Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Python API update sessions and partitions services #461

Merged
merged 2 commits into from
Jan 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions packages/python/src/armonik/client/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from .partitions import ArmoniKPartitions, PartitionFieldFilter
from .sessions import ArmoniKSessions, SessionFieldFilter
from .submitter import ArmoniKSubmitter
from .tasks import ArmoniKTasks, TaskFieldFilter
from .results import ArmoniKResults, ResultFieldFilter
Expand Down
65 changes: 65 additions & 0 deletions packages/python/src/armonik/client/partitions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from typing import cast, List, Tuple

from grpc import Channel

from ..common import Direction, Partition
from ..common.filter import Filter, NumberFilter
from ..protogen.client.partitions_service_pb2_grpc import PartitionsStub
from ..protogen.common.partitions_common_pb2 import ListPartitionsRequest, ListPartitionsResponse, GetPartitionRequest, GetPartitionResponse
from ..protogen.common.partitions_fields_pb2 import PartitionField, PartitionRawField, PARTITION_RAW_ENUM_FIELD_PRIORITY
from ..protogen.common.partitions_filters_pb2 import Filters as rawFilters, FiltersAnd as rawFiltersAnd, FilterField as rawFilterField
from ..protogen.common.sort_direction_pb2 import SortDirection


class PartitionFieldFilter:
PRIORITY = NumberFilter(
PartitionField(partition_raw_field=PartitionRawField(field=PARTITION_RAW_ENUM_FIELD_PRIORITY)),
rawFilters,
rawFiltersAnd,
rawFilterField
)


class ArmoniKPartitions:
def __init__(self, grpc_channel: Channel):
""" Result service client

Args:
grpc_channel: gRPC channel to use
"""
self._client = PartitionsStub(grpc_channel)

def list_partitions(self, partition_filter: Filter | None = None, page: int = 0, page_size: int = 1000, sort_field: Filter = PartitionFieldFilter.PRIORITY, sort_direction: SortDirection = Direction.ASC) -> Tuple[int, List[Partition]]:
"""List partitions based on a filter.

Args:
partition_filter: Filter to apply when listing partitions
page: page number to request, useful for pagination, defaults to 0
page_size: size of a page, defaults to 1000
sort_field: field to sort the resulting list by, defaults to the status
sort_direction: direction of the sort, defaults to ascending

Returns:
A tuple containing :
- The total number of results for the given filter
- The obtained list of results
"""
request = ListPartitionsRequest(
page=page,
page_size=page_size,
filters=cast(rawFilters, partition_filter.to_disjunction().to_message()) if partition_filter else None,
sort=ListPartitionsRequest.Sort(field=cast(PartitionField, sort_field.field), direction=sort_direction),
)
response: ListPartitionsResponse = self._client.ListPartitions(request)
return response.total, [Partition.from_message(p) for p in response.partitions]

