Source code for langcheck.metrics.ja.source_based_text_quality

from __future__ import annotations

from typing import Dict, List, Optional, cast

from openai import OpenAI
from transformers.pipelines import pipeline
from transformers.pipelines.base import Pipeline

from langcheck.metrics._validation import (
    validate_parameters_context_relevance, validate_parameters_source_based)
from langcheck.metrics.en._openai import OpenAIBasedEvaluator
from langcheck.metrics.en.source_based_text_quality import \
    factual_consistency as en_factual_consistency
from langcheck.metrics.metric_value import MetricValue
from langcheck.utils.progess_bar import tqdm_wrapper

_factual_consistency_translation_model_path = 'Helsinki-NLP/opus-mt-ja-en'
_factual_consistency_translation_pipeline: Pipeline | None = None


[docs]def factual_consistency( generated_outputs: List[str] | str, sources: List[str] | str, prompts: Optional[List[str] | str] = None, model_type: str = 'local', openai_client: Optional[OpenAI] = None, openai_args: Optional[Dict[str, str]] = None, *, use_async: bool = False) -> MetricValue[Optional[float]]: '''Calculates the factual consistency between the generated outputs and the sources. This metric takes on float values between [0, 1], where 0 means that the output is not at all consistent with the source text, and 1 means that the output is fully consistent with the source text. (NOTE: when using the OpenAI model, the factuality scores are either 0.0, 0.5, or 1.0. The score may also be `None` if it could not be computed.) We currently support three model types: 1. The 'local' type, where the 'unieval-fact' model is downloaded from HuggingFace and run locally. This is the default model type and there is no setup needed to run this. This function wraps :func:`~langcheck.metrics.en.en_factual_consistency` using the translation model ``Helsinki-NLP/opus-mt-ja-en`` to translate the Japanese texts to English before computing the factual consistency scores. This is because the UniEval-fact model is trained on English text. 2. The 'openai' type, where we use OpenAI's 'gpt-turbo-3.5' model by default. While the model you use is configurable, please make sure to use one that supports function calling (https://platform.openai.com/docs/guides/gpt/function-calling). See `this page <https://langcheck.readthedocs.io/en/latest/metrics.html #computing-metrics-with-openai-models>`__ for examples on setting up the OpenAI API key. 3. The 'azure_openai' type. Essentially the same as the 'openai' type, except that it uses the AzureOpenAI client. Note that you must specify your model deployment to use in ``openai_args``, e.g. ``openai_args={'model': 'YOUR_DEPLOYMENT_NAME'}`` Args: generated_outputs: The model generated output(s) to evaluate sources: The source text(s), one string per generated output prompts: The prompts used to generate the output(s). Prompts are optional metadata and not used to calculate the metric. model_type: The type of model to use ('local', 'openai', or 'azure_openai'), default 'local' openai_client: OpenAI or AzureOpenAI client, default None. If this is None but ``model_type`` is 'openai' or 'azure_openai', we will attempt to create a default client. openai_args: Dict of additional args to pass in to the ``client.chat.completions.create`` function, default None use_async: Whether to use the asynchronous API of OpenAI, default False Returns: An MetricValue object ''' generated_outputs, sources, prompts = validate_parameters_source_based( generated_outputs, sources, prompts) assert model_type in [ 'local', 'openai', 'azure_openai' ], ('Unsupported model type. ' 'The supported ones are ["local", "openai", "azure_openai"]') # The English prompt works well enough for Japanese # TODO: Investigate the performance improvement with Japanese prompt if model_type == 'openai' or model_type == 'azure_openai': metric_value = en_factual_consistency(generated_outputs, sources, prompts, model_type, openai_client, openai_args, use_async=use_async) metric_value.language = 'ja' return metric_value global _factual_consistency_translation_pipeline if _factual_consistency_translation_pipeline is None: _factual_consistency_translation_pipeline = pipeline( 'translation', model=_factual_consistency_translation_model_path, truncation=True) # Translate the sources and generated outputs to English. # Currently, the type checks are not working for the pipeline, since # too diverse types can be returned. batch_size = 8 en_source = [] for i in tqdm_wrapper(range(0, len(sources), batch_size), desc='Translating sources', total=(len(sources) + batch_size - 1) // batch_size): batch_sources = sources[i:i + batch_size] en_source.extend([ cast(str, d['translation_text']) # type: ignore[reportGeneralTypeIssues] for d in _factual_consistency_translation_pipeline( batch_sources) # type: ignore[reportGeneralTypeIssues] ]) en_generated_outputs = [] for i in tqdm_wrapper(range(0, len(generated_outputs), batch_size), desc='Translating generated outputs', total=(len(generated_outputs) + batch_size - 1) // batch_size): batch_generated_outputs = generated_outputs[i:i + batch_size] en_generated_outputs.extend([ cast(str, d['translation_text']) # type: ignore[reportGeneralTypeIssues] for d in _factual_consistency_translation_pipeline( batch_generated_outputs ) # type: ignore[reportGeneralTypeIssues] ]) # Compute the factual consistency scores in English. factual_consistency_scores = en_factual_consistency( generated_outputs=en_generated_outputs, sources=en_source).metric_values return MetricValue(metric_name='factual_consistency', prompts=prompts, generated_outputs=generated_outputs, reference_outputs=None, sources=sources, explanations=None, metric_values=factual_consistency_scores, language='ja')
[docs]def context_relevance(sources: List[str] | str, prompts: List[str] | str, model_type: str = 'openai', openai_client: Optional[OpenAI] = None, openai_args: Optional[Dict[str, str]] = None, *, use_async: bool = False) -> MetricValue[Optional[float]]: '''Calculates the relevance of the sources to the prompts. This metric takes on float values between [0, 1], where 0 means that the source text is not at all relevant to the prompt, and 1 means that the source text is fully relevant to the prompt. We currently support two model types: 1. The 'openai' type, where we use OpenAI's 'gpt-turbo-3.5' model by default. While the model you use is configurable, please make sure to use one that supports function calling (https://platform.openai.com/docs/guides/gpt/function-calling). See `this page <https://langcheck.readthedocs.io/en/latest/metrics.html #computing-metrics-with-openai-models>`__ for examples on setting up the OpenAI API key. 2. The 'azure_openai' type. Essentially the same as the 'openai' type, except that it uses the AzureOpenAI client. Note that you must specify your model deployment to use in ``openai_args``, e.g. ``openai_args={'model': 'YOUR_DEPLOYMENT_NAME'}`` Args: sources: The source text(s), one string per prompt prompts: The prompt(s) model_type: The type of model to use ('openai' or 'azure_openai'), default 'openai' openai_client: OpenAI or AzureOpenAI client, default None. If this is None, we will attempt to create a default client. openai_args: Dict of additional args to pass in to the ``client.chat.completions.create`` function, default None use_async: Whether to use the asynchronous API, default False ''' prompts, sources = validate_parameters_context_relevance(prompts, sources) def _prompt(src: str, user_query: str) -> str: return f''' ユーザーの質問に対してソースの関連性を評価してください。データは以下の通りです: [BEGIN DATA] ************ [ソース]: {src} ************ [ユーザーの質問]: {user_query} ************ [END DATA] ユーザーの質問に対応するために必要な、関連性のある情報がソースに含まれているかを判断 してください。利用可能な評価は以下の通りです: `Fully Relevant` - ソーステキストには、ユーザーの質問に対応するために必要な情報が 含まれています。 `Partially Relevant` - ソーステキストはユーザーの質問に部分的に関連していますが、質問に 対応するために必要なすべての情報を含んでいません。 `Not Relevant` - ソーステキストはユーザーの質問に関連していません。 深呼吸をして、この問題をステップバイステップで取り組んでください。 ''' def _function_call_prompt(long_assessment: str) -> str: return f''' 以下はソースの関連性に関する評価です: ************ [評価]: {long_assessment} ************ 結果として出た評価を保存してください。利用可能な評価は以下の通りです: `Fully Relevant` `Partially Relevant` `Not Relevant` ''' context_relevance_assessment_to_score = { 'Fully Relevant': 1.0, 'Partially Relevant': 0.5, 'Not Relevant': 0.0 } oai_evaluator = OpenAIBasedEvaluator( assessment_to_score_mapping=context_relevance_assessment_to_score, function_name='save_context_relevance_assessment', function_description=("Saves a context relevance assessment."), argument_name='context_relevance', argument_description='The context relevance assessment', client_type=model_type, client=openai_client, openai_args=openai_args, use_async=use_async) scores, explanations = oai_evaluator.get_score( map(_prompt, sources, prompts), _function_call_prompt) return MetricValue(metric_name='context_relevance', prompts=prompts, generated_outputs=None, reference_outputs=None, sources=sources, explanations=explanations, metric_values=scores, language='ja')