Source code for langcheck.metrics.zh.reference_free_text_quality

from __future__ import annotations

import hanlp
from transformers.pipelines import pipeline

from langcheck.metrics.en.reference_free_text_quality import (
    sentiment as en_sentiment,
)
from langcheck.metrics.en.reference_free_text_quality import (
    toxicity as en_toxicity,
)
from langcheck.metrics.eval_clients import EvalClient
from langcheck.metrics.metric_inputs import (
    get_metric_inputs_with_required_lists,
)
from langcheck.metrics.metric_value import MetricValue

LANG = "zh"


[docs] def sentiment( generated_outputs: list[str] | str, prompts: list[str] | str | None = None, eval_model: str | EvalClient = "local", ) -> MetricValue[float | None]: """Calculates the sentiment scores of generated outputs. This metric takes on float values between [0, 1], where 0 is negative sentiment and 1 is positive sentiment. (NOTE: when using an EvalClient, the sentiment scores are either 0.0 (negative), 0.5 (neutral), or 1.0 (positive). The score may also be `None` if it could not be computed.) We currently support two evaluation model types: 1. The 'local' type, where the IDEA-CCNL/Erlangshen-Roberta-110M-Sentiment model is downloaded from HuggingFace and run locally. This is the default model type and there is no setup needed to run this. 2. The EvalClient type, where you can use an EvalClient typically implemented with an LLM. The implementation details are explained in each of the concrete EvalClient classes. Ref: https://huggingface.co/IDEA-CCNL/Erlangshen-Roberta-110M-Sentiment Args: generated_outputs: The model generated output(s) to evaluate prompts: The prompts used to generate the output(s). Prompts are optional metadata and not used to calculate the metric. eval_model: The type of model to use ('local' or the EvalClient instance used for the evaluation). default 'local' Returns: An :class:`~langcheck.metrics.metric_value.MetricValue` object """ metric_inputs, [generated_outputs] = get_metric_inputs_with_required_lists( generated_outputs=generated_outputs, prompts=prompts, required_params=["generated_outputs"], ) if eval_model != "local": # EvalClient assert isinstance(eval_model, EvalClient), ( "An EvalClient must be provided for non-local model types." ) # This reuses the English prompt. # TODO: Update this to use a Chinese prompt. metric_value = en_sentiment(generated_outputs, prompts, eval_model) metric_value.language = "zh" return metric_value # {0:"Negative", 1:'Positive'} from langcheck.metrics.model_manager import manager tokenizer, model = manager.fetch_model(language="zh", metric="sentiment") _sentiment_pipeline = pipeline( "sentiment-analysis", model=model, # type: ignore[reportGeneralTypeIssues] tokenizer=tokenizer, # type: ignore[reportGeneralTypeIssues] ) _model_id2label = _sentiment_pipeline.model.config.id2label _predict_result = _sentiment_pipeline(generated_outputs) # type: ignore[reportGeneralTypeIssues] # if predicted result is 'Positive', use the score directly # else, use 1 - score as the sentiment score scores = [ 1 - x["score"] if x["label"] == _model_id2label[0] else x["score"] # type: ignore[reportGeneralTypeIssues] for x in _predict_result # type: ignore[reportGeneralTypeIssues] ] return MetricValue( metric_name="sentiment", metric_inputs=metric_inputs, explanations=None, metric_values=scores, # type: ignore[reportGeneralTypeIssues] language="zh", )
[docs] def toxicity( generated_outputs: list[str] | str, prompts: list[str] | str | None = None, eval_model: str | EvalClient = "local", eval_prompt_version: str = "v2", ) -> MetricValue[float | None]: """Calculates the toxicity scores of generated outputs. This metric takes on float values between [0, 1], where 0 is low toxicity and 1 is high toxicity. (NOTE: when using an EvalClient, the toxicity scores are in steps of 0.25. The score may also be `None` if it could not be computed.) We currently support two evaluation model types: 1. The 'local' type, where a model file is downloaded from HuggingFace and run locally. This is the default model type and there is no setup needed to run this. The model (alibaba-pai/pai-bert-base-zh-llm-risk-detection) is a risky detection model for LLM generated content released by Alibaba group. 2. The EvalClient type, where you can use an EvalClient typically implemented with an LLM. The implementation details are explained in each of the concrete EvalClient classes. Ref: https://huggingface.co/alibaba-pai/pai-bert-base-zh-llm-risk-detection Args: generated_outputs: The model generated output(s) to evaluate prompts: The prompts used to generate the output(s). Prompts are optional metadata and not used to calculate the metric. eval_model: The type of model to use ('local' or the EvalClient instance used for the evaluation). default 'local' eval_prompt_version: The version of the eval prompt to use when the EvalClient is used. The default version is 'v2' (latest). Returns: An :class:`~langcheck.metrics.metric_value.MetricValue` object """ metric_inputs, [generated_outputs] = get_metric_inputs_with_required_lists( generated_outputs=generated_outputs, prompts=prompts, required_params=["generated_outputs"], ) if eval_model != "local": # EvalClient assert isinstance(eval_model, EvalClient), ( "An EvalClient must be provided for non-local model types." ) # This reuses the English prompt. # TODO: Update this to use a Chinese prompt. metric_value = en_toxicity(generated_outputs, prompts, eval_model) metric_value.language = "zh" return metric_value else: scores = _toxicity_local(generated_outputs) explanations = None return MetricValue( metric_name="toxicity", metric_inputs=metric_inputs, explanations=explanations, # type: ignore (pyright doesn't understand that a list of floats is a list of optional floats) metric_values=scores, language="zh", )
def _toxicity_local(generated_outputs: list[str]) -> list[float]: """Calculates the toxicity scores of generated outputs using a fine-tuned model from `alibaba-pai/pai-bert-base-zh-llm-risk-detection`. This metric takes on float values between [0, 1], where 0 is low toxicity and 1 is high toxicity. Ref: https://huggingface.co/alibaba-pai/pai-bert-base-zh-llm-risk-detection Args: generated_outputs: A list of model generated outputs to evaluate Returns: A list of scores """ # this pipeline output predict probability for each text on each label. # the output format is list[list[dict(str)]] from langcheck.metrics.model_manager import manager tokenizer, model = manager.fetch_model(language="zh", metric="toxicity") _toxicity_pipeline = pipeline( "text-classification", model=model, # type: ignore[reportOptionalIterable] tokenizer=tokenizer, # type: ignore[reportOptionalIterable] top_k=5, ) # {'Normal': 0, 'Pulp': 1, 'Sex': 2, 'Other Risk': 3, 'Adult': 4} _model_id2label = _toxicity_pipeline.model.config.id2label _predict_results = _toxicity_pipeline( generated_outputs # type: ignore[reportGeneralTypeIssues] ) # labels except Normal are all risky, toxicity_score = 1-score['Normal'] toxicity_scores = [] for item_predict_proba in _predict_results: # type: ignore[reportOptionalIterable] for label_proba in item_predict_proba: # type: ignore[reportGeneralTypeIssues] if label_proba["label"] == _model_id2label[0]: # type: ignore[reportGeneralTypeIssues] toxicity_scores.append(1 - label_proba["score"]) # type: ignore[reportGeneralTypeIssues] return toxicity_scores # type: ignore[reportGeneralTypeIssues]
[docs] def xuyaochen_report_readability( generated_outputs: list[str] | str, prompts: list[str] | str | None = None, ) -> MetricValue[float]: """Calculates the readability scores of generated outputs introduced in "中文年报可读性"(Chinese annual report readability). This metric calculates average words per sentence as r1, average of the sum of the numbers of adverbs and coordinating conjunction words in a sentence in given generated outputs as r2, then, refer to the Fog Index that combine r1 with r2 by arithmetic mean as the final outputs. This function uses HanLP Tokenizer and POS at the same time, POS in CTB style https://hanlp.hankcs.com/docs/annotations/pos/ctb.html. The lower the score is, the better the readability. The score is mainly influenced by r1, the average number of words in sentences. Ref: Refer Chinese annual report readability: measurement and test Link: https://www.tandfonline.com/doi/full/10.1080/21697213.2019.1701259 Args: generated_outputs: A list of model generated outputs to evaluate prompts: The prompts used to generate the output(s). Prompts are optional metadata and not used to calculate the metric. Returns: A list of scores """ # split generated_outputs into sentence metric_inputs, [generated_outputs] = get_metric_inputs_with_required_lists( generated_outputs=generated_outputs, prompts=prompts, required_params=["generated_outputs"], ) tokenizer = hanlp.load( hanlp.pretrained.tok.FINE_ELECTRA_SMALL_ZH # type: ignore[reportGeneralTypeIssues] ) postagger = hanlp.load( hanlp.pretrained.pos.CTB9_POS_RADICAL_ELECTRA_SMALL # type: ignore[reportGeneralTypeIssues] ) pos_pipeline = hanlp.pipeline().append(hanlp.utils.rules.split_sentence) # type: ignore[reportGeneralTypeIssues] pos_pipeline = pos_pipeline.append(tokenizer).append(postagger) tokenize_pipeline = hanlp.pipeline().append( hanlp.utils.rules.split_sentence # type: ignore[reportGeneralTypeIssues] ) tokenize_pipeline = tokenize_pipeline.append(tokenizer) # OUTPUT: List[List[List[TOKEN]]] output_tokens = list(map(tokenize_pipeline, generated_outputs)) # List[List[List[POS]]] output_pos = list(map(pos_pipeline, generated_outputs)) def count_tokens(sent_tokens: list[str]) -> int: count = sum( [ not hanlp.utils.string_util.ispunct(token) # type: ignore[reportGeneralTypeIssues] for token in sent_tokens ] ) return count def count_postags(sent_poses: list[str]) -> int: # AD: adverb, CC: coordinating conjunction, # CS: subordinating conjunction count = sum([pos in ["AD", "CC", "CS"] for pos in sent_poses]) return count def calc_r1(content: list[list[str]]) -> float: token_count_by_sentence = list(map(count_tokens, content)) if len(token_count_by_sentence) == 0: return 0 else: return sum(token_count_by_sentence) / len(token_count_by_sentence) def calc_r2(content: list[list[str]]) -> float: pos_count_by_sentence = list(map(count_postags, content)) if len(pos_count_by_sentence) == 0: return 0 else: return sum(pos_count_by_sentence) / len(pos_count_by_sentence) r1 = list(map(calc_r1, output_tokens)) # type: ignore[reportGeneralTypeIssues] r2 = list(map(calc_r2, output_pos)) # type: ignore[reportGeneralTypeIssues] r3 = [(r1_score + r2_score) * 0.5 for r1_score, r2_score in zip(r1, r2)] return MetricValue( metric_name="readability", metric_inputs=metric_inputs, explanations=None, metric_values=r3, language="zh", )