from __future__ import annotations
import json
import re
from collections.abc import Callable, Container, Iterable
from langcheck.metrics.metric_inputs import (
get_metric_inputs_with_required_lists,
)
from langcheck.metrics.metric_value import MetricValue
from langcheck.utils.progress_bar import tqdm_wrapper
[docs]
def is_int(
    generated_outputs: list[str] | str,
    domain: Iterable[int] | Container[int] | None = None,
    prompts: list[str] | str | None = None,
) -> MetricValue[int]:
    """Checks if generated outputs can be parsed as integers, optionally within
    a domain of integers like `range(1, 11)` or `{1, 3, 5}`. This metric takes
    on binary 0 or 1 values.

    Args:
        generated_outputs: The model generated output(s) to evaluate
        domain: The optional domain of valid integers. Objects that implement
            ``__contains__`` (e.g. a ``range``, ``set`` or ``list``) are
            queried directly. Any other iterable, such as a generator, is
            read exactly once before scoring starts, so it must be finite.
        prompts: The prompts used to generate the output(s). Prompts are
            optional metadata and not used to calculate the metric.

    Returns:
        An :class:`~langcheck.metrics.metric_value.MetricValue` object
    """
    metric_inputs, [generated_outputs] = get_metric_inputs_with_required_lists(
        generated_outputs=generated_outputs,
        prompts=prompts,
        required_params=["generated_outputs"],
    )

    # A one-shot iterable (generator, iterator, ...) is partially consumed by
    # the first `in` test, so every later output would be compared against a
    # depleted domain. Snapshot it once so all outputs see the full domain.
    # A tuple keeps the same equality-based membership semantics as
    # iterating the original object.
    if domain is not None and not isinstance(domain, Container):
        domain = tuple(domain)

    # The values are binary: 1 for success and 0 for failure
    metric_values = []
    for output in tqdm_wrapper(generated_outputs):
        try:
            output_int = int(output)
            if domain is None or output_int in domain:
                metric_values.append(1)
            else:
                metric_values.append(0)
        except ValueError:
            # Not parseable as an integer
            metric_values.append(0)

    return MetricValue(
        metric_name="is_int",
        metric_inputs=metric_inputs,
        explanations=None,
        metric_values=metric_values,
        language=None,
    )
[docs]
def is_float(
    generated_outputs: list[str] | str,
    min: float | None = None,
    max: float | None = None,
    prompts: list[str] | str | None = None,
) -> MetricValue[int]:
    """Checks if generated outputs can be parsed as floating point numbers,
    optionally within a min/max range. This metric takes on binary 0 or 1
    values.

    When neither bound is given, any string accepted by :func:`float` scores
    1 (this includes ``"nan"`` and ``"inf"``). When at least one bound is
    given, the parsed value must satisfy ``min <= value <= max`` for the
    bounds that are set; NaN never satisfies a bound and therefore scores 0.

    Args:
        generated_outputs: The model generated output(s) to evaluate
        min: The optional minimum valid float (inclusive)
        max: The optional maximum valid float (inclusive)
        prompts: The prompts used to generate the output(s). Prompts are
            optional metadata and not used to calculate the metric.

    Returns:
        An :class:`~langcheck.metrics.metric_value.MetricValue` object
    """
    metric_inputs, [generated_outputs] = get_metric_inputs_with_required_lists(
        generated_outputs=generated_outputs,
        prompts=prompts,
        required_params=["generated_outputs"],
    )

    # The values are binary: 1 for success and 0 for failure
    metric_values = []
    for output in tqdm_wrapper(generated_outputs):
        try:
            output_float = float(output)
        except ValueError:
            # Not parseable as a float
            metric_values.append(0)
            continue

        # Phrase the bounds as positive ">=" / "<=" checks rather than
        # rejecting on "<" / ">": every comparison with NaN is False, so the
        # rejecting form would wrongly report NaN as inside the range.
        above_min = min is None or output_float >= min
        below_max = max is None or output_float <= max
        metric_values.append(1 if above_min and below_max else 0)

    return MetricValue(
        metric_name="is_float",
        metric_inputs=metric_inputs,
        explanations=None,
        metric_values=metric_values,
        language=None,
    )
