from __future__ import annotations
import operator
import warnings
from dataclasses import dataclass, fields
from statistics import mean
from typing import Generic, TypeVar, Union
import pandas as pd
from langcheck.metrics.metric_inputs import MetricInputs
# Metrics take on float or integer values. Some metrics may return `None`
# values when the score fails to be computed, hence the optional variants.
NumericType = TypeVar(
    "NumericType", float, int, Union[float, None], Union[int, None]
)


@dataclass
class MetricValue(Generic[NumericType]):
    """A rich object that is the output of any langcheck.metrics function.

    Comparing a MetricValue with a number (``<``, ``<=``, ``>``, ``>=``,
    ``==``, ``!=``) does not return a bool: it returns a
    :class:`MetricValueWithThreshold` that records the threshold and which
    data points pass it. Attributes not defined on this class are proxied to
    ``metric_inputs.get_input_list`` (see ``__getattr__``).
    """

    metric_name: str
    # One value per data point; entries may be `None` (see NumericType).
    metric_values: list[NumericType]
    # Input of the metrics such as prompts, generated outputs... etc
    metric_inputs: MetricInputs
    # The whole list may be None. An individual explanation can be None if
    # the metric could not be computed for that data point.
    explanations: list[str | None] | None
    language: str | None

    def to_df(self) -> pd.DataFrame:
        """Returns a DataFrame of metric values for each data point.

        The columns of ``metric_inputs.to_df()`` come first, followed by an
        ``explanations`` column and a ``metric_values`` column.
        """
        input_df = self.metric_inputs.to_df()
        output_df = pd.DataFrame(
            {
                "explanations": self.explanations,
                "metric_values": self.metric_values,
            }
        )
        # Return the column-wise concatenation of the input and output
        # DataFrames (rows are aligned on the index of both frames).
        return pd.concat([input_df, output_df], axis=1)

    def __str__(self) -> str:
        """Returns a string representation of an
        :class:`~langcheck.metrics.metric_value.MetricValue` object.
        """
        return f"Metric: {self.metric_name}\n" f"{self.to_df()}"

    def __repr__(self) -> str:
        """Returns a string representation of an
        :class:`~langcheck.metrics.metric_value.MetricValue` object.
        """
        return str(self)

    def _repr_html_(self) -> str:
        """Returns an HTML representation of an
        :class:`~langcheck.metrics.metric_value.MetricValue`, which is
        automatically called by Jupyter notebooks.
        """
        return (
            f"Metric: {self.metric_name}<br>" f"{self.to_df()._repr_html_()}"  # type: ignore
        )

    def _with_threshold(
        self, threshold: float | int, threshold_op: str
    ) -> MetricValueWithThreshold:
        """Builds a :class:`MetricValueWithThreshold` from every dataclass
        field of this object plus the given threshold and operator.

        Shared implementation of the rich comparison operators below. Note
        that when ``self`` is already a MetricValueWithThreshold,
        ``fields(self)`` includes ``threshold``/``threshold_op``, so the
        constructor call raises TypeError (duplicate keyword argument).
        """
        all_fields = {f.name: getattr(self, f.name) for f in fields(self)}
        return MetricValueWithThreshold(
            **all_fields, threshold=threshold, threshold_op=threshold_op
        )

    def __lt__(self, threshold: float | int) -> MetricValueWithThreshold:
        """Allows the user to write a `metric_value < 0.5` expression."""
        return self._with_threshold(threshold, "<")

    def __le__(self, threshold: float | int) -> MetricValueWithThreshold:
        """Allows the user to write a `metric_value <= 0.5` expression."""
        return self._with_threshold(threshold, "<=")

    def __gt__(self, threshold: float | int) -> MetricValueWithThreshold:
        """Allows the user to write a `metric_value > 0.5` expression."""
        return self._with_threshold(threshold, ">")

    def __ge__(self, threshold: float | int) -> MetricValueWithThreshold:
        """Allows the user to write a `metric_value >= 0.5` expression."""
        return self._with_threshold(threshold, ">=")

    # Overriding __eq__ (without __hash__) leaves instances unhashable.
    def __eq__(self, threshold: float | int) -> MetricValueWithThreshold:
        """Allows the user to write a `metric_value == 0.5` expression."""
        return self._with_threshold(threshold, "==")

    def __ne__(self, threshold: float | int) -> MetricValueWithThreshold:
        """Allows the user to write a `metric_value != 0.5` expression."""
        return self._with_threshold(threshold, "!=")

    def all(self) -> bool:
        """Equivalent to all(metric_value.metric_values). This is mostly useful
        for binary metric functions.
        """
        return all(self.metric_values)

    def any(self) -> bool:
        """Equivalent to any(metric_value.metric_values). This is mostly useful
        for binary metric functions.
        """
        return any(self.metric_values)

    def __bool__(self):
        """Always raises: a bare MetricValue has no single truth value.

        Raises:
            ValueError: Unconditionally.
        """
        raise ValueError(
            "A MetricValue cannot be used as a boolean. "
            "Try an expression like `metric_value > 0.5`, "
            "`metric_value.all()`, or `metric_value.any()` instead."
        )

    def __getattr__(self, name: str):
        """If the attribute is not found in the MetricValue object, we try to
        proxy the attribute to the MetricInputs object.

        Python only calls this after normal attribute lookup fails, so real
        fields and methods always take precedence over proxied inputs.

        Raises:
            AttributeError: If ``name`` is a dunder name, if
                ``metric_inputs`` is not set on this instance, or if
                ``metric_inputs.get_input_list(name)`` raises ValueError.
        """
        message = (
            f"'{self.__class__.__name__}' object has no attribute '{name}'"
        )
        # Dunder lookups (e.g. `__setstate__`, `__deepcopy__` probed by the
        # copy and pickle machinery) are protocol hooks, never metric inputs.
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(message)
        # Read the field from the instance dict instead of `self.metric_inputs`:
        # copy/pickle rebuild objects with `__new__` and no `__init__`, and on
        # such an object `self.metric_inputs` would re-enter `__getattr__` and
        # recurse until RecursionError.
        metric_inputs = self.__dict__.get("metric_inputs")
        if metric_inputs is None:
            raise AttributeError(message)
        try:
            return metric_inputs.get_input_list(name)
        except ValueError:
            raise AttributeError(message) from None

    @property
    def is_scatter_compatible(self) -> bool:
        """Checks if the metric value is compatible with the scatter plot
        method. It is only available for metric values with only non-pairwise
        metric values used from initial release (generated_outputs, prompts,
        reference_outputs and sources)
        """
        allowed_inputs = {
            "generated_outputs",
            "prompts",
            "reference_outputs",
            "sources",
        }
        return len(self.metric_inputs.pairwise_inputs) == 0 and all(
            input_name in allowed_inputs
            for input_name in self.metric_inputs.individual_inputs
        )

    def scatter(self, jupyter_mode: str = "inline") -> None:
        """Shows an interactive scatter plot of all data points in MetricValue.
        Intended to be used in a Jupyter notebook.
        This is a convenience function that calls
        :func:`langcheck.plot.scatter()`.
        """
        from langcheck.plot import scatter as plot_scatter

        # Type ignore because a Self type is only valid in class contexts
        return plot_scatter(
            self,  # type: ignore[reportGeneralTypeIssues]
            jupyter_mode=jupyter_mode,
        )

    def histogram(self, jupyter_mode: str = "inline") -> None:
        """Shows an interactive histogram of all data points in MetricValue.
        Intended to be used in a Jupyter notebook.
        This is a convenience function that calls
        :func:`langcheck.plot.histogram()`.
        """
        from langcheck.plot import histogram as plot_histogram

        # Type ignore because a Self type is only valid in class contexts
        return plot_histogram(
            self,  # type: ignore[reportGeneralTypeIssues]
            jupyter_mode=jupyter_mode,
        )
