Source code for langcheck.plot._scatter

from __future__ import annotations

import math
import textwrap
from copy import deepcopy

import plotly.express as px
from dash import Dash, Input, Output, dcc, html
from pandas.core.indexes.base import Index

from langcheck.metrics.metric_value import MetricValue, MetricValueWithThreshold
from langcheck.plot._css import GLOBAL_CSS, INPUT_CSS, NUM_RESULTS_CSS
from langcheck.plot._utils import Axis, _plot_threshold


[docs] def scatter( metric_value: MetricValue, other_metric_value: MetricValue | None = None, jupyter_mode: str = "inline", ) -> None: """Shows an interactive scatter plot of all data points in an :class:`~langcheck.metrics.metric_value.MetricValue`. When run in a notebook, this usually displays the chart inline in the cell output. Args: metric_value: The :class:`~langcheck.metrics.metric_value.MetricValue` to plot. other_metric_value: If provided, another :class:`~langcheck.metrics.metric_value.MetricValue` to plot on the same chart. jupyter_mode: Defaults to 'inline', which displays the chart in the cell output. For Colab, set this to 'external' instead. See the Dash documentation for more info: https://dash.plotly.com/workspaces/using-dash-in-jupyter-and-workspaces#display-modes """ if not metric_value.is_scatter_compatible or ( other_metric_value is not None and not other_metric_value.is_scatter_compatible ): raise NotImplementedError( "Scatter plots for pairwise MetricValues are not supported yet" ) if other_metric_value is None: _scatter_one_metric_value(metric_value, jupyter_mode) else: _scatter_two_metric_values( metric_value, other_metric_value, jupyter_mode )
def _format_text_for_hover(text: str): """Helper function to format a string so that it displays nicely on hover in the scatter plot. """ # First, split the text by newline characters. This is recommended in # https://docs.python.org/3/library/textwrap.html#textwrap.TextWrapper.replace_whitespace paragraphs = text.split("\n") # Then, split the paragraphs into separate lines with a max width of 70 # chars (default) lines = [line for p in paragraphs for line in textwrap.wrap(p)] # Only show a max of 5 lines. If there are more than 5, add '...' to # indicate that the text has been cut off if len(lines) > 5: lines = lines[:5] + ["..."] return "<br>".join(lines) def _scatter_one_metric_value( metric_value: MetricValue, jupyter_mode: str ) -> None: """Shows an interactive scatter plot of all data points in one :class:`~langcheck.metrics.metric_value.MetricValue`. """ # Rename some MetricValue fields for display df = metric_value.to_df() df.rename( columns={ "metric_values": metric_value.metric_name, # Rename the other columns for display "prompts": "prompt", "reference_outputs": "reference_output", "sources": "source", "explanations": "explanation", "generated_outputs": "generated_output", }, inplace=True, ) df["prompt"] = df["prompt"].fillna("None").apply(_format_text_for_hover) df["reference_output"] = ( df["reference_output"].fillna("None").apply(_format_text_for_hover) ) df["source"] = df["source"].fillna("None").apply(_format_text_for_hover) df["explanation"] = ( df["explanation"].fillna("None").apply(_format_text_for_hover) ) df["generated_output"] = ( df["generated_output"].fillna("None").apply(_format_text_for_hover) ) # Define layout of the Dash app (chart + search boxes) app = Dash(__name__) app.layout = html.Div( [ html.Div( [ html.Label("Filter generated_outputs: "), dcc.Input( id="filter_generated_outputs", type="text", placeholder="Type to search...", style=INPUT_CSS, ), ] ), html.Div( [ html.Label("Filter reference_outputs: "), dcc.Input( id="filter_reference_outputs", type="text", placeholder="Type to search...", style=INPUT_CSS, ), ] ), html.Div( [ html.Label("Filter prompts: "), dcc.Input( id="filter_prompts", type="text", placeholder="Type to search...", style=INPUT_CSS, ), ] ), html.Div( [ html.Label("Filter sources: "), dcc.Input( id="filter_sources", type="text", placeholder="Type to search...", style=INPUT_CSS, ), ] ), html.Div( [html.Span(id="num_results_message", style=NUM_RESULTS_CSS)] ), dcc.Graph( id="scatter_plot", config={ "displaylogo": False, "modeBarButtonsToRemove": [ "select", "lasso2d", "resetScale", ], }, ), ], style=GLOBAL_CSS, ) # This function gets called whenever the user types in the search boxes @app.callback( Output("scatter_plot", "figure"), Output("num_results_message", "children"), Input("filter_generated_outputs", "value"), Input("filter_reference_outputs", "value"), Input("filter_prompts", "value"), Input("filter_sources", "value"), ) def update_figure( filter_generated_outputs, filter_reference_outputs, filter_prompts, filter_sources, ): # Filter data points based on search boxes, case-insensitive filtered_df = df.copy() if filter_generated_outputs: filtered_df = filtered_df[ filtered_df["generated_output"] .str.lower() .str.contains(filter_generated_outputs.lower()) ] if filter_reference_outputs: filtered_df = filtered_df[ filtered_df["reference_output"] .str.lower() .str.contains(filter_reference_outputs.lower()) ] if filter_prompts: filtered_df = filtered_df[ filtered_df["prompt"] .str.lower() .str.contains(filter_prompts.lower()) ] if filter_sources: filtered_df = filtered_df[ filtered_df["source"] .str.lower() .str.contains(filter_sources.lower()) ] # Configure the actual scatter plot fig = px.scatter( filtered_df, x=filtered_df.index, y=metric_value.metric_name, hover_data=filtered_df.columns, ) if isinstance(metric_value, MetricValueWithThreshold): _plot_threshold( fig, metric_value.threshold_op, metric_value.threshold, Axis.horizontal, ) # Explicitly set the default axis ranges (with a little padding) so that # the plot doesn't change when the user types in the search boxes fig.update_xaxes(range=[-0.1, len(df)]) fig.update_yaxes( range=[ min(-0.1, math.floor(df[metric_value.metric_name].min())), max(1.1, math.ceil(df[metric_value.metric_name].max())), ] ) # However, if the user manually zoomed in, keep that zoom level even # when update_figure() re-runs fig.update_layout(uirevision="constant") # Disable drag-to-zoom by default (the user can still enable it # in the modebar) fig.update_layout(dragmode=False) # Display a message about how many data points are hidden num_results_message = ( f"Showing {len(filtered_df)} of {len(df)} data points." ) return fig, num_results_message # Display the Dash app inline in the notebook # TODO: This doesn't seem to display inline if you click "Run All" in VSCode # instead of running the cell directly app.run(jupyter_mode=jupyter_mode) # type: ignore def _scatter_two_metric_values( metric_value: MetricValue, other_metric_value: MetricValue, jupyter_mode: str, ) -> None: """Shows an interactive scatter plot of all data points in two :class:`~langcheck.metrics.metric_value.MetricValue`. """ # Validate that the two MetricValues have the same data points if metric_value.generated_outputs != other_metric_value.generated_outputs: raise ValueError( "Both MetricValues must have the same generated_outputs" ) if metric_value.prompts != other_metric_value.prompts: raise ValueError("Both MetricValues must have the same prompts") if metric_value.reference_outputs != other_metric_value.reference_outputs: raise ValueError( "Both MetricValues must have the same reference_outputs" ) if metric_value.language != other_metric_value.language: raise ValueError("Both MetricValues must have the same language") # Append "(other)" to the metric name of the second MetricValue if # necessary. (It's possible to plot two MetricValues from the same metric, # e.g. if you compute semantic_similarity() with a local model and an OpenAI # model) if metric_value.metric_name == other_metric_value.metric_name: other_metric_value = deepcopy(other_metric_value) other_metric_value.metric_name += " (other)" # Rename some MetricValue fields for display df = metric_value.to_df() df.rename( columns={ "metric_values": metric_value.metric_name, # Rename the other columns for display "prompts": "prompt", "reference_outputs": "reference_output", "sources": "source", "explanations": "explanation", "generated_outputs": "generated_output", }, inplace=True, ) df[other_metric_value.metric_name] = other_metric_value.to_df()[ "metric_value" ] df["prompt"] = df["prompt"].fillna("None").apply(_format_text_for_hover) df["reference_output"] = ( df["reference_output"].fillna("None").apply(_format_text_for_hover) ) df["source"] = df["source"].fillna("None").apply(_format_text_for_hover) df["explanation"] = ( df["explanation"].fillna("None").apply(_format_text_for_hover) ) df["generated_output"] = ( df["generated_output"].fillna("None").apply(_format_text_for_hover) ) # Define layout of the Dash app (chart + search boxes) app = Dash(__name__) app.layout = html.Div( [ html.Div( [ html.Label("Filter generated_outputs: "), dcc.Input( id="filter_generated_outputs", type="text", placeholder="Type to search...", style=INPUT_CSS, ), ] ), html.Div( [ html.Label("Filter reference_outputs: "), dcc.Input( id="filter_reference_outputs", type="text", placeholder="Type to search...", style=INPUT_CSS, ), ] ), html.Div( [ html.Label("Filter prompts: "), dcc.Input( id="filter_prompts", type="text", placeholder="Type to search...", style=INPUT_CSS, ), ] ), html.Div( [ html.Label("Filter sources: "), dcc.Input( id="filter_sources", type="text", placeholder="Type to search...", style=INPUT_CSS, ), ] ), html.Div( [html.Span(id="num_results_message", style=NUM_RESULTS_CSS)] ), dcc.Graph( id="scatter_plot", config={ "displaylogo": False, "modeBarButtonsToRemove": [ "select", "lasso2d", "resetScale", ], }, ), ], style=GLOBAL_CSS, ) # This function gets called whenever the user types in the search boxes @app.callback( Output("scatter_plot", "figure"), Output("num_results_message", "children"), Input("filter_generated_outputs", "value"), Input("filter_reference_outputs", "value"), Input("filter_prompts", "value"), Input("filter_sources", "value"), ) def update_figure( filter_generated_outputs, filter_reference_outputs, filter_prompts, filter_sources, ): # Filter data points based on search boxes, case-insensitive filtered_df = df.copy() if filter_generated_outputs: filtered_df = filtered_df[ filtered_df["generated_output"] .str.lower() .str.contains(filter_generated_outputs.lower()) ] if filter_reference_outputs: filtered_df = filtered_df[ filtered_df["reference_output"] .str.lower() .str.contains(filter_reference_outputs.lower()) ] if filter_prompts: filtered_df = filtered_df[ filtered_df["prompt"] .str.lower() .str.contains(filter_prompts.lower()) ] if filter_sources: filtered_df = filtered_df[ filtered_df["source"] .str.lower() .str.contains(filter_sources.lower()) ] # Configure the actual scatter plot # (We need to explicitly add the index column into hover_data here. # Unfortunately it's not possible to make "index" show up at the top of # the tooltip like _scatter_one_metric_value() since Plotly always # displays the x and y values at the top.) hover_data: dict[str, bool | Index] = { col: True for col in filtered_df.columns } hover_data["index"] = filtered_df.index fig = px.scatter( filtered_df, x=metric_value.metric_name, y=other_metric_value.metric_name, hover_data=hover_data, ) # Draw threshold if any of metric_value is MetricValueWithThreshold if isinstance(metric_value, MetricValueWithThreshold): _plot_threshold( fig, metric_value.threshold_op, metric_value.threshold, Axis.vertical, ) if isinstance(other_metric_value, MetricValueWithThreshold): _plot_threshold( fig, other_metric_value.threshold_op, other_metric_value.threshold, Axis.horizontal, ) # Explicitly set the default axis ranges (with a little padding) so that # the plot doesn't change when the user types in the search boxes fig.update_xaxes( range=[ min(-0.1, math.floor(df[metric_value.metric_name].min())), max(1.1, math.ceil(df[metric_value.metric_name].max())), ] ) fig.update_yaxes( range=[ min(-0.1, math.floor(df[other_metric_value.metric_name].min())), max(1.1, math.ceil(df[other_metric_value.metric_name].max())), ] ) # However, if the user manually zoomed in, keep that zoom level even # when update_figure() re-runs fig.update_layout(uirevision="constant") # Disable drag-to-zoom by default (the user can still enable it in the # modebar) fig.update_layout(dragmode=False) # Display a message about how many data points are hidden num_results_message = ( f"Showing {len(filtered_df)} of {len(df)} data points." ) return fig, num_results_message # Display the Dash app inline in the notebook # TODO: This doesn't seem to display inline if you click "Run All" in VSCode # instead of running the cell directly app.run(jupyter_mode=jupyter_mode) # type: ignore