Source code for langcheck.metrics.eval_clients._gemini

from __future__ import annotations

import asyncio
import os
import warnings
from typing import Any, Literal

import torch
from google import genai
from google.genai import types
from pydantic import BaseModel

from langcheck.metrics.eval_clients.eval_response import (
    ResponsesWithMetadata,
)
from langcheck.utils.progress_bar import tqdm_wrapper

from ..prompts._utils import get_template
from ..scorer._base import BaseSimilarityScorer
from ._base import EvalClient
from .extractor import Extractor


class GeminiEvalClient(EvalClient):
    """EvalClient defined for the Gemini model."""

    def __init__(
        self,
        model_name: str = "gemini-1.5-flash",
        embed_model_name: str | None = None,
        generate_content_args: dict[str, Any] | None = None,
        genai_client: genai.Client | None = None,
        *,
        use_async: bool = False,
        vertexai: bool = False,
        system_prompt: str | None = None,
        extractor: Extractor | None = None,
    ):
        """
        Initialize the Gemini evaluation client. You can provide your own
        genai.Client instance via the `genai_client` argument, or set the
        necessary environment variables.

        If you want to use Gemini Developer API, please set `GOOGLE_API_KEY`.
        If you want to use Vertex AI API, set the `vertexai` argument to True,
        and set the following environment variables:
        - GOOGLE_CLOUD_PROJECT=<your-project-id>
        - GOOGLE_CLOUD_LOCATION=<location> (e.g. europe-west1)
        - GOOGLE_APPLICATION_CREDENTIALS=<path-to-your-credentials>

        References:
            - https://ai.google.dev/api/python/google/generativeai/GenerativeModel
            - https://cloud.google.com/docs/authentication/application-default-credentials

        Args:
            model_name: The Gemini model to use. Defaults to
                "gemini-1.5-flash".
            embed_model_name (Optional): The name of the embedding model to
                use. If not provided, the "text-embedding-004" model will be
                used.
            generate_content_args (Optional): Dict of args to pass in to the
                ``generate_content`` function. The keys should be the same as
                the keys in the ``genai.types.GenerateContentConfig`` type.
            genai_client (Optional): The genai.Client instance to use. If not
                provided, the client will be created using the environment
                variables.
            use_async: If True, the async client will be used. Defaults to
                False.
            vertexai: If True, the Vertex AI client will be used. Ignored when
                `genai_client` is provided. Defaults to False.
            system_prompt (Optional): The system prompt for
                ``generate_content`` in ``get_text_responses`` function. If not
                provided, no system prompt will be used.
            extractor (Optional): The extractor to use. If not provided, a
                default GeminiExtractor will be used.

        Raises:
            ValueError: If ``generate_content_args`` is rejected by
                ``_validate_generate_content_config``, or if no
                ``genai_client`` is given and a required environment variable
                is missing.
        """
        # stacklevel=2 attributes the warning to the code that instantiated
        # the client rather than to this constructor.
        warnings.warn(
            "GeminiEvalClient will be deprecated in the next release. "
            "Please use LiteLLMEvalClient instead.",
            stacklevel=2,
        )

        self._model_name = model_name
        self._generate_content_args = generate_content_args or {}
        _validate_generate_content_config(self._generate_content_args)
        self._embed_model_name = embed_model_name
        self._use_async = use_async
        self._system_instruction = system_prompt

        if genai_client is None:
            # Check for required environment variables
            if vertexai:
                # Vertex AI requires these environment variables
                for env_var in [
                    "GOOGLE_CLOUD_PROJECT",
                    "GOOGLE_CLOUD_LOCATION",
                    "GOOGLE_APPLICATION_CREDENTIALS",
                ]:
                    if not os.environ.get(env_var):
                        raise ValueError(
                            f"Environment variable '{env_var}' must be set when using Vertex AI."
                        )
                # Warn that `GOOGLE_API_KEY` is not used when using Vertex AI
                if os.environ.get("GOOGLE_API_KEY", None):
                    warnings.warn(
                        "`GOOGLE_API_KEY` is set when using Vertex AI. "
                        "Vertex AI will take precedence over the API key from "
                        "the environment variable."
                    )
            elif os.environ.get("GOOGLE_API_KEY", None) is None:
                # Gemini Developer API requires API key
                raise ValueError(
                    "`GOOGLE_API_KEY` is not set when using Gemini Developer API. "
                    "Please set the `GOOGLE_API_KEY` environment variable."
                )

            self._client = genai.Client(vertexai=vertexai)
            self._vertexai = vertexai
        else:
            self._client = genai_client
            self._vertexai = genai_client.vertexai

            # Client config will take precedence over the argument, and the
            # argument will be ignored.
            if self._vertexai and not vertexai:
                warnings.warn(
                    "The provided `genai_client` is a Vertex AI client, "
                    "so the `vertexai=False` argument will be ignored. The Vertex AI client will be used."
                )
            elif not self._vertexai and vertexai:
                warnings.warn(
                    "The provided `genai_client` is a Gemini Developer client, "
                    "so the `vertexai=True` argument will be ignored. The Gemini Developer client will be used."
                )

        if extractor is None:
            # The default extractor shares this client's connection and
            # sync/async mode.
            self._extractor = GeminiExtractor(
                genai_client=self._client,
                use_async=self._use_async,
                vertexai=self._vertexai,
            )
        else:
            self._extractor = extractor

    def get_text_responses(
        self,
        prompts: list[str],
        *,
        tqdm_description: str | None = None,
    ) -> ResponsesWithMetadata[str]:
        """The function that gets responses to the given prompt texts.

        Args:
            prompts: The prompts you want to get the responses for.
            tqdm_description (Optional): The description to be shown in the
                tqdm bar. Defaults to "Intermediate assessments (1/2)".

        Returns:
            A list of responses to the prompts. The responses can be None if
            the evaluation fails.
        """
        config: dict[str, Any] = {
            "temperature": 0.0,
            "system_instruction": self._system_instruction,
        }
        # User-supplied args override the defaults above.
        config.update(self._generate_content_args or {})

        tqdm_description = tqdm_description or "Intermediate assessments (1/2)"
        responses = _call_api(
            model=self._model_name,
            prompts=prompts,
            config=config,
            client=self._client,
            use_async=self._use_async,
            tqdm_description=tqdm_description,
        )
        # `_call_api` replaces failed or blocked responses with None.
        response_texts = [
            response.text if response else None for response in responses
        ]

        # Token usage is not supported in GeminiEvalClient
        # If you need token usage, please use LiteLLMEvalClient instead.
        return ResponsesWithMetadata(response_texts, None)

    def similarity_scorer(self) -> GeminiSimilarityScorer:
        """Return a similarity scorer backed by this client's Gemini client.

        Returns:
            A GeminiSimilarityScorer configured with this client's embedding
            model name, genai client and sync/async mode.
        """
        return GeminiSimilarityScorer(
            embed_model_name=self._embed_model_name,
            client=self._client,
            use_async=self._use_async,
        )
