From 60dde00c260a5591f8bb40eabfe8a5b807ad8d2f Mon Sep 17 00:00:00 2001
From: qdelamea
Date: Thu, 4 Jan 2024 12:02:15 +0100
Subject: [PATCH] feat: Python API update results service

---
 .../python/src/armonik/client/__init__.py     |   2 +-
 packages/python/src/armonik/client/results.py | 167 +++++++++++++++++-
 .../python/src/armonik/common/__init__.py     |   2 +-
 packages/python/tests/conftest.py             |   4 +-
 packages/python/tests/test_results.py         | 116 ++++++++++++
 5 files changed, 283 insertions(+), 8 deletions(-)
 create mode 100644 packages/python/tests/test_results.py

diff --git a/packages/python/src/armonik/client/__init__.py b/packages/python/src/armonik/client/__init__.py
index 8b74161dc..a4a7cd74e 100644
--- a/packages/python/src/armonik/client/__init__.py
+++ b/packages/python/src/armonik/client/__init__.py
@@ -1,4 +1,4 @@
 from .submitter import ArmoniKSubmitter
 from .tasks import ArmoniKTasks, TaskFieldFilter
-from .results import ArmoniKResult
+from .results import ArmoniKResults, ResultFieldFilter
 from .versions import ArmoniKVersions
diff --git a/packages/python/src/armonik/client/results.py b/packages/python/src/armonik/client/results.py
index 942add9c1..1937bba7b 100644
--- a/packages/python/src/armonik/client/results.py
+++ b/packages/python/src/armonik/client/results.py
@@ -1,21 +1,26 @@
 from __future__ import annotations
 from grpc import Channel
+from deprecation import deprecated
 from typing import List, Dict, cast, Tuple
 
 from ..protogen.client.results_service_pb2_grpc import ResultsStub
-from ..protogen.common.results_common_pb2 import CreateResultsMetaDataRequest, CreateResultsMetaDataResponse, ListResultsRequest, ListResultsResponse
+from ..protogen.common.results_common_pb2 import CreateResultsMetaDataRequest, CreateResultsMetaDataResponse, ListResultsRequest, ListResultsResponse, GetOwnerTaskIdRequest, GetOwnerTaskIdResponse, CreateResultsRequest, CreateResultsResponse, ResultsServiceConfigurationResponse, DeleteResultsDataRequest, DeleteResultsDataResponse, UploadResultDataRequest, UploadResultDataResponse, DownloadResultDataRequest, DownloadResultDataResponse, GetResultRequest, GetResultResponse
 from ..protogen.common.results_filters_pb2 import Filters as rawFilters, FiltersAnd as rawFilterAnd, FilterField as rawFilterField, FilterStatus as rawFilterStatus
 from ..protogen.common.results_fields_pb2 import ResultField
+from ..protogen.common.objects_pb2 import Empty
 from ..common.filter import StringFilter, StatusFilter, DateFilter, NumberFilter, Filter
 from ..protogen.common.sort_direction_pb2 import SortDirection
 from ..common import Direction , Result
-from ..protogen.common.results_fields_pb2 import ResultField, ResultRawField, ResultRawEnumField, RESULT_RAW_ENUM_FIELD_STATUS
+from ..protogen.common.results_fields_pb2 import ResultField, ResultRawField, ResultRawEnumField, RESULT_RAW_ENUM_FIELD_STATUS, RESULT_RAW_ENUM_FIELD_RESULT_ID
+from ..common.helpers import batched
+
 
 class ResultFieldFilter:
     STATUS = StatusFilter(ResultField(result_raw_field=ResultRawField(field=RESULT_RAW_ENUM_FIELD_STATUS)), rawFilters, rawFilterAnd, rawFilterField, rawFilterStatus)
+    RESULT_ID = StringFilter(ResultField(result_raw_field=ResultRawField(field=RESULT_RAW_ENUM_FIELD_RESULT_ID)), rawFilters, rawFilterAnd, rawFilterField)
 
 
-class ArmoniKResult:
+class ArmoniKResults:
 
     def __init__(self, grpc_channel: Channel):
         """ Result service client
@@ -24,10 +29,11 @@ def __init__(self, grpc_channel: Channel):
         """
         self._client = ResultsStub(grpc_channel)
 
+    @deprecated(deprecated_in="3.15.0", details="Use create_results_metadata or create_results instead.")
     def get_results_ids(self, session_id: str, names: List[str]) -> Dict[str, str]:
         return {r.name : r.result_id for r in cast(CreateResultsMetaDataResponse, self._client.CreateResultsMetaData(CreateResultsMetaDataRequest(results=[CreateResultsMetaDataRequest.ResultCreate(name = n) for n in names], session_id=session_id))).results}
 
-    def list_results(self, result_filter: Filter, page: int = 0, page_size: int = 1000, sort_field: Filter = ResultFieldFilter.STATUS,sort_direction: SortDirection = Direction.ASC ) -> Tuple[int, List[Result]]:
+    def list_results(self, result_filter: Filter | None = None, page: int = 0, page_size: int = 1000, sort_field: Filter = ResultFieldFilter.STATUS, sort_direction: SortDirection = Direction.ASC) -> Tuple[int, List[Result]]:
         """List results based on a filter.
 
         Args:
@@ -44,8 +50,159 @@ def list_results(self, result_filter: Filter, page: int = 0, page_size: int = 10
         request: ListResultsRequest = ListResultsRequest(
             page=page,
             page_size=page_size,
-            filters=cast(rawFilters, result_filter.to_disjunction().to_message()),
+            filters=cast(rawFilters, result_filter.to_disjunction().to_message()) if result_filter else None,
             sort=ListResultsRequest.Sort(field=cast(ResultField, sort_field.field), direction=sort_direction),
         )
         list_response: ListResultsResponse = self._client.ListResults(request)
         return list_response.total, [Result.from_message(r) for r in list_response.results]
+
+    def get_result(self, result_id: str) -> Result:
+        """Get a result by id.
+
+        Args:
+            result_id: The ID of the result.
+
+        Return:
+            The result summary.
+        """
+        request = GetResultRequest(result_id=result_id)
+        response: GetResultResponse = self._client.GetResult(request)
+        return Result.from_message(response.result)
+
+    def get_owner_task_id(self, result_ids: List[str], session_id: str, batch_size: int = 500) -> Dict[str, str]:
+        """Get the IDs of the tasks that should produce the results.
+
+        Args:
+            result_ids: A list of result IDs.
+            session_id: The ID of the session to which the results belong.
+            batch_size: Batch size for querying.
+
+        Return:
+            A dictionary mapping each result ID to its owner task ID.
+        """
+        results = {}
+        for result_ids_batch in batched(result_ids, batch_size):
+            request = GetOwnerTaskIdRequest(session_id=session_id, result_id=result_ids_batch)
+            response: GetOwnerTaskIdResponse = self._client.GetOwnerTaskId(request)
+            for result_task in response.result_task:
+                results[result_task.result_id] = result_task.task_id
+        return results
+
+    def create_results_metadata(self, result_names: List[str], session_id: str, batch_size: int = 100) -> Dict[str, Result]:
+        """Create the metadata of multiple results at once.
+        Data must be uploaded separately.
+
+        Args:
+            result_names: The list of the names of the results to create.
+            session_id: The ID of the session to which the results belong.
+            batch_size: Batch size for querying.
+
+        Return:
+            A dictionary mapping each result name to its corresponding result summary.
+ """ + results = {} + for result_names_batch in batched(result_names, batch_size): + request = CreateResultsMetaDataRequest( + results=[CreateResultsMetaDataRequest.ResultCreate(name=result_name) for result_name in result_names_batch], + session_id=session_id + ) + response: CreateResultsMetaDataResponse = self._client.CreateResultsMetaData(request) + for result_message in response.results: + results[result_message.name] = Result.from_message(result_message) + return results + + def create_results(self, results_data: Dict[str, bytes], session_id: str, batch_size: int = 1) -> Dict[str, Result]: + """Create one result with data included in the request. + + Args: + results_data: A dictionnary mapping the result names to their actual data. + session_id: The ID of the session to which the results belongs. + batch_size: Batch size for querying. + + Return: + A dictionnary mappin each result name to its corresponding result summary. + """ + results = {} + for results_names_batch in batched(results_data.keys(), batch_size): + request = CreateResultsRequest( + results=[CreateResultsRequest.ResultCreate(name=name, data=results_data[name]) for name in results_names_batch], + session_id=session_id + ) + response: CreateResultsResponse = self._client.CreateResults(request) + for message in response.results: + results[message.name] = Result.from_message(message) + return results + + def upload_result_data(self, result_id: str, session_id: str, result_data: bytes | bytearray) -> None: + """Upload data for an empty result already created. + + Args: + result_id: The ID of the result. + result_data: The result data. + session_id: The ID of the session. + """ + data_chunk_max_size = self.get_service_config() + + def upload_result_stream(): + request = UploadResultDataRequest( + id=UploadResultDataRequest.ResultIdentifier( + session_id=session_id, result_id=result_id + ) + ) + yield request + + start = 0 + data_len = len(result_data) + while start < data_len: + chunk_size = min(data_chunk_max_size, data_len - start) + request = UploadResultDataRequest( + data_chunk=result_data[start : start + chunk_size] + ) + yield request + start += chunk_size + + self._client.UploadResultData(upload_result_stream()) + + def download_result_data(self, result_id: str, session_id: str) -> bytes: + """Retrieve data of a result. + + Args: + result_id: The ID of the result. + session_id: The session of the result. + + Return: + Result data. + """ + request = DownloadResultDataRequest( + result_id=result_id, + session_id=session_id + ) + streaming_call = self._client.DownloadResultData(request) + return b''.join([message.data_chunk for message in streaming_call]) + + def delete_result_data(self, result_ids: List[str], session_id: str, batch_size: int = 100) -> None: + """Delete data from multiple results + + Args: + result_ids: The IDs of the results which data must be deleted. + session_id: The ID of the session to which the results belongs. + batch_size: Batch size for querying. + """ + for result_ids_batch in batched(result_ids, batch_size): + request = DeleteResultsDataRequest( + result_id=result_ids_batch, + session_id=session_id + ) + self._client.DeleteResultsData(request) + + def get_service_config(self) -> int: + """Get the configuration of the service. + + Return: + Maximum size supported by a data chunk for the result service. 
+ """ + response: ResultsServiceConfigurationResponse = self._client.GetServiceConfiguration(Empty()) + return response.data_chunk_max_size + + def watch_results(self): + raise NotImplementedError() diff --git a/packages/python/src/armonik/common/__init__.py b/packages/python/src/armonik/common/__init__.py index 001721868..5d44f4c9f 100644 --- a/packages/python/src/armonik/common/__init__.py +++ b/packages/python/src/armonik/common/__init__.py @@ -1,4 +1,4 @@ from .helpers import datetime_to_timestamp, timestamp_to_datetime, duration_to_timedelta, timedelta_to_duration, get_task_filter from .objects import Task, TaskDefinition, TaskOptions, Output, ResultAvailability, Session, Result -from .enumwrapper import HealthCheckStatus, TaskStatus, Direction +from .enumwrapper import HealthCheckStatus, TaskStatus, Direction, ResultStatus from .filter import StringFilter, StatusFilter diff --git a/packages/python/tests/conftest.py b/packages/python/tests/conftest.py index b4bb0af6f..b60a02483 100644 --- a/packages/python/tests/conftest.py +++ b/packages/python/tests/conftest.py @@ -3,7 +3,7 @@ import pytest import requests -from armonik.client import ArmoniKTasks, ArmoniKVersions +from armonik.client import ArmoniKResults, ArmoniKTasks, ArmoniKVersions from armonik.protogen.worker.agent_service_pb2_grpc import AgentStub from typing import List, Union @@ -75,6 +75,8 @@ def get_client(client_name: str, endpoint: str = grpc_endpoint) -> Union[ArmoniK """ channel = grpc.insecure_channel(endpoint).__enter__() match client_name: + case "Results": + return ArmoniKResults(channel) case "Tasks": return ArmoniKTasks(channel) case "Versions": diff --git a/packages/python/tests/test_results.py b/packages/python/tests/test_results.py new file mode 100644 index 000000000..cb3dd7d0f --- /dev/null +++ b/packages/python/tests/test_results.py @@ -0,0 +1,116 @@ +import datetime +import pytest +import warnings + +from .conftest import all_rpc_called, rpc_called, get_client +from armonik.client import ArmoniKResults, ResultFieldFilter +from armonik.common import Result, ResultStatus + + +class TestArmoniKResults: + + def test_get_result(self): + results_client: ArmoniKResults = get_client("Results") + result = results_client.get_result("result-name") + + assert rpc_called("Results", "GetResult") + assert isinstance(result, Result) + assert result.session_id == 'session-id' + assert result.name == 'result-name' + assert result.owner_task_id == 'owner-task-id' + assert result.status == 2 + assert result.created_at == datetime.datetime(1970, 1, 1, 0, 0, tzinfo=datetime.timezone.utc) + assert result.completed_at == datetime.datetime(1970, 1, 1, 0, 0, tzinfo=datetime.timezone.utc) + assert result.result_id == 'result-id' + assert result.size == 0 + + def test_get_owner_task_id(self): + results_client: ArmoniKResults = get_client("Results") + results_tasks = results_client.get_owner_task_id(["result-id"], "session-id") + + assert rpc_called("Results", "GetOwnerTaskId") + # TODO: Mock must be updated to return something and so that changes the following assertions + assert results_tasks == {} + + def test_list_results_no_filter(self): + results_client: ArmoniKResults = get_client("Results") + num, results = results_client.list_results() + + assert rpc_called("Results", "ListResults") + # TODO: Mock must be updated to return something and so that changes the following assertions + assert num == 0 + assert results == [] + + def test_list_results_with_filter(self): + results_client: ArmoniKResults = get_client("Results") + num, 
+        num, results = results_client.list_results(ResultFieldFilter.STATUS == ResultStatus.COMPLETED)
+
+        assert rpc_called("Results", "ListResults", 2)
+        # TODO: Mock must be updated to return something, which will change the following assertions
+        assert num == 0
+        assert results == []
+
+    def test_create_results_metadata(self):
+        results_client: ArmoniKResults = get_client("Results")
+        results = results_client.create_results_metadata(["result-name"], "session-id")
+
+        assert rpc_called("Results", "CreateResultsMetaData")
+        # TODO: Mock must be updated to return something, which will change the following assertions
+        assert results == {}
+
+    def test_create_results(self):
+        results_client: ArmoniKResults = get_client("Results")
+        results = results_client.create_results({"result-name": b"test data"}, "session-id")
+
+        assert rpc_called("Results", "CreateResults")
+        assert results == {}
+
+    def test_get_service_config(self):
+        results_client: ArmoniKResults = get_client("Results")
+        chunk_size = results_client.get_service_config()
+
+        assert rpc_called("Results", "GetServiceConfiguration")
+        assert isinstance(chunk_size, int)
+        assert chunk_size == 81920
+
+    def test_upload_result_data(self):
+        results_client: ArmoniKResults = get_client("Results")
+        result = results_client.upload_result_data("result-name", "session-id", b"test data")
+
+        assert rpc_called("Results", "UploadResultData")
+        assert result is None
+
+    def test_download_result_data(self):
+        results_client: ArmoniKResults = get_client("Results")
+        data = results_client.download_result_data("result-name", "session-id")
+
+        assert rpc_called("Results", "DownloadResultData")
+        assert data == b""
+
+    def test_delete_result_data(self):
+        results_client: ArmoniKResults = get_client("Results")
+        result = results_client.delete_result_data(["result-name"], "session-id")
+
+        assert rpc_called("Results", "DeleteResultsData")
+        assert result is None
+
+    def test_watch_results(self):
+        results_client: ArmoniKResults = get_client("Results")
+        with pytest.raises(NotImplementedError, match=""):
+            results_client.watch_results()
+        assert rpc_called("Results", "WatchResults", 0)
+
+    def test_get_results_ids(self):
+        with warnings.catch_warnings(record=True) as w:
+            # Cause all warnings to always be triggered.
+            warnings.simplefilter("always")
+
+            results_client: ArmoniKResults = get_client("Results")
+            results = results_client.get_results_ids("session-id", ["result_1"])
+
+            assert issubclass(w[-1].category, DeprecationWarning)
+            assert rpc_called("Results", "CreateResultsMetaData", 2)
+            assert results == {}
+
+    def test_service_fully_implemented(self):
+        assert all_rpc_called("Results", missings=["WatchResults"])
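
# --- Illustrative usage sketch (not part of the patch above) ----------------
# A short sketch of the filtering support added by this patch: comparisons on
# ResultFieldFilter fields build Filter objects that list_results accepts, as
# exercised in test_list_results_with_filter. The endpoint "localhost:5001" is
# a placeholder and a running ArmoniK control plane is assumed.
import grpc

from armonik.client import ArmoniKResults, ResultFieldFilter
from armonik.common import ResultStatus

with grpc.insecure_channel("localhost:5001") as channel:  # assumed endpoint
    client = ArmoniKResults(channel)

    # Without a filter, list_results returns one page of results plus the total count.
    total, results = client.list_results(page=0, page_size=100)

    # With a filter, only results with the requested status are returned.
    completed_total, completed = client.list_results(
        ResultFieldFilter.STATUS == ResultStatus.COMPLETED
    )
    print(f"{completed_total}/{total} results are completed")
# ----------------------------------------------------------------------------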