Commit cf82b96: cloudwatch message reader

danielgafni committed Aug 6, 2024 (1 parent: f5b1c38)

Showing 3 changed files with 186 additions and 41 deletions.
124 changes: 96 additions & 28 deletions python_modules/libraries/dagster-aws/dagster_aws/pipes.py
@@ -5,7 +5,19 @@
import string
import time
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Dict, Iterator, Literal, Mapping, Optional, Sequence
from typing import (
TYPE_CHECKING,
Any,
Dict,
Generator,
Iterator,
List,
Literal,
Mapping,
Optional,
Sequence,
TypedDict,
)

import boto3
import dagster._check as check
@@ -157,10 +169,22 @@ def no_messages_debug_text(self) -> str:
)


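# Shape of the entries in the `events` list returned by CloudWatch `get_log_events`.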
class CloudWatchEvent(TypedDict):
timestamp: int
message: str
ingestionTime: int


@experimental
class PipesCloudWatchMessageReader(PipesMessageReader):
"""Message reader that consumes AWS CloudWatch logs to read pipes messages."""

def __init__(self, client: Optional[boto3.client] = None):
"""Args:
client (boto3.client): boto3 CloudWatch client.
"""
self.client = client or boto3.client("logs")

@contextmanager
def read_messages(
self,
@@ -174,13 +198,53 @@
self._handler = None

def consume_cloudwatch_logs(
self, client: boto3.client, log_group: str, log_stream: str
self,
log_group: str,
log_stream: str,
start_time: Optional[int] = None,
end_time: Optional[int] = None,
) -> None:
raise NotImplementedError("CloudWatch logs are not yet supported in the pipes protocol.")
handler = check.not_none(
self._handler, "Can only consume logs within context manager scope."
)

for events_batch in self._get_all_cloudwatch_events(
log_group=log_group, log_stream=log_stream, start_time=start_time, end_time=end_time
):
for event in events_batch:
for log_line in event["message"].splitlines():
extract_message_or_forward_to_stdout(handler, log_line)

def no_messages_debug_text(self) -> str:
return "Attempted to read messages by extracting them from the tail of CloudWatch logs directly."

def _get_all_cloudwatch_events(
self,
log_group: str,
log_stream: str,
start_time: Optional[int] = None,
end_time: Optional[int] = None,
) -> Generator[List[CloudWatchEvent], None, None]:
"""Returns batches of CloudWatch events until the stream is complete or end_time."""
params: Dict[str, Any] = {
"logGroupName": log_group,
"logStreamName": log_stream,
}

if start_time is not None:
params["startTime"] = start_time
if end_time is not None:
params["endTime"] = end_time

response = self.client.get_log_events(**params)

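# An empty `events` page signals the end of the stream; otherwise advance
# to `nextForwardToken` and fetch the next page.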
while events := response.get("events"):
yield events

params["nextToken"] = response["nextForwardToken"]

response = self.client.get_log_events(**params)
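
The loop above relies on the standard `get_log_events` pagination contract: every response carries a `nextForwardToken`, and re-requesting with that token returns the next page. Below is a minimal self-contained sketch of the same pattern; the helper name and the `startFromHead` choice are illustrative, not part of this commit:

import boto3
from typing import Optional

def iter_log_events(log_group: str, log_stream: str, start_time: Optional[int] = None):
    """Yield CloudWatch log events oldest-first until the stream is drained."""
    client = boto3.client("logs")
    params = {
        "logGroupName": log_group,
        "logStreamName": log_stream,
        "startFromHead": True,  # read oldest-first instead of tailing
    }
    if start_time is not None:
        params["startTime"] = start_time  # unix epoch, in milliseconds
    response = client.get_log_events(**params)
    while events := response.get("events"):
        yield from events
        # Advance to the token from the latest response; reusing an old
        # token would fetch the same page again.
        params["nextToken"] = response["nextForwardToken"]
        response = client.get_log_events(**params)

A stricter termination check is to stop once `nextForwardToken` equals the token just used, which is how the CloudWatch API signals the end of a stream; the empty-page check mirrors the reader implementation in this commit.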


class PipesLambdaEventContextInjector(PipesEnvContextInjector):
def no_messages_debug_text(self) -> str:
@@ -203,11 +267,11 @@ class PipesLambdaClient(PipesClient, TreatAsResourceParam):

def __init__(
self,
client: boto3.client,
client: Optional[boto3.client] = None,
context_injector: Optional[PipesContextInjector] = None,
message_reader: Optional[PipesMessageReader] = None,
):
self._client = client
self._client = client or boto3.client("lambda")
self._message_reader = message_reader or PipesLambdaLogsMessageReader()
self._context_injector = context_injector or PipesLambdaEventContextInjector()

@@ -272,35 +336,34 @@ def run(

class PipesGlueContextInjector(PipesS3ContextInjector):
def no_messages_debug_text(self) -> str:
return "Attempted to inject context via Glue job arguments."
return "Attempted to inject context via Glue job Arguments"


class PipesGlueLogsMessageReader(PipesCloudWatchMessageReader):
def no_messages_debug_text(self) -> str:
return "Attempted to read messages by extracting them from the tail of CloudWatch logs directly."
pass


@experimental
class PipesGlueClient(PipesClient, TreatAsResourceParam):
"""A pipes client for invoking AWS Glue jobs.
Args:
client (boto3.client): The boto Glue client used to call invoke.
context_injector (Optional[PipesContextInjector]): A context injector to use to inject
context into the Glue job, for example, :py:class:`PipesGlueContextInjector`.
message_reader (Optional[PipesMessageReader]): A message reader to use to read messages
from the glue job run. Defaults to :py:class:`PipesGlueLogsMessageReader`.
client (Optional[boto3.client]): The boto3 Glue client used to launch the Glue job. Defaults to boto3.client("glue").
"""

def __init__(
self,
client: boto3.client,
context_injector: PipesContextInjector,
message_reader: Optional[PipesMessageReader] = None,
client: Optional[boto3.client] = None,
):
self._client = client
self._client = client or boto3.client("glue")
self._context_injector = context_injector
self._message_reader = message_reader or PipesCloudWatchMessageReader()
self._message_reader = message_reader or PipesGlueLogsMessageReader()

@classmethod
def _is_dagster_maintained(cls) -> bool:
@@ -377,19 +440,10 @@ def run(
# so we need to filter them out
params = {k: v for k, v in params.items() if v is not None}

try:
response = self._client.start_job_run(**params)
run_id = response["JobRunId"]
context.log.info(f"Started AWS Glue job {job_name} run: {run_id}")
response = self._wait_for_job_run_completion(job_name, run_id)

if response["JobRun"]["JobRunState"] == "FAILED":
raise RuntimeError(
f"Glue job {job_name} run {run_id} failed:\n{response['JobRun']['ErrorMessage']}"
)
else:
context.log.info(f"Glue job {job_name} run {run_id} completed successfully")
start_timestamp = time.time() * 1000 # unix time in ms

try:
run_id = self._client.start_job_run(**params)["JobRunId"]
except ClientError as err:
context.log.error(
"Couldn't create job %s. Here's why: %s: %s",
@@ -399,11 +453,25 @@
)
raise

# TODO: get logs from CloudWatch. There are 2 separate streams (stdout and driver stderr) to read from.
# The log group can be found in the response from start_job_run, and the log stream is the job run id.
# Worker logs have log streams like <job_id>_<worker_id>, but we probably don't need to read those.
response = self._client.get_job_run(JobName=job_name, RunId=run_id)
log_group = response["JobRun"]["LogGroupName"]
context.log.info(f"Started AWS Glue job {job_name} run: {run_id}")

response = self._wait_for_job_run_completion(job_name, run_id)

if response["JobRun"]["JobRunState"] == "FAILED":
raise RuntimeError(
f"Glue job {job_name} run {run_id} failed:\n{response['JobRun']['ErrorMessage']}"
)
else:
context.log.info(f"Glue job {job_name} run {run_id} completed successfully")

if isinstance(self._message_reader, PipesCloudWatchMessageReader):
# TODO: receive messages from a background thread in real-time
self._message_reader.consume_cloudwatch_logs(
f"{log_group}/output", run_id, start_time=int(start_timestamp)
)

# should probably have a way to return the Glue job run result payload
return PipesClientCompletedInvocation(session)

def _wait_for_job_run_completion(self, job_name: str, run_id: str) -> Dict[str, Any]:
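For orientation, here is a hedged sketch of how the pieces above fit together from a Dagster asset. The `run(...)` keyword names and the job name are assumptions inferred from this diff, not verified against a released dagster-aws API:

import boto3
from dagster import AssetExecutionContext, asset
from dagster_aws.pipes import (
    PipesCloudWatchMessageReader,
    PipesGlueClient,
    PipesS3ContextInjector,
)

@asset
def glue_pipes_asset(context: AssetExecutionContext, pipes_glue_client: PipesGlueClient):
    # start_job_run -> wait for completion -> consume the CloudWatch
    # output stream for Pipes messages (the run() flow shown above)
    return pipes_glue_client.run(
        context=context,
        job_name="my_pipes_job",  # hypothetical Glue job name
    ).get_results()

pipes_glue_client = PipesGlueClient(
    context_injector=PipesS3ContextInjector(bucket="my-bucket", client=boto3.client("s3")),
    message_reader=PipesCloudWatchMessageReader(),  # defaults to boto3.client("logs")
)

The second and third changed files below exercise the same flow locally against moto-backed clients.
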
@@ -1,24 +1,37 @@
import subprocess
import sys
import tempfile
from typing import Dict, Optional
import time
from typing import Dict, Literal, Optional

import boto3


class LocalGlueMockClient:
def __init__(self, s3_client: boto3.client, glue_client: boto3.client):
def __init__(
self,
aws_endpoint_url: str, # usually received from moto
s3_client: boto3.client,
glue_client: boto3.client,
pipes_messages_backend: Literal["s3", "cloudwatch"],
cloudwatch_client: Optional[boto3.client] = None,
):
"""This class wraps moto3 clients for S3 and Glue, and provides a way to "run" Glue jobs locally.
This is necessary because moto3 does not actually run anything when you start a Glue job, so we won't be able
to receive any Dagster messages from it.
If pipes_messages_backend is configured to be CloudWatch, it also uploads stderr and stdout logs to CloudWatch
as if this has been done by Glue.
"""
self.aws_endpoint_url = aws_endpoint_url
self.s3_client = s3_client
self.glue_client = glue_client
self.pipes_messages_backend = pipes_messages_backend
self.cloudwatch_client = cloudwatch_client

def get_job_run(self, *args, **kwargs):
return self.glue_client.get_job_run(*args, **kwargs)

def start_job_run(self, JobName: str, Arguments: Optional[Dict[str, str]], *args, **kwargs):
def start_job_run(self, JobName: str, Arguments: Optional[Dict[str, str]], **kwargs):
params = {
"JobName": JobName,
}
@@ -44,12 +57,46 @@ def start_job_run(self, JobName: str, Arguments: Optional[Dict[str, str]], *args
result = subprocess.run(
[sys.executable, f.name, *args],
check=False,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
env={
"AWS_ENDPOINT_URL": self.aws_endpoint_url,
"PIPES_MESSAGES_BACKEND": self.pipes_messages_backend,
},
capture_output=True,
)

# mock the job run with moto
response = self.glue_client.start_job_run(**params)
job_run_id = response["JobRunId"]

job_run_response = self.glue_client.get_job_run(JobName=JobName, RunId=job_run_id)
log_group = job_run_response["JobRun"]["LogGroupName"]

if self.pipes_messages_backend == "cloudwatch":
assert (
self.cloudwatch_client is not None
), "cloudwatch_client must be provided when pipes_messages_backend is 'cloudwatch'"

self.cloudwatch_client.create_log_group(
logGroupName=f"{log_group}/output",
)

self.cloudwatch_client.create_log_stream(
logGroupName=f"{log_group}/output",
logStreamName=job_run_id,
)

# uploading log lines one by one is good enough for tests
for line in result.stderr.decode().split("\n"):
if line:
self.cloudwatch_client.put_log_events(
logGroupName=f"{log_group}/output", # yes, Glue routes stderr to /output
logStreamName=job_run_id,
logEvents=[{"timestamp": int(time.time() * 1000), "message": str(line)}],
)
# sleep so each line gets a distinct millisecond timestamp and later
# reads filtered by startTime behave correctly
time.sleep(0.01)

# replace run state with actual results
response["JobRun"] = {}
Expand Down
@@ -6,7 +6,7 @@
import textwrap
from contextlib import contextmanager
from tempfile import NamedTemporaryFile
from typing import Any, Callable, Iterator
from typing import Any, Callable, Iterator, Literal

import boto3
import pytest
@@ -24,6 +24,7 @@
from dagster._core.pipes.utils import PipesEnvContextInjector
from dagster._core.storage.asset_check_execution_record import AssetCheckExecutionRecordStatus
from dagster_aws.pipes import (
PipesCloudWatchMessageReader,
PipesGlueClient,
PipesLambdaClient,
PipesLambdaLogsMessageReader,
@@ -283,6 +284,7 @@ def fake_lambda_asset(context):
def external_s3_glue_script(s3_client) -> Iterator[str]:
# This is called in an external process and so cannot access outer scope
def script_fn():
import os
import time

import boto3
Expand All @@ -293,12 +295,18 @@ def script_fn():
open_dagster_pipes,
)

client = boto3.client("s3", region_name="us-east-1", endpoint_url="http://localhost:5193")
context_loader = PipesS3ContextLoader(client=client)
message_writer = PipesS3MessageWriter(client, interval=0.001)
s3_client = boto3.client(
"s3", region_name="us-east-1", endpoint_url="http://localhost:5193"
)

messages_backend = os.environ["PIPES_MESSAGES_BACKEND"]
if messages_backend == "s3":
message_writer = PipesS3MessageWriter(s3_client, interval=0.001)
else:
message_writer = None

with open_dagster_pipes(
context_loader=context_loader,
context_loader=PipesS3ContextLoader(client=s3_client),
message_writer=message_writer,
params_loader=PipesCliArgsParamsLoader(),
) as context:
@@ -341,9 +349,25 @@ def glue_client(moto_server, external_s3_glue_script, s3_client) -> boto3.client
return client


def test_glue_s3_pipes(capsys, s3_client, glue_client):
@pytest.fixture
def cloudwatch_client(moto_server, external_s3_glue_script, s3_client) -> boto3.client:
return boto3.client("logs", region_name="us-east-1", endpoint_url=_MOTO_SERVER_URL)


@pytest.mark.parametrize("pipes_messages_backend", ["s3", "cloudwatch"])
def test_glue_pipes(
capsys,
s3_client,
glue_client,
cloudwatch_client,
pipes_messages_backend: Literal["s3", "cloudwatch"],
):
context_injector = PipesS3ContextInjector(bucket=_S3_TEST_BUCKET, client=s3_client)
message_reader = PipesS3MessageReader(bucket=_S3_TEST_BUCKET, client=s3_client, interval=0.001)
message_reader = (
PipesS3MessageReader(bucket=_S3_TEST_BUCKET, client=s3_client, interval=0.001)
if pipes_messages_backend == "s3"
else PipesCloudWatchMessageReader(client=cloudwatch_client)
)

@asset(check_specs=[AssetCheckSpec(name="foo_check", asset=AssetKey(["foo"]))])
def foo(context: AssetExecutionContext, pipes_glue_client: PipesGlueClient):
@@ -355,7 +379,13 @@ def foo(context: AssetExecutionContext, pipes_glue_client: PipesGlueClient):
return results

pipes_glue_client = PipesGlueClient(
client=LocalGlueMockClient(glue_client=glue_client, s3_client=s3_client),
client=LocalGlueMockClient(
aws_endpoint_url=_MOTO_SERVER_URL,
glue_client=glue_client,
s3_client=s3_client,
cloudwatch_client=cloudwatch_client,
pipes_messages_backend=pipes_messages_backend,
),
context_injector=context_injector,
message_reader=message_reader,
)
