from __future__ import annotations
from collections.abc import Iterable
from jinja2 import Template
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from ..prompts._utils import get_template
from ._base import EvalClient
[docs]
class PrometheusEvalClient(EvalClient):
    """EvalClient defined for the Prometheus 2 model.

    This eval client currently supports only English.

    Prometheus 2 is presented in `"Prometheus 2: An Open Source Language Model
    Specialized in Evaluating Other Language Models"
    <https://arxiv.org/abs/2405.01535>`_.

    The prompts are adapted from
    <https://github.com/prometheus-eval/prometheus-eval/blob/main/libs/prometheus-eval/prometheus_eval/prompts.py>.
    """
def __init__(
    self,
    model_name: str = "prometheus-eval/prometheus-7b-v2.0",
    torch_dtype: str = "bfloat16",
    tensor_parallel_size: int = 1,
    device: str = "cuda",
):
    """Initialize the Prometheus evaluation client.

    Creates the vLLM engine (with ``max_model_len`` fixed to 8192), loads
    the tokenizer published under the same model name, and prepares the
    sampling parameters reused for every generation call.

    Args:
        model_name: The name of the model to use.
        torch_dtype: The torch dtype to use. torch.bfloat16 is recommended.
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        device: The device to load the model on.
    """
    engine_options = {
        "model": model_name,
        "max_model_len": 8192,
        "dtype": torch_dtype,
        "tensor_parallel_size": tensor_parallel_size,
        "device": device,
    }
    self._model = LLM(**engine_options)

    # The tokenizer is only used to render prompts with the chat template.
    self._tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Sampling settings shared by all requests sent to the engine.
    sampling_options = {
        "temperature": 0.6,
        "top_p": 0.9,
        "max_tokens": 1000,
        "skip_special_tokens": True,
    }
    self._sampling_params = SamplingParams(**sampling_options)
[docs]
def load_prompt_template(
    self,
    language: str,
    metric_name: str,
    eval_prompt_version: str | None = None,
) -> Template:
    """Load the Jinja prompt template for a Prometheus metric.

    The template path is
    ``{language}/metrics/prometheus/{metric_name}.j2`` for the default
    version, or
    ``{language}/metrics/prometheus/{metric_name}_{eval_prompt_version}.j2``
    when a version is given.

    Args:
        language (str): The language of the template.
        metric_name (str): The name of the metric.
        eval_prompt_version (str | None): The version of the eval prompt.
            If None, the default version is used.

    Returns:
        Template: The Jinja template.

    Raises:
        ValueError: If the template lookup raises ``FileNotFoundError``.
            The original error is attached as the cause.
    """
    if eval_prompt_version is None:
        template_path = f"{language}/metrics/prometheus/{metric_name}.j2"
        requested = f"language = {language}"
    else:
        template_path = f"{language}/metrics/prometheus/{metric_name}_{eval_prompt_version}.j2"
        requested = f"language = {language}, version = {eval_prompt_version}"

    try:
        return get_template(template_path)
    except FileNotFoundError as err:
        # Chain the lookup failure so the missing path stays visible in the
        # traceback instead of being reported as an unrelated error.
        raise ValueError(
            f"The {metric_name} metric ({requested}) is not yet supported by the Prometheus eval client."
        ) from err
[docs]
def get_text_responses(self, prompts: Iterable[str]) -> list[str | None]:
    """Generate one response for each of the given prompt texts.

    Each prompt is wrapped as a single user message, rendered to text with
    the tokenizer's chat template (with the generation prompt appended),
    and passed to the model together with the stored sampling parameters.

    Args:
        prompts: The prompts you want to get the responses for.

    Returns:
        A list of responses to the prompts. A response is None if the
        evaluation fails, i.e. the model returned no output or an empty
        text for that prompt. An empty list is returned for empty input.
    """
    messages = [[{"role": "user", "content": prompt}] for prompt in prompts]
    if not messages:
        # Nothing to evaluate: skip the tokenizer and the engine entirely.
        return []

    rendered = self._tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # The tokenizer may hand back a single string instead of a list;
    # normalize to a list of strings either way.
    if isinstance(rendered, str):
        processed_prompts = [rendered]
    else:
        processed_prompts = [str(p) for p in rendered]

    responses = self._model.generate(
        processed_prompts, self._sampling_params
    )

    response_texts: list[str | None] = []
    for response in responses:
        # Guard against a missing response or an empty output list rather
        # than failing on outputs[0].
        outputs = response.outputs if response else None
        text = outputs[0].text if outputs else ""
        response_texts.append(text or None)
    return response_texts
