Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[dagster-aws] Pipes AWS Glue Dagster run interruption handler #23354

Merged
merged 2 commits into from
Aug 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 35 additions & 4 deletions python_modules/libraries/dagster-aws/dagster_aws/pipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from dagster import PipesClient
from dagster._annotations import experimental
from dagster._core.definitions.resource_annotation import TreatAsResourceParam
from dagster._core.errors import DagsterExecutionInterruptedError
from dagster._core.execution.context.compute import OpExecutionContext
from dagster._core.pipes.client import (
PipesClientCompletedInvocation,
Expand Down Expand Up @@ -344,17 +345,20 @@ class PipesGlueClient(PipesClient, TreatAsResourceParam):
message_reader (Optional[PipesMessageReader]): A message reader to use to read messages
from the glue job run. Defaults to :py:class:`PipesCloudWatchsMessageReader`.
client (Optional[boto3.client]): The boto Glue client used to launch the Glue job
forward_termination (bool): Whether to cancel the Glue job run when the Dagster process receives a termination signal.
"""

def __init__(
self,
context_injector: PipesContextInjector,
message_reader: Optional[PipesMessageReader] = None,
client: Optional[boto3.client] = None,
forward_termination: bool = True,
):
self._client = client or boto3.client("glue")
self._context_injector = context_injector
self._message_reader = message_reader or PipesCloudWatchMessageReader()
self.forward_termination = check.bool_param(forward_termination, "forward_termination")

@classmethod
def _is_dagster_maintained(cls) -> bool:
Expand Down Expand Up @@ -435,6 +439,7 @@ def run(

try:
run_id = self._client.start_job_run(**params)["JobRunId"]

except ClientError as err:
context.log.error(
"Couldn't create job %s. Here's why: %s: %s",
Expand All @@ -448,11 +453,16 @@ def run(
log_group = response["JobRun"]["LogGroupName"]
context.log.info(f"Started AWS Glue job {job_name} run: {run_id}")

response = self._wait_for_job_run_completion(job_name, run_id)
try:
response = self._wait_for_job_run_completion(job_name, run_id)
except DagsterExecutionInterruptedError:
if self.forward_termination:
self._terminate_job_run(context=context, job_name=job_name, run_id=run_id)
raise

if response["JobRun"]["JobRunState"] == "FAILED":
if status := response["JobRun"]["JobRunState"] != "SUCCEEDED":
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

congrats for introducing only the third instance of the walrus to the codebase! I'm so trained not to use new features given how we have to support old Python version I totally forget they exist.

raise RuntimeError(
f"Glue job {job_name} run {run_id} failed:\n{response['JobRun']['ErrorMessage']}"
f"Glue job {job_name} run {run_id} completed with status {status} :\n{response['JobRun'].get('ErrorMessage')}"
)
else:
context.log.info(f"Glue job {job_name} run {run_id} completed successfully")
Expand All @@ -470,6 +480,27 @@ def run(
def _wait_for_job_run_completion(self, job_name: str, run_id: str) -> Dict[str, Any]:
while True:
response = self._client.get_job_run(JobName=job_name, RunId=run_id)
if response["JobRun"]["JobRunState"] in ["FAILED", "SUCCEEDED"]:
# https://docs.aws.amazon.com/glue/latest/dg/job-run-statuses.html
if response["JobRun"]["JobRunState"] in [
"FAILED",
"SUCCEEDED",
"STOPPED",
"TIMEOUT",
"ERROR",
]:
return response
time.sleep(5)

def _terminate_job_run(self, context: OpExecutionContext, job_name: str, run_id: str):
"""Creates a handler which will gracefully stop the Run in case of external termination.
It will stop the Glue job before doing so.
"""
context.log.warning(f"[pipes] execution interrupted, stopping Glue job run {run_id}...")
response = self._client.batch_stop_job_run(JobName=job_name, JobRunIds=[run_id])
runs = response["SuccessfulSubmissions"]
if len(runs) > 0:
context.log.warning(f"Successfully stopped Glue job run {run_id}.")
else:
context.log.warning(
f"Something went wrong during Glue job run termination: {response['errors']}"
)
Original file line number Diff line number Diff line change
@@ -1,12 +1,23 @@
import subprocess
import sys
import tempfile
import time
from typing import Dict, Literal, Optional
import warnings
from dataclasses import dataclass
from subprocess import PIPE, Popen
from typing import Dict, List, Literal, Optional

import boto3


@dataclass
class SimulatedJobRun:
popen: Popen
job_run_id: str
log_group: str
local_script: tempfile._TemporaryFileWrapper
stopped: bool = False


class LocalGlueMockClient:
def __init__(
self,
Expand All @@ -21,15 +32,49 @@ def __init__(
to receive any Dagster messages from it.
If pipes_messages_backend is configured to be CloudWatch, it also uploads stderr and stdout logs to CloudWatch
as if this has been done by Glue.

Once the job is submitted, it is being executed in a separate process to mimic Glue behavior.
Once the job status is requested, the process is checked for its status and the result is returned.
"""
self.aws_endpoint_url = aws_endpoint_url
self.s3_client = s3_client
self.glue_client = glue_client
self.pipes_messages_backend = pipes_messages_backend
self.cloudwatch_client = cloudwatch_client

def get_job_run(self, *args, **kwargs):
return self.glue_client.get_job_run(*args, **kwargs)
self.process = None # jobs will be executed in a separate process

self._job_runs: Dict[str, SimulatedJobRun] = {} # mapping of JobRunId to SimulatedJobRun

def get_job_run(self, JobName: str, RunId: str):
# get original response
response = self.glue_client.get_job_run(JobName=JobName, RunId=RunId)

# check if status override is set
simulated_job_run = self._job_runs[RunId]

if simulated_job_run.stopped:
response["JobRun"]["JobRunState"] = "STOPPED"
return response

# check if popen has completed
if simulated_job_run.popen.poll() is not None:
simulated_job_run.popen.wait()
# check status code
if simulated_job_run.popen.returncode == 0:
response["JobRun"]["JobRunState"] = "SUCCEEDED"
else:
response["JobRun"]["JobRunState"] = "FAILED"
_, stderr = simulated_job_run.popen.communicate()
response["JobRun"]["ErrorMessage"] = stderr.decode()

# upload logs to cloudwatch
if self.pipes_messages_backend == "cloudwatch":
self._upload_logs_to_cloudwatch(RunId)
else:
response["JobRun"]["JobRunState"] = "RUNNING"

return response

def start_job_run(self, JobName: str, Arguments: Optional[Dict[str, str]], **kwargs):
params = {
Expand All @@ -45,67 +90,97 @@ def start_job_run(self, JobName: str, Arguments: Optional[Dict[str, str]], **kwa
bucket = script_s3_path.split("/")[2]
key = "/".join(script_s3_path.split("/")[3:])

# load the script and execute it locally
with tempfile.NamedTemporaryFile() as f:
self.s3_client.download_file(bucket, key, f.name)

args = []
for key, val in (Arguments or {}).items():
args.append(key)
args.append(val)

result = subprocess.run(
[sys.executable, f.name, *args],
check=False,
env={
"AWS_ENDPOINT_URL": self.aws_endpoint_url,
"TESTING_PIPES_MESSAGES_BACKEND": self.pipes_messages_backend,
},
capture_output=True,
)

# mock the job run with moto
response = self.glue_client.start_job_run(**params)
job_run_id = response["JobRunId"]

job_run_response = self.glue_client.get_job_run(JobName=JobName, RunId=job_run_id)
log_group = job_run_response["JobRun"]["LogGroupName"]
with warnings.catch_warnings():
warnings.simplefilter("ignore")
f = tempfile.NamedTemporaryFile(
delete=False
) # we will close this file later during garbage collection
# load the S3 script to a local file
self.s3_client.download_file(bucket, key, f.name)

# execute the script in a separate process
args = []
for key, val in (Arguments or {}).items():
args.append(key)
args.append(val)
popen = Popen(
[sys.executable, f.name, *args],
env={
"AWS_ENDPOINT_URL": self.aws_endpoint_url,
"TESTING_PIPES_MESSAGES_BACKEND": self.pipes_messages_backend,
},
stdout=PIPE,
stderr=PIPE,
)

# record execution metadata for later use
self._job_runs[job_run_id] = SimulatedJobRun(
popen=popen,
job_run_id=job_run_id,
log_group=self.glue_client.get_job_run(JobName=JobName, RunId=job_run_id)["JobRun"][
"LogGroupName"
],
local_script=f,
)

return response

def batch_stop_job_run(self, JobName: str, JobRunIds: List[str]):
for job_run_id in JobRunIds:
if simulated_job_run := self._job_runs.get(job_run_id):
simulated_job_run.popen.terminate()
simulated_job_run.stopped = True
self._upload_logs_to_cloudwatch(job_run_id)

def _upload_logs_to_cloudwatch(self, job_run_id: str):
log_group = self._job_runs[job_run_id].log_group
stdout, stderr = self._job_runs[job_run_id].popen.communicate()

if self.pipes_messages_backend == "cloudwatch":
assert (
self.cloudwatch_client is not None
), "cloudwatch_client has to be provided with cloudwatch messages backend"

self.cloudwatch_client.create_log_group(
logGroupName=f"{log_group}/output",
)

self.cloudwatch_client.create_log_stream(
logGroupName=f"{log_group}/output",
logStreamName=job_run_id,
)

for line in result.stderr.decode().split(
"\n"
): # uploading log lines one by one is good enough for tests
if line:
self.cloudwatch_client.put_log_events(
logGroupName=f"{log_group}/output", # yes, Glue routes stderr to /output
logStreamName=job_run_id,
logEvents=[{"timestamp": int(time.time() * 1000), "message": str(line)}],
)
time.sleep(
0.01
) # make sure the logs will be properly filtered by ms timestamp when accessed next time

# replace run state with actual results
response["JobRun"] = {}

response["JobRun"]["JobRunState"] = "SUCCEEDED" if result.returncode == 0 else "FAILED"

# add error message if failed
if result.returncode != 0:
# this actually has to be just the Python exception, but this is good enough for now
response["JobRun"]["ErrorMessage"] = result.stderr
assert (
self.cloudwatch_client is not None
), "cloudwatch_client has to be provided with cloudwatch messages backend"

return response
try:
self.cloudwatch_client.create_log_group(
logGroupName=f"{log_group}/output",
)
except self.cloudwatch_client.exceptions.ResourceAlreadyExistsException:
pass

try:
self.cloudwatch_client.create_log_stream(
logGroupName=f"{log_group}/output",
logStreamName=job_run_id,
)
except self.cloudwatch_client.exceptions.ResourceAlreadyExistsException:
pass

for out in [stderr, stdout]: # Glue routes both stderr and stdout to /output
for line in out.decode().split(
"\n"
): # uploading log lines one by one is good enough for tests
if line:
self.cloudwatch_client.put_log_events(
logGroupName=f"{log_group}/output",
logStreamName=job_run_id,
logEvents=[
{"timestamp": int(time.time() * 1000), "message": str(line)}
],
)
time.sleep(
0.01
) # make sure the logs will be properly filtered by ms timestamp when accessed next time

def __del__(self):
# cleanup local script paths
for job_run in self._job_runs.values():
job_run.local_script.close()
Loading