Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add a new backend API for activity feed #857

Merged
merged 4 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 100 additions & 69 deletions skore/src/skore/ui/project_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,28 @@
from __future__ import annotations

import base64
from collections import defaultdict
import operator
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any

from fastapi import APIRouter, HTTPException, Request, status

from skore.item import (
Item,
MediaItem,
NumpyArrayItem,
PandasDataFrameItem,
PandasSeriesItem,
PolarsDataFrameItem,
PolarsSeriesItem,
PrimitiveItem,
SklearnBaseEstimatorItem,
)
from skore.item.cross_validation_item import (
CrossValidationAggregationItem,
CrossValidationItem,
)
from skore.item.media_item import MediaItem
from skore.item.numpy_array_item import NumpyArrayItem
from skore.item.pandas_dataframe_item import PandasDataFrameItem
from skore.item.pandas_series_item import PandasSeriesItem
from skore.item.polars_dataframe_item import PolarsDataFrameItem
from skore.item.polars_series_item import PolarsSeriesItem
from skore.item.primitive_item import PrimitiveItem
from skore.item.sklearn_base_estimator_item import SklearnBaseEstimatorItem
from skore.project import Project
from skore.view.view import Layout, View

Expand All @@ -31,80 +35,84 @@


@dataclass
class SerializedItem:
class SerializableItem:
"""Serialized item."""

name: str
media_type: str
value: Any
updated_at: str
created_at: str


@dataclass
class SerializedProject:
class SerializableProject:
"""Serialized project, to be sent to the skore-ui."""

items: dict[str, list[SerializedItem]]
items: dict[str, list[SerializableItem]]
views: dict[str, Layout]


def __serialize_project(project: Project) -> SerializedProject:
items = defaultdict(list)

def pandas_dataframe_to_serializable(df: pandas.DataFrame):
return df.fillna("NaN").to_dict(orient="tight")

for key in project.list_item_keys():
for item in project.get_item_versions(key):
if isinstance(item, PrimitiveItem):
value = item.primitive
media_type = "text/markdown"
elif isinstance(item, NumpyArrayItem):
value = item.array.tolist()
media_type = "text/markdown"
elif isinstance(item, PandasDataFrameItem):
value = pandas_dataframe_to_serializable(item.dataframe)
media_type = "application/vnd.dataframe+json"
elif isinstance(item, PandasSeriesItem):
value = item.series.fillna("NaN").to_list()
media_type = "text/markdown"
elif isinstance(item, PolarsDataFrameItem):
value = pandas_dataframe_to_serializable(item.dataframe.to_pandas())
media_type = "application/vnd.dataframe+json"
elif isinstance(item, PolarsSeriesItem):
value = item.series.to_list()
media_type = "text/markdown"
elif isinstance(item, SklearnBaseEstimatorItem):
value = item.estimator_html_repr
media_type = "application/vnd.sklearn.estimator+html"
elif isinstance(item, MediaItem):
if "text" in item.media_type:
value = item.media_bytes.decode(encoding=item.media_encoding)
media_type = f"{item.media_type}"
else:
value = base64.b64encode(item.media_bytes).decode()
media_type = f"{item.media_type};base64"
elif isinstance(
item, (CrossValidationItem, CrossValidationAggregationItem)
):
value = base64.b64encode(item.plot_bytes).decode()
media_type = "application/vnd.plotly.v1+json;base64"
else:
raise ValueError(f"Item {item} is not a known item type.")

items[key].append(
SerializedItem(
media_type=media_type,
value=value,
updated_at=item.updated_at,
created_at=item.created_at,
)
)
def __pandas_dataframe_as_serializable(df: pandas.DataFrame):
return df.fillna("NaN").to_dict(orient="tight")


def __item_as_serializable(name: str, item: Item) -> SerializableItem:
if isinstance(item, PrimitiveItem):
value = item.primitive
media_type = "text/markdown"
elif isinstance(item, NumpyArrayItem):
value = item.array.tolist()
media_type = "text/markdown"
elif isinstance(item, PandasDataFrameItem):
value = __pandas_dataframe_as_serializable(item.dataframe)
media_type = "application/vnd.dataframe+json"
elif isinstance(item, PandasSeriesItem):
value = item.series.fillna("NaN").to_list()
media_type = "text/markdown"
elif isinstance(item, PolarsDataFrameItem):
value = __pandas_dataframe_as_serializable(item.dataframe.to_pandas())
media_type = "application/vnd.dataframe+json"
elif isinstance(item, PolarsSeriesItem):
value = item.series.to_list()
media_type = "text/markdown"
elif isinstance(item, SklearnBaseEstimatorItem):
value = item.estimator_html_repr
media_type = "application/vnd.sklearn.estimator+html"
elif isinstance(item, MediaItem):
if "text" in item.media_type:
value = item.media_bytes.decode(encoding=item.media_encoding)
media_type = f"{item.media_type}"
else:
value = base64.b64encode(item.media_bytes).decode()
media_type = f"{item.media_type};base64"
elif isinstance(item, (CrossValidationItem, CrossValidationAggregationItem)):
value = base64.b64encode(item.plot_bytes).decode()
media_type = "application/vnd.plotly.v1+json;base64"
else:
raise ValueError(f"Item {item} is not a known item type.")

