Skip to content

Commit

Permalink
Propagate trace context to webhook and upload requests
Browse files Browse the repository at this point in the history
Based on the implementation in #1698 for sync cog.

If the request to /predict contains the headers `traceparent` and
`tracestate` defined by the W3C Trace Context specification[^1], then these
headers are forwarded on to the webhook and upload calls.

This allows observability systems to link requests passing through cog.

[^1]: https://www.w3.org/TR/trace-context/

Signed-off-by: technillogue <technillogue@gmail.com>
  • Loading branch information
aron authored and technillogue committed Jul 12, 2024
1 parent 8d834f0 commit 5d38ae7
Show file tree
Hide file tree
Showing 4 changed files with 163 additions and 11 deletions.
21 changes: 19 additions & 2 deletions python/cog/server/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .eventtypes import PredictionInput
from .response_throttler import ResponseThrottler
from .retry_transport import RetryTransport
from .telemetry import current_trace_context

log = structlog.get_logger(__name__)

Expand Down Expand Up @@ -45,14 +46,25 @@ def _get_version() -> str:
WebhookSenderType = Callable[[Any, WebhookEvent], Awaitable[None]]


def webhook_headers() -> "dict[str, str]":
def common_headers() -> "dict[str, str]":
    """Return the base headers sent on every outbound HTTP request from cog."""
    return {"user-agent": _user_agent}


def webhook_headers() -> "dict[str, str]":
    """Return headers for webhook requests.

    Extends the common headers with a bearer authorization token when the
    WEBHOOK_AUTH_TOKEN environment variable is set (and non-empty).
    """
    headers = common_headers()
    token = os.environ.get("WEBHOOK_AUTH_TOKEN")
    if not token:
        return headers
    headers["authorization"] = "Bearer " + token
    return headers


async def on_request_trace_context_hook(request: httpx.Request) -> None:
    """httpx request event hook: copy the current trace context headers
    (traceparent/tracestate) onto the outgoing request, if one is set."""
    trace_headers = current_trace_context()
    if trace_headers:
        request.headers.update(trace_headers)


def httpx_webhook_client() -> httpx.AsyncClient:
    """Return an httpx client for one-shot (non-retrying) webhook delivery.

    Registers the trace-context request hook so w3c traceparent/tracestate
    headers from the inbound request are propagated to webhook calls,
    matching httpx_retry_client and httpx_file_client — without it, webhooks
    sent through this client would drop the trace context.
    """
    return httpx.AsyncClient(
        event_hooks={"request": [on_request_trace_context_hook]},
        headers=webhook_headers(),
        follow_redirects=True,
    )

Expand All @@ -68,7 +80,10 @@ def httpx_retry_client() -> httpx.AsyncClient:
retryable_methods=["POST"],
)
return httpx.AsyncClient(
headers=webhook_headers(), transport=transport, follow_redirects=True
event_hooks={"request": [on_request_trace_context_hook]},
headers=webhook_headers(),
transport=transport,
follow_redirects=True,
)


Expand All @@ -87,6 +102,8 @@ def httpx_file_client() -> httpx.AsyncClient:
# httpx default for pool is 5, use that
timeout = httpx.Timeout(connect=10, read=15, write=None, pool=5)
return httpx.AsyncClient(
event_hooks={"request": [on_request_trace_context_hook]},
headers=common_headers(),
transport=transport,
follow_redirects=True,
timeout=timeout,
Expand Down
39 changes: 30 additions & 9 deletions python/cog/server/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
Dict,
Optional,
TypeVar,
Union,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -52,6 +51,7 @@
SetupTask,
UnknownPredictionError,
)
from .telemetry import make_trace_context, trace_context

log = structlog.get_logger("cog.server.http")

Expand Down Expand Up @@ -190,9 +190,16 @@ class TrainingRequest(
)
def train(
request: TrainingRequest = Body(default=None),
prefer: Union[str, None] = Header(default=None),
prefer: Optional[str] = Header(default=None),
traceparent: Optional[str] = Header(default=None, include_in_schema=False),
tracestate: Optional[str] = Header(default=None, include_in_schema=False),
) -> Any: # type: ignore
return predict(request, prefer)
return predict(
request,
prefer=prefer,
traceparent=traceparent,
tracestate=tracestate,
)