@dataclass
class MetricValueWithThreshold(MetricValue):
    """A rich object that is the output of comparing an
    :class:`~langcheck.metrics.metric_value.MetricValue` object,
    e.g. `metric_value >= 0.5`.

    In addition to the MetricValue fields it records the ``threshold`` and
    the comparison operator ``threshold_op``. The per-data-point results and
    the pass rate are computed once, in ``__post_init__``, and exposed via the
    read-only ``threshold_results`` and ``pass_rate`` properties.
    """

    # NOTE(review): this body defines no `__eq__`, so `@dataclass` generates a
    # field-wise `__eq__` for this subclass; `thresholded == 0.5` evaluates to
    # False rather than building a new threshold as MetricValue.__eq__ does.
    # Confirm whether that asymmetry is intended.
    threshold: float | int
    threshold_op: str  # One of '<', '<=', '>', '>=', '==', '!='

    def __post_init__(self) -> None:
        """Computes self.pass_rate and self.threshold_results based on the
        constructor arguments.

        A `None` metric value never passes the threshold. When
        ``metric_values`` is empty the pass rate is undefined and is set to
        NaN (``statistics.mean`` would otherwise raise ``StatisticsError``).

        Raises:
            ValueError: If ``threshold_op`` is not a supported operator, or
                if ``threshold`` is None.
        """
        operators = {
            "<": operator.lt,
            "<=": operator.le,
            ">": operator.gt,
            ">=": operator.ge,
            "==": operator.eq,
            "!=": operator.ne,
        }
        if self.threshold_op not in operators:
            raise ValueError(f"Invalid threshold operator: {self.threshold_op}")
        if self.threshold is None:
            raise ValueError("A threshold of `None` is not supported.")
        if None in self.metric_values:
            warnings.warn(
                "The threshold result for `None` values in `metric_values` will"
                " always be `False`."
            )
        compare = operators[self.threshold_op]
        # Set the result to `False` if the metric value is `None`
        self._threshold_results = [
            compare(x, self.threshold) if x is not None else False
            for x in self.metric_values
        ]
        # The proportion of passing points is undefined for zero points.
        self._pass_rate = (
            mean(self._threshold_results)
            if self._threshold_results
            else float("nan")
        )

    @property
    def pass_rate(self) -> float:
        """Returns the proportion of data points that pass the threshold
        (NaN when there are no data points).
        """
        return self._pass_rate

    @property
    def threshold_results(self) -> list[bool]:
        """Returns a list of booleans indicating whether each data point passes
        the threshold.
        """
        return self._threshold_results

    def to_df(self) -> pd.DataFrame:
        """Returns a DataFrame of metric values for each data point.

        Extends the parent DataFrame with a ``threshold_test`` column (e.g.
        ``">= 0.5"``) and a ``threshold_result`` column of booleans.
        """
        dataframe = super().to_df()
        dataframe["threshold_test"] = [
            f"{self.threshold_op} {self.threshold}" for _ in self.metric_values
        ]
        dataframe["threshold_result"] = self.threshold_results
        return dataframe

    def __str__(self) -> str:
        """Returns a string representation of an
        :class:`~langcheck.metrics.metric_value.MetricValue`.
        """
        return (
            f"Metric: {self.metric_name}\n"
            f"Pass Rate: {round(self.pass_rate*100, 2)}%\n"
            f"{self.to_df()}"
        )

    def __repr__(self) -> str:
        """Returns a string representation of an
        :class:`~langcheck.metrics.metric_value.MetricValue` object.
        """
        return str(self)

    def _repr_html_(self) -> str:
        """Returns an HTML representation of an
        :class:`~langcheck.metrics.metric_value.MetricValue`, which is
        automatically called by Jupyter notebooks.
        """
        return (
            f"Metric: {self.metric_name}<br>"
            f"Pass Rate: {round(self.pass_rate*100, 2)}%<br>"
            f"{self.to_df()._repr_html_()}"  # type: ignore
        )

    def all(self) -> bool:
        """Returns True if all data points pass the threshold."""
        return all(self.threshold_results)

    def any(self) -> bool:
        """Returns True if any data points pass the threshold."""
        return any(self.threshold_results)

    def __bool__(self) -> bool:
        """Allows the user to write an `assert metric_value > 0.5` or
        `if metric_value > 0.5:` expression. This is shorthand for
        `assert (metric_value > 0.5).all()`.
        """
        return self.all()