class GeminiSimilarityScorer(BaseSimilarityScorer):
    """Similarity scorer that uses the Gemini API to embed the inputs.
    In the current version of langcheck, the class is only instantiated within
    EvalClients.
    """

    def __init__(
        self,
        embed_model_name: str | None,
        client: genai.Client,
        *,
        use_async: bool = False,
    ):
        """Initialize the scorer.

        Args:
            embed_model_name: The name of the embedding model to use. If None
                (or empty), "text-embedding-004" is used.
            client: The genai.Client instance used to call ``embed_content``.
            use_async: If True, the async client (``client.aio``) will be
                used. Defaults to False.
        """
        super().__init__()

        self._embed_model_name = embed_model_name or "text-embedding-004"
        self._client = client
        self._use_async = use_async

    def _embed(self, inputs: list[str]) -> torch.Tensor:
        """Embed the inputs using the Gemini API.

        Args:
            inputs: The texts to embed.

        Returns:
            A tensor built from the ``values`` of each returned embedding, in
            the order the API returned them.

        Raises:
            ValueError: If the API response contains no embeddings.
        """
        contents = [types.Part.from_text(text=prompt) for prompt in inputs]

        if self._use_async:

            async def _call_async_api():
                return await self._client.aio.models.embed_content(
                    model=self._embed_model_name,
                    contents=contents,
                )

            # Use `asyncio.run` (as `_call_api` does) instead of
            # `asyncio.get_event_loop()`: once a previous `asyncio.run` has
            # finished, the current loop is cleared and `get_event_loop()`
            # raises RuntimeError.
            embed_response = asyncio.run(_call_async_api())
        else:
            embed_response = self._client.models.embed_content(
                model=self._embed_model_name,
                contents=contents,
            )

        # Explicit check rather than `assert`, which is stripped under -O.
        if embed_response.embeddings is None:
            raise ValueError(
                "Gemini returned no embeddings for the given inputs."
            )

        return torch.Tensor(
            [embed.values for embed in embed_response.embeddings]
        )
