
feat: Python API update results service (#459)
aneojgurhem authored Jan 4, 2024
2 parents b853582 + 60dde00 commit 7399481
Showing 5 changed files with 283 additions and 8 deletions.
2 changes: 1 addition & 1 deletion packages/python/src/armonik/client/__init__.py
@@ -1,4 +1,4 @@
from .submitter import ArmoniKSubmitter
from .tasks import ArmoniKTasks, TaskFieldFilter
from .results import ArmoniKResult
from .results import ArmoniKResults, ResultFieldFilter
from .versions import ArmoniKVersions
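
The results client class was renamed, so downstream imports need a one-line update. A minimal migration sketch (the commented line shows the old name, which is no longer exported):

# Sketch, not part of the commit: update imports to the renamed client.
# from armonik.client import ArmoniKResult            # old name, removed
from armonik.client import ArmoniKResults, ResultFieldFilter
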
167 changes: 162 additions & 5 deletions packages/python/src/armonik/client/results.py
@@ -1,21 +1,26 @@
from __future__ import annotations
from grpc import Channel
from deprecation import deprecated

from typing import List, Dict, cast, Tuple

from ..protogen.client.results_service_pb2_grpc import ResultsStub
from ..protogen.common.results_common_pb2 import CreateResultsMetaDataRequest, CreateResultsMetaDataResponse, ListResultsRequest, ListResultsResponse
from ..protogen.common.results_common_pb2 import CreateResultsMetaDataRequest, CreateResultsMetaDataResponse, ListResultsRequest, ListResultsResponse, GetOwnerTaskIdRequest, GetOwnerTaskIdResponse, CreateResultsRequest, CreateResultsResponse, ResultsServiceConfigurationResponse, DeleteResultsDataRequest, DeleteResultsDataResponse, UploadResultDataRequest, UploadResultDataResponse, DownloadResultDataRequest, DownloadResultDataResponse, GetResultRequest, GetResultResponse
from ..protogen.common.results_filters_pb2 import Filters as rawFilters, FiltersAnd as rawFilterAnd, FilterField as rawFilterField, FilterStatus as rawFilterStatus
from ..protogen.common.results_fields_pb2 import ResultField
from ..protogen.common.objects_pb2 import Empty
from ..common.filter import StringFilter, StatusFilter, DateFilter, NumberFilter, Filter
from ..protogen.common.sort_direction_pb2 import SortDirection
from ..common import Direction, Result
from ..protogen.common.results_fields_pb2 import ResultField, ResultRawField, ResultRawEnumField, RESULT_RAW_ENUM_FIELD_STATUS
from ..protogen.common.results_fields_pb2 import ResultField, ResultRawField, ResultRawEnumField, RESULT_RAW_ENUM_FIELD_STATUS, RESULT_RAW_ENUM_FIELD_RESULT_ID
from ..common.helpers import batched


class ResultFieldFilter:
STATUS = StatusFilter(ResultField(result_raw_field=ResultRawField(field=RESULT_RAW_ENUM_FIELD_STATUS)), rawFilters, rawFilterAnd, rawFilterField, rawFilterStatus)
RESULT_ID = StringFilter(ResultField(result_raw_field=ResultRawField(field=RESULT_RAW_ENUM_FIELD_RESULT_ID)), rawFilters, rawFilterAnd, rawFilterField)

class ArmoniKResult:
class ArmoniKResults:
def __init__(self, grpc_channel: Channel):
""" Result service client
@@ -24,10 +29,11 @@ def __init__(self, grpc_channel: Channel):
"""
self._client = ResultsStub(grpc_channel)

@deprecated(deprecated_in="3.15.0", details="Use create_results_metadata or create_results instead.")
def get_results_ids(self, session_id: str, names: List[str]) -> Dict[str, str]:
return {r.name : r.result_id for r in cast(CreateResultsMetaDataResponse, self._client.CreateResultsMetaData(CreateResultsMetaDataRequest(results=[CreateResultsMetaDataRequest.ResultCreate(name = n) for n in names], session_id=session_id))).results}

def list_results(self, result_filter: Filter, page: int = 0, page_size: int = 1000, sort_field: Filter = ResultFieldFilter.STATUS,sort_direction: SortDirection = Direction.ASC ) -> Tuple[int, List[Result]]:
def list_results(self, result_filter: Filter | None = None, page: int = 0, page_size: int = 1000, sort_field: Filter = ResultFieldFilter.STATUS, sort_direction: SortDirection = Direction.ASC) -> Tuple[int, List[Result]]:
"""List results based on a filter.
Args:
@@ -44,8 +50,159 @@ def list_results(self, result_filter: Filter, page: int = 0, page_size: int = 10
request: ListResultsRequest = ListResultsRequest(
page=page,
page_size=page_size,
filters=cast(rawFilters, result_filter.to_disjunction().to_message()),
filters=cast(rawFilters, result_filter.to_disjunction().to_message()) if result_filter else None,
sort=ListResultsRequest.Sort(field=cast(ResultField, sort_field.field), direction=sort_direction),
)
list_response: ListResultsResponse = self._client.ListResults(request)
return list_response.total, [Result.from_message(r) for r in list_response.results]

def get_result(self, result_id: str) -> Result:
"""Get a result by id.
Args:
result_id: The ID of the result.
Return:
The result summary.
"""
request = GetResultRequest(result_id=result_id)
response: GetResultResponse = self._client.GetResult(request)
return Result.from_message(response.result)

def get_owner_task_id(self, result_ids: List[str], session_id: str, batch_size: int = 500) -> Dict[str, str]:
"""Get the IDs of the tasks that should produce the results.
Args:
result_ids: A list of result IDs.
session_id: The ID of the session to which the results belong.
batch_size: Batch size for querying.
Return:
A dictionary mapping each result ID to its owner task ID.
"""
results = {}
for result_ids_batch in batched(result_ids, batch_size):
request = GetOwnerTaskIdRequest(session_id=session_id, result_id=result_ids_batch)
response: GetOwnerTaskIdResponse = self._client.GetOwnerTaskId(request)
for result_task in response.result_task:
results[result_task.result_id] = result_task.task_id
return results

def create_results_metadata(self, result_names: List[str], session_id: str, batch_size: int = 100) -> Dict[str, Result]:
"""Create the metadata of multiple results at once.
Data must be uploaded separately.
Args:
result_names: The list of the names of the results to create.
session_id: The ID of the session to which the results belong.
batch_size: Batch size for querying.
Return:
A dictionary mapping each result name to its corresponding result summary.
"""
results = {}
for result_names_batch in batched(result_names, batch_size):
request = CreateResultsMetaDataRequest(
results=[CreateResultsMetaDataRequest.ResultCreate(name=result_name) for result_name in result_names_batch],
session_id=session_id
)
response: CreateResultsMetaDataResponse = self._client.CreateResultsMetaData(request)
for result_message in response.results:
results[result_message.name] = Result.from_message(result_message)
return results

def create_results(self, results_data: Dict[str, bytes], session_id: str, batch_size: int = 1) -> Dict[str, Result]:
"""Create one result with data included in the request.
Args:
results_data: A dictionary mapping the result names to their actual data.
session_id: The ID of the session to which the results belong.
batch_size: Batch size for querying.
Return:
A dictionary mapping each result name to its corresponding result summary.
"""
results = {}
for results_names_batch in batched(results_data.keys(), batch_size):
request = CreateResultsRequest(
results=[CreateResultsRequest.ResultCreate(name=name, data=results_data[name]) for name in results_names_batch],
session_id=session_id
)
response: CreateResultsResponse = self._client.CreateResults(request)
for message in response.results:
results[message.name] = Result.from_message(message)
return results

def upload_result_data(self, result_id: str, session_id: str, result_data: bytes | bytearray) -> None:
"""Upload data for an empty result already created.
Args:
result_id: The ID of the result.
session_id: The ID of the session.
result_data: The result data.
"""
data_chunk_max_size = self.get_service_config()

def upload_result_stream():
request = UploadResultDataRequest(
id=UploadResultDataRequest.ResultIdentifier(
session_id=session_id, result_id=result_id
)
)
yield request

start = 0
data_len = len(result_data)
while start < data_len:
chunk_size = min(data_chunk_max_size, data_len - start)
request = UploadResultDataRequest(
data_chunk=result_data[start : start + chunk_size]
)
yield request
start += chunk_size

self._client.UploadResultData(upload_result_stream())

def download_result_data(self, result_id: str, session_id: str) -> bytes:
"""Retrieve data of a result.
Args:
result_id: The ID of the result.
session_id: The session of the result.
Return:
Result data.
"""
request = DownloadResultDataRequest(
result_id=result_id,
session_id=session_id
)
streaming_call = self._client.DownloadResultData(request)
return b''.join([message.data_chunk for message in streaming_call])

def delete_result_data(self, result_ids: List[str], session_id: str, batch_size: int = 100) -> None:
"""Delete data from multiple results
Args:
result_ids: The IDs of the results which data must be deleted.
session_id: The ID of the session to which the results belongs.
batch_size: Batch size for querying.
"""
for result_ids_batch in batched(result_ids, batch_size):
request = DeleteResultsDataRequest(
result_id=result_ids_batch,
session_id=session_id
)
self._client.DeleteResultsData(request)

def get_service_config(self) -> int:
"""Get the configuration of the service.
Return:
Maximum size supported by a data chunk for the result service.
"""
response: ResultsServiceConfigurationResponse = self._client.GetServiceConfiguration(Empty())
return response.data_chunk_max_size

def watch_results(self):
raise NotImplementedError()
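
Taken together, the new methods cover the full result lifecycle: create metadata, upload data in chunks bounded by the service configuration, download it back, and list or filter results. A minimal usage sketch, assuming a reachable control plane and an already-created session (the endpoint and session ID are placeholders, not part of this commit):

import grpc

from armonik.client import ArmoniKResults, ResultFieldFilter
from armonik.common import ResultStatus

# Placeholder endpoint and session ID; the session is assumed to exist already.
with grpc.insecure_channel("localhost:5001") as channel:
    client = ArmoniKResults(channel)
    session_id = "my-session-id"

    # Create an empty result, then upload its data (chunked internally according
    # to the maximum chunk size reported by get_service_config()).
    results = client.create_results_metadata(["my-result"], session_id)
    result_id = results["my-result"].result_id
    client.upload_result_data(result_id, session_id, b"some payload")

    # Read the data back and list all completed results.
    data = client.download_result_data(result_id, session_id)
    total, completed = client.list_results(ResultFieldFilter.STATUS == ResultStatus.COMPLETED)
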
2 changes: 1 addition & 1 deletion packages/python/src/armonik/common/__init__.py
@@ -1,4 +1,4 @@
from .helpers import datetime_to_timestamp, timestamp_to_datetime, duration_to_timedelta, timedelta_to_duration, get_task_filter
from .objects import Task, TaskDefinition, TaskOptions, Output, ResultAvailability, Session, Result
from .enumwrapper import HealthCheckStatus, TaskStatus, Direction
from .enumwrapper import HealthCheckStatus, TaskStatus, Direction, ResultStatus
from .filter import StringFilter, StatusFilter
4 changes: 3 additions & 1 deletion packages/python/tests/conftest.py
@@ -3,7 +3,7 @@
import pytest
import requests

from armonik.client import ArmoniKTasks, ArmoniKVersions
from armonik.client import ArmoniKResults, ArmoniKTasks, ArmoniKVersions
from armonik.protogen.worker.agent_service_pb2_grpc import AgentStub
from typing import List, Union

@@ -75,6 +75,8 @@ def get_client(client_name: str, endpoint: str = grpc_endpoint) -> Union[ArmoniK
"""
channel = grpc.insecure_channel(endpoint).__enter__()
match client_name:
case "Results":
return ArmoniKResults(channel)
case "Tasks":
return ArmoniKTasks(channel)
case "Versions":
116 changes: 116 additions & 0 deletions packages/python/tests/test_results.py
@@ -0,0 +1,116 @@
import datetime
import pytest
import warnings

from .conftest import all_rpc_called, rpc_called, get_client
from armonik.client import ArmoniKResults, ResultFieldFilter
from armonik.common import Result, ResultStatus


class TestArmoniKResults:

def test_get_result(self):
results_client: ArmoniKResults = get_client("Results")
result = results_client.get_result("result-name")

assert rpc_called("Results", "GetResult")
assert isinstance(result, Result)
assert result.session_id == 'session-id'
assert result.name == 'result-name'
assert result.owner_task_id == 'owner-task-id'
assert result.status == 2
assert result.created_at == datetime.datetime(1970, 1, 1, 0, 0, tzinfo=datetime.timezone.utc)
assert result.completed_at == datetime.datetime(1970, 1, 1, 0, 0, tzinfo=datetime.timezone.utc)
assert result.result_id == 'result-id'
assert result.size == 0

def test_get_owner_task_id(self):
results_client: ArmoniKResults = get_client("Results")
results_tasks = results_client.get_owner_task_id(["result-id"], "session-id")

assert rpc_called("Results", "GetOwnerTaskId")
# TODO: Mock must be updated to return something, which will change the following assertions
assert results_tasks == {}

def test_list_results_no_filter(self):
results_client: ArmoniKResults = get_client("Results")
num, results = results_client.list_results()

assert rpc_called("Results", "ListResults")
# TODO: Mock must be updated to return something, which will change the following assertions
assert num == 0
assert results == []

def test_list_results_with_filter(self):
results_client: ArmoniKResults = get_client("Results")
num, results = results_client.list_results(ResultFieldFilter.STATUS == ResultStatus.COMPLETED)

assert rpc_called("Results", "ListResults", 2)
# TODO: Mock must be updated to return something, which will change the following assertions
assert num == 0
assert results == []

def test_create_results_metadata(self):
results_client: ArmoniKResults = get_client("Results")
results = results_client.create_results_metadata(["result-name"], "session-id")

assert rpc_called("Results", "CreateResultsMetaData")
# TODO: Mock must be updated to return something, which will change the following assertions
assert results == {}

def test_create_results(self):
results_client: ArmoniKResults = get_client("Results")
results = results_client.create_results({"result-name": b"test data"}, "session-id")

assert rpc_called("Results", "CreateResults")
assert results == {}

def test_get_service_config(self):
results_client: ArmoniKResults = get_client("Results")
chunk_size = results_client.get_service_config()

assert rpc_called("Results", "GetServiceConfiguration")
assert isinstance(chunk_size, int)
assert chunk_size == 81920

def test_upload_result_data(self):
results_client: ArmoniKResults = get_client("Results")
result = results_client.upload_result_data("result-name", "session-id", b"test data")

assert rpc_called("Results", "UploadResultData")
assert result is None

def test_download_result_data(self):
results_client: ArmoniKResults = get_client("Results")
data = results_client.download_result_data("result-name", "session-id")

assert rpc_called("Results", "DownloadResultData")
assert data == b""

def test_delete_result_data(self):
results_client: ArmoniKResults = get_client("Results")
result = results_client.delete_result_data(["result-name"], "session-id")

assert rpc_called("Results", "DeleteResultsData")
assert result is None

def test_watch_results(self):
results_client: ArmoniKResults = get_client("Results")
with pytest.raises(NotImplementedError, match=""):
results_client.watch_results()
assert rpc_called("Results", "WatchResults", 0)

def test_get_results_ids(self):
with warnings.catch_warnings(record=True) as w:
# Cause all warnings to always be triggered.
warnings.simplefilter("always")

results_client: ArmoniKResults = get_client("Results")
results = results_client.get_results_ids("session-id", ["result_1"])

assert issubclass(w[-1].category, DeprecationWarning)
assert rpc_called("Results", "CreateResultsMetaData", 2)
assert results == {}

def test_service_fully_implemented(self):
assert all_rpc_called("Results", missings=["WatchResults"])
