Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Python API add events and health checks services #464

Merged
merged 1 commit into from
Jan 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions packages/python/src/armonik/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,20 @@
from .tasks import ArmoniKTasks, TaskFieldFilter
from .results import ArmoniKResults, ResultFieldFilter
from .versions import ArmoniKVersions
from .events import ArmoniKEvents
from .health_checks import ArmoniKHealthChecks

__all__ = [
'ArmoniKPartitions',
'ArmoniKSessions',
'ArmoniKSubmitter',
'ArmoniKTasks',
'ArmoniKResults',
"ArmoniKVersions",
"ArmoniKEvents",
"ArmoniKHealthChecks",
"TaskFieldFilter",
"PartitionFieldFilter",
"SessionFieldFilter",
"ResultFieldFilter"
]
86 changes: 86 additions & 0 deletions packages/python/src/armonik/client/events.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from typing import Any, Callable, cast, List

from grpc import Channel

from .results import ArmoniKResults
from ..common import EventTypes, Filter, NewTaskEvent, NewResultEvent, ResultOwnerUpdateEvent, ResultStatusUpdateEvent, TaskStatusUpdateEvent, ResultStatus, Event
from .results import ResultFieldFilter
from ..protogen.client.events_service_pb2_grpc import EventsStub
from ..protogen.common.events_common_pb2 import EventSubscriptionRequest, EventSubscriptionResponse
from ..protogen.common.results_filters_pb2 import Filters as rawResultFilters
from ..protogen.common.tasks_filters_pb2 import Filters as rawTaskFilters

class ArmoniKEvents:

_events_obj_mapping = {
"new_result": NewResultEvent,
"new_task": NewTaskEvent,
"result_owner_update": ResultOwnerUpdateEvent,
"result_status_update": ResultStatusUpdateEvent,
"task_status_update": TaskStatusUpdateEvent
}

def __init__(self, grpc_channel: Channel):
"""Events service client

Args:
grpc_channel: gRPC channel to use
"""
self._client = EventsStub(grpc_channel)
self._results_client = ArmoniKResults(grpc_channel)

def get_events(self, session_id: str, event_types: List[EventTypes], event_handlers: List[Callable[[str, EventTypes, Event], bool]], task_filter: Filter | None = None, result_filter: Filter | None = None) -> None:
"""Get events that represents updates of result and tasks data.

Args:
session_id: The ID of the session.
event_types: The list of the types of event to catch.
event_handlers: The list of handlers that process the events. Handlers are evaluated in he order they are provided.
An handler takes three positional arguments: the ID of the session, the type of event and the event as an object.
An handler returns a boolean, if True the process continues, otherwise the stream is closed and the service stops
listening to new events.
task_filter: A filter on tasks.
result_filter: A filter on results.

"""
request = EventSubscriptionRequest(
session_id=session_id,
returned_events=event_types
)
if task_filter:
request.tasks_filters=cast(rawTaskFilters, task_filter.to_disjunction().to_message()),
if result_filter:
request.results_filters=cast(rawResultFilters, result_filter.to_disjunction().to_message()),

streaming_call = self._client.GetEvents(request)
for message in streaming_call:
event_type = message.WhichOneof("update")
if any([event_handler(session_id, EventTypes.from_string(event_type), self._events_obj_mapping[event_type].from_raw_event(getattr(message, event_type))) for event_handler in event_handlers]):
break

def wait_for_result_availability(self, result_id: str, session_id: str) -> None:
"""Wait until a result is ready i.e its status updates to COMPLETED.

Args:
result_id: The ID of the result.
session_id: The ID of the session.

Raises:
RuntimeError: If the result status is ABORTED.
"""
def handler(session_id, event_type, event):
if not isinstance(event, ResultStatusUpdateEvent):
raise ValueError("Handler should receive event of type 'ResultStatusUpdateEvent'.")
if event.status == ResultStatus.COMPLETED:
return False
elif event.status == ResultStatus.ABORTED:
raise RuntimeError(f"Result {result.name} with ID {result_id} is aborted.")
return True

