diff --git a/skore/src/skore/ui/project_routes.py b/skore/src/skore/ui/project_routes.py
index 82340c88..9c07bd16 100644
--- a/skore/src/skore/ui/project_routes.py
+++ b/skore/src/skore/ui/project_routes.py
@@ -3,24 +3,28 @@
 from __future__ import annotations
 
 import base64
-from collections import defaultdict
+import operator
 from dataclasses import dataclass
+from datetime import datetime, timezone
 from typing import TYPE_CHECKING, Any
 
 from fastapi import APIRouter, HTTPException, Request, status
 
+from skore.item import (
+    Item,
+    MediaItem,
+    NumpyArrayItem,
+    PandasDataFrameItem,
+    PandasSeriesItem,
+    PolarsDataFrameItem,
+    PolarsSeriesItem,
+    PrimitiveItem,
+    SklearnBaseEstimatorItem,
+)
 from skore.item.cross_validation_item import (
     CrossValidationAggregationItem,
     CrossValidationItem,
 )
-from skore.item.media_item import MediaItem
-from skore.item.numpy_array_item import NumpyArrayItem
-from skore.item.pandas_dataframe_item import PandasDataFrameItem
-from skore.item.pandas_series_item import PandasSeriesItem
-from skore.item.polars_dataframe_item import PolarsDataFrameItem
-from skore.item.polars_series_item import PolarsSeriesItem
-from skore.item.primitive_item import PrimitiveItem
-from skore.item.sklearn_base_estimator_item import SklearnBaseEstimatorItem
 from skore.project import Project
 from skore.view.view import Layout, View
 
@@ -31,9 +35,10 @@
 
 
 @dataclass
-class SerializedItem:
+class SerializableItem:
     """Serialized item."""
 
+    name: str
     media_type: str
     value: Any
     updated_at: str
@@ -41,70 +46,79 @@ class SerializedItem:
 
 
 @dataclass
-class SerializedProject:
+class SerializableProject:
     """Serialized project, to be sent to the skore-ui."""
 
-    items: dict[str, list[SerializedItem]]
+    items: dict[str, list[SerializableItem]]
     views: dict[str, Layout]
 
 
