deps: Update google-cloud-aiplatform (#510)
Upgrading to the latest google-cloud-aiplatform library introduced a breaking
change in the evaluation system. This PR handles the library version upgrade
along with the corresponding updates to the evaluation system.
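In practice, the breaking change is that the evaluation API moved out of the preview namespace and the metric definitions were reworked; a minimal before/after sketch of the import change, mirroring the diffs below:

    # Before, with google-cloud-aiplatform[rapid_evaluation]==1.62.0:
    # from vertexai.preview.evaluation import EvalTask
    # After, with google-cloud-aiplatform[evaluation]==1.72.0:
    from vertexai.evaluation import EvalTask, MetricPromptTemplateExamples, PointwiseMetric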
Yuan325 authored Nov 19, 2024
1 parent 20870ea commit 292fec6
Showing 5 changed files with 84 additions and 28 deletions.
11 changes: 8 additions & 3 deletions llm_demo/evaluation/eval_golden.py
@@ -39,7 +39,8 @@ class EvalData(BaseModel):
category: Optional[str] = Field(default=None, description="Evaluation category")
query: Optional[str] = Field(default=None, description="User query")
instruction: Optional[str] = Field(
default=None, description="Instruction to llm system"
default="",
description="Part of the input user prompt. It refers to the inference instruction that is sent to you llm",
)
content: Optional[str] = Field(
default=None,
@@ -48,16 +49,20 @@ class EvalData(BaseModel):
tool_calls: List[ToolCall] = Field(
default=[], description="Golden tool call for evaluation"
)
prompt: Optional[str] = Field(
default="",
description="User input for the Gen AI model or application. It's optional in some cases.",
)
context: Optional[List[Dict[str, Any] | List[Dict[str, Any]]]] = Field(
default=None, description="Context given to llm in order to answer user query"
)
output: Optional[str] = Field(
default=None, description="Golden output for evaluation"
)
prediction_tool_calls: List[ToolCall] = Field(
llm_tool_calls: List[ToolCall] = Field(
default=[], description="Tool call output from LLM"
)
prediction_output: str = Field(default="", description="Final output from LLM")
llm_output: str = Field(default="", description="Final output from LLM")
reset: bool = Field(
default=True, description="Determine to reset the chat after invoke"
)
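For orientation, a minimal sketch of a golden-dataset entry under the renamed fields; the category, query, and tool arguments are illustrative, and the import path assumes code running from the llm_demo directory:

    from evaluation.eval_golden import EvalData, ToolCall

    golden = EvalData(
        category="Search Airport",  # hypothetical category label
        query="What flights leave SFO today?",  # hypothetical user query
        tool_calls=[
            # arguments is assumed to take a dict of tool parameters
            ToolCall(name="search_flights", arguments={"departure_airport": "SFO"})
        ],
    )
    # run_llm_for_eval() later fills in the renamed prediction fields:
    # golden.llm_tool_calls, golden.llm_output, golden.prompt, golden.instruction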
44 changes: 21 additions & 23 deletions llm_demo/evaluation/evaluation.py
@@ -18,19 +18,21 @@

import pandas as pd
from pydantic import BaseModel, Field
from vertexai.preview.evaluation import EvalTask # type: ignore
from vertexai.preview.evaluation import _base as evaluation_base
from vertexai.evaluation import EvalTask
from vertexai.evaluation import _base as evaluation_base

from orchestrator import BaseOrchestrator

from .eval_golden import EvalData, ToolCall
from .metrics import response_phase_metrics, retrieval_phase_metrics


async def run_llm_for_eval(
eval_list: List[EvalData], orc: BaseOrchestrator, session: Dict, session_id: str
) -> List[EvalData]:
"""
Generate prediction_tool_calls and prediction_output for golden dataset query.
Generate llm_tool_calls and llm_output for golden dataset query.
This function is only compatible with the langchain-tools orchestration.
"""
agent = orc.get_user_session(session_id)
for eval_data in eval_list:
@@ -39,23 +41,25 @@ async def run_llm_for_eval(
except Exception as e:
print(f"error invoking agent: {e}")
else:
eval_data.prediction_output = query_response.get("output")
eval_data.llm_output = query_response.get("output")

# Retrieve prediction_tool_calls from query response
prediction_tool_calls = []
# Retrieve llm_tool_calls from query response
llm_tool_calls = []
contexts = []
for step in query_response.get("intermediate_steps"):
called_tool = step[0]
tool_call = ToolCall(
name=called_tool.tool,
arguments=called_tool.tool_input,
)
prediction_tool_calls.append(tool_call)
llm_tool_calls.append(tool_call)
context = step[-1]
contexts.append(context)

eval_data.prediction_tool_calls = prediction_tool_calls
eval_data.llm_tool_calls = llm_tool_calls
eval_data.context = contexts
eval_data.prompt = PROMPT
eval_data.instruction = f"Answer user query based on context given. User query is {eval_data.query}."

if eval_data.reset:
orc.user_session_reset(session, session_id)
@@ -68,7 +72,6 @@ def evaluate_retrieval_phase(
"""
Run evaluation for the ability of a model to select the right tool and arguments (retrieval phase).
"""
metrics = ["tool_call_quality"]
# Prepare evaluation task input
responses = []
references = []
@@ -85,7 +88,7 @@ def evaluate_retrieval_phase(
json.dumps(
{
"content": e.content,
"tool_calls": [t.model_dump() for t in e.prediction_tool_calls],
"tool_calls": [t.model_dump() for t in e.llm_tool_calls],
}
)
)
@@ -98,7 +101,7 @@ def evaluate_retrieval_phase(
# Run evaluation
eval_result = EvalTask(
dataset=eval_dataset,
metrics=metrics,
metrics=retrieval_phase_metrics,
experiment=experiment_name,
).evaluate()
return eval_result
@@ -110,37 +113,32 @@ def evaluate_response_phase(
"""
Run evaluation for the ability of a model to generate a response based on the context given (response phase).
"""
metrics = [
"text_generation_quality",
"text_generation_factuality",
"summarization_pointwise_reference_free",
"qa_pointwise_reference_free",
]
# Prepare evaluation task input
instructions = []
contexts = []
responses = []
prompts = []

for e in eval_datas:
instructions.append(
f"Answer user query based on context given. User query is {e.query}."
)
instructions.append(e.instruction)
context_str = (
[json.dumps(c) for c in e.context] if e.context else ["no data retrieved"]
)
contexts.append(PROMPT + ", " + ", ".join(context_str))
responses.append(e.prediction_output or "")
prompts.append(e.prompt)
contexts.append(", ".join(context_str))
responses.append(e.llm_output or "")
eval_dataset = pd.DataFrame(
{
"instruction": instructions,
"prompt": prompts,
"context": contexts,
"response": responses,
}
)
# Run evaluation
eval_result = EvalTask(
dataset=eval_dataset,
metrics=metrics,
metrics=response_phase_metrics,
experiment=experiment_name,
).evaluate()
return eval_result
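As a self-contained illustration of what evaluate_response_phase now hands to EvalTask under the 1.72.0 API (the row values are made up, and vertexai.init() against a project with the Gen AI evaluation service enabled is assumed):

    import pandas as pd
    import vertexai
    from vertexai.evaluation import EvalTask

    from evaluation.metrics import response_phase_metrics

    vertexai.init(project="my-project", location="us-central1")  # hypothetical project

    eval_dataset = pd.DataFrame(
        {
            "instruction": ["Answer user query based on context given. User query is ..."],
            "prompt": ["..."],  # stands in for the PROMPT constant in the real code
            "context": ["no data retrieved"],
            "response": ["There are no matching flights today."],
        }
    )
    result = EvalTask(
        dataset=eval_dataset,
        metrics=response_phase_metrics,
        experiment="response-phase-eval",  # hypothetical experiment name
    ).evaluate()
    print(result.summary_metrics)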
49 changes: 49 additions & 0 deletions llm_demo/evaluation/metrics.py
@@ -0,0 +1,49 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from vertexai.evaluation import MetricPromptTemplateExamples, PointwiseMetric

text_quality_metric = PointwiseMetric(
metric="text_quality",
metric_prompt_template=MetricPromptTemplateExamples.get_prompt_template(
"text_quality"
),
)

summarization_quality_metric = PointwiseMetric(
metric="summarization_quality",
metric_prompt_template=MetricPromptTemplateExamples.get_prompt_template(
"summarization_quality"
),
)

question_answering_quality_metric = PointwiseMetric(
metric="question_answering_quality",
metric_prompt_template=MetricPromptTemplateExamples.get_prompt_template(
"question_answering_quality"
),
)

response_phase_metrics = [
text_quality_metric,
summarization_quality_metric,
question_answering_quality_metric,
]

retrieval_phase_metrics = [
"tool_call_valid",
"tool_name_match",
"tool_parameter_key_match",
"tool_parameter_kv_match",
]
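And a matching sketch for the retrieval phase: the built-in tool-call metric names compare a JSON-encoded response against a JSON-encoded reference, which is what evaluate_retrieval_phase serializes per row (the tool call below is illustrative; vertexai.init() is assumed as above):

    import json

    import pandas as pd
    from vertexai.evaluation import EvalTask

    from evaluation.metrics import retrieval_phase_metrics

    call = {
        "content": None,
        "tool_calls": [{"name": "search_flights", "arguments": {"departure_airport": "SFO"}}],
    }
    eval_dataset = pd.DataFrame(
        {
            "response": [json.dumps(call)],   # tool calls the model actually made
            "reference": [json.dumps(call)],  # golden tool calls
        }
    )
    result = EvalTask(
        dataset=eval_dataset,
        metrics=retrieval_phase_metrics,
        experiment="retrieval-phase-eval",  # hypothetical experiment name
    ).evaluate()
    print(result.summary_metrics)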
6 changes: 5 additions & 1 deletion llm_demo/pyproject.toml
@@ -3,4 +3,8 @@ profile = "black"

[tool.mypy]
python_version = 3.11
warn_unused_configs = true
warn_unused_configs = true

[[tool.mypy.overrides]]
module = ["vertexai.evaluation"]
ignore_missing_imports = true
2 changes: 1 addition & 1 deletion llm_demo/requirements.txt
@@ -1,6 +1,6 @@
fastapi==0.109.2
google-auth==2.33.0
google-cloud-aiplatform[rapid_evaluation]==1.62.0
google-cloud-aiplatform[evaluation]==1.72.0
itsdangerous==2.2.0
jinja2==3.1.4
langchain-community==0.2.9