[docs]
def is_json_object(
    generated_outputs: list[str] | str,
    prompts: list[str] | str | None = None,
) -> MetricValue[int]:
    """Scores each generated output by whether it parses as a JSON object.

    An output scores 1 when :func:`json.loads` succeeds and the top-level
    value is a ``dict``; it scores 0 when decoding fails or the top-level
    value is anything else.

    Args:
        generated_outputs: The model generated output(s) to evaluate
        prompts: The prompts used to generate the output(s). Prompts are
            optional metadata and not used to calculate the metric.

    Returns:
        An :class:`~langcheck.metrics.metric_value.MetricValue` object
    """
    metric_inputs, [outputs] = get_metric_inputs_with_required_lists(
        generated_outputs=generated_outputs,
        prompts=prompts,
        required_params=["generated_outputs"],
    )

    def _score(text: str) -> int:
        # Binary score: 1 only for valid JSON whose top level is an object
        try:
            parsed = json.loads(text)
        except json.JSONDecodeError:
            return 0
        return int(isinstance(parsed, dict))

    scores = [_score(text) for text in tqdm_wrapper(outputs)]

    return MetricValue(
        metric_name="is_json_object",
        metric_inputs=metric_inputs,
        explanations=None,
        metric_values=scores,
        language=None,
    )
[docs]
def is_json_array(
    generated_outputs: list[str] | str,
    prompts: list[str] | str | None = None,
) -> MetricValue[int]:
    """Scores each generated output by whether it parses as a JSON array.

    An output scores 1 when :func:`json.loads` succeeds and the top-level
    value is a ``list``; it scores 0 when decoding fails or the top-level
    value is anything else.

    Args:
        generated_outputs: The model generated output(s) to evaluate
        prompts: The prompts used to generate the output(s). Prompts are
            optional metadata and not used to calculate the metric.

    Returns:
        An :class:`~langcheck.metrics.metric_value.MetricValue` object
    """
    metric_inputs, [outputs] = get_metric_inputs_with_required_lists(
        generated_outputs=generated_outputs,
        prompts=prompts,
        required_params=["generated_outputs"],
    )

    scores = []
    for text in tqdm_wrapper(outputs):
        # Start from failure; only a decoded top-level list upgrades to 1
        score = 0
        try:
            decoded = json.loads(text)
        except json.JSONDecodeError:
            scores.append(score)
            continue
        if isinstance(decoded, list):
            score = 1
        scores.append(score)

    return MetricValue(
        metric_name="is_json_array",
        metric_inputs=metric_inputs,
        explanations=None,
        metric_values=scores,
        language=None,
    )
[docs]
def matches_regex(
    generated_outputs: list[str] | str,
    regex: str,
    prompts: list[str] | str | None = None,
) -> MetricValue[int]:
    """Scores each generated output by whether the whole string matches a
    regular expression (via :func:`re.fullmatch`). Each score is 1 for a
    full match and 0 otherwise.

    Args:
        generated_outputs: The model generated output(s) to evaluate
        regex: The regular expression to match
        prompts: The prompts used to generate the output(s). Prompts are
            optional metadata and not used to calculate the metric.

    Returns:
        An :class:`~langcheck.metrics.metric_value.MetricValue` object
    """
    metric_inputs, [outputs] = get_metric_inputs_with_required_lists(
        generated_outputs=generated_outputs,
        prompts=prompts,
        required_params=["generated_outputs"],
    )

    # Binary scores: 1 when the entire output matches, 0 otherwise
    scores = [
        int(re.fullmatch(regex, text) is not None)
        for text in tqdm_wrapper(outputs)
    ]

    return MetricValue(
        metric_name="matches_regex",
        metric_inputs=metric_inputs,
        explanations=None,
        metric_values=scores,
        language=None,
    )
[docs]
def contains_regex(
    generated_outputs: list[str] | str,
    regex: str,
    prompts: list[str] | str | None = None,
) -> MetricValue[int]:
    """Scores each generated output by whether a regular expression matches
    anywhere inside it (via :func:`re.search`). Each score is 1 when a match
    is found and 0 otherwise.

    Args:
        generated_outputs: The model generated output(s) to evaluate
        regex: The regular expression to search for
        prompts: The prompts used to generate the output(s). Prompts are
            optional metadata and not used to calculate the metric.

    Returns:
        An :class:`~langcheck.metrics.metric_value.MetricValue` object
    """
    metric_inputs, [outputs] = get_metric_inputs_with_required_lists(
        generated_outputs=generated_outputs,
        prompts=prompts,
        required_params=["generated_outputs"],
    )

    scores = []
    for text in tqdm_wrapper(outputs):
        hit = re.search(regex, text)
        # Binary score: 1 if the pattern occurs somewhere in the output
        scores.append(0 if hit is None else 1)

    return MetricValue(
        metric_name="contains_regex",
        metric_inputs=metric_inputs,
        explanations=None,
        metric_values=scores,
        language=None,
    )
