From becde2ce612f7b5af67df2ca55cfda19b8a5fe29 Mon Sep 17 00:00:00 2001 From: Ryan Rishi <51206416+Ryan-Rishi@users.noreply.github.com> Date: Tue, 2 Apr 2024 16:51:33 -0400 Subject: [PATCH] Added parameters to expose LLM prompts from eval (#128) * Added parameters to expose LLM prompts from eval * Fixed wording for consistency metric * Formatting * Switch to not use property --------- Co-authored-by: ethan-tonic --- .../answer_consistency_binary_metric.py | 3 ++- .../metrics/answer_consistency_metric.py | 9 ++++++++ .../metrics/answer_similarity_metric.py | 3 ++- .../metrics/augmentation_accuracy_metric.py | 3 ++- tonic_validate/metrics/duplication_metric.py | 3 ++- .../metrics/hate_speech_content_metric.py | 3 ++- tonic_validate/metrics/metric.py | 6 ++++++ .../metrics/retrieval_precision_metric.py | 3 ++- tonic_validate/utils/llm_calls.py | 21 ++++++++++++------- 9 files changed, 41 insertions(+), 13 deletions(-) diff --git a/tonic_validate/metrics/answer_consistency_binary_metric.py b/tonic_validate/metrics/answer_consistency_binary_metric.py index 95f125e..297d18f 100644 --- a/tonic_validate/metrics/answer_consistency_binary_metric.py +++ b/tonic_validate/metrics/answer_consistency_binary_metric.py @@ -3,13 +3,14 @@ from tonic_validate.metrics.binary_metric import BinaryMetric from tonic_validate.utils.metrics_util import parse_boolean_response from tonic_validate.services.openai_service import OpenAIService -from tonic_validate.utils.llm_calls import answer_consistent_with_context_call +from tonic_validate.utils.llm_calls import answer_consistent_with_context_call, context_consistency_prompt logger = logging.getLogger() class AnswerConsistencyBinaryMetric(BinaryMetric): name: str = "answer_consistency_binary" + prompt: str = context_consistency_prompt() def __init__(self): """ diff --git a/tonic_validate/metrics/answer_consistency_metric.py b/tonic_validate/metrics/answer_consistency_metric.py index eff8d20..794d22b 100644 --- a/tonic_validate/metrics/answer_consistency_metric.py +++ b/tonic_validate/metrics/answer_consistency_metric.py @@ -9,6 +9,8 @@ from tonic_validate.utils.llm_calls import ( main_points_call, statement_derived_from_context_call, + statement_derived_from_context_prompt, + main_points_prompt, ) logger = logging.getLogger() @@ -16,6 +18,13 @@ class AnswerConsistencyMetric(Metric): name: str = "answer_consistency" + prompt: str = ( + "-------------------\n" + f"{main_points_prompt()}\n" + "-------------------\n" + f"{statement_derived_from_context_prompt(statement='EXAMPLE STATEMENT', context_list=[])}\n" + "-------------------\n" + ) def __init__(self): """ diff --git a/tonic_validate/metrics/answer_similarity_metric.py b/tonic_validate/metrics/answer_similarity_metric.py index b3b2741..b5ed699 100644 --- a/tonic_validate/metrics/answer_similarity_metric.py +++ b/tonic_validate/metrics/answer_similarity_metric.py @@ -2,13 +2,14 @@ from tonic_validate.classes.llm_response import LLMResponse from tonic_validate.metrics.metric import Metric from tonic_validate.services.openai_service import OpenAIService -from tonic_validate.utils.llm_calls import similarity_score_call +from tonic_validate.utils.llm_calls import similarity_score_call, similarity_score_prompt logger = logging.getLogger() class AnswerSimilarityMetric(Metric): name: str = "answer_similarity" + prompt: str = similarity_score_prompt() def __init__(self) -> None: """ diff --git a/tonic_validate/metrics/augmentation_accuracy_metric.py b/tonic_validate/metrics/augmentation_accuracy_metric.py index 56af9f1..592d14b 100644 --- a/tonic_validate/metrics/augmentation_accuracy_metric.py +++ b/tonic_validate/metrics/augmentation_accuracy_metric.py @@ -4,13 +4,14 @@ from tonic_validate.metrics.metric import Metric from tonic_validate.utils.metrics_util import parse_boolean_response from tonic_validate.services.openai_service import OpenAIService -from tonic_validate.utils.llm_calls import answer_contains_context_call +from tonic_validate.utils.llm_calls import answer_contains_context_call, answer_contains_context_prompt logger = logging.getLogger() class AugmentationAccuracyMetric(Metric): name: str = "augmentation_accuracy" + prompt: str = answer_contains_context_prompt() def __init__(self): """ diff --git a/tonic_validate/metrics/duplication_metric.py b/tonic_validate/metrics/duplication_metric.py index 91e1a62..9a98822 100644 --- a/tonic_validate/metrics/duplication_metric.py +++ b/tonic_validate/metrics/duplication_metric.py @@ -3,7 +3,7 @@ from tonic_validate.classes.llm_response import LLMResponse from tonic_validate.metrics.binary_metric import BinaryMetric from tonic_validate.services.openai_service import OpenAIService -from tonic_validate.utils.llm_calls import contains_duplicate_information +from tonic_validate.utils.llm_calls import contains_duplicate_information, contains_duplicate_info_prompt from tonic_validate.utils.metrics_util import parse_boolean_response logger = logging.getLogger() @@ -11,6 +11,7 @@ class DuplicationMetric(BinaryMetric): name: str = "duplication_metric" + prompt: str = contains_duplicate_info_prompt() def __init__(self): """ diff --git a/tonic_validate/metrics/hate_speech_content_metric.py b/tonic_validate/metrics/hate_speech_content_metric.py index 0efd4e4..7dd9512 100644 --- a/tonic_validate/metrics/hate_speech_content_metric.py +++ b/tonic_validate/metrics/hate_speech_content_metric.py @@ -3,7 +3,7 @@ from tonic_validate.classes.llm_response import LLMResponse from tonic_validate.metrics.binary_metric import BinaryMetric from tonic_validate.services.openai_service import OpenAIService -from tonic_validate.utils.llm_calls import contains_hate_speech +from tonic_validate.utils.llm_calls import contains_hate_speech, contains_hate_speech_prompt from tonic_validate.utils.metrics_util import parse_boolean_response logger = logging.getLogger() @@ -11,6 +11,7 @@ class HateSpeechContentMetric(BinaryMetric): name: str = "hate_speech_content" + prompt: str = contains_hate_speech_prompt() def __init__(self): """ diff --git a/tonic_validate/metrics/metric.py b/tonic_validate/metrics/metric.py index c9847d6..ba5ea54 100644 --- a/tonic_validate/metrics/metric.py +++ b/tonic_validate/metrics/metric.py @@ -1,10 +1,16 @@ from abc import ABC, abstractmethod +from typing import Optional from tonic_validate.classes.llm_response import LLMResponse from tonic_validate.services.openai_service import OpenAIService class Metric(ABC): + """Abstract class for a metric that can be calculated on an LLM response.""" + + """Prompt for the metric. Can be overridden by subclasses if a specific prompt is needed.""" + prompt: Optional[str] = None + @property @abstractmethod def name(self) -> str: diff --git a/tonic_validate/metrics/retrieval_precision_metric.py b/tonic_validate/metrics/retrieval_precision_metric.py index c67f557..9810179 100644 --- a/tonic_validate/metrics/retrieval_precision_metric.py +++ b/tonic_validate/metrics/retrieval_precision_metric.py @@ -4,13 +4,14 @@ from tonic_validate.metrics.metric import Metric from tonic_validate.utils.metrics_util import parse_boolean_response from tonic_validate.services.openai_service import OpenAIService -from tonic_validate.utils.llm_calls import context_relevancy_call +from tonic_validate.utils.llm_calls import context_relevancy_call, context_relevancy_prompt logger = logging.getLogger() class RetrievalPrecisionMetric(Metric): name: str = "retrieval_precision" + prompt: str = context_relevancy_prompt() def __init__(self): """ diff --git a/tonic_validate/utils/llm_calls.py b/tonic_validate/utils/llm_calls.py index f9f029b..1048561 100644 --- a/tonic_validate/utils/llm_calls.py +++ b/tonic_validate/utils/llm_calls.py @@ -360,12 +360,9 @@ async def statement_derived_from_context_call( logger.debug( f"Asking {openai_service.model} whether statement is derived from context" ) - main_message = "Considering the following statement and list of context(s)" - main_message += f"\n\nSTATEMENT:\n{statement}\nEND OF STATEMENT" - for i, context in enumerate(context_list): - main_message += f"\n\nCONTEXT {i}:\n{context}\nEND OF CONTEXT {i}" - main_message = statement_derived_from_context_prompt(main_message) + main_message = statement_derived_from_context_prompt(statement, context_list) + try: response_message = await openai_service.get_response(main_message) except ContextLengthException as e: @@ -391,18 +388,28 @@ async def statement_derived_from_context_call( return response_message -def statement_derived_from_context_prompt(main_message): +def statement_derived_from_context_prompt(statement: str, context_list: List[str]): """ Parameters ---------- - main_message : The main message to which additional instructions will be added. + statement: str + The statement to be checked. + context_list: List[str] + List of retrieved context. Returns ------- prompt message for determining if a statement can be derived from context. """ + if not context_list: + context_list = ["EXAMPLE CONTEXT"] + + main_message = "Considering the following statement and list of context(s)" + main_message += f"\n\nSTATEMENT:\n{statement}\nEND OF STATEMENT" + for i, context in enumerate(context_list): + main_message += f"\n\nCONTEXT {i}:\n{context}\nEND OF CONTEXT {i}" main_message += ( "\n\nDetermine whether the listed statement above can be derived from the " "context listed above. If the statement can "