class GeminiExtractor(Extractor):
    """Extractor that turns unstructured assessments into scores via Gemini."""

    def __init__(
        self,
        model_name: str = "gemini-1.5-flash",
        genai_client: genai.Client | None = None,
        generate_content_args: dict[str, Any] | None = None,
        *,
        use_async: bool = False,
        vertexai: bool = False,
    ):
        """
        Initialize the Gemini score extraction client. You can provide your own
        genai.Client instance via the `genai_client` argument, or set the
        necessary environment variables.

        If you want to use Gemini Developer API, please set `GOOGLE_API_KEY`.
        If you want to use Vertex AI API, set the `vertexai` argument to True,
        and set the following environment variables:
        - GOOGLE_CLOUD_PROJECT=<your-project-id>
        - GOOGLE_CLOUD_LOCATION=<location> (e.g. europe-west1)
        - GOOGLE_APPLICATION_CREDENTIALS=<path-to-your-credentials>

        References:
            - https://ai.google.dev/api/python/google/generativeai/GenerativeModel
            - https://cloud.google.com/docs/authentication/application-default-credentials

        Args:
            model_name: The Gemini model to use. Defaults to
                "gemini-1.5-flash".
            genai_client (Optional): The genai.Client instance to use. If not
                provided, the client will be created using the environment
                variables.
            generate_content_args (Optional): Dict of args to pass in to the
                ``generate_content`` function. The keys should be the same as
                the keys in the ``genai.types.GenerateContentConfig`` type.
            use_async: If True, the async client will be used. Defaults to
                False.
            vertexai: If True, the Vertex AI client will be used. Ignored when
                `genai_client` is provided. Defaults to False.

        Raises:
            ValueError: If ``generate_content_args`` is rejected by
                ``_validate_generate_content_config``, or if no
                ``genai_client`` is given and a required environment variable
                is missing.
        """
        # stacklevel=2 attributes the warning to the code that instantiated
        # the extractor rather than to this constructor.
        warnings.warn(
            "GeminiExtractor will be deprecated in the next release. "
            "Please use LiteLLMExtractor instead.",
            stacklevel=2,
        )

        self._model_name = model_name
        self._generate_content_args = generate_content_args or {}
        _validate_generate_content_config(self._generate_content_args)
        self._use_async = use_async

        if genai_client is None:
            # Check for required environment variables
            if vertexai:
                # Vertex AI requires these environment variables
                for env_var in [
                    "GOOGLE_CLOUD_PROJECT",
                    "GOOGLE_CLOUD_LOCATION",
                    "GOOGLE_APPLICATION_CREDENTIALS",
                ]:
                    if not os.environ.get(env_var):
                        raise ValueError(
                            f"Environment variable '{env_var}' must be set when using Vertex AI."
                        )
            elif os.environ.get("GOOGLE_API_KEY", None) is None:
                # Gemini Developer API requires API key
                raise ValueError(
                    "`GOOGLE_API_KEY` is not set when using Gemini Developer API. "
                    "Please set the `GOOGLE_API_KEY` environment variable."
                )

            self._client = genai.Client(vertexai=vertexai)
        else:
            self._client = genai_client

            # Client config will take precedence over the argument, and the
            # argument will be ignored.
            if genai_client.vertexai and not vertexai:
                warnings.warn(
                    "The provided `genai_client` is a Vertex AI client, "
                    "so the `vertexai=False` argument will be ignored. The Vertex AI client will be used."
                )
            elif not genai_client.vertexai and vertexai:
                warnings.warn(
                    "The provided `genai_client` is a Gemini Developer client, "
                    "so the `vertexai=True` argument will be ignored. The Gemini Developer client will be used."
                )

    def get_float_score(
        self,
        metric_name: str,
        language: str,
        unstructured_assessment_result: list[str | None],
        score_map: dict[str, float],
        *,
        tqdm_description: str | None = None,
    ) -> ResponsesWithMetadata[float]:
        """The function that transforms the unstructured assessments (i.e. long
        texts that describe the evaluation results) into scores. We leverage the
        structured output API to extract the short assessment results from the
        unstructured assessments, so please make sure that the model you use
        supports structured output (See the References for more details).

        References:
            https://ai.google.dev/gemini-api/docs/structured-output

        Args:
            metric_name: The name of the metric to be used. (e.g. "toxicity")
            language: The language of the prompts. (e.g. "en")
            unstructured_assessment_result: The unstructured assessment results
                for the given assessment prompts.
            score_map: The mapping from the short assessment results
                (e.g. "Good") to the scores.
            tqdm_description (Optional): The description to be shown in the
                tqdm bar. Defaults to "Scores (2/2)".

        Returns:
            A list of scores for the given prompts. The scores can be None if
            the evaluation fails.

        Raises:
            ValueError: If ``language`` is not one of "en", "ja" or "de".
        """
        if language not in ["en", "ja", "de"]:
            raise ValueError(f"Unsupported language: {language}")

        structured_output_template = get_template(
            f"{language}/get_score/structured_output.j2"
        )

        options = list(score_map.keys())

        # Schema for the structured output: `score` must be one of the keys
        # of `score_map`.
        class Response(BaseModel):
            score: Literal[tuple(options)]  # type: ignore

        config = {
            "temperature": 0.0,
            "response_mime_type": "application/json",
            "response_schema": Response,
        }
        # User-supplied args override the defaults above.
        config.update(self._generate_content_args or {})

        # Create prompts, filtering out None
        valid_prompts = []
        prompt_indices = []  # Keep track of original indices
        for i, unstructured_assessment in enumerate(
            unstructured_assessment_result
        ):
            if unstructured_assessment is not None:
                valid_prompts.append(
                    structured_output_template.render(
                        {
                            "metric": metric_name,
                            "unstructured_assessment": unstructured_assessment,
                            "options": options,
                        }
                    )
                )
                prompt_indices.append(i)

        tqdm_description = tqdm_description or "Scores (2/2)"

        # Call API for valid prompts
        if valid_prompts:
            api_responses = _call_api(
                model=self._model_name,
                prompts=valid_prompts,
                config=config,
                client=self._client,
                use_async=self._use_async,
                tqdm_description=tqdm_description,
            )
        else:
            api_responses = []

        # Reconstruct full responses list with None for invalid prompts
        responses: list[Any] = [None] * len(unstructured_assessment_result)
        for original_idx, response in zip(prompt_indices, api_responses):
            responses[original_idx] = response

        # `parsed` is None when the model output could not be parsed into the
        # schema; treat that as a failed evaluation for that item instead of
        # raising AttributeError for the whole batch.
        assessments = []
        for response in responses:
            parsed = response.parsed if response else None
            assessments.append(getattr(parsed, "score", None))

        # Token usage is not supported in GeminiExtractor
        # If you need token usage, please use LiteLLMExtractor instead.
        return ResponsesWithMetadata(
            [
                score_map[assessment]
                if assessment and assessment in options
                else None
                for assessment in assessments
            ],
            None,
        )
