Source code for langcheck.metrics.metric_inputs
from __future__ import annotations
from typing import Union
import pandas as pd
from jinja2 import Environment, meta
# Type of a single metric input argument: one string, a list of strings, or
# None (e.g. an optional argument that was not provided).
# You need "Union" to declare a type in Python < 3.10: this alias is a runtime
# expression rather than an annotation, so `from __future__ import annotations`
# does not defer it and the `str | list[str] | None` syntax would fail there.
IndividualInputType = Union[str, list[str], None]
def _map_pairwise_input_to_list(
    input: tuple[IndividualInputType, IndividualInputType],
) -> tuple[list[str] | None, list[str] | None]:
    """Normalize both halves of a pairwise input.

    Each element of the pair is converted independently with
    `_map_individual_input_to_list`, so a bare string becomes a
    one-element list and None is passed through unchanged.

    Args:
        input: A pair ``(first, second)`` of individual inputs.

    Returns:
        A 2-tuple holding the normalized first and second inputs.
    """
    first, second = input[0], input[1]
    normalized_first = _map_individual_input_to_list(first)
    normalized_second = _map_individual_input_to_list(second)
    return normalized_first, normalized_second
def _map_individual_input_to_list(
    input: IndividualInputType,
) -> list[str] | None:
    """Normalize one input argument to a list (or None).

    Args:
        input: A single string, a list of strings, or None.

    Returns:
        ``[input]`` when a single string is given; otherwise the argument
        itself, unchanged (so None stays None and lists are returned as-is,
        without copying).
    """
    if isinstance(input, str):
        return [input]
    # None and already-list inputs pass straight through.
    return input
