Source code for langcheck.plot._scatter

import math
import textwrap
from copy import deepcopy
from typing import Optional

import plotly.express as px
from dash import Dash, Input, Output, dcc, html

from langcheck.metrics.metric_value import MetricValue, MetricValueWithThreshold
from langcheck.plot._css import GLOBAL_CSS, INPUT_CSS, NUM_RESULTS_CSS
from langcheck.plot._utils import Axis, _plot_threshold


def scatter(metric_value: MetricValue,
            other_metric_value: Optional[MetricValue] = None,
            jupyter_mode: str = 'inline') -> None:
    '''Shows an interactive scatter plot of all data points in an
    :class:`~langcheck.metrics.metric_value.MetricValue`. When run in a
    notebook, this usually displays the chart inline in the cell output.

    With a single :class:`~langcheck.metrics.metric_value.MetricValue`, the
    work is delegated to ``_scatter_one_metric_value()``. When
    ``other_metric_value`` is also given, it is delegated to
    ``_scatter_two_metric_values()`` instead.

    Args:
        metric_value: The :class:`~langcheck.metrics.metric_value.MetricValue`
            to plot.
        other_metric_value: If provided, another
            :class:`~langcheck.metrics.metric_value.MetricValue` to plot on the
            same chart.
        jupyter_mode: Defaults to 'inline', which displays the chart in the
            cell output. For Colab, set this to 'external' instead. See the
            Dash documentation for more info:
            https://dash.plotly.com/workspaces/using-dash-in-jupyter-and-workspaces#display-modes
    '''
    # Two MetricValues: compare them against each other on one chart
    if other_metric_value is not None:
        _scatter_two_metric_values(metric_value, other_metric_value,
                                   jupyter_mode)
        return

    # Only one MetricValue was passed in
    _scatter_one_metric_value(metric_value, jupyter_mode)
def _format_text_for_hover(text: str) -> str:
    '''Helper function to format a string so that it displays nicely on hover
    in the scatter plot.

    Args:
        text: The raw text to format.

    Returns:
        The text wrapped to a max width of 70 chars per line and cut off after
        5 lines (followed by a '...' line if anything was cut off), with the
        lines joined by '<br>'.
    '''
    # First, split the text by newline characters. This is recommended in
    # https://docs.python.org/3/library/textwrap.html#textwrap.TextWrapper.replace_whitespace
    paragraphs = text.split('\n')

    # Then, split the paragraphs into separate lines with a max width of 70
    # chars (default)
    lines = [line for p in paragraphs for line in textwrap.wrap(p)]

    # Only show a max of 5 lines. If there are more than 5, add '...' to
    # indicate that the text has been cut off
    if len(lines) > 5:
        lines = lines[:5] + ['...']

    return '<br>'.join(lines)


def _format_df_for_hover(df) -> None:
    '''Helper function that formats the text columns of ``df`` in place so
    that they display nicely on hover in the scatter plot.

    Missing values in the optional columns are replaced with the string
    'None' before formatting.

    Args:
        df: The DataFrame returned by ``MetricValue.to_df()``.
    '''
    for column in ('prompt', 'reference_output', 'source', 'explanation'):
        df[column] = df[column].fillna('None').apply(_format_text_for_hover)

    # generated_output is formatted without filling in missing values first
    df['generated_output'] = df['generated_output'].apply(
        _format_text_for_hover)


def _build_layout():
    '''Helper function that builds the layout of the Dash app: the four
    search boxes, the "Showing X of Y data points" message, and the chart.

    Returns:
        The root ``html.Div`` to assign to ``app.layout``.
    '''
    search_boxes = (
        ('filter_generated_outputs', 'Filter generated_outputs: '),
        ('filter_reference_outputs', 'Filter reference_outputs: '),
        ('filter_prompts', 'Filter prompts: '),
        ('filter_sources', 'Filter sources: '),
    )
    rows = [
        html.Div([
            html.Label(label),
            dcc.Input(id=box_id,
                      type='text',
                      placeholder='Type to search...',
                      style=INPUT_CSS),
        ]) for box_id, label in search_boxes
    ]
    rows.append(
        html.Div([html.Span(id='num_results_message',
                            style=NUM_RESULTS_CSS)]))
    rows.append(
        dcc.Graph(id='scatter_plot',
                  config={
                      'displaylogo': False,
                      'modeBarButtonsToRemove': [
                          'select', 'lasso2d', 'resetScale'
                      ]
                  }))
    return html.Div(rows, style=GLOBAL_CSS)


