Skip to content

Commit

Permalink
feat: Cancel Job Attachments session action when transfer rates drop …
Browse files Browse the repository at this point in the history
…below threshold (#143)

Signed-off-by: Gahyun Suh <132245153+gahyusuh@users.noreply.github.com>
  • Loading branch information
gahyusuh committed Feb 23, 2024
1 parent 296d1d7 commit c49bbb4
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 14 deletions.
15 changes: 15 additions & 0 deletions src/deadline_worker_agent/aws/deadline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,3 +820,18 @@ def record_sync_outputs_telemetry_event(queue_id: str, summary: SummaryStatistic
event_type="com.amazon.rum.deadline.worker_agent.sync_outputs_summary",
event_details=details,
)


def record_sync_inputs_fail_telemetry_event(
queue_id: str,
failure_reason: str,
) -> None:
"""Calls the telemetry client to record an event capturing the sync-inputs failure."""
details = {
"queue_id": queue_id,
"failure_reason": failure_reason,
}
_get_deadline_telemetry_client().record_event(
event_type="com.amazon.rum.deadline.worker_agent.sync_inputs_failure",
event_details=details,
)
65 changes: 61 additions & 4 deletions src/deadline_worker_agent/sessions/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,11 @@
)
from deadline.job_attachments.progress_tracker import ProgressReportMetadata, SummaryStatistics

from ..aws.deadline import record_sync_inputs_telemetry_event, record_sync_outputs_telemetry_event
from ..aws.deadline import (
record_sync_inputs_fail_telemetry_event,
record_sync_inputs_telemetry_event,
record_sync_outputs_telemetry_event,
)
from ..scheduler.session_action_status import SessionActionStatus
from ..sessions.errors import SessionActionError

Expand All @@ -83,6 +87,14 @@
}
TIME_DELTA_ZERO = timedelta()

# During a SYNC_INPUT_JOB_ATTACHMENTS session action, the transfer rate is periodically reported through
# a callback function. If a transfer rate lower than LOW_TRANSFER_RATE_THRESHOLD is observed in a series
# for LOW_TRANSFER_COUNT_THRESHOLD times, it is considered concerning or potentially stalled, and the
# session action is canceled.
LOW_TRANSFER_RATE_THRESHOLD = 10 * 10**3 # 10 KB/s
LOW_TRANSFER_COUNT_THRESHOLD = (
60 # Each progress report takes 1 sec at the longest, so 60 reports amount to 1 min in total.
)

logger = getLogger(__name__)