[docs]
class MetricInputs:
    """A helper class to handle the inputs for the metric in a consistent way.

    Inputs come in two flavours:

    * individual inputs: one list of strings per parameter name, and
    * pairwise inputs: a tuple of two lists per parameter name, exposed to
      prompt templates and DataFrames with the suffixes ``_a`` and ``_b``.

    All non-None input lists must share the same, non-zero length.
    """

    def __init__(
        self,
        individual_inputs: dict[str, IndividualInputType],
        pairwise_inputs: dict[
            str, tuple[IndividualInputType, IndividualInputType]
        ]
        | None = None,
        required_params: list[str] | None = None,
        optional_params: list[str] | None = None,
        input_name_to_prompt_var_mapping: dict[str, str] | None = None,
    ):
        """Initialize the MetricInputs object.

        Args:
            individual_inputs: A dictionary of individual inputs. The keys are
                the parameter names and the values are the input lists.
            pairwise_inputs: A dictionary of pairwise inputs. The keys are the
                parameter names and the values are tuples of two input lists.
            required_params: A list of required parameters.
            optional_params: A list of optional parameters.
            input_name_to_prompt_var_mapping: A dictionary that maps the input
                names to the variable names in the prompt template. The values
                should therefore correspond with the keys returned from the
                `get_inputs_for_prompt_template` method. The dictionary is
                copied, not modified.

        Raises:
            ValueError: If a required parameter is missing or None, an unknown
                parameter is given, individual and pairwise keys overlap, a
                pairwise key's ``_a``/``_b`` column name clashes with another
                input key, the inputs have different or zero lengths, no
                inputs are provided, or two inputs map to the same prompt
                variable.
        """
        # Instantiate the parameter lists if None
        self.required_params = required_params or []
        self.optional_params = optional_params or []

        self.individual_inputs = {
            key: _map_individual_input_to_list(value)
            for key, value in individual_inputs.items()
        }
        if pairwise_inputs is None:
            self.pairwise_inputs = {}
        else:
            self.pairwise_inputs = {
                key: _map_pairwise_input_to_list(value)
                for key, value in pairwise_inputs.items()
            }

        # Copy the mapping: default entries are added below and must not leak
        # back into the dictionary owned by the caller.
        self.input_name_to_prompt_var_mapping = dict(
            input_name_to_prompt_var_mapping or {}
        )

        all_input_keys = list(self.individual_inputs) + list(
            self.pairwise_inputs
        )

        # Check that all the required parameters are present
        missing_required_params = set(self.required_params) - set(
            all_input_keys
        )
        if missing_required_params:
            raise ValueError(
                f"Missing required parameters: {missing_required_params}"
            )

        # Inputs without an explicit mapping use their own name as the
        # prompt variable.
        for input_key in all_input_keys:
            self.input_name_to_prompt_var_mapping.setdefault(
                input_key, input_key
            )

        self._validate_params(all_input_keys)
        self.input_length = self._validate_lengths()
        self.prompt_var_to_input_name_mapping = (
            self._build_prompt_var_to_input_name_mapping()
        )

    def _validate_params(self, all_input_keys: list[str]) -> None:
        """Check parameter names and that required inputs are not None.

        Dicts (not sets) are iterated so the first reported error is
        deterministic across runs.
        """
        # Validate that individual_inputs and pairwise_inputs are disjoint
        overlap_keys = set(self.individual_inputs) & set(self.pairwise_inputs)
        if overlap_keys:
            raise ValueError(
                "Individual input keys and pairwise input keys should be disjoint."
                f" Overlapping keys: {overlap_keys}"
            )

        # Validate the individual inputs
        for key, individual_input in self.individual_inputs.items():
            if key in self.required_params:
                if individual_input is None:
                    raise ValueError(f"Required parameter '{key}' is None.")
            elif key not in self.optional_params:
                raise ValueError(f"Unknown parameter '{key}'")

        # Validate the pairwise inputs
        for key, (input_a, input_b) in self.pairwise_inputs.items():
            if key in self.required_params:
                if input_a is None or input_b is None:
                    raise ValueError(f"Required parameter '{key}' is None.")
            elif key not in self.optional_params:
                raise ValueError(f"Unknown parameter '{key}'")
            # If to_df is called, each key is mapped into two columns: key_a
            # and key_b. Check that those names are not already input keys.
            for df_key in (key + "_a", key + "_b"):
                if df_key in all_input_keys:
                    raise ValueError(
                        f"Key '{df_key}' will be added as a dataframe column, "
                        "but it is already used as an input key."
                    )

    def _validate_lengths(self) -> int:
        """Return the common length of all non-None input lists.

        Raises:
            ValueError: If the lengths differ, no input is provided, or the
                common length is zero.
        """
        input_lengths: set[int] = set()
        for individual_input in self.individual_inputs.values():
            if individual_input is not None:
                input_lengths.add(len(individual_input))
        for pair in self.pairwise_inputs.values():
            for pairwise_input in pair:
                if pairwise_input is not None:
                    input_lengths.add(len(pairwise_input))

        if len(input_lengths) > 1:
            individual_input_lengths = "\n".join(
                f"{key}: {len(value)}"
                for key, value in self.individual_inputs.items()
                if value is not None
            )
            # Also report partially provided pairs; None marks a missing side.
            pairwise_input_lengths = "\n".join(
                f"{key}: ({None if a is None else len(a)}, "
                f"{None if b is None else len(b)})"
                for key, (a, b) in self.pairwise_inputs.items()
                if a is not None or b is not None
            )
            raise ValueError(
                f"All inputs should have the same length.\n{individual_input_lengths}\n{pairwise_input_lengths}"
            )
        if not input_lengths:
            raise ValueError("No inputs provided.")
        input_length = input_lengths.pop()
        if input_length == 0:
            raise ValueError("All inputs should have at least one element.")
        return input_length

    def _build_prompt_var_to_input_name_mapping(self) -> dict[str, str]:
        """Invert `input_name_to_prompt_var_mapping`, rejecting collisions.

        Each pairwise input contributes two prompt variables, suffixed with
        "_a" and "_b".
        """
        prompt_var_to_input_name: dict[str, str] = {}

        def register(prompt_var: str, input_key: str) -> None:
            if prompt_var in prompt_var_to_input_name:
                raise ValueError(
                    f"Prompt variable '{prompt_var}' is mapped from multiple arguments: "
                    f"{prompt_var_to_input_name[prompt_var]} and {input_key}"
                )
            prompt_var_to_input_name[prompt_var] = input_key

        for key in self.individual_inputs:
            register(self.input_name_to_prompt_var_mapping[key], key)
        for key in self.pairwise_inputs:
            base_prompt_var = self.input_name_to_prompt_var_mapping[key]
            register(base_prompt_var + "_a", key)
            register(base_prompt_var + "_b", key)
        return prompt_var_to_input_name

    def get_inputs_for_prompt_template(
        self, swap_pairwise: bool = False
    ) -> list[dict[str, str | None]]:
        """Get the inputs that can be used as arguments for the prompt
        template.

        Each item is a dictionary where the keys are the prompt variables
        specified in the `input_name_to_prompt_var_mapping` and the values are
        the input values, which are corresponding elements from the input lists.
        For pairwise inputs, the values for the first list and the second list
        are stored in the attributes with the suffixes "_a" and "_b".

        Args:
            swap_pairwise: If True, swap the pairwise inputs.

        Returns:
            A list with one dictionary per input index (``input_length``
            items). Inputs that were not provided map to None.
        """
        inputs_for_prompt_template: list[dict[str, str | None]] = []
        for i in range(self.input_length):
            # Create the inputs for the prompt template for the i-th input
            single_instance_inputs: dict[str, str | None] = {}
            for key, individual_input in self.individual_inputs.items():
                prompt_var = self.input_name_to_prompt_var_mapping[key]
                single_instance_inputs[prompt_var] = (
                    None if individual_input is None else individual_input[i]
                )
            for key, (input_a, input_b) in self.pairwise_inputs.items():
                if swap_pairwise:
                    input_a, input_b = input_b, input_a
                base_prompt_var = self.input_name_to_prompt_var_mapping[key]
                single_instance_inputs[base_prompt_var + "_a"] = (
                    None if input_a is None else input_a[i]
                )
                single_instance_inputs[base_prompt_var + "_b"] = (
                    None if input_b is None else input_b[i]
                )
            inputs_for_prompt_template.append(single_instance_inputs)
        return inputs_for_prompt_template

    def to_df(self) -> pd.DataFrame:
        """Convert the inputs to a DataFrame.

        Individual inputs become one column named after the input key;
        pairwise inputs become two columns, ``<key>_a`` and ``<key>_b``.
        Inputs that were not provided are filled with None.
        """

        def as_column(values: list[str] | None) -> list[str] | list[None]:
            return [None] * self.input_length if values is None else values

        input_lists = {}
        for key, individual_input in self.individual_inputs.items():
            input_lists[key] = as_column(individual_input)
        for key, (input_a, input_b) in self.pairwise_inputs.items():
            input_lists[key + "_a"] = as_column(input_a)
            input_lists[key + "_b"] = as_column(input_b)
        return pd.DataFrame(input_lists)

    def get_input_list(
        self, key: str
    ) -> tuple[list[str] | None, list[str] | None] | list[str] | None:
        """Get the input list for the key.

        Returns the list (or None) for an individual input, or the
        ``(list_a, list_b)`` tuple for a pairwise input.

        Raises:
            ValueError: If `key` is neither an individual nor a pairwise input.
        """
        if key in self.individual_inputs:
            return self.individual_inputs[key]
        elif key in self.pairwise_inputs:
            return self.pairwise_inputs[key]
        else:
            raise ValueError(f"Unknown key: {key}")

    def validate_template(self, template_src: str):
        """Validate that the given prompt template string is compatible with
        the input parameters.

        Args:
            template_src: The prompt template string.

        Raises:
            AssertionError: If the template uses a variable that is not a
                known prompt variable, or one whose input is None. It is
                raised explicitly rather than via ``assert`` so the check
                still runs when Python is started with ``-O``.
        """
        # Validate the expected parameters in the prompt template
        env = Environment()
        expected_params = meta.find_undeclared_variables(
            env.parse(template_src)
        )
        allowed_params = self.prompt_var_to_input_name_mapping.keys()
        if not all(param in allowed_params for param in expected_params):
            raise AssertionError(
                f"The prompt template contains invalid parameters. The allowed parameters are {allowed_params} but the prompt template expects the parameters {expected_params}"
            )
        for param in expected_params:
            arg_key = self.prompt_var_to_input_name_mapping[param]
            if arg_key in self.individual_inputs:
                if self.individual_inputs[arg_key] is None:
                    raise AssertionError(
                        f'The prompt template expects the parameter "{param}" but it is not provided.'
                    )
            else:
                # Both sides are checked whichever suffix the template uses;
                # presumably because get_inputs_for_prompt_template(
                # swap_pairwise=True) can feed either list to either variable.
                pairwise_inputs_a, pairwise_inputs_b = self.pairwise_inputs[
                    arg_key
                ]
                # `param` already carries the suffix, so report the variables
                # from the base prompt variable instead of `param`.
                base_prompt_var = self.input_name_to_prompt_var_mapping[arg_key]
                if pairwise_inputs_a is None:
                    raise AssertionError(
                        f'The prompt template expects the parameter "{base_prompt_var}_a" but it is not provided.'
                    )
                if pairwise_inputs_b is None:
                    raise AssertionError(
                        f'The prompt template expects the parameter "{base_prompt_var}_b" but it is not provided.'
                    )

    def get_required_individual_input(self, key: str) -> list[str]:
        """Get the list of a required parameter in individual_inputs.

        Mainly used for metrics without eval clients.

        Raises:
            ValueError: If `key` is not both an individual input and a
                required parameter.
        """
        if (key not in self.individual_inputs) or (
            key not in self.required_params
        ):
            raise ValueError(f"Unknown key: {key}")
        individual_input = self.individual_inputs[key]
        # It is already validated in __init__ that the input is not None;
        # the assert only narrows the type.
        assert isinstance(individual_input, list)
        return individual_input
