Skip to content

Commit

Permalink
feat: Add a new backend API for activity feed (#857)
Browse files Browse the repository at this point in the history
Author: Thomas S <thomas@probabl.ai>
Author: Matt J <matthieu@probabl.ai>

---

Part of #767.

---------

Co-authored-by: Matt J <matthieu@probabl.ai>
Co-authored-by: Auguste Baum <52001167+augustebaum@users.noreply.github.com>
  • Loading branch information
3 people authored Dec 3, 2024
1 parent 60ac13c commit 2e3c616
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 69 deletions.
169 changes: 100 additions & 69 deletions skore/src/skore/ui/project_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,28 @@
from __future__ import annotations

import base64
from collections import defaultdict
import operator
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any

from fastapi import APIRouter, HTTPException, Request, status

from skore.item import (
Item,
MediaItem,
NumpyArrayItem,
PandasDataFrameItem,
PandasSeriesItem,
PolarsDataFrameItem,
PolarsSeriesItem,
PrimitiveItem,
SklearnBaseEstimatorItem,
)
from skore.item.cross_validation_item import (
CrossValidationAggregationItem,
CrossValidationItem,
)
from skore.item.media_item import MediaItem
from skore.item.numpy_array_item import NumpyArrayItem
from skore.item.pandas_dataframe_item import PandasDataFrameItem
from skore.item.pandas_series_item import PandasSeriesItem
from skore.item.polars_dataframe_item import PolarsDataFrameItem
from skore.item.polars_series_item import PolarsSeriesItem
from skore.item.primitive_item import PrimitiveItem
from skore.item.sklearn_base_estimator_item import SklearnBaseEstimatorItem
from skore.project import Project
from skore.view.view import Layout, View

Expand All @@ -31,80 +35,84 @@


@dataclass
class SerializableItem:
    """One version of a project item, in a form that can be sent to the skore-ui."""

    # Key under which the item is stored in the project.
    name: str
    # Media type telling the UI how to render `value`; binary payloads carry a
    # ";base64" suffix.
    media_type: str
    # The payload itself: a primitive, nested lists/dicts, or text.
    value: Any
    # Timestamps copied verbatim from the underlying item (ISO 8601 strings).
    updated_at: str
    created_at: str


@dataclass
class SerializableProject:
    """Snapshot of a whole project, to be sent to the skore-ui."""

    # Every version of every item, grouped by item key.
    items: dict[str, list[SerializableItem]]
    # Layout of each saved view, indexed by view key.
    views: dict[str, Layout]


def __serialize_project(project: Project) -> SerializedProject:
items = defaultdict(list)

def pandas_dataframe_to_serializable(df: pandas.DataFrame):
return df.fillna("NaN").to_dict(orient="tight")

for key in project.list_item_keys():
for item in project.get_item_versions(key):
if isinstance(item, PrimitiveItem):
value = item.primitive
media_type = "text/markdown"
elif isinstance(item, NumpyArrayItem):
value = item.array.tolist()
media_type = "text/markdown"
elif isinstance(item, PandasDataFrameItem):
value = pandas_dataframe_to_serializable(item.dataframe)
media_type = "application/vnd.dataframe+json"
elif isinstance(item, PandasSeriesItem):
value = item.series.fillna("NaN").to_list()
media_type = "text/markdown"
elif isinstance(item, PolarsDataFrameItem):
value = pandas_dataframe_to_serializable(item.dataframe.to_pandas())
media_type = "application/vnd.dataframe+json"
elif isinstance(item, PolarsSeriesItem):
value = item.series.to_list()
media_type = "text/markdown"
elif isinstance(item, SklearnBaseEstimatorItem):
value = item.estimator_html_repr
media_type = "application/vnd.sklearn.estimator+html"
elif isinstance(item, MediaItem):
if "text" in item.media_type:
value = item.media_bytes.decode(encoding=item.media_encoding)
media_type = f"{item.media_type}"
else:
value = base64.b64encode(item.media_bytes).decode()
media_type = f"{item.media_type};base64"
elif isinstance(
item, (CrossValidationItem, CrossValidationAggregationItem)
):
value = base64.b64encode(item.plot_bytes).decode()
media_type = "application/vnd.plotly.v1+json;base64"
else:
raise ValueError(f"Item {item} is not a known item type.")

items[key].append(
SerializedItem(
media_type=media_type,
value=value,
updated_at=item.updated_at,
created_at=item.created_at,
)
)
def __pandas_dataframe_as_serializable(df: pandas.DataFrame):
    """Turn a pandas DataFrame into a plain dict.

    Missing values are first replaced by the string "NaN"; the frame is then
    dumped with pandas' "tight" orientation.
    """
    without_missing = df.fillna("NaN")
    return without_missing.to_dict(orient="tight")


def __item_as_serializable(name: str, item: Item) -> SerializableItem:
    """Convert one item version into its serializable representation.

    Parameters
    ----------
    name : str
        Key of the item in the project; copied into the result.
    item : Item
        The item version to convert.

    Returns
    -------
    SerializableItem
        The payload and media type derived from the concrete item class,
        together with the item's own timestamps.

    Raises
    ------
    ValueError
        If `item` is not an instance of any of the handled item classes.
    """

    def as_base64(payload: bytes) -> str:
        return base64.b64encode(payload).decode()

    def payload_and_media_type():
        # The checks run in a fixed order and the first match wins.
        if isinstance(item, PrimitiveItem):
            return item.primitive, "text/markdown"
        if isinstance(item, NumpyArrayItem):
            return item.array.tolist(), "text/markdown"
        if isinstance(item, PandasDataFrameItem):
            return (
                __pandas_dataframe_as_serializable(item.dataframe),
                "application/vnd.dataframe+json",
            )
        if isinstance(item, PandasSeriesItem):
            return item.series.fillna("NaN").to_list(), "text/markdown"
        if isinstance(item, PolarsDataFrameItem):
            # Polars frames go through pandas so both share one output format.
            return (
                __pandas_dataframe_as_serializable(item.dataframe.to_pandas()),
                "application/vnd.dataframe+json",
            )
        if isinstance(item, PolarsSeriesItem):
            return item.series.to_list(), "text/markdown"
        if isinstance(item, SklearnBaseEstimatorItem):
            return item.estimator_html_repr, "application/vnd.sklearn.estimator+html"
        if isinstance(item, MediaItem):
            # Textual media is sent decoded; anything else is base64-encoded and
            # flagged as such in the media type.
            if "text" in item.media_type:
                return (
                    item.media_bytes.decode(encoding=item.media_encoding),
                    f"{item.media_type}",
                )
            return as_base64(item.media_bytes), f"{item.media_type};base64"
        if isinstance(item, (CrossValidationItem, CrossValidationAggregationItem)):
            return (
                as_base64(item.plot_bytes),
                "application/vnd.plotly.v1+json;base64",
            )
        raise ValueError(f"Item {item} is not a known item type.")

    value, media_type = payload_and_media_type()

    return SerializableItem(
        name=name,
        media_type=media_type,
        value=value,
        updated_at=item.updated_at,
        created_at=item.created_at,
    )


def __project_as_serializable(project: Project) -> SerializableProject:
    """Build the serializable snapshot of `project`.

    Every version of every item is converted with `__item_as_serializable` and
    grouped by item key; the layout of every view is collected by view key.
    """
    items = {}
    for item_key in project.list_item_keys():
        versions = []
        for item in project.get_item_versions(item_key):
            versions.append(__item_as_serializable(item_key, item))
        items[item_key] = versions

    views = {}
    for view_key in project.list_view_keys():
        views[view_key] = project.get_view(view_key).layout

    return SerializableProject(items=items, views=views)

Expand All @@ -113,7 +121,7 @@ def pandas_dataframe_to_serializable(df: pandas.DataFrame):
async def get_items(request: Request):
    """Return the application's project, items and views, in serializable form."""
    return __project_as_serializable(request.app.state.project)


@router.put("/views", status_code=status.HTTP_201_CREATED)
Expand All @@ -127,7 +135,7 @@ async def put_view(request: Request, key: str, layout: Layout):
view = View(layout=layout)
project.put_view(key, view)

return __serialize_project(project)
return __project_as_serializable(project)


@router.delete("/views", status_code=status.HTTP_202_ACCEPTED)
Expand All @@ -142,4 +150,27 @@ async def delete_view(request: Request, key: str):
status_code=status.HTTP_404_NOT_FOUND, detail="View not found"
) from None

