Source code for langcheck.augment.en._remove_punctuation

from __future__ import annotations

import random
import string


[docs]def remove_punctuation(instances: list[str] | str, *, aug_char_p: float = 1.0, num_perturbations: int = 1) -> list[str]: '''Applies a text perturbation to each string in instances (usually a list of prompts) where some punctuation is removed. Args: instances: A single string or a list of strings to be augmented. aug_char_p: Percentage of puncutation characters that will be removed. num_perturbations: The number of perturbed instances to generate for each string in instances. Returns: A list of perturbed instances. ''' instances = [instances] if isinstance(instances, str) else instances perturbed_instances = [] for instance in instances: for _ in range(num_perturbations): perturbed_instance = '' for char in instance: if char not in string.punctuation: perturbed_instance += char # No augmentation elif random.random() > aug_char_p: perturbed_instance += char # No augmentation else: pass # Remove character perturbed_instances.append(perturbed_instance) return perturbed_instances