Skip to content

Commit

Permalink
test: Refactor python tests with mock server (#456)
Browse files Browse the repository at this point in the history
  • Loading branch information
aneojgurhem authored Jan 3, 2024
2 parents ee0ff9e + d56bf0f commit eda55a4
Show file tree
Hide file tree
Showing 10 changed files with 178 additions and 798 deletions.
11 changes: 11 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,17 @@ jobs:
- name: Install dependencies
run: pip install "$(echo pkg/armonik*.whl)[tests]"

- name: Install .NET Core
uses: actions/setup-dotnet@3447fd6a9f9e57506b15f895c5b76d3b197dc7c2 # v3
with:
dotnet-version: 6.x

- name: Start Mock server
run: |
cd ../csharp/ArmoniK.Api.Mock
nohup dotnet run > /dev/null 2>&1 &
sleep 60
- name: Run tests
run: python -m pytest tests --cov=armonik --cov-config=.coveragerc --cov-report=term-missing --cov-report xml:coverage.xml --cov-report html:coverage_report

Expand Down
31 changes: 30 additions & 1 deletion packages/python/src/armonik/common/helpers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations
from datetime import timedelta, datetime, timezone
from typing import List, Optional
from typing import List, Optional, Iterable, TypeVar

import google.protobuf.duration_pb2 as duration
import google.protobuf.timestamp_pb2 as timestamp
Expand All @@ -9,6 +9,9 @@
from .enumwrapper import TaskStatus


T = TypeVar('T')


def get_task_filter(session_ids: Optional[List[str]] = None, task_ids: Optional[List[str]] = None,
included_statuses: Optional[List[TaskStatus]] = None,
excluded_statuses: Optional[List[TaskStatus]] = None) -> TaskFilter:
Expand Down Expand Up @@ -96,3 +99,29 @@ def timedelta_to_duration(delta: timedelta) -> duration.Duration:
d = duration.Duration()
d.FromTimedelta(delta)
return d


def batched(iterable: Iterable[T], n: int) -> Iterable[List[T]]:
"""
Batches elements from an iterable into lists of size at most 'n'.
Args:
iterable : The input iterable.
n : The batch size.
Yields:
A generator yielding batches of elements from the input iterable.
"""
it = iter(iterable)

sentinel = object()
batch = []
c = next(it, sentinel)
while c is not sentinel:
batch.append(c)
if len(batch) == n:
yield batch
batch.clear()
c = next(it, sentinel)
if len(batch) > 0:
yield batch
40 changes: 0 additions & 40 deletions packages/python/tests/common.py

This file was deleted.

123 changes: 123 additions & 0 deletions packages/python/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import grpc
import os
import pytest
import requests

from armonik.protogen.worker.agent_service_pb2_grpc import AgentStub
from typing import List


# Mock server endpoints used for the tests.
grpc_endpoint = "localhost:5001"
calls_endpoint = "http://localhost:5000/calls.json"
reset_endpoint = "http://localhost:5000/reset"
data_folder = os.getcwd()


@pytest.fixture(scope="session", autouse=True)
def clean_up(request):
"""
This fixture runs at the session scope and is automatically used before and after
running all the tests. It set up and teardown the testing environments by:
- creating dummy files before testing begins;
- clear files after testing;
- resets the mocking gRPC server counters to maintain a clean testing environment.
Yields:
None: This fixture is used as a context manager, and the test code runs between
the 'yield' statement and the cleanup code.
Raises:
requests.exceptions.HTTPError: If an error occurs when attempting to reset
the mocking gRPC server counters.
"""
# Write dumm payload and data dependency to files for testing purposes
with open(os.path.join(data_folder, "payload-id"), "wb") as f:
f.write("payload".encode())
with open(os.path.join(data_folder, "dd-id"), "wb") as f:
f.write("dd".encode())

# Run all the tests
yield

# Remove the temporary files created for testing
os.remove(os.path.join(data_folder, "payload-id"))
os.remove(os.path.join(data_folder, "dd-id"))

# Reset the mock server counters
try:
response = requests.post(reset_endpoint)
response.raise_for_status()
print("\nMock server resetted.")
except requests.exceptions.HTTPError as e:
print("An error occurred when resetting the server: " + str(e))


def rpc_called(service_name: str, rpc_name: str, n_calls: int = 1, endpoint: str = calls_endpoint) -> bool:
"""Check if a remote procedure call (RPC) has been made a specified number of times.
This function uses ArmoniK.Api.Mock. It just gets the '/calls.json' endpoint.
Args:
service_name (str): The name of the service providing the RPC.
rpc_name (str): The name of the specific RPC to check for the number of calls.
n_calls (int, optional): The expected number of times the RPC should have been called. Default is 1.
endpoint (str, optional): The URL of the remote service providing RPC information. Default to
calls_endpoint.
Returns:
bool: True if the specified RPC has been called the expected number of times, False otherwise.
Raises:
requests.exceptions.RequestException: If an error occurs when requesting ArmoniK.Api.Mock.
Example:
>>> rpc_called('http://localhost:5000/calls.json', 'Versions', 'ListVersionss', 0)
True
"""
response = requests.get(endpoint)
response.raise_for_status()
data = response.json()

# Check if the RPC has been called n_calls times
if data[service_name][rpc_name] == n_calls:
return True
return False


def all_rpc_called(service_name: str, missings: List[str] = [], endpoint: str = calls_endpoint) -> bool:
"""
Check if all remote procedure calls (RPCs) in a service have been made at least once.
This function uses ArmoniK.Api.Mock. It just gets the '/calls.json' endpoint.
Args:
service_name (str): The name of the service containing the RPC information in the response.
endpoint (str, optional): The URL of the remote service providing RPC information. Default is
the value of calls_endpoint.
missings (List[str], optional): A list of RPCs known to be not implemented. Default is an empty list.
Returns:
bool: True if all RPCs in the specified service have been called at least once, False otherwise.
Raises:
requests.exceptions.RequestException: If an error occurs when requesting ArmoniK.Api.Mock.
Example:
>>> all_rpc_called('http://localhost:5000/calls.json', 'Versions')
False
"""
response = requests.get(endpoint)
response.raise_for_status()
data = response.json()

missing_rpcs = []

# Check if all RPCs in the service have been called at least once
for rpc_name, rpc_num_calls in data[service_name].items():
if rpc_num_calls == 0:
missing_rpcs.append(rpc_name)
if missing_rpcs:
if missings == missing_rpcs:
return True
print(f"RPCs not implemented in {service_name} service: {missing_rpcs}.")
return False
return True
Loading

0 comments on commit eda55a4

Please sign in to comment.