[docs]
def contains_all_strings(
    generated_outputs: list[str] | str,
    strings: list[str],
    case_sensitive: bool = False,
    prompts: list[str] | str | None = None,
) -> MetricValue[int]:
    """Scores each generated output by whether it contains every string in a
    given list as a substring. Each score is 1 when all strings are present
    and 0 otherwise.

    Args:
        generated_outputs: The model generated output(s) to evaluate
        strings: A list of strings to match
        case_sensitive: Whether to match case sensitively or not, default
            False. When False, both the outputs and ``strings`` are
            lowercased before comparison.
        prompts: The prompts used to generate the output(s). Prompts are
            optional metadata and not used to calculate the metric.

    Returns:
        An :class:`~langcheck.metrics.metric_value.MetricValue` object
    """
    metric_inputs, [outputs] = get_metric_inputs_with_required_lists(
        generated_outputs=generated_outputs,
        prompts=prompts,
        required_params=["generated_outputs"],
    )

    if case_sensitive:
        needles, haystacks = strings, outputs
    else:
        # Case-insensitive matching compares lowercased copies of both sides
        needles = [needle.lower() for needle in strings]
        haystacks = [text.lower() for text in outputs]

    # Binary scores: 1 only when every needle occurs in the output
    scores = [
        int(all(needle in haystack for needle in needles))
        for haystack in tqdm_wrapper(haystacks)
    ]

    return MetricValue(
        metric_name="contains_all_strings",
        metric_inputs=metric_inputs,
        explanations=None,
        metric_values=scores,
        language=None,
    )
[docs]
def contains_any_strings(
    generated_outputs: list[str] | str,
    strings: list[str],
    case_sensitive: bool = False,
    prompts: list[str] | str | None = None,
) -> MetricValue[int]:
    """Scores each generated output by whether it contains at least one string
    from a given list as a substring. Each score is 1 when any string is
    present and 0 otherwise.

    Args:
        generated_outputs: The model generated output(s) to evaluate
        strings: A list of strings to match
        case_sensitive: Whether to match case sensitively or not, default to
            :obj:`False`. When False, both the outputs and ``strings`` are
            lowercased before comparison.
        prompts: The prompts used to generate the output(s). Prompts are
            optional metadata and not used to calculate the metric.

    Returns:
        An :class:`~langcheck.metrics.metric_value.MetricValue` object
    """
    metric_inputs, [outputs] = get_metric_inputs_with_required_lists(
        generated_outputs=generated_outputs,
        prompts=prompts,
        required_params=["generated_outputs"],
    )

    needles = strings
    haystacks = outputs
    if not case_sensitive:
        # Case-insensitive matching compares lowercased copies of both sides
        needles = [needle.lower() for needle in strings]
        haystacks = [text.lower() for text in outputs]

    scores = []
    for haystack in tqdm_wrapper(haystacks):
        # Binary score: flips to 1 at the first needle found in the output
        found = 0
        for needle in needles:
            if needle in haystack:
                found = 1
                break
        scores.append(found)

    return MetricValue(
        metric_name="contains_any_strings",
        metric_inputs=metric_inputs,
        explanations=None,
        metric_values=scores,
        language=None,
    )
[docs]
def validation_fn(
    generated_outputs: list[str] | str,
    valid_fn: Callable[[str], bool],
    prompts: list[str] | str | None = None,
) -> MetricValue[int]:
    """Scores each generated output with a caller-supplied validity function.
    Each score is 1 when ``valid_fn`` returns a truthy value for the output
    and 0 when it returns a falsy value or raises an exception.

    Args:
        generated_outputs: The model generated output(s) to evaluate
        valid_fn: A function that takes a single string and returns a
            bool determining whether the string is valid or not.
            The function can also raise an exception on failure.
        prompts: The prompts used to generate the output(s). Prompts are
            optional metadata and not used to calculate the metric.

    Returns:
        An :class:`~langcheck.metrics.metric_value.MetricValue` object
    """
    metric_inputs, [outputs] = get_metric_inputs_with_required_lists(
        generated_outputs=generated_outputs,
        prompts=prompts,
        required_params=["generated_outputs"],
    )

    scores = []
    for text in tqdm_wrapper(outputs):
        try:
            passed = bool(valid_fn(text))
        except Exception:
            # By contract, an exception from valid_fn means "invalid"
            passed = False
        scores.append(1 if passed else 0)

    return MetricValue(
        metric_name="validation_fn",
        metric_inputs=metric_inputs,
        explanations=None,
        metric_values=scores,
        language=None,
    )