def _filter_df(df, queries: dict):
    '''Helper function that filters data points based on the search boxes.

    The search is a case-insensitive, literal substring match. Note that the
    columns of ``df`` hold the hover-formatted text (wrapped with '<br>' and
    cut off after 5 lines), so that is the text being searched.

    Args:
        df: The DataFrame with all data points. It is not modified.
        queries: Maps a column name of ``df`` to the text typed into the
            corresponding search box (``None`` or '' if the box is empty).

    Returns:
        A new DataFrame with only the rows that match every non-empty query.
    '''
    filtered_df = df.copy()
    for column, query in queries.items():
        if not query:
            continue
        # regex=False is important here: by default, str.contains() treats
        # the pattern as a regular expression, so typing characters such as
        # '(' or '+' into a search box would raise re.error (or match
        # something other than what the user typed).
        matches = filtered_df[column].str.lower().str.contains(query.lower(),
                                                               regex=False)
        filtered_df = filtered_df[matches]
    return filtered_df


def _default_axis_range(values) -> list:
    '''Helper function that computes the default range (with a little
    padding) of an axis that shows metric values.

    The range always covers at least [-0.1, 1.1], and is extended to the
    enclosing integers if the values go below -0.1 or above 1.1.

    Args:
        values: The DataFrame column with the metric values.

    Returns:
        The ``[lower, upper]`` range of the axis.
    '''
    try:
        lower = min(-0.1, math.floor(values.min()))
        upper = max(1.1, math.ceil(values.max()))
    except (TypeError, ValueError):
        # math.floor()/math.ceil() can't handle NaN, which is what min()/max()
        # return if there are no (non-missing) metric values at all. Fall back
        # to the minimum range instead of failing inside the Dash callback.
        return [-0.1, 1.1]
    return [lower, upper]


def _scatter_one_metric_value(metric_value: MetricValue,
                              jupyter_mode: str) -> None:
    '''Shows an interactive scatter plot of all data points in one
    :class:`~langcheck.metrics.metric_value.MetricValue`.

    The x-axis is the index of the data point and the y-axis is the metric
    value.

    Args:
        metric_value: The :class:`~langcheck.metrics.metric_value.MetricValue`
            to plot.
        jupyter_mode: Passed to ``app.run()`` of the Dash app.
    '''
    # Rename some MetricValue fields for display
    df = metric_value.to_df()
    df.rename(columns={'metric_value': metric_value.metric_name}, inplace=True)
    _format_df_for_hover(df)

    # Define layout of the Dash app (chart + search boxes)
    app = Dash(__name__)
    app.layout = _build_layout()

    # This function gets called whenever the user types in the search boxes
    @app.callback(Output('scatter_plot', 'figure'),
                  Output('num_results_message', 'children'),
                  Input('filter_generated_outputs', 'value'),
                  Input('filter_reference_outputs', 'value'),
                  Input('filter_prompts', 'value'),
                  Input('filter_sources', 'value'))
    def update_figure(filter_generated_outputs, filter_reference_outputs,
                      filter_prompts, filter_sources):
        # Filter data points based on search boxes, case-insensitive
        filtered_df = _filter_df(
            df, {
                'generated_output': filter_generated_outputs,
                'reference_output': filter_reference_outputs,
                'prompt': filter_prompts,
                'source': filter_sources,
            })

        # Configure the actual scatter plot
        fig = px.scatter(filtered_df,
                         x=filtered_df.index,
                         y=metric_value.metric_name,
                         hover_data=filtered_df.columns)
        if isinstance(metric_value, MetricValueWithThreshold):
            _plot_threshold(fig, metric_value.threshold_op,
                            metric_value.threshold, Axis.horizontal)

        # Explicitly set the default axis ranges (with a little padding) so that
        # the plot doesn't change when the user types in the search boxes
        fig.update_xaxes(range=[-0.1, len(df)])
        fig.update_yaxes(
            range=_default_axis_range(df[metric_value.metric_name]))

        # However, if the user manually zoomed in, keep that zoom level even
        # when update_figure() re-runs
        fig.update_layout(uirevision='constant')

        # Disable drag-to-zoom by default (the user can still enable it
        # in the modebar)
        fig.update_layout(dragmode=False)

        # Display a message about how many data points are hidden
        num_results_message = (
            f'Showing {len(filtered_df)} of {len(df)} data points.')

        return fig, num_results_message

    # Display the Dash app inline in the notebook
    # TODO: This doesn't seem to display inline if you click "Run All" in VSCode
    # instead of running the cell directly
    app.run(jupyter_mode=jupyter_mode)  # type: ignore


