From f637780da00cb5b66d5e8a3c39778b36516dc324 Mon Sep 17 00:00:00 2001 From: Daniel Gafni Date: Sat, 10 Aug 2024 16:22:02 +0200 Subject: [PATCH] add ECS client --- pyright/alt-1/requirements-pinned.txt | 28 ++- pyright/master/requirements-pinned.txt | 47 +++- .../dagster-aws/dagster_aws/pipes/__init__.py | 3 +- .../dagster_aws/pipes/clients/__init__.py | 3 +- .../dagster_aws/pipes/clients/ecs.py | 219 ++++++++++++++++ .../dagster_aws_tests/pipes_tests/fake_ecs.py | 236 ++++++++++++++++++ python_modules/libraries/dagster-aws/setup.py | 1 + 7 files changed, 515 insertions(+), 22 deletions(-) create mode 100644 python_modules/libraries/dagster-aws/dagster_aws/pipes/clients/ecs.py create mode 100644 python_modules/libraries/dagster-aws/dagster_aws_tests/pipes_tests/fake_ecs.py diff --git a/pyright/alt-1/requirements-pinned.txt b/pyright/alt-1/requirements-pinned.txt index 94228fcfbeb80..8d36c1f277ba4 100644 --- a/pyright/alt-1/requirements-pinned.txt +++ b/pyright/alt-1/requirements-pinned.txt @@ -1,5 +1,5 @@ agate==1.9.1 -aiobotocore==2.13.2 +aiobotocore==2.13.3 aiofile==3.8.8 aiohappyeyeballs==2.4.0 aiohttp==3.10.5 @@ -9,7 +9,6 @@ alembic==1.13.2 aniso8601==9.0.1 annotated-types==0.7.0 anyio==4.4.0 -appnope==0.1.4 argon2-cffi==23.1.0 argon2-cffi-bindings==21.2.0 arrow==1.3.0 @@ -18,15 +17,16 @@ asn1crypto==1.5.1 astroid==3.2.4 asttokens==2.4.1 async-lru==2.0.4 -async-timeout==4.0.3 attrs==24.2.0 babel==2.16.0 backoff==2.2.1 backports-tarfile==1.2.0 beautifulsoup4==4.12.3 bleach==6.1.0 -boto3==1.34.131 -botocore==1.34.131 +boto3==1.34.162 +boto3-stubs==1.35.4 +botocore==1.34.162 +botocore-stubs==1.35.4 buildkite-test-collector==0.1.8 cachetools==5.5.0 caio==0.9.17 @@ -83,7 +83,6 @@ distlib==0.3.8 docker==7.1.0 docstring-parser==0.16 duckdb==1.0.0 -exceptiongroup==1.2.2 execnet==2.1.1 executing==2.0.1 fastjsonschema==2.20.0 @@ -108,6 +107,7 @@ gql==3.5.0 graphene==3.3 graphql-core==3.2.3 graphql-relay==3.2.0 +greenlet==3.0.3 grpcio==1.66.0 
grpcio-health-checking==1.62.3 grpcio-status==1.62.3 @@ -131,6 +131,7 @@ jaraco-classes==3.4.0 jaraco-context==6.0.1 jaraco-functools==4.0.2 jedi==0.19.1 +jeepney==0.8.0 jinja2==3.1.4 jmespath==1.0.1 joblib==1.4.2 @@ -169,6 +170,14 @@ msgpack==1.0.8 multidict==6.0.5 multimethod==1.10 mypy==1.11.1 +mypy-boto3-cloudformation==1.35.0 +mypy-boto3-dynamodb==1.35.0 +mypy-boto3-ec2==1.35.3 +mypy-boto3-ecs==1.35.2 +mypy-boto3-lambda==1.35.3 +mypy-boto3-rds==1.35.0 +mypy-boto3-s3==1.35.2 +mypy-boto3-sqs==1.35.0 mypy-extensions==1.0.0 mypy-protobuf==3.6.0 nbclient==0.10.0 @@ -253,6 +262,7 @@ s3transfer==0.10.2 scikit-learn==1.5.1 scipy==1.14.1 seaborn==0.13.2 +secretstorage==3.3.3 send2trash==1.8.3 setuptools==73.0.1 shellingham==1.5.4 @@ -264,13 +274,13 @@ snowflake-sqlalchemy==1.5.1 sortedcontainers==2.4.0 soupsieve==2.6 sqlalchemy==1.4.53 -sqlglot==25.16.0 +sqlglot==25.16.1 sqlglotrs==0.2.9 sqlparse==0.5.1 stack-data==0.6.3 starlette==0.38.2 structlog==24.4.0 -syrupy==4.6.4 +syrupy==4.7.1 tabulate==0.9.0 terminado==0.18.1 text-unidecode==1.3 @@ -285,6 +295,7 @@ tqdm==4.66.5 traitlets==5.14.3 typeguard==4.3.0 typer==0.12.4 +types-awscrt==0.21.2 types-backports==0.1.3 types-certifi==2021.10.8.3 types-cffi==1.16.0.20240331 @@ -299,6 +310,7 @@ types-python-dateutil==2.9.0.20240821 types-pytz==2024.1.0.20240417 types-pyyaml==6.0.12.20240808 types-requests==2.32.0.20240712 +types-s3transfer==0.10.1 types-setuptools==73.0.0.20240822 types-simplejson==3.19.0.20240801 types-six==1.16.21.20240513 diff --git a/pyright/master/requirements-pinned.txt b/pyright/master/requirements-pinned.txt index 74154502ac45c..61eb14eb6a35a 100644 --- a/pyright/master/requirements-pinned.txt +++ b/pyright/master/requirements-pinned.txt @@ -25,7 +25,6 @@ apache-airflow-providers-sqlite==3.8.2 apeye==1.4.1 apeye-core==1.1.5 apispec==6.6.1 -appnope==0.1.4 argcomplete==3.5.0 argon2-cffi==23.1.0 argon2-cffi-bindings==21.2.0 @@ -37,14 +36,13 @@ asn1crypto==1.5.1 asttokens==2.4.1 astunparse==1.6.3 
async-lru==2.0.4 -async-timeout==4.0.3 attrs==24.2.0 autodocsumm==0.2.13 autoflake==2.3.1 -e python_modules/automation avro==1.11.3 avro-gen3==0.7.13 -aws-sam-translator==1.89.0 +aws-sam-translator==1.91.0 aws-xray-sdk==2.14.0 azure-core==1.30.2 azure-identity==1.17.1 @@ -58,9 +56,11 @@ billiard==4.2.0 bitmath==1.3.3.1 bleach==6.1.0 blinker==1.8.2 -bokeh==3.5.1 +bokeh==3.5.2 boto3==1.35.4 +boto3-stubs==1.35.4 botocore==1.35.4 +botocore-stubs==1.35.4 buildkite-test-collector==0.1.8 cachecontrol==0.14.0 cached-property==1.5.2 @@ -209,7 +209,6 @@ duckdb==1.0.0 ecdsa==0.19.0 email-validator==1.3.1 entrypoints==0.4 -exceptiongroup==1.2.2 execnet==2.1.1 executing==2.0.1 expandvars==0.12.0 @@ -255,6 +254,7 @@ graphql-core==3.2.3 graphql-relay==3.2.0 graphviz==0.20.3 great-expectations==0.17.11 +greenlet==3.0.3 grpcio==1.66.0 grpcio-health-checking==1.62.3 grpcio-status==1.62.3 @@ -320,7 +320,7 @@ langchain-community==0.2.9 langchain-core==0.2.34 langchain-openai==0.1.14 langchain-text-splitters==0.2.2 -langsmith==0.1.102 +langsmith==0.1.104 lazy-object-proxy==1.10.0 leather==0.4.0 limits==3.13.0 @@ -357,6 +357,14 @@ msal-extensions==1.2.0 msgpack==1.0.8 multidict==6.0.5 multimethod==1.10 +mypy-boto3-cloudformation==1.35.0 +mypy-boto3-dynamodb==1.35.0 +mypy-boto3-ec2==1.35.3 +mypy-boto3-ecs==1.35.2 +mypy-boto3-lambda==1.35.3 +mypy-boto3-rds==1.35.0 +mypy-boto3-s3==1.35.2 +mypy-boto3-sqs==1.35.0 mypy-extensions==1.0.0 mypy-protobuf==3.6.0 mysql-connector-python==9.0.0 @@ -372,6 +380,18 @@ noteable-origami==1.1.5 notebook==7.2.1 notebook-shim==0.2.4 numpy==1.26.4 +nvidia-cublas-cu12==12.1.3.1 +nvidia-cuda-cupti-cu12==12.1.105 +nvidia-cuda-nvrtc-cu12==12.1.105 +nvidia-cuda-runtime-cu12==12.1.105 +nvidia-cudnn-cu12==9.1.0.70 +nvidia-cufft-cu12==11.0.2.54 +nvidia-curand-cu12==10.3.2.106 +nvidia-cusolver-cu12==11.4.5.107 +nvidia-cusparse-cu12==12.1.0.106 +nvidia-nccl-cu12==2.20.5 +nvidia-nvjitlink-cu12==12.6.20 +nvidia-nvtx-cu12==12.1.105 oauth2client==4.1.3 oauthlib==3.2.2 
objgraph==3.6.1 @@ -408,7 +428,7 @@ partd==1.4.2 path==16.16.0 pathable==0.4.3 pathspec==0.12.1 -pathvalidate==3.2.0 +pathvalidate==3.2.1 pendulum==2.1.2 pexpect==4.9.0 pillow==10.4.0 @@ -438,7 +458,7 @@ pyarrow-hotfix==0.6 pyasn1==0.6.0 pyasn1-modules==0.4.0 pycparser==2.22 -pydantic==1.10.17 +pydantic==1.10.18 pydata-google-auth==1.8.2 pyflakes==3.2.0 pygments==2.18.0 @@ -498,7 +518,7 @@ s3transfer==0.10.2 scikit-learn==1.5.1 scipy==1.14.1 scrapbook==0.5.0 -sdf-cli==0.3.21 +sdf-cli==0.3.23 seaborn==0.13.2 selenium==4.23.1 semver==3.0.2 @@ -515,7 +535,7 @@ skein==0.8.2 skl2onnx==1.17.0 slack-sdk==3.31.0 sling==1.2.15 -sling-mac-arm64==1.2.15 +sling-linux-amd64==1.2.15 smmap==5.0.1 sniffio==1.3.1 snowballstemmer==2.2.0 @@ -538,7 +558,7 @@ sphinxcontrib-serializinghtml==2.0.0 sqlalchemy==1.4.53 sqlalchemy-jsonfield==1.0.2 sqlalchemy-utils==0.41.2 -sqlglot==25.16.0 +sqlglot==25.16.1 sqlglotrs==0.2.9 sqlparse==0.5.1 sshpubkeys==3.3.1 @@ -547,7 +567,7 @@ stack-data==0.6.3 starlette==0.38.2 structlog==24.4.0 sympy==1.13.2 -syrupy==4.6.4 +syrupy==4.7.1 tabledata==1.3.3 tabulate==0.9.0 tblib==3.0.0 @@ -571,6 +591,7 @@ tqdm==4.66.5 traitlets==5.14.3 trio==0.26.2 trio-websocket==0.11.1 +triton==3.0.0 -e examples/experimental/dagster-airlift/examples/tutorial-example -e examples/tutorial_notebook_assets twilio==9.2.3 @@ -578,6 +599,7 @@ twine==1.15.0 typeguard==4.3.0 typepy==1.3.2 typer==0.12.4 +types-awscrt==0.21.2 types-backports==0.1.3 types-certifi==2021.10.8.3 types-cffi==1.16.0.20240331 @@ -592,6 +614,7 @@ types-python-dateutil==2.9.0.20240821 types-pytz==2024.1.0.20240417 types-pyyaml==6.0.12.20240808 types-requests==2.31.0.6 +types-s3transfer==0.10.1 types-setuptools==73.0.0.20240822 types-simplejson==3.19.0.20240801 types-six==1.16.21.20240513 diff --git a/python_modules/libraries/dagster-aws/dagster_aws/pipes/__init__.py b/python_modules/libraries/dagster-aws/dagster_aws/pipes/__init__.py index 1da771cf018c6..e513f5cc16adf 100644 --- 
# File: python_modules/libraries/dagster-aws/dagster_aws/pipes/clients/ecs.py
# NOTE: the same patch re-exports ``PipesECSClient`` from ``dagster_aws.pipes`` and
# ``dagster_aws.pipes.clients`` (added to both the imports and ``__all__``).
from pprint import pformat
from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast

import boto3
import botocore
import botocore.exceptions  # explicit: ``import botocore`` alone does not guarantee the submodule
import dagster._check as check
from dagster import PipesClient
from dagster._annotations import experimental
from dagster._core.definitions.resource_annotation import TreatAsResourceParam
from dagster._core.errors import DagsterExecutionInterruptedError
from dagster._core.execution.context.compute import OpExecutionContext
from dagster._core.pipes.client import (
    PipesClientCompletedInvocation,
    PipesContextInjector,
    PipesMessageReader,
)
from dagster._core.pipes.utils import PipesEnvContextInjector, open_pipes_session

from dagster_aws.pipes.message_readers import PipesCloudWatchMessageReader

if TYPE_CHECKING:
    from mypy_boto3_ecs.client import ECSClient
    from mypy_boto3_ecs.type_defs import RunTaskRequestRequestTypeDef


@experimental
class PipesECSClient(PipesClient, TreatAsResourceParam):
    """A pipes client for running AWS ECS tasks.

    Args:
        client (Optional[boto3.client]): The boto ECS client used to launch the ECS task
        context_injector (Optional[PipesContextInjector]): A context injector to use to inject
            context into the ECS task. Defaults to :py:class:`PipesEnvContextInjector`.
        message_reader (Optional[PipesMessageReader]): A message reader to use to read messages
            from the ECS task. Defaults to :py:class:`PipesCloudWatchMessageReader`.
        forward_termination (bool): Whether to cancel the ECS task when the Dagster process
            receives a termination signal.
    """

    def __init__(
        self,
        client: Optional[boto3.client] = None,
        context_injector: Optional[PipesContextInjector] = None,
        message_reader: Optional[PipesMessageReader] = None,
        forward_termination: bool = True,
    ):
        self._client: "ECSClient" = client or boto3.client("ecs")
        self._context_injector = context_injector or PipesEnvContextInjector()
        self._message_reader = message_reader or PipesCloudWatchMessageReader()
        self.forward_termination = check.bool_param(forward_termination, "forward_termination")

    @classmethod
    def _is_dagster_maintained(cls) -> bool:
        return True

    def run(
        self,
        *,
        context: OpExecutionContext,
        extras: Optional[Dict[str, Any]] = None,
        params: "RunTaskRequestRequestTypeDef",
    ) -> PipesClientCompletedInvocation:
        """Run ECS tasks, enriched with the pipes protocol.

        The pipes bootstrap environment variables are injected into every container of the
        task definition (user-provided container overrides are kept and extended). The
        caller's ``params`` mapping is not modified.

        Args:
            context (OpExecutionContext): The context of the currently executing Dagster op or asset.
            extras (Optional[Dict[str, Any]]): Additional information to pass to the pipes session.
            params (dict): Parameters for the ``run_task`` boto3 ECS client call.
                Must contain ``taskDefinition`` key.
                See `Boto3 API Documentation <https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs/client/run_task.html>`_

        Returns:
            PipesClientCompletedInvocation: Wrapper containing results reported by the external
            process.

        Raises:
            RuntimeError: If any container of the launched tasks exits with a non-zero code.
            DagsterExecutionInterruptedError: Re-raised after (optionally) stopping the tasks.
        """
        with open_pipes_session(
            context=context,
            message_reader=self._message_reader,
            context_injector=self._context_injector,
            extras=extras,
        ) as session:
            task_definition = params["taskDefinition"]
            cluster = params.get("cluster")

            # The task definition tells us which containers exist and how each one logs.
            container_definitions = self._client.describe_task_definition(
                taskDefinition=task_definition
            )["taskDefinition"]["containerDefinitions"]

            log_configurations = {
                container["name"]: container.get("logConfiguration")
                for container in container_definitions
            }

            overrides = self._build_overrides(
                user_overrides=cast(Optional[Dict[str, Any]], params.get("overrides")),
                container_names=[container["name"] for container in container_definitions],
                env_vars=session.get_bootstrap_env_vars(),
            )

            # Build a new request instead of writing back into the caller's ``params``.
            run_task_params = cast(
                "RunTaskRequestRequestTypeDef", {**params, "overrides": overrides}
            )
            response = self._client.run_task(**run_task_params)

            tasks: List[str] = [task["taskArn"] for task in response["tasks"]]

            try:
                response = self._wait_for_tasks_completion(tasks=tasks, cluster=cluster)
                self._consume_container_logs(
                    context=context,
                    tasks=response["tasks"],
                    log_configurations=log_configurations,
                )
                self._raise_on_failed_containers(response["tasks"])
            except DagsterExecutionInterruptedError:
                if self.forward_termination:
                    context.log.warning(
                        "[pipes] Dagster process interrupted, terminating ECS tasks"
                    )
                    self._terminate_tasks(context=context, tasks=tasks, cluster=cluster)
                raise

            context.log.info(f"[pipes] ECS tasks {tasks} completed")

        return PipesClientCompletedInvocation(session)

    @staticmethod
    def _build_overrides(
        user_overrides: Optional[Dict[str, Any]],
        container_names: List[str],
        env_vars: Dict[str, str],
    ) -> Dict[str, Any]:
        """Return a copy of ``user_overrides`` with ``env_vars`` set on every container.

        Existing container overrides keep their own environment entries (pipes entries are
        appended after them); containers without an override get a new one. Neither the
        input mapping nor any of its nested lists is mutated, and every container receives
        its own ``environment`` list object.
        """

        def pipes_environment() -> List[Dict[str, str]]:
            # a fresh list per container so the overrides never alias each other
            return [{"name": name, "value": value} for name, value in env_vars.items()]

        overrides: Dict[str, Any] = dict(user_overrides or {})
        container_overrides: List[Dict[str, Any]] = []
        overridden_names = set()

        for user_override in overrides.get("containerOverrides") or []:
            container_override = dict(user_override)
            container_override["environment"] = [
                *(user_override.get("environment") or []),
                *pipes_environment(),
            ]
            container_overrides.append(container_override)
            overridden_names.add(container_override["name"])

        for container_name in container_names:
            if container_name not in overridden_names:
                container_overrides.append(
                    {"name": container_name, "environment": pipes_environment()}
                )

        overrides["containerOverrides"] = container_overrides
        return overrides

    def _consume_container_logs(
        self,
        context: OpExecutionContext,
        tasks: List[Dict[str, Any]],
        log_configurations: Dict[str, Optional[Dict[str, Any]]],
    ) -> None:
        """Feed the CloudWatch logs of every finished container to the message reader.

        Containers without a log configuration are skipped silently; containers whose logs
        cannot be located (non-``awslogs`` driver, or no stream prefix) produce a warning.
        """
        for task in tasks:
            task_id = task["taskArn"].split("/")[-1]

            for container in task["containers"]:
                log_config = log_configurations.get(container["name"])
                if not log_config:
                    continue

                if log_config["logDriver"] != "awslogs":
                    context.log.warning(
                        f"[pipes] Unsupported log driver {log_config['logDriver']} for container {container['name']} in task {task['taskArn']}. Dagster Pipes won't be able to receive messages from this container."
                    )
                    continue

                options = log_config["options"]
                log_group = options["awslogs-group"]
                stream_prefix = options.get("awslogs-stream-prefix")

                if stream_prefix is None:
                    # NOTE(review): AWS documents the prefix as optional for the EC2 launch
                    # type; without it the stream name cannot be derived the way it is below.
                    context.log.warning(
                        f"[pipes] Log configuration of container {container['name']} in task {task['taskArn']} has no awslogs-stream-prefix. Dagster Pipes won't be able to receive messages from this container."
                    )
                    continue

                # stream name is combined from: prefix, container name, task id
                log_stream = f"{stream_prefix}/{container['name']}/{task_id}"

                if isinstance(self._message_reader, PipesCloudWatchMessageReader):
                    self._message_reader.consume_cloudwatch_logs(
                        log_group,
                        log_stream,
                        start_time=int(task["createdAt"].timestamp() * 1000),
                    )

    @staticmethod
    def _raise_on_failed_containers(tasks: List[Dict[str, Any]]) -> None:
        """Raise ``RuntimeError`` if any container reported a non-zero exit code.

        Containers without an ``exitCode`` are not treated as failed.
        """
        failed_containers = {}

        for task in tasks:
            for container in task["containers"]:
                if container.get("exitCode") not in (0, None):
                    failed_containers[container["runtimeId"]] = container.get("exitCode")

        if failed_containers:
            raise RuntimeError(
                f"Some ECS containers finished with non-zero exit code:\n{pformat(list(failed_containers.keys()))}"
            )

    def _wait_for_tasks_completion(
        self, tasks: List[str], cluster: Optional[str] = None
    ) -> Dict[str, Any]:
        """Block until all ``tasks`` are stopped and return the ``describe_tasks`` response.

        ``cluster`` is only sent when provided, so ECS falls back to its default cluster.
        """
        # NOTE(review): the waiter runs with boto3's default polling configuration
        # (documented as 6 s delay / 100 attempts for ``tasks_stopped`` — confirm); tasks
        # that run longer than that would surface as a waiter error.
        waiter = self._client.get_waiter("tasks_stopped")

        params: Dict[str, Any] = {"tasks": tasks}

        if cluster:
            params["cluster"] = cluster

        waiter.wait(**params)
        return self._client.describe_tasks(**params)

    def _terminate_tasks(
        self, context: OpExecutionContext, tasks: List[str], cluster: Optional[str] = None
    ):
        """Best-effort stop of every task; failures are logged as warnings, not raised."""
        stop_params: Dict[str, Any] = {"reason": "Dagster process was interrupted"}

        # botocore validates parameter types before sending the request, so an explicit
        # ``cluster=None`` would raise ParamValidationError (not a ClientError) and mask
        # the interruption being handled. Omit the key instead.
        if cluster:
            stop_params["cluster"] = cluster

        for task in tasks:
            try:
                self._client.stop_task(task=task, **stop_params)
            except botocore.exceptions.ClientError as e:
                context.log.warning(
                    f"[pipes] Couldn't stop ECS task {task} in cluster {cluster}:\n{e}"
                )