result = self._results_client.get_result(result_id)
if result.status == ResultStatus.COMPLETED:
return
elif result.status == ResultStatus.ABORTED:
raise RuntimeError(f"Result {result.name} with ID {result_id} is aborted.")

self.get_events(session_id, [EventTypes.RESULT_STATUS_UPDATE], [handler], result_filter=(ResultFieldFilter.RESULT_ID == result_id))
21 changes: 21 additions & 0 deletions packages/python/src/armonik/client/health_checks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from typing import cast, List, Tuple

from grpc import Channel

from ..common import HealthCheckStatus
from ..protogen.client.health_checks_service_pb2_grpc import HealthChecksServiceStub
from ..protogen.common.health_checks_common_pb2 import CheckHealthRequest, CheckHealthResponse


class ArmoniKHealthChecks:
def __init__(self, grpc_channel: Channel):
""" Result service client

Args:
grpc_channel: gRPC channel to use
"""
self._client = HealthChecksServiceStub(grpc_channel)

def check_health(self):
response: CheckHealthResponse = self._client.CheckHealth(CheckHealthRequest())
return {service.name: {"message": service.message, "status": service.healthy} for service in response.services}
43 changes: 40 additions & 3 deletions packages/python/src/armonik/common/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,41 @@
from .helpers import datetime_to_timestamp, timestamp_to_datetime, duration_to_timedelta, timedelta_to_duration, get_task_filter
from .helpers import (
datetime_to_timestamp,
timestamp_to_datetime,
duration_to_timedelta,
timedelta_to_duration,
get_task_filter,
batched
)
from .objects import Task, TaskDefinition, TaskOptions, Output, ResultAvailability, Session, Result, Partition
from .enumwrapper import HealthCheckStatus, TaskStatus, Direction, ResultStatus, SessionStatus
from .filter import StringFilter, StatusFilter
from .enumwrapper import HealthCheckStatus, TaskStatus, Direction, SessionStatus, ResultStatus, EventTypes, ServiceHealthCheckStatus
from .events import *
from .filter import Filter, StringFilter, StatusFilter

__all__ = [
'datetime_to_timestamp',
'timestamp_to_datetime',
'duration_to_timedelta',
'timedelta_to_duration',
'get_task_filter',
'batched',
'Task',
'TaskDefinition',
'TaskOptions',
'Output',
'ResultAvailability',
'Session',
'Result',
'Partition',
'HealthCheckStatus',
'TaskStatus',
'Direction',
'SessionStatus',
'ResultStatus',
'EventTypes',
# Include all names from events module
# Add names from filter module
'Filter',
'StringFilter',
'StatusFilter',
'ServiceHealthCheckStatus'
]
53 changes: 53 additions & 0 deletions packages/python/src/armonik/common/events.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from abc import ABC
from typing import List

from dataclasses import dataclass, fields

from .enumwrapper import TaskStatus, ResultStatus


class Event(ABC):
@classmethod
def from_raw_event(cls, raw_event):
values = {}
for raw_field in cls.__annotations__.keys():
values[raw_field] = getattr(raw_event, raw_field)
return cls(**values)


@dataclass
class TaskStatusUpdateEvent(Event):
task_id: str
status: TaskStatus


@dataclass
class ResultStatusUpdateEvent(Event):
result_id: str
status: ResultStatus


@dataclass
class ResultOwnerUpdateEvent(Event):
result_id: str
previous_owner_id: str
current_owner_id: str


@dataclass
class NewTaskEvent(Event):
task_id: str
payload_id: str
origin_task_id: str
status: TaskStatus
expected_output_keys: List[str]
data_dependencies: List[str]
retry_of_ids: List[str]
parent_task_ids: List[str]


@dataclass
class NewResultEvent(Event):
result_id: str
owner_id: str
status: ResultStatus
10 changes: 7 additions & 3 deletions packages/python/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest
import requests

