diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index b91f05761..19579998f 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -231,6 +231,17 @@ jobs:
       - name: Install dependencies
         run: pip install "$(echo pkg/armonik*.whl)[tests]"
 
+      - name: Install .NET Core
+        uses: actions/setup-dotnet@3447fd6a9f9e57506b15f895c5b76d3b197dc7c2 # v3
+        with:
+          dotnet-version: 6.x
+
+      - name: Start Mock server
+        run: |
+          cd ../csharp/ArmoniK.Api.Mock
+          nohup dotnet run > /dev/null 2>&1 &
+          sleep 60
+
       - name: Run tests
         run: python -m pytest tests --cov=armonik --cov-config=.coveragerc --cov-report=term-missing --cov-report xml:coverage.xml --cov-report html:coverage_report
 
diff --git a/packages/python/src/armonik/common/helpers.py b/packages/python/src/armonik/common/helpers.py
index e174e2f42..3a9cb8324 100644
--- a/packages/python/src/armonik/common/helpers.py
+++ b/packages/python/src/armonik/common/helpers.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 from datetime import timedelta, datetime, timezone
-from typing import List, Optional
+from typing import List, Optional, Iterable, TypeVar
 
 import google.protobuf.duration_pb2 as duration
 import google.protobuf.timestamp_pb2 as timestamp
@@ -9,6 +9,9 @@
 from .enumwrapper import TaskStatus
 
 
+T = TypeVar('T')
+
+
 def get_task_filter(session_ids: Optional[List[str]] = None, task_ids: Optional[List[str]] = None,
                     included_statuses: Optional[List[TaskStatus]] = None,
                     excluded_statuses: Optional[List[TaskStatus]] = None) -> TaskFilter:
@@ -96,3 +99,33 @@ def timedelta_to_duration(delta: timedelta) -> duration.Duration:
     d = duration.Duration()
     d.FromTimedelta(delta)
     return d
+
+
+def batched(iterable: Iterable[T], n: int) -> Iterable[List[T]]:
+    """
+    Batches elements from an iterable into lists of size at most 'n'.
+
+    Args:
+        iterable: The input iterable.
+        n: The batch size.
+
+    Yields:
+        Lists of at most 'n' consecutive elements from the input iterable.
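+
+    Example:
+        >>> list(batched([1, 2, 3, 4, 5], 2))
+        [[1, 2], [3, 4], [5]]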
+ """ + it = iter(iterable) + + sentinel = object() + batch = [] + c = next(it, sentinel) + while c is not sentinel: + batch.append(c) + if len(batch) == n: + yield batch + batch.clear() + c = next(it, sentinel) + if len(batch) > 0: + yield batch diff --git a/packages/python/tests/common.py b/packages/python/tests/common.py deleted file mode 100644 index db2feed4a..000000000 --- a/packages/python/tests/common.py +++ /dev/null @@ -1,40 +0,0 @@ -from grpc import Channel - - -class DummyChannel(Channel): - def __init__(self): - self.method_dict = {} - - def stream_unary(self, *args, **kwargs): - return self.get_method(args[0]) - - def unary_stream(self, *args, **kwargs): - return self.get_method(args[0]) - - def unary_unary(self, *args, **kwargs): - return self.get_method(args[0]) - - def stream_stream(self, *args, **kwargs): - return self.get_method(args[0]) - - def set_instance(self, instance): - self.method_dict = {func: getattr(instance, func) for func in dir(type(instance)) if callable(getattr(type(instance), func)) and not func.startswith("__")} - - def get_method(self, name: str): - return self.method_dict.get(name.split("/")[-1], None) - - def subscribe(self, callback, try_to_connect=False): - pass - - def unsubscribe(self, callback): - pass - - def close(self): - pass - - def __enter__(self): - pass - - def __exit__(self, exc_type, exc_val, exc_tb): - pass - diff --git a/packages/python/tests/conftest.py b/packages/python/tests/conftest.py new file mode 100644 index 000000000..c7eb15a79 --- /dev/null +++ b/packages/python/tests/conftest.py @@ -0,0 +1,123 @@ +import grpc +import os +import pytest +import requests + +from armonik.protogen.worker.agent_service_pb2_grpc import AgentStub +from typing import List + + +# Mock server endpoints used for the tests. +grpc_endpoint = "localhost:5001" +calls_endpoint = "http://localhost:5000/calls.json" +reset_endpoint = "http://localhost:5000/reset" +data_folder = os.getcwd() + + +@pytest.fixture(scope="session", autouse=True) +def clean_up(request): + """ + This fixture runs at the session scope and is automatically used before and after + running all the tests. It set up and teardown the testing environments by: + - creating dummy files before testing begins; + - clear files after testing; + - resets the mocking gRPC server counters to maintain a clean testing environment. + + Yields: + None: This fixture is used as a context manager, and the test code runs between + the 'yield' statement and the cleanup code. + + Raises: + requests.exceptions.HTTPError: If an error occurs when attempting to reset + the mocking gRPC server counters. + """ + # Write dumm payload and data dependency to files for testing purposes + with open(os.path.join(data_folder, "payload-id"), "wb") as f: + f.write("payload".encode()) + with open(os.path.join(data_folder, "dd-id"), "wb") as f: + f.write("dd".encode()) + + # Run all the tests + yield + + # Remove the temporary files created for testing + os.remove(os.path.join(data_folder, "payload-id")) + os.remove(os.path.join(data_folder, "dd-id")) + + # Reset the mock server counters + try: + response = requests.post(reset_endpoint) + response.raise_for_status() + print("\nMock server resetted.") + except requests.exceptions.HTTPError as e: + print("An error occurred when resetting the server: " + str(e)) + + +def rpc_called(service_name: str, rpc_name: str, n_calls: int = 1, endpoint: str = calls_endpoint) -> bool: + """Check if a remote procedure call (RPC) has been made a specified number of times. 
+    This function uses ArmoniK.Api.Mock: it simply queries the '/calls.json' endpoint.
+
+    Args:
+        service_name (str): The name of the service providing the RPC.
+        rpc_name (str): The name of the specific RPC to check for the number of calls.
+        n_calls (int, optional): The expected number of times the RPC should have been called. Defaults to 1.
+        endpoint (str, optional): The URL of the remote service providing RPC information. Defaults to
+            calls_endpoint.
+
+    Returns:
+        bool: True if the specified RPC has been called the expected number of times, False otherwise.
+
+    Raises:
+        requests.exceptions.RequestException: If an error occurs when requesting ArmoniK.Api.Mock.
+
+    Example:
+        >>> rpc_called('Versions', 'ListVersions', 0)
+        True
+    """
+    response = requests.get(endpoint)
+    response.raise_for_status()
+    data = response.json()
+
+    # Check if the RPC has been called exactly n_calls times
+    return data[service_name][rpc_name] == n_calls
+
+
+def all_rpc_called(service_name: str, missings: Optional[List[str]] = None, endpoint: str = calls_endpoint) -> bool:
+    """
+    Check if all remote procedure calls (RPCs) in a service have been made at least once.
+    This function uses ArmoniK.Api.Mock: it simply queries the '/calls.json' endpoint.
+
+    Args:
+        service_name (str): The name of the service containing the RPC information in the response.
+        missings (List[str], optional): A list of RPCs known to be not implemented. Defaults to None.
+        endpoint (str, optional): The URL of the remote service providing RPC information. Defaults to
+            calls_endpoint.
+
+    Returns:
+        bool: True if all RPCs in the specified service have been called at least once, False otherwise.
+
+    Raises:
+        requests.exceptions.RequestException: If an error occurs when requesting ArmoniK.Api.Mock.
+
+    Example:
+        >>> all_rpc_called('Versions')
+        False
+    """
+    missings = missings or []
+    response = requests.get(endpoint)
+    response.raise_for_status()
+    data = response.json()
+
+    missing_rpcs = []
+
+    # Check if all RPCs in the service have been called at least once
+    for rpc_name, rpc_num_calls in data[service_name].items():
+        if rpc_num_calls == 0:
+            missing_rpcs.append(rpc_name)
+    if missing_rpcs:
+        if set(missings) == set(missing_rpcs):
+            return True
+        print(f"RPCs not implemented in {service_name} service: {missing_rpcs}.")
+        return False
+    return True
diff --git a/packages/python/tests/submitter_test.py b/packages/python/tests/submitter_test.py
deleted file mode 100644
index c849efdc0..000000000
--- a/packages/python/tests/submitter_test.py
+++ /dev/null
@@ -1,313 +0,0 @@
-#!/usr/bin/env python3
-import datetime
-import logging
-import pytest
-from armonik.client import ArmoniKSubmitter
-from typing import Iterator, Optional, List
-from .common import DummyChannel
-from armonik.common import TaskOptions, TaskDefinition, TaskStatus, timedelta_to_duration
-from armonik.protogen.client.submitter_service_pb2_grpc import SubmitterStub
-from armonik.protogen.common.objects_pb2 import Empty, Configuration, Session, TaskIdList, ResultRequest, TaskError, Error, \
-    Count, StatusCount, DataChunk
-from armonik.protogen.common.submitter_common_pb2 import CreateSessionRequest, CreateSessionReply, CreateLargeTaskRequest, \
-    CreateTaskReply, TaskFilter, ResultReply, AvailabilityReply, WaitRequest, GetTaskStatusRequest, GetTaskStatusReply
-
-logging.basicConfig()
-logging.getLogger().setLevel(logging.INFO)
-
-
-class DummySubmitter(SubmitterStub):
-    def __init__(self, channel: DummyChannel, max_chunk_size=300):
-        channel.set_instance(self)
-        super().__init__(channel)
-        self.max_chunk_size = max_chunk_size
-        self.large_tasks_requests: List[CreateLargeTaskRequest] = []
-        self.task_filter: Optional[TaskFilter] = None
-        self.create_session: Optional[CreateSessionRequest] = None
-        self.session: Optional[Session] = None
-        self.result_stream: List[ResultReply] = []
-        self.result_request: Optional[ResultRequest] = None
-        self.is_available = True
-        self.wait_request: Optional[WaitRequest] = None
-        self.get_status_request: Optional[GetTaskStatusRequest] = None
-
-    def GetServiceConfiguration(self, _: Empty) -> Configuration:
-        return Configuration(data_chunk_max_size=self.max_chunk_size)
-
-    def CreateSession(self, request: CreateSessionRequest) -> CreateSessionReply:
-        self.create_session = request
-        return CreateSessionReply(session_id="SessionId")
-
-    def CancelSession(self, request: Session) -> Empty:
-        self.session = request
-        return Empty()
-
-    def CreateLargeTasks(self, request: Iterator[CreateLargeTaskRequest]) -> CreateTaskReply:
-        self.large_tasks_requests = [r for r in request]
-        return CreateTaskReply(creation_status_list=CreateTaskReply.CreationStatusList(creation_statuses=[
-            CreateTaskReply.CreationStatus(
-                task_info=CreateTaskReply.TaskInfo(task_id="TaskId", expected_output_keys=["EOK"],
-                                                   data_dependencies=["DD"])),
-            CreateTaskReply.CreationStatus(error="TestError")]))
-
-    def ListTasks(self, request: TaskFilter) -> TaskIdList:
-        self.task_filter = request
-        return TaskIdList(task_ids=["TaskId"])
-
-    def TryGetResultStream(self, request: ResultRequest) -> Iterator[ResultReply]:
-        self.result_request = request
-        for r in self.result_stream:
-            yield r
-
-    def WaitForAvailability(self, request: ResultRequest) -> AvailabilityReply:
-        from armonik.protogen.common.task_status_pb2 import
TASK_STATUS_ERROR - self.result_request = request - return AvailabilityReply(ok=Empty()) if self.is_available else AvailabilityReply( - error=TaskError(task_id="TaskId", errors=[Error(task_status=TASK_STATUS_ERROR, detail="TestError")])) - - def WaitForCompletion(self, request: WaitRequest) -> Count: - from armonik.protogen.common.task_status_pb2 import TASK_STATUS_COMPLETED - self.wait_request = request - return Count(values=[StatusCount(status=TASK_STATUS_COMPLETED, count=1)]) - - def GetTaskStatus(self, request: GetTaskStatusRequest) -> GetTaskStatusReply: - from armonik.protogen.common.task_status_pb2 import TASK_STATUS_COMPLETED - self.get_status_request = request - return GetTaskStatusReply( - id_statuses=[GetTaskStatusReply.IdStatus(task_id="TaskId", status=TASK_STATUS_COMPLETED)]) - - -default_task_option = TaskOptions(datetime.timedelta(seconds=300), priority=1, max_retries=5) - - -@pytest.mark.parametrize("task_options,partitions", [(default_task_option, None), (default_task_option, ["default"])]) -def test_armonik_submitter_should_create_session(task_options, partitions): - channel = DummyChannel() - inner = DummySubmitter(channel) - submitter = ArmoniKSubmitter(channel) - session_id = submitter.create_session(default_task_options=task_options, partition_ids=partitions) - assert session_id == "SessionId" - assert inner.create_session - assert inner.create_session.default_task_option.priority == task_options.priority - assert len(inner.create_session.partition_ids) == 0 if partitions is None else list(inner.create_session.partition_ids) == partitions - assert len(inner.create_session.default_task_option.options) == len(task_options.options) - assert inner.create_session.default_task_option.max_duration == timedelta_to_duration(task_options.max_duration) - assert inner.create_session.default_task_option.max_retries == task_options.max_retries - - -def test_armonik_submitter_should_cancel_session(): - channel = DummyChannel() - inner = DummySubmitter(channel) - submitter = ArmoniKSubmitter(channel) - submitter.cancel_session("SessionId") - assert inner.session is not None - assert inner.session.id == "SessionId" - - -def test_armonik_submitter_should_get_config(): - channel = DummyChannel() - inner = DummySubmitter(channel) - submitter = ArmoniKSubmitter(channel) - config = submitter.get_service_configuration() - assert config is not None - assert config.data_chunk_max_size == 300 - - -should_submit = [ - [TaskDefinition("Payload1".encode('utf-8'), expected_output_ids=["EOK"], data_dependencies=["DD"]), - TaskDefinition("Payload2".encode('utf-8'), expected_output_ids=["EOK"], data_dependencies=["DD"])], - [TaskDefinition("Payload1".encode('utf-8'), expected_output_ids=["EOK"]), - TaskDefinition("Payload2".encode('utf-8'), expected_output_ids=["EOK"])], - [TaskDefinition("".encode('utf-8'), expected_output_ids=["EOK"]), - TaskDefinition("".encode('utf-8'), expected_output_ids=["EOK"])] -] - - -@pytest.mark.parametrize("task_list,task_options", - [(t, default_task_option if i else None) for t in should_submit for i in [True, False]]) -def test_armonik_submitter_should_submit(task_list, task_options): - channel = DummyChannel() - inner = DummySubmitter(channel, max_chunk_size=5) - submitter = ArmoniKSubmitter(channel) - successes, errors = submitter.submit("SessionId", tasks=task_list, task_options=task_options) - # The dummy submitter has been set to submit one successful task and one submission error - assert len(successes) == 1 - assert len(errors) == 1 - assert successes[0].id == 
"TaskId" - assert successes[0].session_id == "SessionId" - assert errors[0] == "TestError" - - reqs = inner.large_tasks_requests - assert len(reqs) > 0 - offset = 0 - assert reqs[0 + offset].WhichOneof("type") == "init_request" - assert reqs[0 + offset].init_request.session_id == "SessionId" - assert reqs[1 + offset].WhichOneof("type") == "init_task" - assert reqs[1 + offset].init_task.header.expected_output_keys[0] == "EOK" - assert reqs[1 + offset].init_task.header.data_dependencies[0] == "DD" if len( - task_list[0].data_dependencies) > 0 else len(reqs[1 + offset].init_task.header.data_dependencies) == 0 - assert reqs[2 + offset].WhichOneof("type") == "task_payload" - assert reqs[2 + offset].task_payload.data == "".encode("utf-8") if len(task_list[0].payload) == 0 \ - else reqs[2 + offset].task_payload.data == task_list[0].payload[:5] - if len(task_list[0].payload) > 0: - offset += 1 - assert reqs[2 + offset].WhichOneof("type") == "task_payload" - assert reqs[2 + offset].task_payload.data == task_list[0].payload[5:] - assert reqs[3 + offset].WhichOneof("type") == "task_payload" - assert reqs[3 + offset].task_payload.data_complete - assert reqs[4 + offset].WhichOneof("type") == "init_task" - assert reqs[4 + offset].init_task.header.expected_output_keys[0] == "EOK" - assert reqs[4 + offset].init_task.header.data_dependencies[0] == "DD" if len( - task_list[1].data_dependencies) > 0 else len(reqs[4 + offset].init_task.header.data_dependencies) == 0 - assert reqs[5 + offset].WhichOneof("type") == "task_payload" - assert reqs[5 + offset].task_payload.data == "".encode("utf-8") if len(task_list[1].payload) == 0 \ - else reqs[5 + offset].task_payload.data == task_list[1].payload[:5] - if len(task_list[1].payload) > 0: - offset += 1 - assert reqs[5 + offset].WhichOneof("type") == "task_payload" - assert reqs[5 + offset].task_payload.data == task_list[1].payload[5:] - assert reqs[6 + offset].WhichOneof("type") == "task_payload" - assert reqs[6 + offset].task_payload.data_complete - assert reqs[7 + offset].WhichOneof("type") == "init_task" - assert reqs[7 + offset].init_task.last_task - - -filters_params = [(session_ids, task_ids, included_statuses, excluded_statuses, - (session_ids is None or task_ids is None) and ( - included_statuses is None or excluded_statuses is None)) - for session_ids in [["SessionId"], None] - for task_ids in [["TaskId"], None] - for included_statuses in [[TaskStatus.COMPLETED], None] - for excluded_statuses in [[TaskStatus.COMPLETED], None]] - - -@pytest.mark.parametrize("session_ids,task_ids,included_statuses,excluded_statuses,should_succeed", filters_params) -def test_armonik_submitter_should_list_tasks(session_ids, task_ids, included_statuses, excluded_statuses, - should_succeed): - channel = DummyChannel() - inner = DummySubmitter(channel) - submitter = ArmoniKSubmitter(channel) - if should_succeed: - tasks = submitter.list_tasks(session_ids=session_ids, task_ids=task_ids, included_statuses=included_statuses, - excluded_statuses=excluded_statuses) - assert len(tasks) > 0 - assert tasks[0] == "TaskId" - assert inner.task_filter is not None - assert all(map(lambda x: x[1] == session_ids[x[0]], enumerate(inner.task_filter.session.ids))) - assert all(map(lambda x: x[1] == task_ids[x[0]], enumerate(inner.task_filter.task.ids))) - assert all(map(lambda x: x[1] == included_statuses[x[0]], enumerate(inner.task_filter.included.statuses))) - assert all(map(lambda x: x[1] == excluded_statuses[x[0]], enumerate(inner.task_filter.excluded.statuses))) - else: - with 
pytest.raises(ValueError): - _ = submitter.list_tasks(session_ids=session_ids, task_ids=task_ids, included_statuses=included_statuses, - excluded_statuses=excluded_statuses) - - -def test_armonik_submitter_should_get_status(): - channel = DummyChannel() - inner = DummySubmitter(channel) - submitter = ArmoniKSubmitter(channel) - - statuses = submitter.get_task_status(["TaskId"]) - assert len(statuses) > 0 - assert "TaskId" in statuses - assert statuses["TaskId"] == TaskStatus.COMPLETED - assert inner.get_status_request is not None - assert len(inner.get_status_request.task_ids) == 1 - assert inner.get_status_request.task_ids[0] == "TaskId" - - -get_result_should_throw = [ - [], - [ResultReply(result=DataChunk(data="payload".encode("utf-8")))], - [ResultReply(result=DataChunk(data="payload".encode("utf-8"))), ResultReply(result=DataChunk(data_complete=True)), - ResultReply(result=DataChunk(data="payload".encode("utf-8")))], - [ResultReply( - error=TaskError(task_id="TaskId", errors=[Error(task_status=TaskStatus.ERROR, detail="TestError")]))], -] - -get_result_should_succeed = [ - [ResultReply(result=DataChunk(data="payload".encode("utf-8"))), ResultReply(result=DataChunk(data_complete=True))] -] - -get_result_should_none = [ - [ResultReply(not_completed_task="NotCompleted")] -] - - -@pytest.mark.parametrize("stream", [iter(x) for x in get_result_should_succeed]) -def test_armonik_submitter_should_get_result(stream): - channel = DummyChannel() - inner = DummySubmitter(channel) - inner.result_stream = stream - submitter = ArmoniKSubmitter(channel) - result = submitter.get_result("SessionId", "ResultId") - assert result is not None - assert len(result) > 0 - assert inner.result_request - assert inner.result_request.result_id == "ResultId" - assert inner.result_request.session == "SessionId" - - -@pytest.mark.parametrize("stream", [iter(x) for x in get_result_should_throw]) -def test_armonik_submitter_get_result_should_throw(stream): - channel = DummyChannel() - inner = DummySubmitter(channel) - inner.result_stream = stream - submitter = ArmoniKSubmitter(channel) - with pytest.raises(Exception): - _ = submitter.get_result("SessionId", "ResultId") - - -@pytest.mark.parametrize("stream", [iter(x) for x in get_result_should_none]) -def test_armonik_submitter_get_result_should_none(stream): - channel = DummyChannel() - inner = DummySubmitter(channel) - inner.result_stream = stream - submitter = ArmoniKSubmitter(channel) - result = submitter.get_result("SessionId", "ResultId") - assert result is None - assert inner.result_request - assert inner.result_request.result_id == "ResultId" - assert inner.result_request.session == "SessionId" - - -@pytest.mark.parametrize("available", [True, False]) -def test_armonik_submitter_wait_availability(available): - channel = DummyChannel() - inner = DummySubmitter(channel) - inner.is_available = available - submitter = ArmoniKSubmitter(channel) - reply = submitter.wait_for_availability("SessionId", "ResultId") - assert reply is not None - assert reply.is_available() == available - assert len(reply.errors) == 0 if available else reply.errors[0] == "TestError" - - -@pytest.mark.parametrize("session_ids,task_ids,included_statuses,excluded_statuses,should_succeed", filters_params) -def test_armonik_submitter_wait_completion(session_ids, task_ids, included_statuses, excluded_statuses, should_succeed): - channel = DummyChannel() - inner = DummySubmitter(channel) - submitter = ArmoniKSubmitter(channel) - - if should_succeed: - counts = 
submitter.wait_for_completion(session_ids=session_ids, task_ids=task_ids, - included_statuses=included_statuses, - excluded_statuses=excluded_statuses) - assert len(counts) > 0 - assert TaskStatus.COMPLETED in counts - assert counts[TaskStatus.COMPLETED] == 1 - assert inner.wait_request is not None - assert all(map(lambda x: x[1] == session_ids[x[0]], enumerate(inner.wait_request.filter.session.ids))) - assert all(map(lambda x: x[1] == task_ids[x[0]], enumerate(inner.wait_request.filter.task.ids))) - assert all(map(lambda x: x[1] == included_statuses[x[0]], - enumerate(inner.wait_request.filter.included.statuses))) - assert all(map(lambda x: x[1] == excluded_statuses[x[0]], - enumerate(inner.wait_request.filter.excluded.statuses))) - assert not inner.wait_request.stop_on_first_task_error - assert not inner.wait_request.stop_on_first_task_cancellation - else: - with pytest.raises(ValueError): - _ = submitter.wait_for_completion(session_ids=session_ids, task_ids=task_ids, - included_statuses=included_statuses, - excluded_statuses=excluded_statuses) diff --git a/packages/python/tests/taskhandler_test.py b/packages/python/tests/taskhandler_test.py deleted file mode 100644 index e4f3c181c..000000000 --- a/packages/python/tests/taskhandler_test.py +++ /dev/null @@ -1,89 +0,0 @@ -#!/usr/bin/env python3 -import os - -import pytest -from typing import Iterator - -from armonik.common import TaskDefinition - -from .common import DummyChannel -from armonik.worker import TaskHandler -from armonik.protogen.worker.agent_service_pb2_grpc import AgentStub -from armonik.protogen.common.agent_common_pb2 import CreateTaskRequest, CreateTaskReply, NotifyResultDataRequest, NotifyResultDataResponse -from armonik.protogen.common.worker_common_pb2 import ProcessRequest -from armonik.protogen.common.objects_pb2 import Configuration -import logging - -logging.basicConfig() -logging.getLogger().setLevel(logging.INFO) - -data_folder = os.getcwd() - - -@pytest.fixture(autouse=True, scope="session") -def setup_teardown(): - with open(os.path.join(data_folder, "payloadid"), "wb") as f: - f.write("payload".encode()) - with open(os.path.join(data_folder, "ddid"), "wb") as f: - f.write("dd".encode()) - yield - os.remove(os.path.join(data_folder, "payloadid")) - os.remove(os.path.join(data_folder, "ddid")) - - -class DummyAgent(AgentStub): - - def __init__(self, channel: DummyChannel) -> None: - channel.set_instance(self) - super(DummyAgent, self).__init__(channel) - self.create_task_messages = [] - self.send_result_task_message = [] - - def CreateTask(self, request_iterator: Iterator[CreateTaskRequest]) -> CreateTaskReply: - self.create_task_messages = [r for r in request_iterator] - return CreateTaskReply(creation_status_list=CreateTaskReply.CreationStatusList(creation_statuses=[ - CreateTaskReply.CreationStatus( - task_info=CreateTaskReply.TaskInfo(task_id="TaskId", expected_output_keys=["EOK"], - data_dependencies=["DD"]))])) - - def NotifyResultData(self, request: NotifyResultDataRequest) -> NotifyResultDataResponse: - self.send_result_task_message.append(request) - return NotifyResultDataResponse(result_ids=[i.result_id for i in request.ids]) - - -should_succeed_case = ProcessRequest(communication_token="token", session_id="sessionid", task_id="taskid", expected_output_keys=["resultid"], payload_id="payloadid", data_dependencies=["ddid"], data_folder=data_folder, configuration=Configuration(data_chunk_max_size=8000)) - - -@pytest.mark.parametrize("requests", [should_succeed_case]) -def 
test_taskhandler_create_should_succeed(requests: ProcessRequest): - agent = DummyAgent(DummyChannel()) - task_handler = TaskHandler(requests, agent) - assert task_handler.token is not None and len(task_handler.token) > 0 - assert len(task_handler.payload) > 0 - assert task_handler.session_id is not None and len(task_handler.session_id) > 0 - assert task_handler.task_id is not None and len(task_handler.task_id) > 0 - - -def test_taskhandler_data_are_correct(): - agent = DummyAgent(DummyChannel()) - task_handler = TaskHandler(should_succeed_case, agent) - assert len(task_handler.payload) > 0 - - task_handler.create_tasks([TaskDefinition("Payload".encode("utf-8"), ["EOK"], ["DD"])]) - - tasks = agent.create_task_messages - assert len(tasks) == 5 - assert tasks[0].WhichOneof("type") == "init_request" - assert tasks[1].WhichOneof("type") == "init_task" - assert len(tasks[1].init_task.header.data_dependencies) == 1 \ - and tasks[1].init_task.header.data_dependencies[0] == "DD" - assert len(tasks[1].init_task.header.expected_output_keys) == 1 \ - and tasks[1].init_task.header.expected_output_keys[0] == "EOK" - assert tasks[2].WhichOneof("type") == "task_payload" - assert tasks[2].task_payload.data == "Payload".encode("utf-8") - assert tasks[3].WhichOneof("type") == "task_payload" - assert tasks[3].task_payload.data_complete - assert tasks[4].WhichOneof("type") == "init_task" - assert tasks[4].init_task.last_task - - diff --git a/packages/python/tests/tasks_test.py b/packages/python/tests/tasks_test.py deleted file mode 100644 index 752b2ac63..000000000 --- a/packages/python/tests/tasks_test.py +++ /dev/null @@ -1,281 +0,0 @@ -#!/usr/bin/env python3 -import dataclasses -from typing import Optional, List, Any, Union, Dict, Collection -from google.protobuf.timestamp_pb2 import Timestamp - -from datetime import datetime - -import pytest - -from .common import DummyChannel -from armonik.client import ArmoniKTasks -from armonik.client.tasks import TaskFieldFilter -from armonik.common import TaskStatus, datetime_to_timestamp, Task -from armonik.common.filter import StringFilter, Filter -from armonik.protogen.client.tasks_service_pb2_grpc import TasksStub -from armonik.protogen.common.tasks_common_pb2 import GetTaskRequest, GetTaskResponse, TaskDetailed -from armonik.protogen.common.tasks_filters_pb2 import Filters, FilterField -from armonik.protogen.common.filters_common_pb2 import * -from armonik.protogen.common.tasks_fields_pb2 import * -from .submitter_test import default_task_option - - -class DummyTasksService(TasksStub): - def __init__(self, channel: DummyChannel): - channel.set_instance(self) - super().__init__(channel) - self.task_request: Optional[GetTaskRequest] = None - - def GetTask(self, request: GetTaskRequest) -> GetTaskResponse: - self.task_request = request - raw = TaskDetailed(id="TaskId", session_id="SessionId", owner_pod_id="PodId", parent_task_ids=["ParentTaskId"], - data_dependencies=["DD"], expected_output_ids=["EOK"], retry_of_ids=["RetryId"], - status=TaskStatus.COMPLETED, status_message="Message", - options=default_task_option.to_message(), - created_at=datetime_to_timestamp(datetime.now()), - started_at=datetime_to_timestamp(datetime.now()), - submitted_at=datetime_to_timestamp(datetime.now()), - ended_at=datetime_to_timestamp(datetime.now()), pod_ttl=datetime_to_timestamp(datetime.now()), - output=TaskDetailed.Output(success=True), pod_hostname="Hostname", received_at=datetime_to_timestamp(datetime.now()), - acquired_at=datetime_to_timestamp(datetime.now()) - ) - return 
GetTaskResponse(task=raw) - - -def test_tasks_get_task_should_succeed(): - channel = DummyChannel() - inner = DummyTasksService(channel) - tasks = ArmoniKTasks(channel) - task = tasks.get_task("TaskId") - assert task is not None - assert inner.task_request is not None - assert inner.task_request.task_id == "TaskId" - assert task.id == "TaskId" - assert task.session_id == "SessionId" - assert task.parent_task_ids == ["ParentTaskId"] - assert task.output - assert task.output.success - - -def test_task_refresh(): - channel = DummyChannel() - inner = DummyTasksService(channel) - tasks = ArmoniKTasks(channel) - current = Task(id="TaskId") - current.refresh(tasks) - assert current is not None - assert inner.task_request is not None - assert inner.task_request.task_id == "TaskId" - assert current.id == "TaskId" - assert current.session_id == "SessionId" - assert current.parent_task_ids == ["ParentTaskId"] - assert current.output - assert current.output.success - - -def test_task_filters(): - filt: StringFilter = TaskFieldFilter.TASK_ID == "TaskId" - message = filt.to_message() - assert isinstance(message, FilterField) - assert message.field.WhichOneof("field") == "task_summary_field" - assert message.field.task_summary_field.field == TASK_SUMMARY_ENUM_FIELD_TASK_ID - assert message.filter_string.value == "TaskId" - assert message.filter_string.operator == FILTER_STRING_OPERATOR_EQUAL - - filt: StringFilter = TaskFieldFilter.TASK_ID != "TaskId" - message = filt.to_message() - assert isinstance(message, FilterField) - assert message.field.WhichOneof("field") == "task_summary_field" - assert message.field.task_summary_field.field == TASK_SUMMARY_ENUM_FIELD_TASK_ID - assert message.filter_string.value == "TaskId" - assert message.filter_string.operator == FILTER_STRING_OPERATOR_NOT_EQUAL - - -@dataclasses.dataclass -class SimpleFieldFilter: - field: Any - value: Any - operator: Any - - -@pytest.mark.parametrize("filt,n_or,n_and,filters", [ - ( - (TaskFieldFilter.INITIAL_TASK_ID == "TestId"), - 1, [1], - [ - SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_INITIAL_TASK_ID, "TestId", FILTER_STRING_OPERATOR_EQUAL) - ] - ), - ( - (TaskFieldFilter.APPLICATION_NAME.contains("TestName") & (TaskFieldFilter.CREATED_AT > Timestamp(seconds=1000, nanos=500))), - 1, [2], - [ - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_APPLICATION_NAME, "TestName", FILTER_STRING_OPERATOR_CONTAINS), - SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_CREATED_AT, Timestamp(seconds=1000, nanos=500), FILTER_DATE_OPERATOR_AFTER) - ] - ), - ( - (((TaskFieldFilter.MAX_RETRIES <= 3) & ~(TaskFieldFilter.SESSION_ID == "SessionId")) | (TaskFieldFilter.task_options_key("MyKey").startswith("Start"))), - 2, [1, 2], - [ - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_MAX_RETRIES, 3, FILTER_NUMBER_OPERATOR_LESS_THAN_OR_EQUAL), - SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_SESSION_ID, "SessionId", FILTER_STRING_OPERATOR_NOT_EQUAL), - SimpleFieldFilter("MyKey", "Start", FILTER_STRING_OPERATOR_STARTS_WITH) - ] - ), - ( - (((TaskFieldFilter.PRIORITY > 3) & ~(TaskFieldFilter.STATUS == TaskStatus.COMPLETED) & TaskFieldFilter.APPLICATION_VERSION.contains("1.0")) | (TaskFieldFilter.ENGINE_TYPE.endswith("Test") & (TaskFieldFilter.ENDED_AT <= Timestamp(seconds=1000, nanos=500)))), - 2, [2, 3], - [ - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_PRIORITY, 3, FILTER_NUMBER_OPERATOR_GREATER_THAN), - SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_STATUS, TaskStatus.COMPLETED, FILTER_STATUS_OPERATOR_NOT_EQUAL), - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_APPLICATION_VERSION, "1.0", 
FILTER_STRING_OPERATOR_CONTAINS), - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_ENGINE_TYPE, "Test", FILTER_STRING_OPERATOR_ENDS_WITH), - SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_ENDED_AT, Timestamp(seconds=1000, nanos=500), FILTER_DATE_OPERATOR_BEFORE_OR_EQUAL), - ] - ), - ( - (((TaskFieldFilter.PRIORITY >= 3) * -(TaskFieldFilter.STATUS != TaskStatus.COMPLETED) * -TaskFieldFilter.APPLICATION_VERSION.contains("1.0")) + (TaskFieldFilter.ENGINE_TYPE.endswith("Test") * (TaskFieldFilter.ENDED_AT <= Timestamp(seconds=1000, nanos=500)))), - 2, [2, 3], - [ - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_PRIORITY, 3, FILTER_NUMBER_OPERATOR_GREATER_THAN_OR_EQUAL), - SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_STATUS, TaskStatus.COMPLETED, FILTER_STATUS_OPERATOR_EQUAL), - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_APPLICATION_VERSION, "1.0", FILTER_STRING_OPERATOR_NOT_CONTAINS), - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_ENGINE_TYPE, "Test", FILTER_STRING_OPERATOR_ENDS_WITH), - SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_ENDED_AT, Timestamp(seconds=1000, nanos=500), FILTER_DATE_OPERATOR_BEFORE_OR_EQUAL), - ] - ) -]) -def test_filter_combination(filt: Filter, n_or: int, n_and: List[int], filters: List[SimpleFieldFilter]): - filt = filt.to_disjunction() - assert len(filt._filters) == n_or - sorted_n_and = sorted(n_and) - sorted_actual = sorted([len(f) for f in filt._filters]) - assert len(sorted_n_and) == len(sorted_actual) - assert all((sorted_n_and[i] == sorted_actual[i] for i in range(len(sorted_actual)))) - for f in filt._filters: - for ff in f: - field_value = getattr(ff.field, ff.field.WhichOneof("field")).field - for i, expected in enumerate(filters): - if expected.field == field_value and expected.value == ff.value and expected.operator == ff.operator: - filters.pop(i) - break - else: - print(f"Could not find {str(ff)}") - assert False - assert len(filters) == 0 - - -def test_name_from_value(): - assert TaskStatus.name_from_value(TaskStatus.COMPLETED) == "TASK_STATUS_COMPLETED" - - -class BasicFilterAnd: - - def __setattr__(self, key, value): - self.__dict__[key] = value - - def __getattr__(self, item): - return self.__dict__[item] - - -@pytest.mark.parametrize("filt,n_or,n_and,filters,expected_type", [ - ( - (TaskFieldFilter.INITIAL_TASK_ID == "TestId"), - 1, [1], - [ - SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_INITIAL_TASK_ID, "TestId", FILTER_STRING_OPERATOR_EQUAL) - ], - 0 - ), - ( - (TaskFieldFilter.APPLICATION_NAME.contains("TestName") & (TaskFieldFilter.CREATED_AT > Timestamp(seconds=1000, nanos=500))), - 1, [2], - [ - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_APPLICATION_NAME, "TestName", FILTER_STRING_OPERATOR_CONTAINS), - SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_CREATED_AT, Timestamp(seconds=1000, nanos=500), FILTER_DATE_OPERATOR_AFTER) - ], - 1 - ), - ( - (((TaskFieldFilter.MAX_RETRIES <= 3) & ~(TaskFieldFilter.SESSION_ID == "SessionId")) | (TaskFieldFilter.task_options_key("MyKey").startswith("Start"))), - 2, [1, 2], - [ - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_MAX_RETRIES, 3, FILTER_NUMBER_OPERATOR_LESS_THAN_OR_EQUAL), - SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_SESSION_ID, "SessionId", FILTER_STRING_OPERATOR_NOT_EQUAL), - SimpleFieldFilter("MyKey", "Start", FILTER_STRING_OPERATOR_STARTS_WITH) - ], - 2 - ), - ( - (((TaskFieldFilter.PRIORITY > 3) & ~(TaskFieldFilter.STATUS == TaskStatus.COMPLETED) & TaskFieldFilter.APPLICATION_VERSION.contains("1.0")) | (TaskFieldFilter.ENGINE_TYPE.endswith("Test") & (TaskFieldFilter.ENDED_AT <= Timestamp(seconds=1000, nanos=500)))), - 2, [2, 3], - [ - 
SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_PRIORITY, 3, FILTER_NUMBER_OPERATOR_GREATER_THAN), - SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_STATUS, TaskStatus.COMPLETED, FILTER_STATUS_OPERATOR_NOT_EQUAL), - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_APPLICATION_VERSION, "1.0", FILTER_STRING_OPERATOR_CONTAINS), - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_ENGINE_TYPE, "Test", FILTER_STRING_OPERATOR_ENDS_WITH), - SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_ENDED_AT, Timestamp(seconds=1000, nanos=500), FILTER_DATE_OPERATOR_BEFORE_OR_EQUAL), - ], - 2 - ), - ( - (((TaskFieldFilter.PRIORITY >= 3) * -(TaskFieldFilter.STATUS != TaskStatus.COMPLETED) * -TaskFieldFilter.APPLICATION_VERSION.contains("1.0")) + (TaskFieldFilter.ENGINE_TYPE.endswith("Test") * (TaskFieldFilter.ENDED_AT <= Timestamp(seconds=1000, nanos=500)))), - 2, [2, 3], - [ - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_PRIORITY, 3, FILTER_NUMBER_OPERATOR_GREATER_THAN_OR_EQUAL), - SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_STATUS, TaskStatus.COMPLETED, FILTER_STATUS_OPERATOR_EQUAL), - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_APPLICATION_VERSION, "1.0", FILTER_STRING_OPERATOR_NOT_CONTAINS), - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_ENGINE_TYPE, "Test", FILTER_STRING_OPERATOR_ENDS_WITH), - SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_ENDED_AT, Timestamp(seconds=1000, nanos=500), FILTER_DATE_OPERATOR_BEFORE_OR_EQUAL), - ], - 2 - ), - ( - (((TaskFieldFilter.PRIORITY >= 3) * -(TaskFieldFilter.STATUS != TaskStatus.COMPLETED) * -TaskFieldFilter.APPLICATION_VERSION.contains("1.0")) + (TaskFieldFilter.ENGINE_TYPE.endswith("Test") * (TaskFieldFilter.ENDED_AT <= Timestamp(seconds=1000, nanos=500)))) + (((TaskFieldFilter.MAX_RETRIES <= 3) & ~(TaskFieldFilter.SESSION_ID == "SessionId")) | (TaskFieldFilter.task_options_key("MyKey").startswith("Start"))), - 4, [2, 3, 2, 1], - [ - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_PRIORITY, 3, FILTER_NUMBER_OPERATOR_GREATER_THAN_OR_EQUAL), - SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_STATUS, TaskStatus.COMPLETED, FILTER_STATUS_OPERATOR_EQUAL), - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_APPLICATION_VERSION, "1.0", FILTER_STRING_OPERATOR_NOT_CONTAINS), - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_ENGINE_TYPE, "Test", FILTER_STRING_OPERATOR_ENDS_WITH), - SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_ENDED_AT, Timestamp(seconds=1000, nanos=500), FILTER_DATE_OPERATOR_BEFORE_OR_EQUAL), - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_MAX_RETRIES, 3, FILTER_NUMBER_OPERATOR_LESS_THAN_OR_EQUAL), - SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_SESSION_ID, "SessionId", FILTER_STRING_OPERATOR_NOT_EQUAL), - SimpleFieldFilter("MyKey", "Start", FILTER_STRING_OPERATOR_STARTS_WITH) - ], - 2 - ) -]) -def test_taskfilter_to_message(filt: Filter, n_or: int, n_and: List[int], filters: List[SimpleFieldFilter], expected_type: int): - print(filt) - message = filt.to_message() - conjs: Collection = [] - if expected_type == 2: # Disjunction - conjs: Collection = getattr(message, "or") - assert len(conjs) == n_or - sorted_n_and = sorted(n_and) - sorted_actual = sorted([len(getattr(f, "and")) for f in conjs]) - assert len(sorted_n_and) == len(sorted_actual) - assert all((sorted_n_and[i] == sorted_actual[i] for i in range(len(sorted_actual)))) - - if expected_type == 1: # Conjunction - conjs: Collection = [message] - - if expected_type == 0: # Simple filter - m = BasicFilterAnd() - setattr(m, "and", [message]) - conjs: Collection = [m] - - for conj in conjs: - basics = getattr(conj, "and") - for f in basics: - field_value = getattr(f.field, 
f.field.WhichOneof("field")).field - for i, expected in enumerate(filters): - if expected.field == field_value and expected.value == getattr(f, f.WhichOneof("value_condition")).value and expected.operator == getattr(f, f.WhichOneof("value_condition")).operator: - filters.pop(i) - break - else: - print(f"Could not find {str(f)}") - assert False - assert len(filters) == 0 diff --git a/packages/python/tests/filters_test.py b/packages/python/tests/test_filters.py similarity index 100% rename from packages/python/tests/filters_test.py rename to packages/python/tests/test_filters.py diff --git a/packages/python/tests/helpers_test.py b/packages/python/tests/test_helpers.py similarity index 76% rename from packages/python/tests/helpers_test.py rename to packages/python/tests/test_helpers.py index 20e07d922..1c0cf517b 100644 --- a/packages/python/tests/helpers_test.py +++ b/packages/python/tests/test_helpers.py @@ -5,7 +5,9 @@ from google.protobuf.duration_pb2 import Duration from dataclasses import dataclass from datetime import datetime, timedelta, timezone -from armonik.common.helpers import datetime_to_timestamp, timestamp_to_datetime, timedelta_to_duration, duration_to_timedelta +from armonik.common.helpers import datetime_to_timestamp, timestamp_to_datetime, timedelta_to_duration, duration_to_timedelta, batched + +from typing import Iterable, List @dataclass @@ -60,3 +62,14 @@ def test_duration_to_timedelta(case: Case): def test_timedelta_to_duration(case: Case): ts = timedelta_to_duration(case.delta) assert ts.seconds == case.duration.seconds and abs(ts.nanos - case.duration.nanos) < 1000 + + +@pytest.mark.parametrize(["iterable", "batch_size", "iterations"], [ + ([1, 2, 3], 3, [[1, 2, 3]]), + ([1, 2, 3], 5, [[1, 2, 3]]), + ([1, 2, 3], 2, [[1, 2], [3]]), + ([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], 3, [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11]]) +]) +def test_batched(iterable: Iterable, batch_size: int, iterations: List[Iterable]): + for index, batch in enumerate(batched(iterable, batch_size)): + assert batch == iterations[index] diff --git a/packages/python/tests/worker_test.py b/packages/python/tests/worker_test.py deleted file mode 100644 index 032c406ee..000000000 --- a/packages/python/tests/worker_test.py +++ /dev/null @@ -1,73 +0,0 @@ -#!/usr/bin/env python3 -import logging -import os -import pytest -from armonik.worker import ArmoniKWorker, TaskHandler, ClefLogger -from armonik.common import Output -from .taskhandler_test import should_succeed_case, data_folder, DummyAgent -from .common import DummyChannel -from armonik.protogen.common.objects_pb2 import Empty -import grpc - - -def do_nothing(_: TaskHandler) -> Output: - return Output() - - -def throw_error(_: TaskHandler) -> Output: - raise ValueError("TestError") - - -def return_error(_: TaskHandler) -> Output: - return Output("TestError") - - -def return_and_send(th: TaskHandler) -> Output: - th.send_result(th.expected_results[0], b"result") - return Output() - - -@pytest.fixture(autouse=True, scope="function") -def remove_result(): - yield - if os.path.exists(os.path.join(data_folder, "resultid")): - os.remove(os.path.join(data_folder, "resultid")) - - -def test_do_nothing_worker(): - with grpc.insecure_channel("unix:///tmp/agent.sock") as agent_channel: - worker = ArmoniKWorker(agent_channel, do_nothing, logger=ClefLogger("TestLogger", level=logging.CRITICAL)) - reply = worker.Process(should_succeed_case, None) - assert Output(reply.output.error.details if reply.output.WhichOneof("type") == "error" else None).success - 
worker.HealthCheck(Empty(), None) - - -def test_worker_should_return_none(): - with grpc.insecure_channel("unix:///tmp/agent.sock") as agent_channel: - worker = ArmoniKWorker(agent_channel, throw_error, logger=ClefLogger("TestLogger", level=logging.CRITICAL)) - reply = worker.Process(should_succeed_case, None) - assert reply is None - - -def test_worker_should_error(): - with grpc.insecure_channel("unix:///tmp/agent.sock") as agent_channel: - worker = ArmoniKWorker(agent_channel, return_error, logger=ClefLogger("TestLogger", level=logging.CRITICAL)) - reply = worker.Process(should_succeed_case, None) - output = Output(reply.output.error.details if reply.output.WhichOneof("type") == "error" else None) - assert not output.success - assert output.error == "TestError" - - -def test_worker_should_write_result(): - with grpc.insecure_channel("unix:///tmp/agent.sock") as agent_channel: - worker = ArmoniKWorker(agent_channel, return_and_send, logger=ClefLogger("TestLogger", level=logging.DEBUG)) - worker._client = DummyAgent(DummyChannel()) - reply = worker.Process(should_succeed_case, None) - assert reply is not None - output = Output(reply.output.error.details if reply.output.WhichOneof("type") == "error" else None) - assert output.success - assert os.path.exists(os.path.join(data_folder, should_succeed_case.expected_output_keys[0])) - with open(os.path.join(data_folder, should_succeed_case.expected_output_keys[0]), "rb") as f: - value = f.read() - assert len(value) > 0 -
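
To illustrate the testing flow that replaces the `DummyChannel`-based mocks deleted above, here is a minimal sketch of a test built on the new `conftest.py` helpers. It assumes the ArmoniK.Api.Mock server is running locally (as in the CI job at the top of this diff); the `Tasks`/`GetTask` service and RPC names and the `missings` entry are illustrative, while `ArmoniKTasks.get_task` is the existing client call used by the removed tests.

```python
import grpc

from armonik.client import ArmoniKTasks
from .conftest import grpc_endpoint, rpc_called, all_rpc_called


def test_get_task_hits_mock_server():
    # Exercise the real client against the mock gRPC endpoint (localhost:5001).
    with grpc.insecure_channel(grpc_endpoint) as channel:
        ArmoniKTasks(channel).get_task("task-id")
    # The mock server counts every RPC it serves, so the test can assert that
    # exactly one GetTask call reached the Tasks service.
    assert rpc_called("Tasks", "GetTask", n_calls=1)


def test_tasks_service_coverage():
    # At the end of a suite, verify that every RPC of the service was called
    # at least once, ignoring RPCs known to be unimplemented (illustrative name).
    assert all_rpc_called("Tasks", missings=["SomeUnimplementedRpc"])
```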