Source code for langcheck.metrics.de._translation

from math import floor

from nltk.tokenize import sent_tokenize
from transformers.pipelines import pipeline


class Translate:
    """Translation class based on HuggingFace's translation pipeline.

    Long inputs are split into groups of sentences that are translated
    separately and re-joined, to avoid problems with long texts (the
    pipeline is created with ``truncation=True``).
    """

    def __init__(self, model_name: str) -> None:
        """Initialize the Translation class with given parameters.

        Args:
            model_name: The name of the model to use for translation. The
                same name is used to load the tokenizer.
        """
        self._translation_pipeline = pipeline("translation",
                                              model=model_name,
                                              tokenizer=model_name,
                                              truncation=True)
        # Maximum length taken from the model config; used below to decide
        # when (and into how many blocks) a text has to be split.
        self._max_length = self._translation_pipeline.model.config.max_length

    @staticmethod
    def _group_sentences(sentences: list, blocks: int) -> list:
        """Join consecutive sentences into at most ``blocks`` text blocks.

        Every block holds ``len(sentences) // blocks + 1`` sentences (the
        last one may hold fewer). Since ``blocks`` groups of that size always
        cover more than ``len(sentences)`` items, trailing groups can be
        empty; those are not emitted, so no empty string is ever handed to
        the translation pipeline.

        Args:
            sentences: The sentences (strings) to group, in order.
            blocks: The targeted number of blocks; must be >= 1.

        Returns:
            The non-empty blocks, each a single space-joined string.
        """
        len_block = len(sentences) // blocks + 1
        # Stepping only up to len(sentences) yields exactly the non-empty
        # groups; the starts beyond it would produce empty slices.
        return [
            " ".join(sentences[start:start + len_block])
            for start in range(0, len(sentences), len_block)
        ]

    def _translate(self, texts: str) -> str:
        """Translate the texts using the translation pipeline.

        It splits the texts into blocks and translates each block separately,
        avoiding problems with long texts.

        Args:
            texts: The texts to translate

        Returns:
            The translated texts, with the translations of the individual
            blocks joined by a single space
        """
        tokenization = self._translation_pipeline.tokenizer(
            texts, return_tensors="pt")  # type: ignore
        num_tokens = tokenization.input_ids.shape[1]

        if num_tokens > (self._max_length / 2):
            # Split the text into blocks, if it is too long
            # starting from 2 * num_tokens / max_length to be sure
            # NB: this comes from a few 100 tests, but it is not a science
            blocks = floor(2 * num_tokens / self._max_length)
            text_list = self._group_sentences(sent_tokenize(texts), blocks)
        else:
            text_list = [texts]

        translated_texts = []
        for text in text_list:
            text_en = [
                str(d["translation_text"])  # type: ignore
                for d in self._translation_pipeline(text)  # type: ignore
            ]
            translated_texts.append(" ".join(text_en))
        return " ".join(translated_texts)

    def __call__(self, text: str) -> str:
        """Translate the text using the translation pipeline.

        Args:
            text: The text to translate

        Returns:
            The translated text
        """
        return self._translate(text)