Source code for langcheck.augment.en._remove_punctuation

from __future__ import annotations

import random
import string


[docs] def remove_punctuation( instances: list[str] | str, *, aug_char_p: float = 1.0, num_perturbations: int = 1, seed: int | None = None, ) -> list[str]: """Applies a text perturbation to each string in instances (usually a list of prompts) where some punctuation is removed. Args: instances: A single string or a list of strings to be augmented. aug_char_p: Percentage of punctuation characters that will be removed. num_perturbations: The number of perturbed instances to generate for each string in instances. seed: The seed for the random number generator. You can fix the seed to deterministically choose which characters to remove. Returns: A list of perturbed instances. """ # Validation on aug_char_p if aug_char_p < 0 or aug_char_p > 1: raise ValueError("aug_char_p must be between 0 and 1") if seed is not None: random.seed(seed) instances = [instances] if isinstance(instances, str) else instances perturbed_instances = [] for instance in instances: for _ in range(num_perturbations): perturbed_instance = "" for char in instance: if char not in string.punctuation: perturbed_instance += char # No augmentation elif random.random() > aug_char_p: perturbed_instance += char # No augmentation else: pass # Remove character perturbed_instances.append(perturbed_instance) return perturbed_instances