Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

function to emit metrics #1649

Merged
merged 2 commits into from
May 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions docs/metrics.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Metrics

Prediction objects have a `metrics` field. This normally includes `predict_time` and `total_time`. Official language models have metrics like `input_token_count`, `output_token_count`, `tokens_per_second`, and `time_to_first_token`. Currently, custom metrics from Cog are ignored when running on Replicate. Official Replicate-published models are the only exception to this. When running outside of Replicate, you can emit custom metrics like this:


```python
import cog
from cog import BasePredictor, Path

class Predictor(BasePredictor):
def predict(self, width: int, height: int) -> Path:
"""Run a single prediction on the model"""
cog.emit_metric(name="pixel_count", value=width * height)
```
2 changes: 2 additions & 0 deletions python/cog/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pydantic import BaseModel

from .predictor import BasePredictor
from .server.worker import emit_metric
from .types import AsyncConcatenateIterator, ConcatenateIterator, File, Input, Path

try:
Expand All @@ -18,4 +19,5 @@
"File",
"Input",
"Path",
"emit_metric",
]
6 changes: 6 additions & 0 deletions python/cog/server/eventtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ class Log:
source: str = field(validator=validators.in_(["stdout", "stderr"]))


@define
class PredictionMetric:
name: str
value: "float | int"


@define
class PredictionOutput:
payload: Any
Expand Down
11 changes: 8 additions & 3 deletions python/cog/server/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
Heartbeat,
Log,
PredictionInput,
PredictionMetric,
PredictionOutput,
PredictionOutputType,
PublicEventType,
Expand Down Expand Up @@ -438,6 +439,7 @@ def __init__(
self.logger.info("starting prediction")
# maybe this should be a deep copy to not share File state with child worker
self.p = schema.PredictionResponse(**request.dict())
self.p.metrics = {}
self.p.status = schema.Status.PROCESSING
self.p.output = None
self.p.logs = ""
Expand Down Expand Up @@ -489,9 +491,9 @@ async def succeeded(self) -> None:
# that...
assert self.p.completed_at is not None
assert self.p.started_at is not None
self.p.metrics = {
"predict_time": (self.p.completed_at - self.p.started_at).total_seconds()
}
self.p.metrics["predict_time"] = (
self.p.completed_at - self.p.started_at
).total_seconds()
await self._send_webhook(schema.WebhookEvent.COMPLETED)

async def failed(self, error: str) -> None:
Expand Down Expand Up @@ -552,6 +554,9 @@ def event_to_handle_future(self, event: PublicEventType) -> Awaitable[None]:
if self._output_type.multi:
return self.set_output([])
return self.noop()
if isinstance(event, PredictionMetric):
self.p.metrics[event.name] = event.value
return self.noop()
if isinstance(event, PredictionOutput):
if self._output_type is None:
return self.failed(error="Predictor returned unexpected output")
Expand Down
19 changes: 19 additions & 0 deletions python/cog/server/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
Heartbeat,
Log,
PredictionInput,
PredictionMetric,
PredictionOutput,
PredictionOutputType,
PublicEventType,
Expand Down Expand Up @@ -85,6 +86,17 @@ async def read(
raise self.fatal


# janky mutable container for a single eventual ChildWorker
worker_reference: "dict[None, _ChildWorker]" = {}
yorickvP marked this conversation as resolved.
Show resolved Hide resolved

def emit_metric(metric_name: str, metric_value: "float | int") -> None:
worker = worker_reference.get(None, None)
if worker is None:
raise Exception("Attempted to emit metric but worker is not running")
worker._emit_metric(metric_name, metric_value)



class _ChildWorker(_spawn.Process): # type: ignore
def __init__(
self,
Expand All @@ -109,6 +121,7 @@ def run(self) -> None:
# We use SIGUSR1 to signal an interrupt for cancelation.
signal.signal(signal.SIGUSR1, self._signal_handler)

worker_reference[None] = self
self.prediction_id_context: ContextVar[str] = ContextVar("prediction_context")

# <could be moved into StreamRedirector>
Expand Down Expand Up @@ -239,6 +252,12 @@ def _handle_predict_error(self, id: str) -> Iterator[None]:
self._stream_redirector.drain()
self._events.send((id, done))

def _emit_metric(self, name: str, value: "int | float") -> None:
prediction_id = self.prediction_id_context.get(None)
if prediction_id is None:
raise Exception("Tried to emit a metric outside a prediction context")
self._events.send((prediction_id, PredictionMetric(name, value)))

def _mk_send(self, id: str) -> Callable[[PublicEventType], None]:
def send(event: PublicEventType) -> None:
self._events.send((id, event))
Expand Down
3 changes: 2 additions & 1 deletion python/tests/server/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
import time
from datetime import datetime
from unittest import mock

import pytest
import pytest_asyncio

from cog.schema import PredictionRequest, PredictionResponse, Status, WebhookEvent
from cog.server.clients import ClientManager
from cog.server.eventtypes import (
Expand All @@ -22,7 +24,6 @@
UnknownPredictionError,
)


# TODO
# - setup logs
# - file inputs being converted
Expand Down