# File: python_modules/libraries/dagster-aws/dagster_aws_tests/pipes_tests/fake_ecs.py
import sys
import time
import uuid
from dataclasses import dataclass
from datetime import datetime
from subprocess import PIPE, Popen
from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast

if TYPE_CHECKING:
    # boto3 is only referenced in annotations in this module
    import boto3


@dataclass
class SimulatedTaskRun:
    """Bookkeeping for one ECS task that is simulated by a local subprocess."""

    popen: Popen  # the local process standing in for the task's single container
    cluster: str
    task_arn: str
    log_group: str  # CloudWatch log group the process output is uploaded to
    log_stream: str  # CloudWatch log stream, formatted "<prefix>/<container name>/<task id>"
    created_at: datetime
    runtime_id: str
    stopped_reason: Optional[str] = None
    stopped: bool = False  # set by ``stop_task``
    logs_uploaded: bool = False  # guards against uploading the process output twice


class LocalECSMockClient:
    """ECS client double that runs task commands as local subprocesses.

    Task/definition bookkeeping is delegated to the wrapped ``ecs_client``; the container
    command itself is executed with ``Popen`` and its output is pushed to
    ``cloudwatch_client`` once the process has finished. Only single-task,
    single-container scenarios are supported.
    """

    def __init__(self, ecs_client: "boto3.client", cloudwatch_client: "boto3.client"):
        self.ecs_client = ecs_client
        self.cloudwatch_client = cloudwatch_client

        # mapping of task ARN to the simulated run backing that task
        self._task_runs: Dict[str, SimulatedTaskRun] = {}

    def get_waiter(self, waiter_name: str):
        """Return a polling waiter bound to this client."""
        return WaiterMock(self, waiter_name)

    def register_task_definition(self, **kwargs):
        """Pass through to the wrapped ECS client."""
        return self.ecs_client.register_task_definition(**kwargs)

    def describe_task_definition(self, **kwargs):
        """Describe a task definition, adding an ``awslogs`` log configuration if missing."""
        response = self.ecs_client.describe_task_definition(**kwargs)
        task_definition = response["taskDefinition"]
        assert (
            len(task_definition["containerDefinitions"]) == 1
        ), "Only 1 container is supported in tests"

        # unlike real ECS, moto doesn't use cloudwatch logging by default
        # so let's add it here
        container_definition = task_definition["containerDefinitions"][0]
        container_definition["logConfiguration"] = container_definition.get(
            "logConfiguration"
        ) or {
            "logDriver": "awslogs",
            "options": {
                # this value doesn't really matter
                "awslogs-group": f"{task_definition['taskDefinitionArn']}",
                "awslogs-stream-prefix": "ecs",
            },
        }
        return response

    def run_task(self, **kwargs):
        """Register the task with the wrapped client and start its command locally.

        The subprocess environment consists solely of the ``environment`` entries of the
        first container override (empty when no override is given).
        """
        response = self.ecs_client.run_task(**kwargs)

        task_arn = response["tasks"][0]["taskArn"]
        task_definition_arn = response["tasks"][0]["taskDefinitionArn"]

        task_definition = self.describe_task_definition(taskDefinition=task_definition_arn)[
            "taskDefinition"
        ]

        assert (
            len(task_definition["containerDefinitions"]) == 1
        ), "Only 1 container is supported in tests"

        container_definition = task_definition["containerDefinitions"][0]

        # execute in a separate process
        command = container_definition["command"]

        assert (
            command[0] == sys.executable
        ), "Only the current Python interpreter is supported in tests"

        created_at = datetime.now()

        # get env from container overrides; tolerate calls without any overrides
        container_overrides = (kwargs.get("overrides") or {}).get("containerOverrides") or [{}]

        popen = Popen(
            command,
            stdout=PIPE,
            stderr=PIPE,
            env={
                env["name"]: env["value"]
                for env in container_overrides[0].get("environment", [])
            },
        )

        log_options = container_definition["logConfiguration"]["options"]
        log_group = log_options["awslogs-group"]
        stream_prefix = log_options["awslogs-stream-prefix"]
        container_name = container_definition["name"]
        log_stream = f"{stream_prefix}/{container_name}/{task_arn.split('/')[-1]}"

        self._task_runs[task_arn] = SimulatedTaskRun(
            popen=popen,
            cluster=kwargs.get("cluster", "default"),
            task_arn=task_arn,
            log_group=log_group,
            log_stream=log_stream,
            created_at=created_at,
            runtime_id=str(uuid.uuid4()),
        )

        return response

    def describe_tasks(self, cluster: Optional[str] = None, tasks: Optional[List[str]] = None):
        """Describe a task, deriving its status and exit code from the local process.

        ``cluster`` is optional (as with the real client) and is only forwarded to the
        wrapped client when given.
        """
        tasks = tasks or []
        assert len(tasks) == 1, "Only 1 task is supported in tests"

        simulated_task = cast(SimulatedTaskRun, self._task_runs[tasks[0]])

        request: Dict[str, Any] = {"tasks": tasks}
        if cluster is not None:
            request["cluster"] = cluster

        response = self.ecs_client.describe_tasks(**request)

        assert len(response["tasks"]) == 1, "Only 1 task is supported in tests"

        task_definition = self.describe_task_definition(
            taskDefinition=response["tasks"][0]["taskDefinitionArn"]
        )["taskDefinition"]

        assert (
            len(task_definition["containerDefinitions"]) == 1
        ), "Only 1 container is supported in tests"

        # need to inject container name since moto doesn't return it
        response["tasks"][0]["containers"].append(
            {
                "name": task_definition["containerDefinitions"][0]["name"],
                "runtimeId": simulated_task.runtime_id,
            }
        )

        response["tasks"][0]["createdAt"] = simulated_task.created_at

        for task in response["tasks"]:
            task_run = self._task_runs.get(task["taskArn"])
            if task_run is None:
                continue

            if task_run.stopped:
                # explicitly stopped via ``stop_task``: always reported as failed
                task["lastStatus"] = "STOPPED"
                task["stoppedReason"] = task_run.stopped_reason
                task["containers"][0]["exitCode"] = 1
                self._upload_logs_to_cloudwatch(task["taskArn"])
            elif task_run.popen.poll() is not None:
                # process finished on its own; any non-zero return code is reported as 1
                task["lastStatus"] = "STOPPED"
                task["containers"][0]["exitCode"] = 0 if task_run.popen.returncode == 0 else 1
                self._upload_logs_to_cloudwatch(task["taskArn"])
            else:
                task["lastStatus"] = "RUNNING"

        return response

    def stop_task(
        self,
        cluster: Optional[str] = None,
        task: Optional[str] = None,
        reason: Optional[str] = None,
    ):
        """Terminate the local process behind ``task`` and upload its output.

        ``cluster`` is accepted for signature compatibility and is not used.

        Raises:
            RuntimeError: If ``task`` was not started through this client.
        """
        simulated_task = self._task_runs.get(task) if task is not None else None
        if simulated_task is None:
            raise RuntimeError(f"Task {task} was not found")

        simulated_task.popen.terminate()
        simulated_task.stopped = True
        simulated_task.stopped_reason = reason
        self._upload_logs_to_cloudwatch(task)

    def _upload_logs_to_cloudwatch(self, task: str):
        """Upload the finished process' stderr and stdout as log events (once per task)."""
        simulated_task = self._task_runs[task]

        if simulated_task.logs_uploaded:
            return

        log_group = simulated_task.log_group
        log_stream = simulated_task.log_stream

        # waits for the process to exit and collects everything it wrote
        stdout, stderr = simulated_task.popen.communicate()

        try:
            self.cloudwatch_client.create_log_group(
                logGroupName=f"{log_group}",
            )
        except self.cloudwatch_client.exceptions.ResourceAlreadyExistsException:
            pass

        try:
            self.cloudwatch_client.create_log_stream(
                logGroupName=f"{log_group}",
                logStreamName=log_stream,
            )
        except self.cloudwatch_client.exceptions.ResourceAlreadyExistsException:
            pass

        for out in [stderr, stdout]:
            for line in out.decode().split("\n"):
                if line:
                    self.cloudwatch_client.put_log_events(
                        logGroupName=f"{log_group}",
                        logStreamName=log_stream,
                        logEvents=[{"timestamp": int(time.time() * 1000), "message": str(line)}],
                    )

                    time.sleep(0.01)

        simulated_task.logs_uploaded = True


class WaiterMock:
    """Minimal stand-in for a boto3 waiter; only ``tasks_stopped`` is implemented."""

    def __init__(self, client: LocalECSMockClient, waiter_name: str):
        self.client = client
        self.waiter_name = waiter_name

    def wait(self, **kwargs):
        """Poll ``describe_tasks(**kwargs)`` until every returned task is ``STOPPED``."""
        if self.waiter_name != "tasks_stopped":
            raise NotImplementedError(f"Waiter {self.waiter_name} is not implemented")

        while True:
            response = self.client.describe_tasks(**kwargs)
            if all(task["lastStatus"] == "STOPPED" for task in response["tasks"]):
                return
            time.sleep(0.1)


# NOTE: the same patch also adds "boto3-stubs[essential, ecs]" to ``install_requires`` in
# python_modules/libraries/dagster-aws/setup.py.