from armonik.client import ArmoniKPartitions, ArmoniKResults, ArmoniKSessions, ArmoniKTasks, ArmoniKVersions
from armonik.client import ArmoniKEvents, ArmoniKHealthChecks, ArmoniKPartitions, ArmoniKResults, ArmoniKSessions, ArmoniKTasks, ArmoniKVersions
from armonik.protogen.worker.agent_service_pb2_grpc import AgentStub
from typing import List, Union

Expand Down Expand Up @@ -55,7 +55,7 @@ def clean_up(request):
print("An error occurred when resetting the server: " + str(e))


def get_client(client_name: str, endpoint: str = grpc_endpoint) -> Union[AgentStub, ArmoniKPartitions, ArmoniKResults, ArmoniKSessions, ArmoniKTasks, ArmoniKVersions]:
def get_client(client_name: str, endpoint: str = grpc_endpoint) -> Union[AgentStub, ArmoniKEvents, ArmoniKHealthChecks, ArmoniKPartitions, ArmoniKResults, ArmoniKSessions, ArmoniKTasks, ArmoniKVersions]:
"""
Get the ArmoniK client instance based on the specified service name.

Expand All @@ -64,7 +64,7 @@ def get_client(client_name: str, endpoint: str = grpc_endpoint) -> Union[AgentSt
endpoint (str, optional): The gRPC server endpoint. Defaults to grpc_endpoint.

Returns:
Union[AgentStub, ArmoniKPartitions, ArmoniKResults, ArmoniKSessions, ArmoniKTasks, ArmoniKVersions]
Union[AgentStub, ArmoniKEvents, ArmoniKHealthChecks, ArmoniKPartitions, ArmoniKResults, ArmoniKSessions, ArmoniKTasks, ArmoniKVersions]
An instance of the specified ArmoniK client.

Raises:
Expand All @@ -78,6 +78,10 @@ def get_client(client_name: str, endpoint: str = grpc_endpoint) -> Union[AgentSt
match client_name:
case "Agent":
return AgentStub(channel)
case "Events":
return ArmoniKEvents(channel)
case "HealthChecks":
return ArmoniKHealthChecks(channel)
case "Partitions":
return ArmoniKPartitions(channel)
case "Results":
Expand Down
22 changes: 22 additions & 0 deletions packages/python/tests/test_events.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from .conftest import all_rpc_called, rpc_called, get_client
from armonik.client import ArmoniKEvents
from armonik.common import EventTypes, NewResultEvent, ResultStatus


class TestArmoniKEvents:
def test_get_events_no_filter(self):
def test_handler(session_id, event_type, event):
assert session_id == "session-id"
assert event_type == EventTypes.NEW_RESULT
assert isinstance(event, NewResultEvent)
assert event.result_id == "result-id"
assert event.owner_id == "owner-id"
assert event.status == ResultStatus.CREATED

tasks_client: ArmoniKEvents = get_client("Events")
tasks_client.get_events("session-id", [EventTypes.TASK_STATUS_UPDATE], [test_handler])

assert rpc_called("Events", "GetEvents")

def test_service_fully_implemented(self):
assert all_rpc_called("Events")
18 changes: 18 additions & 0 deletions packages/python/tests/test_healthcheck.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import datetime

from .conftest import all_rpc_called, rpc_called, get_client
from armonik.client import ArmoniKHealthChecks
from armonik.common import ServiceHealthCheckStatus


class TestArmoniKHealthChecks:

def test_check_health(self):
health_checks_client: ArmoniKHealthChecks = get_client("HealthChecks")
services_health = health_checks_client.check_health()

assert rpc_called("HealthChecks", "CheckHealth")
assert services_health == {'mock': {'message': 'Mock is healthy', 'status': ServiceHealthCheckStatus.HEALTHY}}

def test_service_fully_implemented(self):
assert all_rpc_called("HealthChecks")
Loading