Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Cancel Job Attachments session action when transfer rates drop below threshold #143

Merged
merged 1 commit into from
Feb 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions src/deadline_worker_agent/aws/deadline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,3 +820,18 @@ def record_sync_outputs_telemetry_event(queue_id: str, summary: SummaryStatistic
event_type="com.amazon.rum.deadline.worker_agent.sync_outputs_summary",
event_details=details,
)


def record_sync_inputs_fail_telemetry_event(
    queue_id: str,
    failure_reason: str,
) -> None:
    """Report a failed sync-inputs operation through the Deadline telemetry client.

    Args:
        queue_id: Queue identifier, forwarded as the ``queue_id`` event detail.
        failure_reason: Description of why the sync failed, forwarded as the
            ``failure_reason`` event detail.
    """
    telemetry_client = _get_deadline_telemetry_client()
    telemetry_client.record_event(
        event_type="com.amazon.rum.deadline.worker_agent.sync_inputs_failure",
        event_details=dict(queue_id=queue_id, failure_reason=failure_reason),
    )
65 changes: 61 additions & 4 deletions src/deadline_worker_agent/sessions/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,11 @@
)
from deadline.job_attachments.progress_tracker import ProgressReportMetadata, SummaryStatistics

from ..aws.deadline import record_sync_inputs_telemetry_event, record_sync_outputs_telemetry_event
from ..aws.deadline import (
record_sync_inputs_fail_telemetry_event,
record_sync_inputs_telemetry_event,
record_sync_outputs_telemetry_event,
)
from ..scheduler.session_action_status import SessionActionStatus
from ..sessions.errors import SessionActionError

Expand All @@ -83,6 +87,14 @@
}
TIME_DELTA_ZERO = timedelta()

# Stall detection for SYNC_INPUT_JOB_ATTACHMENTS session actions.
#
# While inputs are being synced, the transfer rate is reported periodically through a
# callback. A report whose rate is below LOW_TRANSFER_RATE_THRESHOLD counts as "low"; once
# LOW_TRANSFER_COUNT_THRESHOLD low reports have been seen in a row, the transfer is treated
# as concerning or potentially stalled and the session action is canceled.

# 10 KB/s
LOW_TRANSFER_RATE_THRESHOLD = 10_000

# Each progress report takes 1 sec at the longest, so 60 reports amount to 1 min in total.
LOW_TRANSFER_COUNT_THRESHOLD = 60

logger = getLogger(__name__)