return __serialize_project(project)
return __project_as_serializable(project)


@router.get("/activity")
async def get_activity(
    request: Request,
    after: datetime = datetime(1, 1, 1, 0, 0, 0, 0, timezone.utc),
):
    """Send all recent activity as a JSON array.

    The activity is composed of every version of every item whose `updated_at`
    timestamp is strictly later than `after`, sorted from newest to oldest.

    Parameters
    ----------
    request : Request
        Incoming request; the project is read from `request.app.state.project`.
    after : datetime, optional
        Lower bound (exclusive) on `updated_at`. A value without timezone
        information is interpreted as UTC. Defaults to the earliest
        representable UTC datetime, i.e. no filtering.

    Returns
    -------
    list of SerializableItem
        The matching item versions, most recently updated first.
    """

    def as_aware(moment: datetime) -> datetime:
        # Comparing a naive datetime with an aware one raises TypeError, which
        # would turn a query such as `?after=2024-12-03T10:00:00` into a 500.
        # Naive values are therefore assumed to be expressed in UTC.
        if moment.utcoffset() is None:
            return moment.replace(tzinfo=timezone.utc)
        return moment

    threshold = as_aware(after)
    project = request.app.state.project

    recent = []
    for key in project.list_item_keys():
        for version in project.get_item_versions(key):
            updated_at = as_aware(datetime.fromisoformat(version.updated_at))
            if updated_at > threshold:
                recent.append((updated_at, __item_as_serializable(key, version)))

    # Order on the parsed timestamps rather than on the raw ISO strings, whose
    # lexicographic order is only chronological when all UTC offsets match.
    recent.sort(key=operator.itemgetter(0), reverse=True)
    return [serializable for _, serializable in recent]