-def __serialize_project(project: Project) -> SerializedProject:
-    items = defaultdict(list)
-
-    def pandas_dataframe_to_serializable(df: pandas.DataFrame):
-        return df.fillna("NaN").to_dict(orient="tight")
-
-    for key in project.list_item_keys():
-        for item in project.get_item_versions(key):
-            if isinstance(item, PrimitiveItem):
-                value = item.primitive
-                media_type = "text/markdown"
-            elif isinstance(item, NumpyArrayItem):
-                value = item.array.tolist()
-                media_type = "text/markdown"
-            elif isinstance(item, PandasDataFrameItem):
-                value = pandas_dataframe_to_serializable(item.dataframe)
-                media_type = "application/vnd.dataframe+json"
-            elif isinstance(item, PandasSeriesItem):
-                value = item.series.fillna("NaN").to_list()
-                media_type = "text/markdown"
-            elif isinstance(item, PolarsDataFrameItem):
-                value = pandas_dataframe_to_serializable(item.dataframe.to_pandas())
-                media_type = "application/vnd.dataframe+json"
-            elif isinstance(item, PolarsSeriesItem):
-                value = item.series.to_list()
-                media_type = "text/markdown"
-            elif isinstance(item, SklearnBaseEstimatorItem):
-                value = item.estimator_html_repr
-                media_type = "application/vnd.sklearn.estimator+html"
-            elif isinstance(item, MediaItem):
-                if "text" in item.media_type:
-                    value = item.media_bytes.decode(encoding=item.media_encoding)
-                    media_type = f"{item.media_type}"
-                else:
-                    value = base64.b64encode(item.media_bytes).decode()
-                    media_type = f"{item.media_type};base64"
-            elif isinstance(
-                item, (CrossValidationItem, CrossValidationAggregationItem)
-            ):
-                value = base64.b64encode(item.plot_bytes).decode()
-                media_type = "application/vnd.plotly.v1+json;base64"
-            else:
-                raise ValueError(f"Item {item} is not a known item type.")
-
-            items[key].append(
-                SerializedItem(
-                    media_type=media_type,
-                    value=value,
-                    updated_at=item.updated_at,
-                    created_at=item.created_at,
-                )
-            )
+def __pandas_dataframe_as_serializable(df: pandas.DataFrame):
+    """Return `df` as a "tight" dict, with missing values replaced by "NaN"."""
+    return df.fillna("NaN").to_dict(orient="tight")
+
+
+def __item_as_serializable(name: str, item: Item) -> SerializableItem:
+    """Convert `item`, stored under the key `name`, to a `SerializableItem`.
+
+    Raise `ValueError` if the type of `item` is not a known item type.
+    """
+    if isinstance(item, PrimitiveItem):
+        value = item.primitive
+        media_type = "text/markdown"
+    elif isinstance(item, NumpyArrayItem):
+        value = item.array.tolist()
+        media_type = "text/markdown"
+    elif isinstance(item, PandasDataFrameItem):
+        value = __pandas_dataframe_as_serializable(item.dataframe)
+        media_type = "application/vnd.dataframe+json"
+    elif isinstance(item, PandasSeriesItem):
+        value = item.series.fillna("NaN").to_list()
+        media_type = "text/markdown"
+    elif isinstance(item, PolarsDataFrameItem):
+        value = __pandas_dataframe_as_serializable(item.dataframe.to_pandas())
+        media_type = "application/vnd.dataframe+json"
+    elif isinstance(item, PolarsSeriesItem):
+        value = item.series.to_list()
+        media_type = "text/markdown"
+    elif isinstance(item, SklearnBaseEstimatorItem):
+        value = item.estimator_html_repr
+        media_type = "application/vnd.sklearn.estimator+html"
+    elif isinstance(item, MediaItem):
+        if "text" in item.media_type:
+            value = item.media_bytes.decode(encoding=item.media_encoding)
+            media_type = f"{item.media_type}"
+        else:
+            value = base64.b64encode(item.media_bytes).decode()
+            media_type = f"{item.media_type};base64"
+    elif isinstance(item, (CrossValidationItem, CrossValidationAggregationItem)):
+        value = base64.b64encode(item.plot_bytes).decode()
+        media_type = "application/vnd.plotly.v1+json;base64"
+    else:
+        raise ValueError(f"Item {item} is not a known item type.")
+
+    return SerializableItem(
+        name=name,
+        media_type=media_type,
+        value=value,
+        updated_at=item.updated_at,
+        created_at=item.created_at,
+    )
+
+
+def __project_as_serializable(project: Project) -> SerializableProject:
+    """Serialize all the item versions and the view layouts of `project`."""
+    items = {
+        key: [
+            __item_as_serializable(key, item) for item in project.get_item_versions(key)
+        ]
+        for key in project.list_item_keys()
+    }
 
     views = {key: project.get_view(key).layout for key in project.list_view_keys()}
 
-    return SerializedProject(
-        items=dict(items),
+    return SerializableProject(
+        items=items,
         views=views,
     )
 
@@ -113,7 +127,7 @@ def pandas_dataframe_to_serializable(df: pandas.DataFrame):
 async def get_items(request: Request):
     """Serialize a project and send it."""
     project = request.app.state.project
-    return __serialize_project(project)
+    return __project_as_serializable(project)
 
 
 @router.put("/views", status_code=status.HTTP_201_CREATED)
@@ -127,7 +141,7 @@ async def put_view(request: Request, key: str, layout: Layout):
     view = View(layout=layout)
     project.put_view(key, view)
 
-    return __serialize_project(project)
+    return __project_as_serializable(project)
 
 
 @router.delete("/views", status_code=status.HTTP_202_ACCEPTED)
@@ -142,4 +156,34 @@ async def delete_view(request: Request, key: str):
             status_code=status.HTTP_404_NOT_FOUND, detail="View not found"
         ) from None
 
