Skip to content

Commit

Permalink
pydantic
Browse files Browse the repository at this point in the history
  • Loading branch information
rogelioLpz committed Dec 30, 2022
1 parent 1f310b9 commit 28d5f80
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 7 deletions.
3 changes: 2 additions & 1 deletion examples/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from fast_agave.middlewares import FastAgaveErrorHandler
from .resources import app as resources
from .middlewares import AuthedMiddleware
from .tasks.task_example import dummy_task
from .tasks.task_example import dummy_task, task_validator

connect(host='mongomock://localhost:27017/db')
app = FastAPI(title='example')
Expand All @@ -27,3 +27,4 @@ async def on_startup() -> None: # pragma: no cover
# Inicializa el task que recibe mensajes
# provenientes de SQS
asyncio.create_task(dummy_task())
asyncio.create_task(task_validator())
15 changes: 15 additions & 0 deletions examples/tasks/task_example.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,26 @@
from typing import Optional

from fast_agave.tasks.sqs_tasks import task

from pydantic import BaseModel

# Esta URL es solo un mock de la queue.
# Debes reemplazarla con la URL de tu queue
QUEUE_URL = 'http://127.0.0.1:4000/123456789012/core.fifo'
QUEUE2_URL = 'http://127.0.0.1:4000/123456789012/validator.fifo'


class ValidatorModel(BaseModel):
name: str
age: int
nick_name: Optional[str]


@task(queue_url=QUEUE_URL, region_name='us-east-1')
async def dummy_task(message) -> None:
print(message)


@task(queue_url=QUEUE2_URL, region_name='us-east-1', validator=ValidatorModel)
async def task_validator(message: ValidatorModel) -> None:
print(message.dict())
20 changes: 15 additions & 5 deletions fast_agave/tasks/sqs_tasks.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,33 @@
import asyncio
import json
import os
from functools import wraps
from itertools import count
from typing import AsyncGenerator, Callable, Coroutine
from typing import AsyncGenerator, Callable, Coroutine, Optional, Type

from aiobotocore.httpsession import HTTPClientError
from aiobotocore.session import get_session
from pydantic import BaseModel

from ..exc import RetryTask

AWS_DEFAULT_REGION = os.getenv('AWS_DEFAULT_REGION', '')


async def run_task(
coro: Coroutine,
task_func: Callable,
body: dict,
sqs,
queue_url: str,
receipt_handle: str,
message_receive_count: int,
max_retries: int,
validator: Optional[Type[BaseModel]] = None,
) -> None:
delete_message = True
try:
await coro
data = validator(**body) if validator else body
await task_func(data)
except RetryTask:
delete_message = message_receive_count >= max_retries + 1
finally:
Expand Down Expand Up @@ -67,11 +74,12 @@ async def get_running_fast_agave_tasks():

def task(
queue_url: str,
region_name: str,
region_name: str = AWS_DEFAULT_REGION,
wait_time_seconds: int = 15,
visibility_timeout: int = 3600,
max_retries: int = 1,
max_concurrent_tasks: int = 5,
validator: Optional[Type[BaseModel]] = None,
):
def task_builder(task_func: Callable):
@wraps(task_func)
Expand Down Expand Up @@ -106,12 +114,14 @@ async def concurrency_controller(coro: Coroutine) -> None:
asyncio.create_task(
concurrency_controller(
run_task(
task_func(body),
task_func,
body,
sqs,
queue_url,
message['ReceiptHandle'],
message_receive_count,
max_retries,
validator,
),
),
name='fast-agave-task',
Expand Down
2 changes: 1 addition & 1 deletion fast_agave/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.8.0'
__version__ = '0.8.1.dev0'
40 changes: 40 additions & 0 deletions tests/tasks/test_sqs_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import aiobotocore.client
import pytest
from aiobotocore.httpsession import HTTPClientError
from pydantic import BaseModel

from fast_agave.exc import RetryTask
from fast_agave.tasks.sqs_tasks import get_running_fast_agave_tasks, task
Expand Down Expand Up @@ -40,6 +41,45 @@ async def test_execute_tasks(sqs_client) -> None:
assert 'Messages' not in resp


@pytest.mark.asyncio
async def test_execute_tasks_validator(sqs_client) -> None:
async_mock_function = AsyncMock(return_value=None)

class Validator(BaseModel):
id: str
name: str

task_params = dict(
queue_url=sqs_client.queue_url,
region_name=CORE_QUEUE_REGION,
wait_time_seconds=1,
visibility_timeout=1,
validator=Validator,
)
# Invalid body, not execute function
await sqs_client.send_message(
MessageBody=json.dumps(dict(foo='bar')),
MessageGroupId='4321',
)
await task(**task_params)(async_mock_function)()
assert async_mock_function.call_count == 0
resp = await sqs_client.receive_message()
assert 'Messages' not in resp

# Body approve validator, function receive Validator
test_message = Validator(id='abc123', name='fast-agave')
await sqs_client.send_message(
MessageBody=test_message.json(),
MessageGroupId='1234',
)
await task(**task_params)(async_mock_function)()
async_mock_function.assert_called_with(test_message)
assert async_mock_function.call_count == 1

resp = await sqs_client.receive_message()
assert 'Messages' not in resp


@pytest.mark.asyncio
async def test_not_execute_tasks(sqs_client) -> None:
"""
Expand Down

0 comments on commit 28d5f80

Please sign in to comment.