@app.put(
"/trainings/{training_id}",
Expand All @@ -202,9 +209,17 @@ def train(
def train_idempotent(
training_id: str = Path(..., title="Training ID"),
request: TrainingRequest = Body(..., title="Training Request"),
prefer: Union[str, None] = Header(default=None),
prefer: Optional[str] = Header(default=None),
traceparent: Optional[str] = Header(default=None, include_in_schema=False),
tracestate: Optional[str] = Header(default=None, include_in_schema=False),
) -> Any:
return predict_idempotent(training_id, request, prefer)
return predict_idempotent(
prediction_id=training_id,
request=request,
prefer=prefer,
traceparent=traceparent,
tracestate=tracestate,
)

@app.post("/trainings/{training_id}/cancel")
def cancel_training(training_id: str = Path(..., title="Training ID")) -> Any:
Expand Down Expand Up @@ -270,7 +285,9 @@ async def ready() -> Any:
)
async def predict(
request: PredictionRequest = Body(default=None),
prefer: Union[str, None] = Header(default=None),
prefer: Optional[str] = Header(default=None),
traceparent: Optional[str] = Header(default=None, include_in_schema=False),
tracestate: Optional[str] = Header(default=None, include_in_schema=False),
) -> Any: # type: ignore
"""
Run a single prediction on the model
Expand All @@ -285,7 +302,8 @@ async def predict(
# TODO: spec-compliant parsing of Prefer header.
respond_async = prefer == "respond-async"

return await shared_predict(request=request, respond_async=respond_async)
with trace_context(make_trace_context(traceparent, tracestate)):
return await shared_predict(request=request, respond_async=respond_async)

@limited
@app.put(
Expand All @@ -296,7 +314,9 @@ async def predict(
async def predict_idempotent(
prediction_id: str = Path(..., title="Prediction ID"),
request: PredictionRequest = Body(..., title="Prediction Request"),
prefer: Union[str, None] = Header(default=None),
prefer: Optional[str] = Header(default=None),
traceparent: Optional[str] = Header(default=None, include_in_schema=False),
tracestate: Optional[str] = Header(default=None, include_in_schema=False),
) -> Any:
"""
Run a single prediction on the model (idempotent creation).
Expand All @@ -314,7 +334,8 @@ async def predict_idempotent(
# TODO: spec-compliant parsing of Prefer header.
respond_async = prefer == "respond-async"

return await shared_predict(request=request, respond_async=respond_async)
with trace_context(make_trace_context(traceparent, tracestate)):
return await shared_predict(request=request, respond_async=respond_async)

async def shared_predict(
*, request: Optional[PredictionRequest], respond_async: bool = False
Expand Down
54 changes: 54 additions & 0 deletions python/cog/server/telemetry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Generator, Optional

# TypedDict was added in 3.8
from typing_extensions import TypedDict


# W3C Trace Context headers, forwarded verbatim between services.
# See: https://www.w3.org/TR/trace-context/
class TraceContext(TypedDict, total=False):
    # total=False: either header may be absent on a given request.
    traceparent: str
    tracestate: str


# Trace context of the request currently being handled, or None outside a
# request. A ContextVar keeps the value isolated per task/thread, so
# concurrent requests do not observe each other's context.
TRACE_CONTEXT: ContextVar[Optional[TraceContext]] = ContextVar(
    "trace_context", default=None
)


def make_trace_context(
    traceparent: Optional[str] = None, tracestate: Optional[str] = None
) -> TraceContext:
    """
    Build a TraceContext from the inbound ``traceparent`` and ``tracestate``
    header values, leaving out any value that is missing or empty. The
    resulting dict is used to pass the trace context between services.
    """
    candidates = (("traceparent", traceparent), ("tracestate", tracestate))
    ctx: TraceContext = {key: value for key, value in candidates if value}
    return ctx


def current_trace_context() -> Optional[TraceContext]:
    """
    Returns the current trace context, this needs to be added via HTTP headers
    to all outgoing HTTP requests.

    Returns None when no trace context has been installed (e.g. when code
    runs outside the scope of an inbound request's trace_context() block).
    """
    return TRACE_CONTEXT.get()


@contextmanager
def trace_context(ctx: TraceContext) -> Generator[None, None, None]:
    """
    Install *ctx* as the current trace context for the duration of the
    ``with`` block, restoring the previous context on exit — even when the
    body raises. Used to link the inbound HTTP request to all internal
    outgoing HTTP requests made while handling it.
    """
    token = TRACE_CONTEXT.set(ctx)
    try:
        yield
    finally:
        # Restore whatever context was active before this block.
        TRACE_CONTEXT.reset(token)
60 changes: 60 additions & 0 deletions python/tests/server/test_http.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import base64
import httpx
import io
import respx
import time
import unittest.mock as mock

Expand Down Expand Up @@ -560,6 +562,64 @@ def test_asynchronous_prediction_endpoint(client, match):
assert webhook.call_count == 1


# End-to-end test for passing tracing headers on to downstream services.
@pytest.mark.asyncio
@pytest.mark.respx(base_url="https://example.com")
@uses_predictor_with_client_options(
    "output_file", upload_url="https://example.com/upload"
)
async def test_asynchronous_prediction_endpoint_with_trace_context(
    respx_mock: respx.MockRouter, client, match
):
    # Both downstream calls (webhook and file upload) must carry the trace
    # headers supplied on the inbound /predictions request.
    webhook_route = respx_mock.post(
        "/webhook",
        json__id="12345abcde",
        json__status="succeeded",
        json__output="https://example.com/upload/file",
        headers={
            "traceparent": "traceparent-123",
            "tracestate": "tracestate-123",
        },
    ).respond(200)
    upload_route = respx_mock.put(
        "/upload/file",
        headers={
            "content-type": "application/octet-stream",
            "traceparent": "traceparent-123",
            "tracestate": "tracestate-123",
        },
    ).respond(200)

    response = client.post(
        "/predictions",
        json={
            "id": "12345abcde",
            "input": {},
            "webhook": "https://example.com/webhook",
            "webhook_events_filter": ["completed"],
        },
        headers={
            "Prefer": "respond-async",
            "traceparent": "traceparent-123",
            "tracestate": "tracestate-123",
        },
    )
    assert response.status_code == 202

    body = response.json()
    assert body == match(
        {"status": "processing", "output": None, "started_at": mock.ANY}
    )
    assert body["started_at"] is not None

    # Poll briefly for the asynchronous webhook delivery to complete.
    for _ in range(10):
        if webhook_route.call_count >= 1:
            break
        time.sleep(0.1)

    assert webhook_route.call_count == 1
    assert upload_route.call_count == 1


@uses_predictor("sleep")
def test_prediction_cancel(client):
resp = client.post("/predictions/123/cancel")
Expand Down

0 comments on commit 5d38ae7

Please sign in to comment.