-    return __serialize_project(project)
+    return __project_as_serializable(project)
+
+
+@router.get("/activity")
+async def get_activity(
+    request: Request,
+    after: datetime = datetime(1, 1, 1, 0, 0, 0, 0, timezone.utc),
+):
+    """Send all recent activity as a JSON array.
+
+    The activity is composed of all the items and their versions updated after the
+    datetime `after`, sorted from newest to oldest. A naive `after` is interpreted
+    as UTC.
+    """
+    project = request.app.state.project
+
+    # The default `after` is timezone-aware (UTC), and ordering an aware datetime
+    # against a naive one raises TypeError: interpret a naive `after` as UTC.
+    if after.tzinfo is None:
+        after = after.replace(tzinfo=timezone.utc)
+
+    return sorted(
+        (
+            __item_as_serializable(key, version)
+            for key in project.list_item_keys()
+            for version in project.get_item_versions(key)
+            if datetime.fromisoformat(version.updated_at) > after
+        ),
+        key=operator.attrgetter("updated_at"),
+        reverse=True,
+    )
diff --git a/skore/tests/integration/ui/test_ui.py b/skore/tests/integration/ui/test_ui.py
index f538e595..076825ab 100644
--- a/skore/tests/integration/ui/test_ui.py
+++ b/skore/tests/integration/ui/test_ui.py
@@ -1,3 +1,5 @@
+import datetime
+
 import numpy
 import pandas
 import polars
@@ -37,6 +39,7 @@ def test_get_items(client, in_memory_project):
         "items": {
             "test": [
                 {
+                    "name": "test",
                     "media_type": "text/markdown",
                     "value": item.primitive,
                     "created_at": item.created_at,
@@ -141,3 +144,52 @@ def test_serialize_media_item(client, in_memory_project):
     assert "image" in project["items"]["img"][0]["media_type"]
     assert project["items"]["html"][0]["value"] == html
     assert project["items"]["media html"][0]["value"] == html
+
+
+def test_activity_feed(monkeypatch, client, in_memory_project):
+    class MockDatetime:
+        NOW = datetime.datetime.now(tz=datetime.timezone.utc)
+        TIMEDELTA = datetime.timedelta(days=1)
+
+        def __init__(self, *args, **kwargs): ...
+
+        @staticmethod
+        def now(*args, **kwargs):
+            MockDatetime.NOW += MockDatetime.TIMEDELTA
+            return MockDatetime.NOW
+
+    monkeypatch.setattr("skore.item.item.datetime", MockDatetime)
+
+    for i in range(5):
+        in_memory_project.put(str(i), i)
+
+    response = client.get("/api/project/activity")
+    assert response.status_code == 200
+    assert [(item["name"], item["value"]) for item in response.json()] == [
+        ("4", 4),
+        ("3", 3),
+        ("2", 2),
+        ("1", 1),
+        ("0", 0),
+    ]
+
+    now = MockDatetime.NOW  # snapshot: reading NOW does not advance the clock
+
+    in_memory_project.put("4", 5)
+    in_memory_project.put("5", 5)
+
+    response = client.get("/api/project/activity", params={"after": now})
+    assert response.status_code == 200
+    assert [(item["name"], item["value"]) for item in response.json()] == [
+        ("5", 5),
+        ("4", 5),
+    ]
+
+    # A naive `after` is interpreted as UTC instead of failing the request.
+    naive_now = now.replace(tzinfo=None).isoformat()
+    response = client.get("/api/project/activity", params={"after": naive_now})
+    assert response.status_code == 200
+    assert [(item["name"], item["value"]) for item in response.json()] == [
+        ("5", 5),
+        ("4", 5),
+    ]