Source code for langcheck.augment.en._synonym

from __future__ import annotations

import nltk
from nlpaug.augmenter.word import SynonymAug


[docs] def synonym( instances: list[str] | str, *, num_perturbations: int = 1, **kwargs, ) -> list[str]: """Applies a text perturbation to each string in instances (usually a list of prompts) where some words are replaced with synonyms. Args: instances: A single string or a list of strings to be augmented. num_perturbations: The number of perturbed instances to generate for each string in instances aug_p: Percentage of words which will be augmented. Defaults to `0.1`. aug_max: Maximum number of words which will be augmented. Defaults to `None`. .. note:: Any argument that can be passed to `nlpaug.augmenter.word.SynonymAug <https://nlpaug.readthedocs.io/en/latest/_modules/nlpaug/augmenter/word/synonym.html>`_ is acceptable. Some of the more useful ones from the `nlpaug` documentation are listed below: - ``aug_p`` (float): Percentage of words which will be augmented. - ``aug_min`` (int): Minimum number of words that will be augmented. - ``aug_max`` (int): Maximum number of words that will be augmented. Note that the default values for these arguments may be different from the ``nlpaug`` defaults. Returns: A list of perturbed instances. """ kwargs["aug_p"] = kwargs.get("aug_p", 0.1) kwargs["aug_max"] = kwargs.get("aug_max") try: nltk.data.find("taggers/averaged_perceptron_tagger_eng") except LookupError: nltk.download("averaged_perceptron_tagger_eng") instances = [instances] if isinstance(instances, str) else instances perturbed_instances = [] aug = SynonymAug(**kwargs) for instance in instances: for _ in range(num_perturbations): perturbed_instances += aug.augment(instance) return perturbed_instances