Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/pipes-cloudwatch-message-reader'…
Browse files: browse the repository at this point in the history
… into pipes-cloudwatch-message-reader
  • Loading branch information
danielgafni committed Aug 5, 2024
2 parents 20098b3 + 18aafe3 commit 2e31171
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 14 deletions.
2 changes: 1 addition & 1 deletion python_modules/libraries/dagster-aws/dagster_aws/pipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ def run(
# so we need to filter them out
params = {k: v for k, v in params.items() if v is not None}

start_timestamp = time.time()
start_timestamp = time.time() * 1000 # unix time in ms

try:
run_id = self._client.start_job_run(**params)["JobRunId"]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,24 +1,37 @@
import subprocess
import sys
import tempfile
from typing import Dict, Optional
import time
from typing import Dict, Literal, Optional

import boto3


class LocalGlueMockClient:
def __init__(self, s3_client: boto3.client, glue_client: boto3.client):
def __init__(
    self,
    aws_endpoint_url: str,  # usually received from moto
    s3_client: boto3.client,
    glue_client: boto3.client,
    pipes_messages_backend: Literal["s3", "cloudwatch"],
    cloudwatch_client: Optional[boto3.client] = None,
):
    """Wrap moto-backed boto3 clients and provide a way to "run" Glue jobs locally.

    This is necessary because moto does not actually run anything when you start a Glue job,
    so we won't be able to receive any Dagster messages from it.
    If pipes_messages_backend is configured to be CloudWatch, the captured stderr of the
    locally executed script is also uploaded to CloudWatch, as if this had been done by Glue.

    Args:
        aws_endpoint_url: Endpoint URL exported to the job subprocess as the
            ``AWS_ENDPOINT_URL`` environment variable.
        s3_client: S3 client; stored on the instance.
            NOTE(review): its use is outside the visible part of this class -- confirm.
        glue_client: Glue client used to register the mocked job run and to look it up.
        pipes_messages_backend: Either ``"s3"`` or ``"cloudwatch"``; exported to the job
            subprocess as the ``PIPES_MESSAGES_BACKEND`` environment variable.
        cloudwatch_client: CloudWatch Logs client. ``start_job_run`` asserts it is not
            ``None`` when ``pipes_messages_backend`` is ``"cloudwatch"``.
    """
    self.aws_endpoint_url = aws_endpoint_url
    self.s3_client = s3_client
    self.glue_client = glue_client
    self.pipes_messages_backend = pipes_messages_backend
    self.cloudwatch_client = cloudwatch_client

def get_job_run(self, *args, **kwargs):
    """Forward the call, with all arguments unchanged, to the wrapped Glue client's ``get_job_run``."""
    return self.glue_client.get_job_run(*args, **kwargs)

def start_job_run(self, JobName: str, Arguments: Optional[Dict[str, str]], *args, **kwargs):
def start_job_run(self, JobName: str, Arguments: Optional[Dict[str, str]], **kwargs):
params = {
"JobName": JobName,
}
Expand All @@ -44,12 +57,46 @@ def start_job_run(self, JobName: str, Arguments: Optional[Dict[str, str]], *args
result = subprocess.run(
[sys.executable, f.name, *args],
check=False,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
env={
"AWS_ENDPOINT_URL": self.aws_endpoint_url,
"PIPES_MESSAGES_BACKEND": self.pipes_messages_backend,
},
capture_output=True,
)

# mock the job run with moto
response = self.glue_client.start_job_run(**params)
job_run_id = response["JobRunId"]

job_run_response = self.glue_client.get_job_run(JobName=JobName, RunId=job_run_id)
log_group = job_run_response["JobRun"]["LogGroupName"]

if self.pipes_messages_backend == "cloudwatch":
assert (
self.cloudwatch_client is not None
), "cloudwatch_client has to be provided with cloudwatch messages backend"

self.cloudwatch_client.create_log_group(
logGroupName=f"{log_group}/output",
)

self.cloudwatch_client.create_log_stream(
logGroupName=f"{log_group}/output",
logStreamName=job_run_id,
)

for line in result.stderr.decode().split(
"\n"
): # uploading log lines one by one is good enough for tests
if line:
self.cloudwatch_client.put_log_events(
logGroupName=f"{log_group}/output", # yes, Glue routes stderr to /output
logStreamName=job_run_id,
logEvents=[{"timestamp": int(time.time() * 1000), "message": str(line)}],
)
time.sleep(
0.01
) # make sure the logs will be properly filtered by ms timestamp when accessed next time

# replace run state with actual results
response["JobRun"] = {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import textwrap
from contextlib import contextmanager
from tempfile import NamedTemporaryFile
from typing import Any, Callable, Iterator
from typing import Any, Callable, Iterator, Literal

import boto3
import pytest
Expand All @@ -24,6 +24,7 @@
from dagster._core.pipes.utils import PipesEnvContextInjector
from dagster._core.storage.asset_check_execution_record import AssetCheckExecutionRecordStatus
from dagster_aws.pipes import (
PipesCloudWatchMessageReader,
PipesGlueClient,
PipesLambdaClient,
PipesLambdaLogsMessageReader,
Expand Down Expand Up @@ -283,6 +284,7 @@ def fake_lambda_asset(context):
def external_s3_glue_script(s3_client) -> Iterator[str]:
# This is called in an external process and so cannot access outer scope
def script_fn():
import os
import time

import boto3
Expand All @@ -293,12 +295,18 @@ def script_fn():
open_dagster_pipes,
)

client = boto3.client("s3", region_name="us-east-1", endpoint_url="http://localhost:5193")
context_loader = PipesS3ContextLoader(client=client)
message_writer = PipesS3MessageWriter(client, interval=0.001)
s3_client = boto3.client(
"s3", region_name="us-east-1", endpoint_url="http://localhost:5193"
)

messages_backend = os.environ["PIPES_MESSAGES_BACKEND"]
if messages_backend == "s3":
message_writer = PipesS3MessageWriter(s3_client, interval=0.001)
else:
message_writer = None

with open_dagster_pipes(
context_loader=context_loader,
context_loader=PipesS3ContextLoader(client=s3_client),
message_writer=message_writer,
params_loader=PipesCliArgsParamsLoader(),
) as context:
Expand Down Expand Up @@ -341,9 +349,25 @@ def glue_client(moto_server, external_s3_glue_script, s3_client) -> boto3.client
return client


def test_glue_s3_pipes(capsys, s3_client, glue_client):
@pytest.fixture
def cloudwatch_client(moto_server, external_s3_glue_script, s3_client) -> boto3.client:
    """Return a boto3 ``logs`` client for region ``us-east-1`` pointed at ``_MOTO_SERVER_URL``."""
    # NOTE(review): `external_s3_glue_script` and `s3_client` are requested but never used in
    # the body -- presumably only to force fixture setup order; confirm they are needed.
    return boto3.client("logs", region_name="us-east-1", endpoint_url=_MOTO_SERVER_URL)


@pytest.mark.parametrize("pipes_messages_backend", ["s3", "cloudwatch"])
def test_glue_pipes(
capsys,
s3_client,
glue_client,
cloudwatch_client,
pipes_messages_backend: Literal["s3", "cloudwatch"],
):
context_injector = PipesS3ContextInjector(bucket=_S3_TEST_BUCKET, client=s3_client)
message_reader = PipesS3MessageReader(bucket=_S3_TEST_BUCKET, client=s3_client, interval=0.001)
message_reader = (
PipesS3MessageReader(bucket=_S3_TEST_BUCKET, client=s3_client, interval=0.001)
if pipes_messages_backend == "s3"
else PipesCloudWatchMessageReader(client=cloudwatch_client)
)

@asset(check_specs=[AssetCheckSpec(name="foo_check", asset=AssetKey(["foo"]))])
def foo(context: AssetExecutionContext, pipes_glue_client: PipesGlueClient):
Expand All @@ -355,7 +379,13 @@ def foo(context: AssetExecutionContext, pipes_glue_client: PipesGlueClient):
return results

pipes_glue_client = PipesGlueClient(
client=LocalGlueMockClient(glue_client=glue_client, s3_client=s3_client),
client=LocalGlueMockClient(
aws_endpoint_url=_MOTO_SERVER_URL,
glue_client=glue_client,
s3_client=s3_client,
cloudwatch_client=cloudwatch_client,
pipes_messages_backend=pipes_messages_backend,
),
context_injector=context_injector,
message_reader=message_reader,
)
Expand Down

0 comments on commit 2e31171

Please sign in to comment.