def _call_api(
    model: str,
    prompts: list[str],
    config: dict[str, Any],
    client: genai.Client,
    *,
    use_async: bool = False,
    tqdm_description: str | None = None,
) -> list[Any]:
    """A helper function to call `generate_content` of the Gemini API.

    Failures are reported with ``print`` and never raised: a prompt whose call
    raised, returned no candidates, or was stopped by the safety settings gets
    ``None`` at its position in the result.

    Args:
        model: The model name to use.
        prompts: The prompts for `generate_content`.
        config: The config for `generate_content`.
        client: The genai client.
        use_async: If True, the async client will be used. Defaults to False.
        tqdm_description (Optional): The description to be shown in the tqdm
            bar. Only used by the synchronous path.

    Returns:
        A list of responses from the Gemini API, one per prompt and in the
        same order, with ``None`` for the prompts that failed.
    """
    if use_async:

        async def _call_async_api() -> list[Any]:
            return await asyncio.gather(
                *[
                    client.aio.models.generate_content(
                        model=model,
                        contents=types.Part.from_text(text=prompt),
                        config=types.GenerateContentConfig(**config),
                    )
                    for prompt in prompts
                ],
                return_exceptions=True,
            )

        responses = asyncio.run(_call_async_api())
    else:
        # A helper function to call the API with exception filter for alignment
        # of exception handling with the async version.
        def _call_api_with_exception_filter(prompt: str) -> Any:
            try:
                return client.models.generate_content(
                    model=model,
                    contents=types.Part.from_text(text=prompt),
                    config=types.GenerateContentConfig(**config),
                )
            except Exception as e:
                return e

        responses = [
            _call_api_with_exception_filter(prompt)
            for prompt in tqdm_wrapper(prompts, desc=tqdm_description)
        ]

    # Filter out exceptions and print them out. Also filter out responses
    # that are blocked by safety settings and print out the safety ratings.
    for i, response in enumerate(responses):
        # BaseException (not Exception) so that a CancelledError returned by
        # `gather(return_exceptions=True)` is also treated as a failure.
        if isinstance(response, BaseException):
            print(
                "Gemini failed to return an assessment corresponding to "
                f"{i}th prompt: {response}"
            )
            responses[i] = None
            continue

        candidates = getattr(response, "candidates", None)
        if not candidates:
            # Without a candidate there is nothing to index into below; the
            # prompt feedback (if any) is the only diagnostic available.
            prompt_feedback = getattr(response, "prompt_feedback", None)
            print(
                f"Gemini returned no candidates for the {i}th prompt:\n "
                f"{prompt_feedback}"
            )
            responses[i] = None
        elif _is_safety_finish_reason(candidates[0].finish_reason):
            print(
                f"Gemini's safety settings blocked the {i}th prompt:\n "
                f"{candidates[0].safety_ratings}"
            )
            responses[i] = None

    return responses


def _is_safety_finish_reason(finish_reason: Any) -> bool:
    """Return True if `finish_reason` means the safety settings stopped it.

    Args:
        finish_reason: The ``finish_reason`` of a response candidate.

    Returns:
        True for ``types.FinishReason.SAFETY`` and for the integer 3, which
        the original check compared against; False otherwise.
    """
    return finish_reason == types.FinishReason.SAFETY or finish_reason == 3


def _validate_generate_content_config(
    generate_content_args: dict[str, Any],
) -> None:
    """A helper function to validate the generate_content_args.

    Args:
        generate_content_args: The generate_content_args to validate.

    Raises:
        ValueError: If ``types.GenerateContentConfig`` cannot be built from
            the given args. The original error is chained as the cause.
    """
    try:
        _ = types.GenerateContentConfig(**generate_content_args)
    except (TypeError, ValueError) as e:
        raise ValueError(
            f"Invalid generate_content_args: {generate_content_args}. "
            f"Error: {e}"
        ) from e