def get_partition(self, partition_id: str) -> Partition:
"""Get a partition by its ID.

Args:
partition_id: The partition ID.

Return:
The partition summary.
"""
return Partition.from_message(self._client.GetPartition(GetPartitionRequest(id=partition_id)).partition)
33 changes: 22 additions & 11 deletions packages/python/src/armonik/client/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,26 @@ def create_session(self, default_task_options: TaskOptions, partition_ids: Optio
Returns:
Session Id
"""
if partition_ids is None:
partition_ids = []
request = CreateSessionRequest(default_task_option=default_task_options.to_message())
for partition in partition_ids:
request.partition_ids.append(partition)
request = CreateSessionRequest(
default_task_option=default_task_options.to_message(),
partition_ids=partition_ids if partition_ids else []
)
return self._client.CreateSession(request).session_id

def list_sessions(self, task_filter: Filter, page: int = 0, page_size: int = 1000, sort_field: Filter = SessionFieldFilter.STATUS, sort_direction: SortDirection = Direction.ASC) -> Tuple[int, List[Session]]:
def get_session(self, session_id: str):
"""Get a session by its ID.

Args:
session_id: The ID of the session.

Return:
The session summary.
"""
request = GetSessionRequest(session_id=session_id)
response: GetSessionResponse = self._client.GetSession(request)
return Session.from_message(response.session)

def list_sessions(self, session_filter: Filter | None = None, page: int = 0, page_size: int = 1000, sort_field: Filter = SessionFieldFilter.STATUS, sort_direction: SortDirection = Direction.ASC) -> Tuple[int, List[Session]]:
"""
List sessions

Expand All @@ -76,14 +88,14 @@ def list_sessions(self, task_filter: Filter, page: int = 0, page_size: int = 100
- The total number of sessions for the given filter
- The obtained list of sessions
"""
request : ListSessionsRequest = ListSessionsRequest(
request = ListSessionsRequest(
page=page,
page_size=page_size,
filters=cast(rawFilters, task_filter.to_disjunction().to_message()),
filters=cast(rawFilters, session_filter.to_disjunction().to_message()) if session_filter else None,
sort=ListSessionsRequest.Sort(field=cast(SessionField, sort_field.field), direction=sort_direction),
)
list_response : ListSessionsResponse = self._client.ListSessions(request)
return list_response.total, [Session.from_message(t) for t in list_response.sessions]
response : ListSessionsResponse = self._client.ListSessions(request)
return response.total, [Session.from_message(s) for s in response.sessions]

def cancel_session(self, session_id: str) -> None:
"""Cancel a session
Expand All @@ -92,4 +104,3 @@ def cancel_session(self, session_id: str) -> None:
session_id: Id of the session to b cancelled
"""
self._client.CancelSession(CancelSessionRequest(session_id=session_id))

4 changes: 2 additions & 2 deletions packages/python/src/armonik/common/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .helpers import datetime_to_timestamp, timestamp_to_datetime, duration_to_timedelta, timedelta_to_duration, get_task_filter
from .objects import Task, TaskDefinition, TaskOptions, Output, ResultAvailability, Session, Result
from .enumwrapper import HealthCheckStatus, TaskStatus, Direction, ResultStatus
from .objects import Task, TaskDefinition, TaskOptions, Output, ResultAvailability, Session, Result, Partition
from .enumwrapper import HealthCheckStatus, TaskStatus, Direction, ResultStatus, SessionStatus
from .filter import StringFilter, StatusFilter
22 changes: 22 additions & 0 deletions packages/python/src/armonik/common/enumwrapper.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

from ..protogen.common.task_status_pb2 import TaskStatus as RawStatus, _TASKSTATUS, TASK_STATUS_CANCELLED, TASK_STATUS_CANCELLING, TASK_STATUS_COMPLETED, TASK_STATUS_CREATING, TASK_STATUS_DISPATCHED, TASK_STATUS_ERROR, TASK_STATUS_PROCESSED, TASK_STATUS_PROCESSING, TASK_STATUS_SUBMITTED, TASK_STATUS_TIMEOUT, TASK_STATUS_UNSPECIFIED, TASK_STATUS_RETRIED
from ..protogen.common.events_common_pb2 import EventsEnum as rawEventsEnum, EVENTS_ENUM_UNSPECIFIED, EVENTS_ENUM_NEW_TASK, EVENTS_ENUM_TASK_STATUS_UPDATE, EVENTS_ENUM_NEW_RESULT, EVENTS_ENUM_RESULT_STATUS_UPDATE, EVENTS_ENUM_RESULT_OWNER_UPDATE
from ..protogen.common.session_status_pb2 import SessionStatus as RawSessionStatus, _SESSIONSTATUS, SESSION_STATUS_UNSPECIFIED, SESSION_STATUS_CANCELLED, SESSION_STATUS_RUNNING
from ..protogen.common.result_status_pb2 import ResultStatus as RawResultStatus, _RESULTSTATUS, RESULT_STATUS_UNSPECIFIED, RESULT_STATUS_CREATED, RESULT_STATUS_COMPLETED, RESULT_STATUS_ABORTED, RESULT_STATUS_NOTFOUND
from ..protogen.common.health_checks_common_pb2 import HEALTH_STATUS_ENUM_UNSPECIFIED, HEALTH_STATUS_ENUM_HEALTHY, HEALTH_STATUS_ENUM_DEGRADED, HEALTH_STATUS_ENUM_UNHEALTHY
from ..protogen.common.worker_common_pb2 import HealthCheckReply
from ..protogen.common.sort_direction_pb2 import SORT_DIRECTION_ASC, SORT_DIRECTION_DESC

Expand Down Expand Up @@ -58,3 +60,23 @@ def name_from_value(status: RawResultStatus) -> str:
COMPLETED = RESULT_STATUS_COMPLETED
ABORTED = RESULT_STATUS_ABORTED
NOTFOUND = RESULT_STATUS_NOTFOUND


class EventTypes:
UNSPECIFIED = EVENTS_ENUM_UNSPECIFIED
NEW_TASK = EVENTS_ENUM_NEW_TASK
TASK_STATUS_UPDATE = EVENTS_ENUM_TASK_STATUS_UPDATE
NEW_RESULT = EVENTS_ENUM_NEW_RESULT
RESULT_STATUS_UPDATE = EVENTS_ENUM_RESULT_STATUS_UPDATE
RESULT_OWNER_UPDATE = EVENTS_ENUM_RESULT_OWNER_UPDATE

@classmethod
def from_string(cls, name: str):
return getattr(cls, name.upper())
Comment on lines +65 to +75
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a bit too soon no ?



class ServiceHealthCheckStatus:
UNSPECIFIED = HEALTH_STATUS_ENUM_UNSPECIFIED
HEALTHY = HEALTH_STATUS_ENUM_HEALTHY
DEGRADED = HEALTH_STATUS_ENUM_DEGRADED
UNHEALTHY = HEALTH_STATUS_ENUM_UNHEALTHY
22 changes: 22 additions & 0 deletions packages/python/src/armonik/common/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,3 +215,25 @@ def from_message(cls, result_raw: ResultRaw) -> "Result":
result_id=result_raw.result_id,
size=result_raw.size
)

@dataclass
class Partition:
id: str
parent_partition_ids: List[str]
pod_reserved: int
pod_max: int
pod_configuration: Dict[str, str]
preemption_percentage: int
priority: int

@classmethod
def from_message(cls, partition_raw: PartitionRaw) -> "Partition":
return cls(
id=partition_raw.id,
parent_partition_ids=partition_raw.parent_partition_ids,
pod_reserved=partition_raw.pod_reserved,
pod_max=partition_raw.pod_max,
pod_configuration=partition_raw.pod_configuration,
preemption_percentage=partition_raw.preemption_percentage,
priority=partition_raw.priority
)
10 changes: 7 additions & 3 deletions packages/python/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest
import requests

from armonik.client import ArmoniKResults, ArmoniKTasks, ArmoniKVersions
from armonik.client import ArmoniKPartitions, ArmoniKResults, ArmoniKSessions, ArmoniKTasks, ArmoniKVersions
from armonik.protogen.worker.agent_service_pb2_grpc import AgentStub
from typing import List, Union

Expand Down Expand Up @@ -54,7 +54,7 @@ def clean_up(request):
print("An error occurred when resetting the server: " + str(e))


def get_client(client_name: str, endpoint: str = grpc_endpoint) -> Union[ArmoniKTasks, ArmoniKVersions]:
def get_client(client_name: str, endpoint: str = grpc_endpoint) -> Union[ArmoniKPartitions, ArmoniKResults, ArmoniKSessions, ArmoniKTasks, ArmoniKVersions]:
"""
Get the ArmoniK client instance based on the specified service name.

Expand All @@ -63,7 +63,7 @@ def get_client(client_name: str, endpoint: str = grpc_endpoint) -> Union[ArmoniK
endpoint (str, optional): The gRPC server endpoint. Defaults to grpc_endpoint.

Returns:
Union[ArmoniKTasks, ArmoniKVersions]
Union[ArmoniKPartitions, ArmoniKResults, ArmoniKSessions, ArmoniKTasks, ArmoniKVersions]
An instance of the specified ArmoniK client.

Raises:
Expand All @@ -75,8 +75,12 @@ def get_client(client_name: str, endpoint: str = grpc_endpoint) -> Union[ArmoniK
"""
channel = grpc.insecure_channel(endpoint).__enter__()
match client_name:
case "Partitions":
return ArmoniKPartitions(channel)
case "Results":
return ArmoniKResults(channel)
case "Sessions":
return ArmoniKSessions(channel)
case "Tasks":
return ArmoniKTasks(channel)
case "Versions":
Expand Down
41 changes: 41 additions & 0 deletions packages/python/tests/test_partitions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from .conftest import all_rpc_called, rpc_called, get_client
from armonik.client import ArmoniKPartitions, PartitionFieldFilter
from armonik.common import Partition


class TestArmoniKPartitions:

def test_get_partitions(self):
partitions_client: ArmoniKPartitions = get_client("Partitions")
partition = partitions_client.get_partition("partition-id")

assert rpc_called("Partitions", "GetPartition")
assert isinstance(partition, Partition)
assert partition.id == 'partition-id'
assert partition.parent_partition_ids == []
assert partition.pod_reserved == 1
assert partition.pod_max == 1
assert partition.pod_configuration == {}
assert partition.preemption_percentage == 0
assert partition.priority == 1

def test_list_partitions_no_filter(self):
partitions_client: ArmoniKPartitions = get_client("Partitions")
num, partitions = partitions_client.list_partitions()

assert rpc_called("Partitions", "ListPartitions")
# TODO: Mock must be updated to return something and so that changes the following assertions
assert num == 0
assert partitions == []

def test_list_partitions_with_filter(self):
partitions_client: ArmoniKPartitions = get_client("Partitions")
num, partitions = partitions_client.list_partitions(PartitionFieldFilter.PRIORITY == 1)

assert rpc_called("Partitions", "ListPartitions", 2)
# TODO: Mock must be updated to return something and so that changes the following assertions
assert num == 0
assert partitions == []

def test_service_fully_implemented(self):
assert all_rpc_called("Partitions")
72 changes: 72 additions & 0 deletions packages/python/tests/test_sessions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import datetime

from .conftest import all_rpc_called, rpc_called, get_client
from armonik.client import ArmoniKSessions, SessionFieldFilter
from armonik.common import Session, SessionStatus, TaskOptions


class TestArmoniKSessions:

def test_create_session(self):
sessions_client: ArmoniKSessions = get_client("Sessions")
default_task_options = TaskOptions(
max_duration=datetime.timedelta(seconds=1),
priority=1,
max_retries=1
)
session_id = sessions_client.create_session(default_task_options)

assert rpc_called("Sessions", "CreateSession")
assert session_id == "session-id"

def test_get_session(self):
sessions_client: ArmoniKSessions = get_client("Sessions")
session = sessions_client.get_session("session-id")

assert rpc_called("Sessions", "GetSession")
assert isinstance(session, Session)
assert session.session_id == 'session-id'
assert session.status == SessionStatus.CANCELLED
assert session.partition_ids == []
assert session.options == TaskOptions(
max_duration=datetime.timedelta(0),
priority=0,
max_retries=0,
partition_id='',
application_name='',
application_version='',
application_namespace='',
application_service='',
engine_type='',
options={}
)
assert session.created_at == datetime.datetime(1970, 1, 1, 0, 0, tzinfo=datetime.timezone.utc)
assert session.cancelled_at == datetime.datetime(1970, 1, 1, 0, 0, tzinfo=datetime.timezone.utc)
assert session.duration == datetime.timedelta(0)

def test_list_session_no_filter(self):
sessions_client: ArmoniKSessions = get_client("Sessions")
num, sessions = sessions_client.list_sessions()

assert rpc_called("Sessions", "ListSessions")
# TODO: Mock must be updated to return something and so that changes the following assertions
assert num == 0
assert sessions == []

def test_list_session_with_filter(self):
sessions_client: ArmoniKSessions = get_client("Sessions")
num, sessions = sessions_client.list_sessions(SessionFieldFilter.STATUS == SessionStatus.RUNNING)

assert rpc_called("Sessions", "ListSessions", 2)
# TODO: Mock must be updated to return something and so that changes the following assertions
assert num == 0
assert sessions == []

def test_cancel_session(self):
sessions_client: ArmoniKSessions = get_client("Sessions")
sessions_client.cancel_session("session-id")

assert rpc_called("Sessions", "CancelSession")

def test_service_fully_implemented(self):
assert all_rpc_called("Sessions")
Loading