43 changes: 43 additions & 0 deletions skore/tests/integration/ui/test_ui.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import datetime

import numpy
import pandas
import polars
Expand Down Expand Up @@ -37,6 +39,7 @@ def test_get_items(client, in_memory_project):
"items": {
"test": [
{
"name": "test",
"media_type": "text/markdown",
"value": item.primitive,
"created_at": item.created_at,
Expand Down Expand Up @@ -141,3 +144,43 @@ def test_serialize_media_item(client, in_memory_project):
assert "image" in project["items"]["img"][0]["media_type"]
assert project["items"]["html"][0]["value"] == html
assert project["items"]["media html"][0]["value"] == html


def test_activity_feed(monkeypatch, client, in_memory_project):
    """Check that the activity endpoint is newest-first and honours `after`."""

    class MockDatetime:
        """Stand-in for `datetime` whose clock moves one day ahead per `now()`."""

        NOW = datetime.datetime.now(tz=datetime.timezone.utc)
        TIMEDELTA = datetime.timedelta(days=1)

        def __init__(self, *args, **kwargs): ...

        @staticmethod
        def now(*args, **kwargs):
            MockDatetime.NOW += MockDatetime.TIMEDELTA
            return MockDatetime.NOW

    def fetch_activity(**request_kwargs):
        # Query the endpoint and keep only what the assertions compare.
        response = client.get("/api/project/activity", **request_kwargs)
        assert response.status_code == 200
        return [(entry["name"], entry["value"]) for entry in response.json()]

    monkeypatch.setattr("skore.item.item.datetime", MockDatetime)

    for number in range(5):
        in_memory_project.put(str(number), number)

    # Without `after`, every item comes back, most recent first.
    assert fetch_activity() == [
        (str(number), number) for number in reversed(range(5))
    ]

    # Reading NOW does not advance the fake clock; this is the latest timestamp
    # handed out so far and serves as the `after` threshold below.
    checkpoint = MockDatetime.NOW

    in_memory_project.put("4", 5)
    in_memory_project.put("5", 5)

    # Only the two versions stored after the checkpoint are expected.
    assert fetch_activity(params={"after": checkpoint}) == [
        ("5", 5),
        ("4", 5),
    ]

0 comments on commit 2e3c616

Please sign in to comment.