EPD-612: Evaluations with AutoblocksTracer #124

Merged
merged 27 commits on Mar 5, 2024
Changes from 4 commits
54 changes: 54 additions & 0 deletions autoblocks/_impl/testing/models.py
@@ -1,10 +1,28 @@
import abc
import dataclasses
import functools
import uuid
from typing import Any
from typing import Optional


@dataclasses.dataclass()
@dmorton2297 (Contributor, Author) commented on Feb 26, 2024:

Note: This is temporarily here - I am going to create a tracer folder under _impl that contains its own models + tracer implementation, but have held off to make the diff easier to review

class TracerEvent:
message: str
trace_id: str
timestamp: str
properties: dict

@classmethod
def to_json(cls, event):
return {
"message": event.message,
"traceId": event.trace_id,
"timestamp": event.timestamp,
"properties": event.properties,
}


@dataclasses.dataclass()
class Threshold:
lt: Optional[float] = None
@@ -13,6 +31,26 @@ class Threshold:
gte: Optional[float] = None


@dataclasses.dataclass()
class EventEvaluation:
evaluator_external_id: str
score: float
id: Optional[str] = dataclasses.field(default_factory=lambda: str(uuid.uuid4()))
metadata: Optional[dict] = None
threshold: Optional[Threshold] = None

@classmethod
def to_json(cls, event_evaluation):
return dict(
evaluatorExternalId=event_evaluation.evaluator_external_id,
id=str(event_evaluation.id),
score=event_evaluation.score,
metadata=dict(event_evaluation.metadata) if event_evaluation.metadata else None,
threshold=dict(event_evaluation.threshold) if event_evaluation.threshold else None,
)


# TODO: Rename TestEvaluation?
@dataclasses.dataclass()
class Evaluation:
score: float
@@ -43,3 +81,19 @@ def id(self) -> str:
@abc.abstractmethod
def evaluate_test_case(self, test_case: BaseTestCase, output: Any) -> Evaluation:
pass


class BaseEventEvaluator(abc.ABC):
"""
An abstract base class for implementing an evaluator that runs on events
in an online testing scenario.
"""

@property
@abc.abstractmethod
def id(self) -> str:
pass

@abc.abstractmethod
def evaluate_event(self, event: TracerEvent) -> EventEvaluation:
pass
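
For illustration, a minimal concrete evaluator built on this base class might look like the sketch below. The evaluator id, scoring rule, and message-length check are hypothetical and not part of this PR; as in the tests added later in this diff, the threshold is passed as a plain dict.

from autoblocks._impl.testing.models import BaseEventEvaluator
from autoblocks._impl.testing.models import EventEvaluation
from autoblocks._impl.testing.models import TracerEvent


class HasSubstantiveMessage(BaseEventEvaluator):
    # Hypothetical evaluator id, for illustration only.
    id = "has-substantive-message"

    def evaluate_event(self, event: TracerEvent) -> EventEvaluation:
        # Illustrative rule: score 1.0 when the event message is longer than 10 characters.
        score = 1.0 if len(event.message) > 10 else 0.0
        return EventEvaluation(
            evaluator_external_id=self.id,
            score=score,
            metadata={"message_length": len(event.message)},
            # The tests in this PR pass threshold as a plain dict.
            threshold={"gte": 1.0},
        )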
89 changes: 81 additions & 8 deletions autoblocks/_impl/tracer.py
@@ -1,4 +1,6 @@
import asyncio
import contextvars
import inspect
import logging
import uuid
from contextlib import contextmanager
@@ -7,20 +9,26 @@
from datetime import timedelta
from datetime import timezone
from typing import Dict
from typing import List
from typing import Optional

from autoblocks._impl import global_state
from autoblocks._impl.config.constants import INGESTION_ENDPOINT
from autoblocks._impl.testing.models import BaseEventEvaluator
from autoblocks._impl.testing.models import EventEvaluation
from autoblocks._impl.testing.models import TracerEvent
from autoblocks._impl.util import AutoblocksEnvVar

log = logging.getLogger(__name__)
from autoblocks._impl.util import gather_with_max_concurrency


@dataclass
class SendEventResponse:
trace_id: Optional[str]


log = logging.getLogger(__name__)


class AutoblocksTracer:
def __init__(
self,
@@ -101,6 +109,60 @@ def start_span(self):
props["span_id"] = prev_span_id
self.set_properties(props)

async def evaluate_event(self, event: TracerEvent, evaluator: BaseEventEvaluator) -> Optional[EventEvaluation]:
"""
Evaluates an event using a provided evaluator.
"""
if inspect.iscoroutinefunction(evaluator.evaluate_event):
try:
evaluation = await evaluator.evaluate_event(event=event)
except Exception as err:
log.error("Event evaluation threw an exception: %s", err)
return None
else:
try:
ctx = contextvars.copy_context()
evaluation = await global_state.event_loop().run_in_executor(
None,
ctx.run,
evaluator.evaluate_event,
event,
)
except Exception as err:
log.error("Event evaluation threw an exception: %s", err)
return None

return evaluation

async def _run_and_build_evals_properties(
self, evaluators: List[BaseEventEvaluator], event: TracerEvent, max_evaluator_concurrency: int
) -> Dict:
event_dict = TracerEvent.to_json(event)
if len(evaluators) == 0:
return event_dict
try:
evaluations: List[EventEvaluation] = await gather_with_max_concurrency(
max_evaluator_concurrency,
[
self.evaluate_event(
event=event,
evaluator=evaluator,
)
for evaluator in evaluators
],
)
if evaluations and len(evaluations) > 0:
evaluations_json = [
EventEvaluation.to_json(evaluation)
for evaluation in filter(lambda x: isinstance(x, EventEvaluation), evaluations)
]
if len(evaluations_json) > 0:
event_dict["properties"]["evaluations"] = evaluations_json
return event_dict
except Exception as err:
log.error("Unable to complete evaluating event", err)
return event_dict

async def _send_event_unsafe(
self,
# Require all arguments to be specified via key=value
@@ -112,6 +174,8 @@ async def _send_event_unsafe(
timestamp: Optional[str] = None,
properties: Optional[Dict] = None,
prompt_tracking: Optional[Dict] = None,
evaluators: List[BaseEventEvaluator] = [],
max_evaluator_concurrency: int,
) -> SendEventResponse:
merged_properties = dict(self._properties)
merged_properties.update(properties or {})
Expand All @@ -125,14 +189,19 @@ async def _send_event_unsafe(
trace_id = trace_id or self._trace_id
timestamp = timestamp or datetime.now(timezone.utc).isoformat()

event = TracerEvent(
message=message,
trace_id=trace_id,
timestamp=timestamp,
properties=merged_properties,
)

transformed_event_json = await self._run_and_build_evals_properties(
evaluators, event, max_evaluator_concurrency
)
req = await global_state.http_client().post(
url=INGESTION_ENDPOINT,
json={
"message": message,
"traceId": trace_id,
"timestamp": timestamp,
"properties": merged_properties,
},
json=transformed_event_json,
headers=self._client_headers,
timeout=self._timeout_seconds,
)
@@ -150,6 +219,8 @@ def send_event(
parent_span_id: Optional[str] = None,
timestamp: Optional[str] = None,
properties: Optional[Dict] = None,
evaluators: List[BaseEventEvaluator] = [],
max_evaluator_concurrency: int = 5,
prompt_tracking: Optional[Dict] = None,
) -> SendEventResponse:
"""
@@ -167,6 +238,8 @@
parent_span_id=parent_span_id,
timestamp=timestamp,
properties=properties,
evaluators=evaluators,
max_evaluator_concurrency=max_evaluator_concurrency,
prompt_tracking=prompt_tracking,
),
global_state.event_loop(),
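
Putting the new parameters together, a usage sketch of send_event with an evaluator attached. The ingestion key, message, and properties are placeholders, and MyEvaluator stands in for an evaluator like the one defined in the tests added below; this mirrors those tests rather than documenting a final API.

from autoblocks.tracer import AutoblocksTracer

# Placeholder values throughout; MyEvaluator refers to an evaluator such as
# the one defined in the tests below.
tracer = AutoblocksTracer("my-ingestion-key")
response = tracer.send_event(
    "my-message",
    trace_id="my-trace-id",
    properties={"my-prop": "my-value"},
    evaluators=[MyEvaluator()],
    max_evaluator_concurrency=5,  # default shown in send_event above
)
print(response.trace_id)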
2 changes: 1 addition & 1 deletion autoblocks/_impl/util.py
@@ -50,4 +50,4 @@ async def sem_coro(coro: Coroutine):

# return_exceptions=True causes exceptions to be returned as values instead
# of propagating them to the caller. this is similar in behavior to Promise.allSettled
await asyncio.gather(*(sem_coro(c) for c in coroutines), return_exceptions=True)
return await asyncio.gather(*(sem_coro(c) for c in coroutines), return_exceptions=True)
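
The rest of gather_with_max_concurrency is not visible in this hunk; below is a sketch of what a semaphore-bounded helper with this inner sem_coro and final gather line typically looks like. Everything outside the lines shown in the diff (the function signature, parameter names, and semaphore setup) is an assumption.

import asyncio
from typing import Coroutine
from typing import List


async def gather_with_max_concurrency(max_concurrency: int, coroutines: List[Coroutine]) -> list:
    # Assumed implementation: bound concurrency with a semaphore.
    semaphore = asyncio.Semaphore(max_concurrency)

    async def sem_coro(coro: Coroutine):
        async with semaphore:
            return await coro

    # return_exceptions=True causes exceptions to be returned as values instead
    # of propagating them to the caller. this is similar in behavior to Promise.allSettled;
    # the PR now returns these results so callers such as evaluate_event can use them.
    return await asyncio.gather(*(sem_coro(c) for c in coroutines), return_exceptions=True)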
90 changes: 90 additions & 0 deletions tests/autoblocks/test_tracer.py
@@ -7,6 +7,9 @@
import pytest

from autoblocks._impl.config.constants import INGESTION_ENDPOINT
from autoblocks._impl.testing.models import BaseEventEvaluator
from autoblocks._impl.testing.models import EventEvaluation
from autoblocks._impl.testing.models import TracerEvent
from autoblocks.tracer import AutoblocksTracer
from tests.autoblocks.util import make_expected_body

@@ -448,3 +451,90 @@ def test_tracer_start_span(*args, **kwargs):

assert tracer._properties.get("span_id") is None
assert tracer._properties.get("parent_span_id") is None


def test_tracer_prod_evaluations(httpx_mock):
test_evaluation_id = uuid.uuid4()

class MyEvaluator(BaseEventEvaluator):
id = "my-evaluator"

def evaluate_event(self, event: TracerEvent) -> EventEvaluation:
return EventEvaluation(
evaluator_external_id=self.id,
id=test_evaluation_id,
score=0.9,
threshold={"gte": 0.5},
)

mock_input = {
"trace_id": "my-trace-id",
"timestamp": timestamp,
"properties": {},
"evaluators": [
MyEvaluator(),
],
}
httpx_mock.add_response(
url=INGESTION_ENDPOINT,
method="POST",
status_code=200,
json={"traceId": "my-trace-id"},
match_headers={"Authorization": "Bearer mock-ingestion-key"},
match_content=make_expected_body(
dict(
message="my-message",
traceId="my-trace-id",
timestamp=timestamp,
properties={
"evaluations": [
{
"evaluatorExternalId": "my-evaluator",
"id": str(test_evaluation_id),
"score": 0.9,
"metadata": None,
"threshold": {"gte": 0.5},
}
]
},
)
),
)
tracer = AutoblocksTracer("mock-ingestion-key")
resp = tracer.send_event("my-message", **mock_input)
assert resp.trace_id == "my-trace-id"


def test_tracer_failing_evaluation(httpx_mock):
class MyEvaluator(BaseEventEvaluator):
id = "my-evaluator"

def evaluate_event(self, event: TracerEvent) -> EventEvaluation:
raise Exception("Something terrible went wrong")

mock_input = {
"trace_id": "my-trace-id",
"timestamp": timestamp,
"properties": {},
"evaluators": [
MyEvaluator(),
],
}
httpx_mock.add_response(
url=INGESTION_ENDPOINT,
method="POST",
status_code=200,
json={"traceId": "my-trace-id"},
match_headers={"Authorization": "Bearer mock-ingestion-key"},
match_content=make_expected_body(
dict(
message="my-message",
traceId="my-trace-id",
timestamp=timestamp,
properties={},
)
),
)
tracer = AutoblocksTracer("mock-ingestion-key")
resp = tracer.send_event("my-message", **mock_input)
assert resp.trace_id == "my-trace-id"
Show resolved Hide resolved