Expand Down Expand Up @@ -776,12 +788,45 @@ def sync_asset_inputs(
if self._asset_sync is None:
return

def progress_handler(job_upload_status: ProgressReportMetadata) -> bool:
low_transfer_count = 0

def progress_handler(job_attachments_download_status: ProgressReportMetadata) -> bool:
"""
Callback for Job Attachments' sync_inputs() to track the download progress.
Returns True if the operation should continue as normal or False to cancel.
"""
# Check the transfer rate from the progress report. It monitors for a series of
# alarmingly low transfer rates, and if the count exceeds the specified threshold,
# cancels the download and fails the current (SYNC_INPUT_JOB_ATTACHMENTS) action.
nonlocal low_transfer_count
transfer_rate = job_attachments_download_status.transferRate

if transfer_rate < LOW_TRANSFER_RATE_THRESHOLD:
low_transfer_count += 1
else:
low_transfer_count = 0
if low_transfer_count >= LOW_TRANSFER_COUNT_THRESHOLD:
cancel.set()
action_status = ActionStatus(
state=ActionState.FAILED,
fail_message=(
f"Input syncing failed due to successive low transfer rates (< {LOW_TRANSFER_RATE_THRESHOLD / 1000} KB/s). "
f"The transfer rate was below the threshold for the last {self._seconds_to_minutes_str(LOW_TRANSFER_COUNT_THRESHOLD)}."
),
)
self.update_action(action_status)
# Send the telemetry data of input syncing failure due to insufficient download speed.
record_sync_inputs_fail_telemetry_event(
queue_id=self._queue_id,
failure_reason=(f"Insufficient download speed: {action_status.fail_message}"),
)
return False

self.update_action(
action_status=ActionStatus(
state=ActionState.RUNNING,
status_message=job_upload_status.progressMessage,
progress=job_upload_status.progress,
status_message=job_attachments_download_status.progressMessage,
progress=job_attachments_download_status.progress,
),
)
return not cancel.is_set()
Expand Down Expand Up @@ -888,6 +933,18 @@ def progress_handler(job_upload_status: ProgressReportMetadata) -> bool:
# sort here since we're modifying that internal list appending to the list.
self._session._path_mapping_rules.sort(key=lambda rule: -len(rule.source_path.parts))

def _seconds_to_minutes_str(self, seconds: int) -> str:
minutes = seconds // 60
remaining_seconds = seconds % 60
if minutes > 0 and remaining_seconds > 0:
return f"{minutes} minute{'s' if minutes != 1 else ''} {remaining_seconds} second{'s' if remaining_seconds != 1 else ''}"
elif minutes > 0:
return f"{minutes} minute{'s' if minutes != 1 else ''}"
elif remaining_seconds == 0:
return "0 seconds"
else:
return f"{remaining_seconds} second{'s' if remaining_seconds != 1 else ''}"

def update_action(self, action_status: ActionStatus) -> None:
"""Callback called on every Open Job Description status/progress update and the completion/exit of the
current action.
Expand Down
93 changes: 83 additions & 10 deletions test/unit/sessions/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,10 @@

from deadline_worker_agent.api_models import EnvironmentAction, TaskRunAction
from deadline_worker_agent.sessions import Session
from deadline_worker_agent.sessions import session as session_module
import deadline_worker_agent.sessions.session as session_mod
from deadline_worker_agent.sessions.session import (
LOW_TRANSFER_RATE_THRESHOLD,
LOW_TRANSFER_COUNT_THRESHOLD,
CurrentAction,
SessionActionStatus,
)
Expand All @@ -51,9 +53,11 @@
JobAttachmentS3Settings,
)
from deadline.job_attachments.os_file_permission import PosixFileSystemPermissionSettings
from deadline.job_attachments.progress_tracker import SummaryStatistics

import deadline_worker_agent.sessions.session as session_mod
from deadline.job_attachments.progress_tracker import (
ProgressReportMetadata,
ProgressStatus,
SummaryStatistics,
)


@pytest.fixture(params=(PosixSessionUser(user="some-user", group="some-group"),))
Expand Down Expand Up @@ -121,15 +125,13 @@ def action_update_lock() -> MagicMock:

@pytest.fixture(autouse=True)
def mock_telemetry_event_for_sync_inputs() -> Generator[MagicMock, None, None]:
with patch.object(session_module, "record_sync_inputs_telemetry_event") as mock_telemetry_event:
with patch.object(session_mod, "record_sync_inputs_telemetry_event") as mock_telemetry_event:
yield mock_telemetry_event


@pytest.fixture(autouse=True)
def mock_telemetry_event_for_sync_outputs() -> Generator[MagicMock, None, None]:
with patch.object(
session_module, "record_sync_outputs_telemetry_event"
) as mock_telemetry_event:
with patch.object(session_mod, "record_sync_outputs_telemetry_event") as mock_telemetry_event:
yield mock_telemetry_event


Expand Down Expand Up @@ -662,6 +664,77 @@ def test_sync_asset_inputs(
sync_asset_inputs_args_sequence
)

def test_sync_asset_inputs_cancellation_by_low_transfer_rate(
self,
session: Session,
mock_asset_sync: MagicMock,
):
"""
Tests that the session is canceled if it observes a series of alarmingly low transfer rates.
"""

# Mock out the Job Attachment's sync_inputs function to report multiple consecutive low transfer rates
# (lower than the threshold) via callback function.
def mock_sync_inputs(on_downloading_files, *args, **kwargs):
low_transfer_rate_report = ProgressReportMetadata(
status=ProgressStatus.DOWNLOAD_IN_PROGRESS,
progress=0.0,
transferRate=LOW_TRANSFER_RATE_THRESHOLD / 2,
progressMessage="",
)
for _ in range(LOW_TRANSFER_COUNT_THRESHOLD):
on_downloading_files(low_transfer_rate_report)
return ({}, {})

mock_asset_sync.sync_inputs = mock_sync_inputs
mock_cancel = MagicMock(spec=Event)

with patch.object(session, "update_action") as mock_update_action, patch.object(
session_mod, "record_sync_inputs_fail_telemetry_event"
) as mock_record_sync_inputs_fail_telemetry_event:
session.sync_asset_inputs(
cancel=mock_cancel,
job_attachment_details=JobAttachmentDetails(
manifests=[],
job_attachments_file_system=JobAttachmentsFileSystem.COPIED,
),
)
mock_cancel.set.assert_called_once()
mock_update_action.assert_called_with(
ActionStatus(
state=ActionState.FAILED,
fail_message=(
f"Input syncing failed due to successive low transfer rates (< {LOW_TRANSFER_RATE_THRESHOLD / 1000} KB/s). "
f"The transfer rate was below the threshold for the last {session._seconds_to_minutes_str(LOW_TRANSFER_COUNT_THRESHOLD)}."
),
),
)
mock_record_sync_inputs_fail_telemetry_event.assert_called_once_with(
queue_id="queue-aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
failure_reason=(
"Insufficient download speed: "
f"Input syncing failed due to successive low transfer rates (< {LOW_TRANSFER_RATE_THRESHOLD / 1000} KB/s). "
f"The transfer rate was below the threshold for the last {session._seconds_to_minutes_str(LOW_TRANSFER_COUNT_THRESHOLD)}."
),
)

@pytest.mark.parametrize(
"seconds, expected_str",
[
(0, "0 seconds"),
(1, "1 second"),
(30, "30 seconds"),
(60, "1 minute"),
(61, "1 minute 1 second"),
(90, "1 minute 30 seconds"),
(120, "2 minutes"),
(121, "2 minutes 1 second"),
(150, "2 minutes 30 seconds"),
],
)
def test_seconds_to_minutes_str(self, session: Session, seconds: int, expected_str: str):
assert session._seconds_to_minutes_str(seconds) == expected_str


class TestSessionSyncAssetOutputs:
@pytest.fixture(autouse=True)
Expand Down Expand Up @@ -1254,7 +1327,7 @@ def test_success_task_run(
def mock_now(*arg, **kwarg) -> datetime:
return action_complete_time

with patch.object(session_module, "datetime") as mock_datetime, patch.object(
with patch.object(session_mod, "datetime") as mock_datetime, patch.object(
session, "_sync_asset_outputs"
) as mock_sync_asset_outputs:
mock_datetime.now.side_effect = mock_now
Expand Down Expand Up @@ -1335,7 +1408,7 @@ def test_success_task_run_fail_output_sync(
def mock_now(*arg, **kwarg) -> datetime:
return action_complete_time

with patch.object(session_module, "datetime") as mock_datetime, patch.object(
with patch.object(session_mod, "datetime") as mock_datetime, patch.object(
session, "_sync_asset_outputs", side_effect=sync_outputs_exception
) as mock_sync_asset_outputs:
mock_datetime.now.side_effect = mock_now
Expand Down

0 comments on commit c49bbb4

Please sign in to comment.