[docs]
def get_float_score(
    self,
    metric_name: str,
    language: str,
    unstructured_assessment_result: list[str | None],
    score_map: dict[str, float],
) -> list[float | None]:
    """Transform unstructured assessments into scores.

    The unstructured assessments are long texts that describe the
    evaluation results. For each text, we simply find the option (a key of
    ``score_map``) whose last occurrence starts latest in the text and
    return the score mapped to it. If several options tie, the one listed
    first in ``score_map`` wins.

    Args:
        metric_name: The name of the metric to be used. (e.g. "toxicity")
            Not used by this implementation.
        language: The language of the prompts. (e.g. "en")
        unstructured_assessment_result: The unstructured assessment results
            for the given assessment prompts.
        score_map: The mapping from the short assessment results
            (e.g. "Good") to the scores.

    Returns:
        A list of scores for the given prompts. A score is None if the
        assessment itself is None or if no option appears in it (which
        includes the case of an empty ``score_map``).

    Raises:
        ValueError: If ``language`` is not "en".
    """
    if language != "en":
        raise ValueError(f"Unsupported language: {language}")

    options = list(score_map)
    scores: list[float | None] = []
    for unstructured_assessment in unstructured_assessment_result:
        if unstructured_assessment is None:
            scores.append(None)
            continue

        # Find the option that appears latest in the assessment. The strict
        # comparison keeps the first option on ties, and starting from -1
        # means options that are absent (rfind == -1) are never selected.
        best_option = None
        best_position = -1
        for option in options:
            position = unstructured_assessment.rfind(option)
            if position > best_position:
                best_option, best_position = option, position

        if best_option is None:
            print("No options found in the assessment.")
            scores.append(None)
        else:
            scores.append(score_map[best_option])
    return scores
[docs]
def get_score(
    self,
    metric_name: str,
    language: str,
    prompts: str | Iterable[str],
    score_map: dict[str, float],
) -> tuple[list[float | None], list[str | None]]:
    """Score the texts embedded in the given prompts.

    This is a thin pipeline: it calls get_text_responses to obtain the
    unstructured explanations and then get_float_score to turn those
    explanations into scores.

    Args:
        metric_name: The name of the metric to be used. (e.g. "toxicity")
        language: The language of the prompts. (e.g. "en")
        prompts: The prompts that contain the original text to be scored,
            the evaluation criteria... etc. Typically it is based on the
            Jinja prompt templates and instantiated within each metric
            function. A single string is treated as a one-element batch.
        score_map: The mapping from the short assessment results
            (e.g. "Good") to the scores.

    Returns:
        A tuple of two lists. The first list contains the scores for each
        prompt and the second list contains the unstructured assessment
        results for each prompt. Entries of both lists can be None if the
        evaluation fails.
    """
    prompt_batch = [prompts] if isinstance(prompts, str) else prompts
    explanations = self.get_text_responses(prompt_batch)
    return (
        self.get_float_score(metric_name, language, explanations, score_map),
        explanations,
    )
[docs]
def similarity_scorer(self):
    """Reject requests for an embedding-based similarity scorer.

    Raises:
        NotImplementedError: Always; this client does not support
            embedding-based metrics.
    """
    # The trailing space keeps the two adjacent literals from running
    # together into "...PrometheusEvalClient.Use other...".
    raise NotImplementedError(
        "Embedding-based metrics are not supported in PrometheusEvalClient. "
        "Use other EvalClients to get these metrics."
    )