Expand Down Expand Up @@ -776,12 +788,45 @@ def sync_asset_inputs(
if self._asset_sync is None:
return

def progress_handler(job_upload_status: ProgressReportMetadata) -> bool:
low_transfer_count = 0

def progress_handler(job_attachments_download_status: ProgressReportMetadata) -> bool:
"""
Callback for Job Attachments' sync_inputs() to track the download progress.
Returns True if the operation should continue as normal or False to cancel.
"""
# Check the transfer rate from the progress report. It monitors for a series of
# alarmingly low transfer rates, and if the count exceeds the specified threshold,
# cancels the download and fails the current (SYNC_INPUT_JOB_ATTACHMENTS) action.
nonlocal low_transfer_count
transfer_rate = job_attachments_download_status.transferRate

if transfer_rate < LOW_TRANSFER_RATE_THRESHOLD:
low_transfer_count += 1
else:
low_transfer_count = 0
if low_transfer_count >= LOW_TRANSFER_COUNT_THRESHOLD:
cancel.set()
action_status = ActionStatus(
state=ActionState.FAILED,
fail_message=(
f"Input syncing failed due to successive low transfer rates (< {LOW_TRANSFER_RATE_THRESHOLD / 1000} KB/s). "
f"The transfer rate was below the threshold for the last {self._seconds_to_minutes_str(LOW_TRANSFER_COUNT_THRESHOLD)}."
),
)
self.update_action(action_status)
# Send the telemetry data of input syncing failure due to insufficient download speed.
record_sync_inputs_fail_telemetry_event(
queue_id=self._queue_id,
failure_reason=(f"Insufficient download speed: {action_status.fail_message}"),
)
return False

self.update_action(
action_status=ActionStatus(
state=ActionState.RUNNING,
status_message=job_upload_status.progressMessage,
progress=job_upload_status.progress,
status_message=job_attachments_download_status.progressMessage,
progress=job_attachments_download_status.progress,
),
)
return not cancel.is_set()
Expand Down Expand Up @@ -888,6 +933,18 @@ def progress_handler(job_upload_status: ProgressReportMetadata) -> bool:
# sort here since we're modifying that internal list appending to the list.
self._session._path_mapping_rules.sort(key=lambda rule: -len(rule.source_path.parts))

def _seconds_to_minutes_str(self, seconds: int) -> str:
minutes = seconds // 60
remaining_seconds = seconds % 60
if minutes > 0 and remaining_seconds > 0:
return f"{minutes} minute{'s' if minutes != 1 else ''} {remaining_seconds} second{'s' if remaining_seconds != 1 else ''}"
elif minutes > 0:
return f"{minutes} minute{'s' if minutes != 1 else ''}"
elif remaining_seconds == 0:
return "0 seconds"
else:
return f"{remaining_seconds} second{'s' if remaining_seconds != 1 else ''}"

def update_action(self, action_status: ActionStatus) -> None:
"""Callback called on every Open Job Description status/progress update and the completion/exit of the
current action.
Expand Down
93 changes: 83 additions & 10 deletions test/unit/sessions/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,10 @@

from deadline_worker_agent.api_models import EnvironmentAction, TaskRunAction
from deadline_worker_agent.sessions import Session
from deadline_worker_agent.sessions import session as session_module
import deadline_worker_agent.sessions.session as session_mod
from deadline_worker_agent.sessions.session import (
LOW_TRANSFER_RATE_THRESHOLD,
LOW_TRANSFER_COUNT_THRESHOLD,
CurrentAction,
SessionActionStatus,
)
Expand All @@ -51,9 +53,11 @@
JobAttachmentS3Settings,
)
from deadline.job_attachments.os_file_permission import PosixFileSystemPermissionSettings
from deadline.job_attachments.progress_tracker import SummaryStatistics

import deadline_worker_agent.sessions.session as session_mod
from deadline.job_attachments.progress_tracker import (
ProgressReportMetadata,
ProgressStatus,
SummaryStatistics,
)


@pytest.fixture(params=(PosixSessionUser(user="some-user", group="some-group"),))
Expand Down Expand Up @@ -121,15 +125,13 @@ def action_update_lock() -> MagicMock:

@pytest.fixture(autouse=True)
def mock_telemetry_event_for_sync_inputs() -> Generator[MagicMock, None, None]:
    """Patch ``record_sync_inputs_telemetry_event`` in the session module for every test.

    Yields the mock that stands in for the telemetry function while the test runs.
    """
    with patch.object(session_mod, "record_sync_inputs_telemetry_event") as mock_telemetry_event:
        yield mock_telemetry_event


@pytest.fixture(autouse=True)
def mock_telemetry_event_for_sync_outputs() -> Generator[MagicMock, None, None]:
    """Patch ``record_sync_outputs_telemetry_event`` in the session module for every test.

    Yields the mock that stands in for the telemetry function while the test runs.
    """
    with patch.object(session_mod, "record_sync_outputs_telemetry_event") as mock_telemetry_event:
        yield mock_telemetry_event


Expand Down Expand Up @@ -662,6 +664,77 @@ def test_sync_asset_inputs(
sync_asset_inputs_args_sequence
)

def test_sync_asset_inputs_cancellation_by_low_transfer_rate(
    self,
    session: Session,
    mock_asset_sync: MagicMock,
):
    """
    Tests that the session is canceled if it observes a series of alarmingly low transfer rates.
    """

    # Stand-in for Job Attachments' sync_inputs: it feeds the progress callback enough
    # consecutive below-threshold reports to reach the cancellation threshold.
    def mock_sync_inputs(on_downloading_files, *args, **kwargs):
        slow_report = ProgressReportMetadata(
            status=ProgressStatus.DOWNLOAD_IN_PROGRESS,
            progress=0.0,
            transferRate=LOW_TRANSFER_RATE_THRESHOLD / 2,
            progressMessage="",
        )
        for _ in range(LOW_TRANSFER_COUNT_THRESHOLD):
            on_downloading_files(slow_report)
        return ({}, {})

    mock_asset_sync.sync_inputs = mock_sync_inputs
    mock_cancel = MagicMock(spec=Event)

    # The failure message is expected verbatim both on the FAILED action status and,
    # with a prefix, in the telemetry event.
    expected_fail_message = (
        f"Input syncing failed due to successive low transfer rates (< {LOW_TRANSFER_RATE_THRESHOLD / 1000} KB/s). "
        f"The transfer rate was below the threshold for the last {session._seconds_to_minutes_str(LOW_TRANSFER_COUNT_THRESHOLD)}."
    )
    attachment_details = JobAttachmentDetails(
        manifests=[],
        job_attachments_file_system=JobAttachmentsFileSystem.COPIED,
    )

    with patch.object(session, "update_action") as mock_update_action:
        with patch.object(
            session_mod, "record_sync_inputs_fail_telemetry_event"
        ) as mock_record_sync_inputs_fail_telemetry_event:
            session.sync_asset_inputs(
                cancel=mock_cancel,
                job_attachment_details=attachment_details,
            )

            mock_cancel.set.assert_called_once()
            mock_update_action.assert_called_with(
                ActionStatus(
                    state=ActionState.FAILED,
                    fail_message=expected_fail_message,
                ),
            )
            mock_record_sync_inputs_fail_telemetry_event.assert_called_once_with(
                queue_id="queue-aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
                failure_reason="Insufficient download speed: " + expected_fail_message,
            )

@pytest.mark.parametrize(
    ("seconds", "expected_str"),
    [
        (0, "0 seconds"),
        (1, "1 second"),
        (30, "30 seconds"),
        (60, "1 minute"),
        (61, "1 minute 1 second"),
        (90, "1 minute 30 seconds"),
        (120, "2 minutes"),
        (121, "2 minutes 1 second"),
        (150, "2 minutes 30 seconds"),
    ],
)
def test_seconds_to_minutes_str(self, session: Session, seconds: int, expected_str: str):
    """Checks the minutes/seconds rendering, including singular and plural unit forms."""
    rendered = session._seconds_to_minutes_str(seconds)
    assert rendered == expected_str


class TestSessionSyncAssetOutputs:
@pytest.fixture(autouse=True)
Expand Down Expand Up @@ -1254,7 +1327,7 @@ def test_success_task_run(
def mock_now(*arg, **kwarg) -> datetime:
return action_complete_time

with patch.object(session_module, "datetime") as mock_datetime, patch.object(
with patch.object(session_mod, "datetime") as mock_datetime, patch.object(
session, "_sync_asset_outputs"
) as mock_sync_asset_outputs:
mock_datetime.now.side_effect = mock_now
Expand Down Expand Up @@ -1335,7 +1408,7 @@ def test_success_task_run_fail_output_sync(
def mock_now(*arg, **kwarg) -> datetime:
return action_complete_time

with patch.object(session_module, "datetime") as mock_datetime, patch.object(
with patch.object(session_mod, "datetime") as mock_datetime, patch.object(
session, "_sync_asset_outputs", side_effect=sync_outputs_exception
) as mock_sync_asset_outputs:
mock_datetime.now.side_effect = mock_now
Expand Down