Source code for langcheck.augment.en._ocr_typo

from __future__ import annotations

from nlpaug.augmenter.char.ocr import OcrAug


def ocr_typo(
    instances: list[str] | str,
    *,
    num_perturbations: int = 1,
    **kwargs,
) -> list[str]:
    """Applies an OCR typo text perturbation to each string in instances
    (usually a list of prompts).

    Args:
        instances: A single string or a list of strings to be augmented.
        num_perturbations: The number of perturbed instances to generate for
            each string in instances.
        aug_char_p: Percentage of characters (per token) that will be
            augmented. Must be between 0 and 1. Defaults to `0.1`.
        aug_char_max: Maximum number of characters which will be augmented.
            Defaults to `None`.
        aug_word_max: Maximum number of words which will be augmented.
            Defaults to `None`.

    .. note::
        Any argument that can be passed to
        `nlpaug.augmenter.char.ocr.OcrAug <https://nlpaug.readthedocs.io/en/latest/augmenter/char/ocr.html#nlpaug.augmenter.char.ocr.OcrAug>`_
        is acceptable. Some of the more useful ones from the `nlpaug`
        documentation are listed below:

        - ``aug_char_p`` (float): Percentage of characters (per token) that
          will be augmented.
        - ``aug_char_min`` (int): Minimum number of characters that will be
          augmented.
        - ``aug_char_max`` (int): Maximum number of characters that will be
          augmented.
        - ``aug_word_p`` (float): Percentage of words that will be augmented.
        - ``aug_word_min`` (int): Minimum number of words that will be
          augmented.
        - ``aug_word_max`` (int): Maximum number of words that will be
          augmented.

        Note that the default values for these arguments may be different
        from the ``nlpaug`` defaults.

    Returns:
        A list of perturbed instances. The results of each
        ``OcrAug.augment`` call are appended in order: all perturbations of
        the first instance, then those of the second, and so on.

    Raises:
        ValueError: If ``aug_char_p`` is not between 0 and 1 (NaN included).
    """
    # Override these nlpaug defaults; values passed explicitly by the caller
    # take precedence.
    kwargs.setdefault("aug_char_p", 0.1)
    kwargs.setdefault("aug_char_max", None)
    kwargs.setdefault("aug_word_max", None)

    # Chained comparison rather than `p < 0 or p > 1`: NaN compares False
    # against everything, so the old form silently accepted it.
    if not 0 <= kwargs["aug_char_p"] <= 1:
        raise ValueError("aug_char_p must be between 0 and 1")

    if isinstance(instances, str):
        instances = [instances]

    # One augmenter is built and reused for every instance and perturbation.
    aug = OcrAug(**kwargs)
    perturbed_instances: list[str] = []
    for instance in instances:
        for _ in range(num_perturbations):
            perturbed_instances.extend(aug.augment(instance))
    return perturbed_instances