[docs]
def get_metric_inputs(
    *,
    generated_outputs: IndividualInputType
    | tuple[IndividualInputType, IndividualInputType] = None,
    prompts: IndividualInputType
    | tuple[IndividualInputType, IndividualInputType] = None,
    sources: IndividualInputType
    | tuple[IndividualInputType, IndividualInputType] = None,
    reference_outputs: IndividualInputType
    | tuple[IndividualInputType, IndividualInputType] = None,
    additional_inputs: dict[str, IndividualInputType] | None = None,
    additional_input_name_to_prompt_var_mapping: dict[str, str] | None = None,
    required_params: list[str],
) -> MetricInputs:
    """Create a metric inputs object with the standard parameters
    (i.e. generated_outputs, prompts, sources, reference_outputs) and the
    specified additional parameters.

    Any argument given as a tuple is treated as a pairwise input; all other
    arguments are individual inputs.

    Args:
        generated_outputs: The generated outputs.
        prompts: The prompts.
        sources: The sources.
        reference_outputs: The reference outputs.
        additional_inputs: Additional inputs other than the standard ones.
            Their names must differ from the standard parameter names.
        additional_input_name_to_prompt_var_mapping: A dictionary that maps the
            additional input names to the variable names in the prompt template.
            It is applied after the default mapping, so it can also rename the
            prompt variables of the standard parameters.
        required_params: A list of required parameters.

    Returns:
        A MetricInputs object.

    Raises:
        ValueError: If `additional_inputs` reuses a standard parameter name,
            if `required_params` contains an unknown name, or if MetricInputs
            rejects the inputs.
    """
    if additional_inputs is None:
        additional_inputs = {}
    if additional_input_name_to_prompt_var_mapping is None:
        additional_input_name_to_prompt_var_mapping = {}

    standard_inputs = {
        "generated_outputs": generated_outputs,
        "prompts": prompts,
        "sources": sources,
        "reference_outputs": reference_outputs,
    }
    # An additional input named like a standard parameter would silently
    # replace the explicitly passed standard argument in the merge below.
    conflicting_keys = standard_inputs.keys() & additional_inputs.keys()
    if conflicting_keys:
        raise ValueError(
            "Additional inputs must not reuse the standard parameter names: "
            f"{sorted(conflicting_keys)}"
        )
    all_inputs = {**standard_inputs, **additional_inputs}

    allowed_params = list(all_inputs)
    for param in required_params:
        if param not in allowed_params:
            raise ValueError(f"Unknown parameter: {param}")
    # Preserve argument order (a set difference would make it arbitrary).
    optional_params = [
        param for param in allowed_params if param not in required_params
    ]

    # Split individual and pairwise inputs
    individual_inputs = {
        key: value
        for key, value in all_inputs.items()
        if not isinstance(value, tuple)
    }
    pairwise_inputs = {
        key: value
        for key, value in all_inputs.items()
        if isinstance(value, tuple)
    }
    return MetricInputs(
        individual_inputs=individual_inputs,
        pairwise_inputs=pairwise_inputs,
        required_params=required_params,
        optional_params=optional_params,
        input_name_to_prompt_var_mapping={
            "generated_outputs": "gen_output",
            "prompts": "user_query",
            "sources": "src",
            "reference_outputs": "ref_output",
            **additional_input_name_to_prompt_var_mapping,
        },
    )