return SerializableItem(
name=name,
media_type=media_type,
value=value,
updated_at=item.updated_at,
created_at=item.created_at,
)


def __project_as_serializable(project: Project) -> SerializableProject:
items = {
key: [
__item_as_serializable(key, item) for item in project.get_item_versions(key)
]
for key in project.list_item_keys()
}

views = {key: project.get_view(key).layout for key in project.list_view_keys()}

return SerializedProject(
items=dict(items),
return SerializableProject(
items=items,
views=views,
)

Expand All @@ -113,7 +121,7 @@ def pandas_dataframe_to_serializable(df: pandas.DataFrame):
async def get_items(request: Request):
"""Serialize a project and send it."""
project = request.app.state.project
return __serialize_project(project)
return __project_as_serializable(project)


@router.put("/views", status_code=status.HTTP_201_CREATED)
Expand All @@ -127,7 +135,7 @@ async def put_view(request: Request, key: str, layout: Layout):
view = View(layout=layout)
project.put_view(key, view)

return __serialize_project(project)
return __project_as_serializable(project)


@router.delete("/views", status_code=status.HTTP_202_ACCEPTED)
Expand All @@ -142,4 +150,27 @@ async def delete_view(request: Request, key: str):
status_code=status.HTTP_404_NOT_FOUND, detail="View not found"
) from None

return __serialize_project(project)
return __project_as_serializable(project)


@router.get("/activity")
async def get_activity(
request: Request,
after: datetime = datetime(1, 1, 1, 0, 0, 0, 0, timezone.utc),
):
"""Send all recent activity as a JSON array.

The activity is composed of all the items and their versions created after the
datetime `after`, sorted from newest to oldest.
"""
project = request.app.state.project
return sorted(
(
__item_as_serializable(key, version)
for key in project.list_item_keys()
for version in project.get_item_versions(key)
if datetime.fromisoformat(version.updated_at) > after
),
key=operator.attrgetter("updated_at"),
augustebaum marked this conversation as resolved.
Show resolved Hide resolved
reverse=True,
)
43 changes: 43 additions & 0 deletions skore/tests/integration/ui/test_ui.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import datetime

import numpy
import pandas
import polars
Expand Down Expand Up @@ -37,6 +39,7 @@ def test_get_items(client, in_memory_project):
"items": {
"test": [
{
"name": "test",
"media_type": "text/markdown",
"value": item.primitive,
"created_at": item.created_at,
Expand Down Expand Up @@ -141,3 +144,43 @@ def test_serialize_media_item(client, in_memory_project):
assert "image" in project["items"]["img"][0]["media_type"]
assert project["items"]["html"][0]["value"] == html
assert project["items"]["media html"][0]["value"] == html


def test_activity_feed(monkeypatch, client, in_memory_project):
class MockDatetime:
NOW = datetime.datetime.now(tz=datetime.timezone.utc)
TIMEDELTA = datetime.timedelta(days=1)

def __init__(self, *args, **kwargs): ...

@staticmethod
def now(*args, **kwargs):
MockDatetime.NOW += MockDatetime.TIMEDELTA
return MockDatetime.NOW

monkeypatch.setattr("skore.item.item.datetime", MockDatetime)

for i in range(5):
in_memory_project.put(str(i), i)

response = client.get("/api/project/activity")
assert response.status_code == 200
assert [(item["name"], item["value"]) for item in response.json()] == [
("4", 4),
("3", 3),
("2", 2),
("1", 1),
("0", 0),
]

now = MockDatetime.NOW # increments now

in_memory_project.put("4", 5)
in_memory_project.put("5", 5)

response = client.get("/api/project/activity", params={"after": now})
augustebaum marked this conversation as resolved.
Show resolved Hide resolved
assert response.status_code == 200
assert [(item["name"], item["value"]) for item in response.json()] == [
("5", 5),
("4", 5),
]
Loading