def _scatter_two_metric_values(metric_value: MetricValue,
                               other_metric_value: MetricValue,
                               jupyter_mode: str) -> None:
    '''Shows an interactive scatter plot of all data points in two
    :class:`~langcheck.metrics.metric_value.MetricValue`.

    The x-axis is the value of ``metric_value`` and the y-axis is the value
    of ``other_metric_value``.

    Args:
        metric_value: The :class:`~langcheck.metrics.metric_value.MetricValue`
            to plot on the x-axis.
        other_metric_value: The
            :class:`~langcheck.metrics.metric_value.MetricValue` to plot on
            the y-axis.
        jupyter_mode: Passed to ``app.run()`` of the Dash app.

    Raises:
        ValueError: If the two MetricValues don't have the same
            generated_outputs, prompts, reference_outputs, or language.
    '''
    # Validate that the two MetricValues have the same data points
    if metric_value.generated_outputs != other_metric_value.generated_outputs:
        raise ValueError(
            'Both MetricValues must have the same generated_outputs')
    if metric_value.prompts != other_metric_value.prompts:
        raise ValueError('Both MetricValues must have the same prompts')
    if metric_value.reference_outputs != other_metric_value.reference_outputs:
        raise ValueError(
            'Both MetricValues must have the same reference_outputs')
    if metric_value.language != other_metric_value.language:
        raise ValueError('Both MetricValues must have the same language')

    # Append "(other)" to the metric name of the second MetricValue if
    # necessary. (It's possible to plot two MetricValues from the same metric,
    # e.g. if you compute semantic_similarity() with a local model and an OpenAI
    # model)
    if metric_value.metric_name == other_metric_value.metric_name:
        # Work on a copy so that the caller's MetricValue is not renamed
        other_metric_value = deepcopy(other_metric_value)
        other_metric_value.metric_name += ' (other)'

    # Rename some MetricValue fields for display
    df = metric_value.to_df()
    df.rename(columns={'metric_value': metric_value.metric_name}, inplace=True)
    df[other_metric_value.metric_name] = other_metric_value.to_df(
    )['metric_value']
    _format_df_for_hover(df)

    # Define layout of the Dash app (chart + search boxes)
    app = Dash(__name__)
    app.layout = _build_layout()

    # This function gets called whenever the user types in the search boxes
    @app.callback(Output('scatter_plot', 'figure'),
                  Output('num_results_message', 'children'),
                  Input('filter_generated_outputs', 'value'),
                  Input('filter_reference_outputs', 'value'),
                  Input('filter_prompts', 'value'),
                  Input('filter_sources', 'value'))
    def update_figure(filter_generated_outputs, filter_reference_outputs,
                      filter_prompts, filter_sources):
        # Filter data points based on search boxes, case-insensitive
        filtered_df = _filter_df(
            df, {
                'generated_output': filter_generated_outputs,
                'reference_output': filter_reference_outputs,
                'prompt': filter_prompts,
                'source': filter_sources,
            })

        # Configure the actual scatter plot
        # (We need to explicitly add the index column into hover_data here.
        # Unfortunately it's not possible to make "index" show up at the top of
        # the tooltip like _scatter_one_metric_value() since Plotly always
        # displays the x and y values at the top.)
        hover_data = {col: True for col in filtered_df.columns}
        hover_data['index'] = filtered_df.index
        fig = px.scatter(filtered_df,
                         x=metric_value.metric_name,
                         y=other_metric_value.metric_name,
                         hover_data=hover_data)

        # Draw threshold if any of metric_value is MetricValueWithThreshold
        if isinstance(metric_value, MetricValueWithThreshold):
            _plot_threshold(fig, metric_value.threshold_op,
                            metric_value.threshold, Axis.vertical)
        if isinstance(other_metric_value, MetricValueWithThreshold):
            _plot_threshold(fig, other_metric_value.threshold_op,
                            other_metric_value.threshold, Axis.horizontal)

        # Explicitly set the default axis ranges (with a little padding) so that
        # the plot doesn't change when the user types in the search boxes
        fig.update_xaxes(
            range=_default_axis_range(df[metric_value.metric_name]))
        fig.update_yaxes(
            range=_default_axis_range(df[other_metric_value.metric_name]))

        # However, if the user manually zoomed in, keep that zoom level even
        # when update_figure() re-runs
        fig.update_layout(uirevision='constant')

        # Disable drag-to-zoom by default (the user can still enable it in the
        # modebar)
        fig.update_layout(dragmode=False)

        # Display a message about how many data points are hidden
        num_results_message = (
            f'Showing {len(filtered_df)} of {len(df)} data points.')

        return fig, num_results_message

    # Display the Dash app inline in the notebook
    # TODO: This doesn't seem to display inline if you click "Run All" in VSCode
    # instead of running the cell directly
    app.run(jupyter_mode=jupyter_mode)  # type: ignore