[docs]
def get_metric_inputs_with_required_lists(
    *,
    generated_outputs: IndividualInputType
    | tuple[IndividualInputType, IndividualInputType] = None,
    prompts: IndividualInputType
    | tuple[IndividualInputType, IndividualInputType] = None,
    sources: IndividualInputType
    | tuple[IndividualInputType, IndividualInputType] = None,
    reference_outputs: IndividualInputType
    | tuple[IndividualInputType, IndividualInputType] = None,
    additional_inputs: dict[str, IndividualInputType] | None = None,
    additional_input_name_to_prompt_var_mapping: dict[str, str] | None = None,
    required_params: list[str],
) -> tuple[MetricInputs, list[list[str]]]:
    """Build a MetricInputs object and also return the required input lists.

    This is a thin wrapper around `get_metric_inputs` that, in addition,
    extracts the raw list of every required parameter (in the order given
    by `required_params`). That is convenient for metrics that do not use
    an eval client and just need the plain lists.

    Args:
        generated_outputs: The generated outputs.
        prompts: The prompts.
        sources: The sources.
        reference_outputs: The reference outputs.
        additional_inputs: Additional inputs other than the standard ones.
        additional_input_name_to_prompt_var_mapping: A dictionary that maps the
            additional input names to the variable names in the prompt template.
        required_params: A list of required parameters.

    Returns:
        A ``(metric_inputs, required_lists)`` tuple.

    Raises:
        ValueError: Propagated from `get_metric_inputs`, or from
            `MetricInputs.get_required_individual_input` when a required
            parameter was not passed as an individual input.
    """
    inputs = get_metric_inputs(
        generated_outputs=generated_outputs,
        prompts=prompts,
        sources=sources,
        reference_outputs=reference_outputs,
        additional_inputs=additional_inputs,
        additional_input_name_to_prompt_var_mapping=additional_input_name_to_prompt_var_mapping,
        required_params=required_params,
    )
    extract = inputs.get_required_individual_input
    return inputs, list(map(extract, required_params))