# Source code for langcheck.metrics.de._translation
from math import floor
from nltk.tokenize import sent_tokenize
from transformers.pipelines import pipeline
# [docs]
class Translate:
    """Translation class based on HuggingFace's translation pipeline.

    Short inputs are translated in a single pipeline call. Inputs whose token
    count exceeds half of the model's configured ``max_length`` are split into
    groups of consecutive sentences that are translated separately and then
    re-joined with single spaces.
    """

    def __init__(self, model_name: str) -> None:
        """Initialize the Translation class with given parameters.

        Args:
            model_name: The name of the model to use for translation. The
                same name is used to load the tokenizer.
        """
        self._translation_pipeline = pipeline(
            "translation",
            model=model_name,
            tokenizer=model_name,
            truncation=True,
        )
        # Token limit read from the model config; it drives the splitting
        # heuristic in ``_translate``.
        self._max_length = self._translation_pipeline.model.config.max_length

    @staticmethod
    def _group_sentences(sentences: list, blocks: int) -> list:
        """Join consecutive sentences into at most ``blocks`` non-empty texts.

        Args:
            sentences: The sentences to group, in reading order.
            blocks: The desired number of groups; values below 1 are treated
                as 1.

        Returns:
            A list of texts, each made of consecutive sentences joined with a
            single space. The list is empty if ``sentences`` is empty and never
            contains an empty group.
        """
        if not sentences:
            return []
        blocks = max(1, blocks)
        # Ceiling division keeps the groups balanced: 2 sentences in 2 blocks
        # gives 1 + 1 (``floor(n / blocks) + 1`` would give 2 + 0, i.e. no
        # split at all plus an empty block).
        size = -(-len(sentences) // blocks)
        # Stepping over the actual sentence count never yields an empty
        # group, so the model is never asked to translate "".
        return [
            " ".join(sentences[start : start + size])
            for start in range(0, len(sentences), size)
        ]

    def _translate(self, texts: str) -> str:
        """Translate the texts using the translation pipeline.

        It splits the texts into blocks and translates each block separately,
        avoiding problems with long texts.

        Args:
            texts: The texts to translate

        Returns:
            The translated texts, with the translations of the individual
            blocks joined by single spaces
        """
        tokenization = self._translation_pipeline.tokenizer(
            texts, return_tensors="pt"
        )  # type: ignore
        num_tokens = tokenization.input_ids.shape[1]

        text_list = [texts]
        if num_tokens > (self._max_length / 2):
            # Split the text into blocks, if it is too long
            # starting from 2 * num_tokens / max_length to be sure
            # NB: this comes from a few 100 tests, but it is not a science
            blocks = floor(2 * num_tokens / self._max_length)
            # Fall back to the unsplit text if the sentence splitter finds
            # no sentences at all.
            text_list = self._group_sentences(
                sent_tokenize(texts), blocks
            ) or [texts]

        translated_texts = [
            " ".join(
                str(d["translation_text"])  # type: ignore
                for d in self._translation_pipeline(text)  # type: ignore
            )
            for text in text_list
        ]
        return " ".join(translated_texts)

    def __call__(self, text: str) -> str:
        """Translate the text using the translation pipeline.

        Args:
            text: The text to translate

        Returns:
            The translated text
        """
        return self._translate(text)