pipes-glue-interruption-tests-wip (#23502)
Exploring whether it's possible to rewrite the fake testing Glue client to
execute scripts in a background process so we can test interruption
danielgafni committed Aug 8, 2024
1 parent 221c062 commit 7356891
Showing 3 changed files with 279 additions and 69 deletions.
16 changes: 12 additions & 4 deletions python_modules/libraries/dagster-aws/dagster_aws/pipes.py
@@ -353,7 +353,7 @@ def __init__(
context_injector: PipesContextInjector,
message_reader: Optional[PipesMessageReader] = None,
client: Optional[boto3.client] = None,
-        forward_termination: bool = False,
+        forward_termination: bool = True,
):
self._client = client or boto3.client("glue")
self._context_injector = context_injector
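With the default flipped to True, interrupting the Dagster process now also stops the remote Glue job run. A minimal opt-out sketch, assuming the class under edit is dagster-aws's PipesGlueClient and using an illustrative S3 context injector wiring (the bucket name is a placeholder, not part of this diff):

import boto3
from dagster_aws.pipes import PipesGlueClient, PipesS3ContextInjector

glue_pipes_client = PipesGlueClient(
    context_injector=PipesS3ContextInjector(bucket="my-pipes-bucket", client=boto3.client("s3")),
    forward_termination=False,  # restore the old default: leave the Glue run alive on interruption
)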
@@ -439,6 +439,7 @@ def run(

try:
run_id = self._client.start_job_run(**params)["JobRunId"]

except ClientError as err:
context.log.error(
"Couldn't create job %s. Here's why: %s: %s",
@@ -459,9 +460,9 @@ def run(
self._terminate_job_run(context=context, job_name=job_name, run_id=run_id)
raise

-        if response["JobRun"]["JobRunState"] == "FAILED":
+        if (status := response["JobRun"]["JobRunState"]) != "SUCCEEDED":
            raise RuntimeError(
-                f"Glue job {job_name} run {run_id} failed:\n{response['JobRun']['ErrorMessage']}"
+                f"Glue job {job_name} run {run_id} completed with status {status}:\n{response['JobRun'].get('ErrorMessage')}"
            )
else:
context.log.info(f"Glue job {job_name} run {run_id} completed successfully")
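One subtlety in the status check above: `:=` binds more loosely than `!=`, so the parentheses around the assignment are required; without them, status would capture the boolean result of the comparison rather than the state string. A short illustration, independent of this codebase:

state = "FAILED"

if status := state != "SUCCEEDED":  # without parentheses: status is True (the comparison result)
    print(status)

if (status := state) != "SUCCEEDED":  # with parentheses: status is "FAILED" (the state string)
    print(status)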
@@ -479,7 +480,14 @@ def run(
def _wait_for_job_run_completion(self, job_name: str, run_id: str) -> Dict[str, Any]:
while True:
response = self._client.get_job_run(JobName=job_name, RunId=run_id)
-            if response["JobRun"]["JobRunState"] in ["FAILED", "SUCCEEDED"]:
+            # https://docs.aws.amazon.com/glue/latest/dg/job-run-statuses.html
+            if response["JobRun"]["JobRunState"] in [
+                "FAILED",
+                "SUCCEEDED",
+                "STOPPED",
+                "TIMEOUT",
+                "ERROR",
+            ]:
return response
time.sleep(5)

@@ -1,12 +1,23 @@
-import subprocess
import sys
import tempfile
import time
-from typing import Dict, Literal, Optional
+import warnings
+from dataclasses import dataclass
+from subprocess import PIPE, Popen
+from typing import Dict, List, Literal, Optional

import boto3


@dataclass
class SimulatedJobRun:
popen: Popen
job_run_id: str
log_group: str
local_script: tempfile._TemporaryFileWrapper
stopped: bool = False


class LocalGlueMockClient:
def __init__(
self,
@@ -21,15 +32,49 @@ def __init__(
to receive any Dagster messages from it.
If pipes_messages_backend is configured to be CloudWatch, it also uploads stderr and stdout logs to CloudWatch
as if this has been done by Glue.
Once the job is submitted, it is executed in a separate process to mimic Glue behavior.
When the job status is requested, the process is polled and the corresponding job state is returned.
"""
self.aws_endpoint_url = aws_endpoint_url
self.s3_client = s3_client
self.glue_client = glue_client
self.pipes_messages_backend = pipes_messages_backend
self.cloudwatch_client = cloudwatch_client

-    def get_job_run(self, *args, **kwargs):
-        return self.glue_client.get_job_run(*args, **kwargs)
self.process = None # jobs will be executed in a separate process

self._job_runs: Dict[str, SimulatedJobRun] = {} # mapping of JobRunId to SimulatedJobRun

def get_job_run(self, JobName: str, RunId: str):
# get original response
response = self.glue_client.get_job_run(JobName=JobName, RunId=RunId)

# check if status override is set
simulated_job_run = self._job_runs[RunId]

if simulated_job_run.stopped:
response["JobRun"]["JobRunState"] = "STOPPED"
return response

# check if popen has completed
if simulated_job_run.popen.poll() is not None:
simulated_job_run.popen.wait()
# check status code
if simulated_job_run.popen.returncode == 0:
response["JobRun"]["JobRunState"] = "SUCCEEDED"
else:
response["JobRun"]["JobRunState"] = "FAILED"
_, stderr = simulated_job_run.popen.communicate()
response["JobRun"]["ErrorMessage"] = stderr.decode()

# upload logs to cloudwatch
if self.pipes_messages_backend == "cloudwatch":
self._upload_logs_to_cloudwatch(RunId)
else:
response["JobRun"]["JobRunState"] = "RUNNING"

return response
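get_job_run thus exposes a small state machine: RUNNING while the child process is alive, SUCCEEDED or FAILED once it exits, and STOPPED after a stop request. A hypothetical test helper (not part of this commit) that polls it the same way _wait_for_job_run_completion does on the client side:

import time

def wait_for_terminal_state(mock_client, job_name: str, run_id: str, timeout: float = 30.0) -> str:
    # Poll the fake client until the simulated run leaves RUNNING, or give up.
    deadline = time.time() + timeout
    while time.time() < deadline:
        state = mock_client.get_job_run(JobName=job_name, RunId=run_id)["JobRun"]["JobRunState"]
        if state != "RUNNING":
            return state
        time.sleep(0.1)
    raise TimeoutError(f"job run {run_id} still RUNNING after {timeout}s")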

def start_job_run(self, JobName: str, Arguments: Optional[Dict[str, str]], **kwargs):
params = {
@@ -45,67 +90,97 @@ def start_job_run(self, JobName: str, Arguments: Optional[Dict[str, str]], **kwargs):
bucket = script_s3_path.split("/")[2]
key = "/".join(script_s3_path.split("/")[3:])

-        # load the script and execute it locally
-        with tempfile.NamedTemporaryFile() as f:
-            self.s3_client.download_file(bucket, key, f.name)
-
-            args = []
-            for key, val in (Arguments or {}).items():
-                args.append(key)
-                args.append(val)
-
-            result = subprocess.run(
-                [sys.executable, f.name, *args],
-                check=False,
-                env={
-                    "AWS_ENDPOINT_URL": self.aws_endpoint_url,
-                    "TESTING_PIPES_MESSAGES_BACKEND": self.pipes_messages_backend,
-                },
-                capture_output=True,
-            )

# mock the job run with moto
response = self.glue_client.start_job_run(**params)
job_run_id = response["JobRunId"]

-        job_run_response = self.glue_client.get_job_run(JobName=JobName, RunId=job_run_id)
-        log_group = job_run_response["JobRun"]["LogGroupName"]
with warnings.catch_warnings():
warnings.simplefilter("ignore")
f = tempfile.NamedTemporaryFile(
delete=False
) # we will close this file later during garbage collection
# load the S3 script to a local file
self.s3_client.download_file(bucket, key, f.name)

# execute the script in a separate process
args = []
for key, val in (Arguments or {}).items():
args.append(key)
args.append(val)
popen = Popen(
[sys.executable, f.name, *args],
env={
"AWS_ENDPOINT_URL": self.aws_endpoint_url,
"TESTING_PIPES_MESSAGES_BACKEND": self.pipes_messages_backend,
},
stdout=PIPE,
stderr=PIPE,
)

# record execution metadata for later use
self._job_runs[job_run_id] = SimulatedJobRun(
popen=popen,
job_run_id=job_run_id,
log_group=self.glue_client.get_job_run(JobName=JobName, RunId=job_run_id)["JobRun"][
"LogGroupName"
],
local_script=f,
)

return response

def batch_stop_job_run(self, JobName: str, JobRunIds: List[str]):
for job_run_id in JobRunIds:
if simulated_job_run := self._job_runs.get(job_run_id):
simulated_job_run.popen.terminate()
simulated_job_run.stopped = True
self._upload_logs_to_cloudwatch(job_run_id)

def _upload_logs_to_cloudwatch(self, job_run_id: str):
log_group = self._job_runs[job_run_id].log_group
stdout, stderr = self._job_runs[job_run_id].popen.communicate()

if self.pipes_messages_backend == "cloudwatch":
-            assert (
-                self.cloudwatch_client is not None
-            ), "cloudwatch_client has to be provided with cloudwatch messages backend"
-
-            self.cloudwatch_client.create_log_group(
-                logGroupName=f"{log_group}/output",
-            )
-
-            self.cloudwatch_client.create_log_stream(
-                logGroupName=f"{log_group}/output",
-                logStreamName=job_run_id,
-            )
-
-            for line in result.stderr.decode().split(
-                "\n"
-            ):  # uploading log lines one by one is good enough for tests
-                if line:
-                    self.cloudwatch_client.put_log_events(
-                        logGroupName=f"{log_group}/output",  # yes, Glue routes stderr to /output
-                        logStreamName=job_run_id,
-                        logEvents=[{"timestamp": int(time.time() * 1000), "message": str(line)}],
-                    )
-                time.sleep(
-                    0.01
-                )  # make sure the logs will be properly filtered by ms timestamp when accessed next time
-
-        # replace run state with actual results
-        response["JobRun"] = {}
-
-        response["JobRun"]["JobRunState"] = "SUCCEEDED" if result.returncode == 0 else "FAILED"
-
-        # add error message if failed
-        if result.returncode != 0:
-            # this actually has to be just the Python exception, but this is good enough for now
-            response["JobRun"]["ErrorMessage"] = result.stderr
assert (
self.cloudwatch_client is not None
), "cloudwatch_client has to be provided with cloudwatch messages backend"

-        return response
try:
self.cloudwatch_client.create_log_group(
logGroupName=f"{log_group}/output",
)
except self.cloudwatch_client.exceptions.ResourceAlreadyExistsException:
pass

try:
self.cloudwatch_client.create_log_stream(
logGroupName=f"{log_group}/output",
logStreamName=job_run_id,
)
except self.cloudwatch_client.exceptions.ResourceAlreadyExistsException:
pass

for out in [stderr, stdout]: # Glue routes both stderr and stdout to /output
for line in out.decode().split(
"\n"
): # uploading log lines one by one is good enough for tests
if line:
self.cloudwatch_client.put_log_events(
logGroupName=f"{log_group}/output",
logStreamName=job_run_id,
logEvents=[
{"timestamp": int(time.time() * 1000), "message": str(line)}
],
)
time.sleep(
0.01
) # make sure the logs will be properly filtered by ms timestamp when accessed next time
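The millisecond timestamps (and the 10 ms sleep between writes) matter because a reader on the other side filters by startTime to fetch only new events. A sketch of such a readback with the standard boto3 logs API; the log group, stream name, and cursor value below are placeholders:

import boto3

logs = boto3.client("logs")
last_seen_ms = 0  # cursor: timestamp of the newest event processed so far

resp = logs.get_log_events(
    logGroupName="example-log-group/output",  # placeholder
    logStreamName="example-job-run-id",  # placeholder
    startTime=last_seen_ms + 1,  # skip everything already seen
    startFromHead=True,
)
for event in resp["events"]:
    last_seen_ms = max(last_seen_ms, event["timestamp"])
    print(event["message"])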

def __del__(self):
# cleanup local script paths
for job_run in self._job_runs.values():
job_run.local_script.close()
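Together with batch_stop_job_run, this makes the interruption scenario from the commit message testable: start a run, stop it mid-flight, and assert on the simulated state. A rough sketch of the shape such a test could take; the moto-backed s3_client/glue_client/cloudwatch_client fixtures, the registered "test_job", and the "default" messages-backend literal are assumptions, not part of this diff:

import time

def test_glue_job_interruption(s3_client, glue_client, cloudwatch_client):
    mock_client = LocalGlueMockClient(
        aws_endpoint_url="http://localhost:5193",  # placeholder moto endpoint
        s3_client=s3_client,
        glue_client=glue_client,
        pipes_messages_backend="default",
        cloudwatch_client=cloudwatch_client,  # passed defensively; stopping may upload logs
    )

    run_id = mock_client.start_job_run(JobName="test_job", Arguments=None)["JobRunId"]

    time.sleep(1)  # let the child process start
    mock_client.batch_stop_job_run(JobName="test_job", JobRunIds=[run_id])

    state = mock_client.get_job_run(JobName="test_job", RunId=run_id)["JobRun"]["JobRunState"]
    assert state == "STOPPED"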