# Source code for langcheck.metrics.ja.source_based_text_quality

from __future__ import annotations

from typing import Dict, List, Optional, cast

from openai import OpenAI
from transformers.pipelines import pipeline
from transformers.pipelines.base import Pipeline

from langcheck.metrics._validation import validate_parameters_source_based
from langcheck.metrics.en.source_based_text_quality import \
    factual_consistency as en_factual_consistency
from langcheck.metrics.metric_value import MetricValue

# Hugging Face model identifier of the Japanese-to-English translation model
# used by factual_consistency() when running with the 'local' model type.
_factual_consistency_translation_model_path = 'Helsinki-NLP/opus-mt-ja-en'
# Module-level cache for the translation pipeline. It stays None until
# factual_consistency() first needs it, and is then reused by later calls so
# the model is only loaded once per process.
_factual_consistency_translation_pipeline: Pipeline | None = None


def _get_translation_pipeline() -> Pipeline:
    '''Returns the shared Japanese-to-English translation pipeline, creating
    it on first use and caching it in the module-level variable.
    '''
    global _factual_consistency_translation_pipeline
    if _factual_consistency_translation_pipeline is None:
        _factual_consistency_translation_pipeline = pipeline(
            'translation', model=_factual_consistency_translation_model_path)
    return _factual_consistency_translation_pipeline


def _translate_to_english(texts: List[str]) -> List[str]:
    '''Translates a list of Japanese texts to English using the shared
    translation pipeline, returning one translated string per input text.
    '''
    translation_pipeline = _get_translation_pipeline()
    # Currently, the type checks are not working for the pipeline, since
    # too diverse types can be returned.
    return [
        cast(str,
             d['translation_text'])  # type: ignore[reportGeneralTypeIssues]
        for d in translation_pipeline(
            texts)  # type: ignore[reportGeneralTypeIssues]
    ]


def factual_consistency(
    generated_outputs: List[str] | str,
    sources: List[str] | str,
    prompts: Optional[List[str] | str] = None,
    model_type: str = 'local',
    openai_client: Optional[OpenAI] = None,
    openai_args: Optional[Dict[str, str]] = None
) -> MetricValue[Optional[float]]:
    '''Calculates the factual consistency between the generated outputs and
    the sources. This metric takes on float values between [0, 1], where 0
    means that the output is not at all consistent with the source text, and 1
    means that the output is fully consistent with the source text. (NOTE: when
    using the OpenAI model, the factuality scores are either 0.0, 0.5, or 1.0.
    The score may also be `None` if it could not be computed.)

    We currently support three model types:

    1. The 'local' type, where the 'unieval-fact' model is downloaded
    from HuggingFace and run locally. This is the default model type and
    there is no setup needed to run this.
    This function wraps :func:`~langcheck.metrics.en.en_factual_consistency`
    using the translation model ``Helsinki-NLP/opus-mt-ja-en`` to translate the
    Japanese texts to English before computing the factual consistency
    scores. This is because the UniEval-fact model is trained on English
    text.

    2. The 'openai' type, where we use OpenAI's 'gpt-3.5-turbo' model
    by default. While the model you use is configurable, please make sure to
    use one that supports function calling
    (https://platform.openai.com/docs/guides/gpt/function-calling). See
    `this page <https://langcheck.readthedocs.io/en/latest/metrics.html
    #computing-metrics-with-openai-models>`__
    for examples on setting up the OpenAI API key.

    3. The 'azure_openai' type. Essentially the same as the 'openai' type,
    except that it uses the AzureOpenAI client. Note that you must specify your
    model deployment to use in ``openai_args``, e.g.
    ``openai_args={'model': 'YOUR_DEPLOYMENT_NAME'}``

    Args:
        generated_outputs: The model generated output(s) to evaluate
        sources: The source text(s), one string per generated output
        prompts: The prompts used to generate the output(s). Prompts are
            optional metadata and not used to calculate the metric.
        model_type: The type of model to use ('local', 'openai', or
            'azure_openai'), default 'local'
        openai_client: OpenAI or AzureOpenAI client, default None. If this is
            None but ``model_type`` is 'openai' or 'azure_openai', we will
            attempt to create a default client.
        openai_args: Dict of additional args to pass in to the
            ``client.chat.completions.create`` function, default None

    Returns:
        A MetricValue object

    Raises:
        AssertionError: If ``model_type`` is not one of 'local', 'openai', or
            'azure_openai'.
    '''
    generated_outputs, sources, prompts = validate_parameters_source_based(
        generated_outputs, sources, prompts)

    # Raised explicitly (rather than via an ``assert`` statement) so that the
    # check is still enforced when Python runs with -O. The exception type
    # and message are the same as the assert would have produced.
    if model_type not in ('local', 'openai', 'azure_openai'):
        raise AssertionError(
            'Unsupported model type. '
            'The supported ones are ["local", "openai", "azure_openai"]')

    # The English prompt works well enough for Japanese
    # TODO: Investigate the performance improvement with Japanese prompt
    if model_type == 'openai' or model_type == 'azure_openai':
        metric_value = en_factual_consistency(generated_outputs, sources,
                                              prompts, model_type,
                                              openai_client, openai_args)
        # The English implementation tags its result as English; relabel it
        # since the evaluated texts are Japanese.
        metric_value.language = 'ja'
        return metric_value

    # 'local' model type: translate the sources and generated outputs to
    # English, because the local scoring model is trained on English text.
    en_source = _translate_to_english(sources)
    en_generated_outputs = _translate_to_english(generated_outputs)

    # Compute the factual consistency scores in English.
    factual_consistency_scores = en_factual_consistency(
        generated_outputs=en_generated_outputs,
        sources=en_source).metric_values

    # Report the original (Japanese) texts, not the intermediate translations.
    return MetricValue(metric_name='factual_consistency',
                       prompts=prompts,
                       generated_outputs=generated_outputs,
                       reference_outputs=None,
                       sources=sources,
                       explanations=None,
                       metric_values=factual_consistency_scores,
                       language='ja')