diff --git a/airflow/providers/amazon/aws/operators/batch.py b/airflow/providers/amazon/aws/operators/batch.py
index 6565bcecfbaf1..dac0e1b23cabd 100644
--- a/airflow/providers/amazon/aws/operators/batch.py
+++ b/airflow/providers/amazon/aws/operators/batch.py
@@ -42,6 +42,8 @@ if TYPE_CHECKING:
     from airflow.utils.context import Context
 
+import boto3
+import botocore.waiter
 
 
 class BatchOperator(BaseOperator):
     """
@@ -363,3 +365,80 @@ def execute(self, context: Context):
         self.hook.client.create_compute_environment(**trim_none_values(kwargs))
 
         self.log.info("AWS Batch compute environment created successfully")
+
+
+def submit_batch_job(job_name, job_definition, job_queue, command):
+    """Submit an AWS Batch job and return its job id.
+
+    ``command`` must be a list of strings, e.g. ``["echo", "hello"]``.
+    """
+    client = boto3.client("batch")
+    response = client.submit_job(
+        jobName=job_name,
+        jobQueue=job_queue,
+        jobDefinition=job_definition,
+        containerOverrides={"command": command},
+    )
+    return response["jobId"]
+
+
+class JobStatusWaiter:
+    """Wait for an AWS Batch job to reach ``desired_status``.
+
+    Builds a custom botocore waiter from a ``WaiterModel`` instead of
+    subclassing ``botocore.waiter.Waiter``, whose constructor takes a
+    waiter name, a parsed config, and an operation method rather than
+    the keyword arguments used before.
+    """
+
+    def __init__(self, client, job_id, desired_status):
+        # ``pathAll``/``pathAny`` compare each element of the JMESPath
+        # result; a plain ``path`` matcher would compare the whole
+        # ``jobs[].status`` list against a single string and never match.
+        acceptors = [
+            {
+                "expected": desired_status,
+                "matcher": "pathAll",
+                "state": "success",
+                "argument": "jobs[].status",
+            },
+            {
+                "expected": "FAILED",
+                "matcher": "pathAny",
+                "state": "failure",
+                "argument": "jobs[].status",
+            },
+        ]
+        if desired_status != "SUCCEEDED":
+            # A finished job can never reach an intermediate status such
+            # as RUNNING, so SUCCEEDED counts as a failure only when
+            # waiting for some other status.
+            acceptors.append(
+                {
+                    "expected": "SUCCEEDED",
+                    "matcher": "pathAny",
+                    "state": "failure",
+                    "argument": "jobs[].status",
+                }
+            )
+        waiter_model = botocore.waiter.WaiterModel(
+            {
+                "version": 2,
+                "waiters": {
+                    "JobStatus": {
+                        "delay": 10,
+                        "maxAttempts": 60,
+                        "operation": "DescribeJobs",
+                        "acceptors": acceptors,
+                    }
+                },
+            }
+        )
+        self.job_id = job_id
+        self._waiter = botocore.waiter.create_waiter_with_client(
+            "JobStatus", waiter_model, client
+        )
+
+    def wait(self):
+        # DescribeJobs expects job ids under the ``jobs`` parameter.
+        self._waiter.wait(jobs=[self.job_id])
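
For context, a minimal sketch of how the two new helpers compose, assuming AWS credentials and region are already configured and that the job name, queue, and definition below are hypothetical placeholders, not values from this patch:

import boto3

from airflow.providers.amazon.aws.operators.batch import (
    JobStatusWaiter,
    submit_batch_job,
)

# Hypothetical names; substitute a real queue and a registered job definition.
job_id = submit_batch_job(
    job_name="example-job",
    job_definition="example-job-def",
    job_queue="example-queue",
    command=["echo", "hello"],
)

# Polls DescribeJobs every 10 seconds (up to 60 attempts) and raises
# botocore.exceptions.WaiterError if the job reaches FAILED first.
client = boto3.client("batch")
JobStatusWaiter(client, job_id, desired_status="SUCCEEDED").wait()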