diff --git a/src/deadline_worker_agent/aws/deadline/__init__.py b/src/deadline_worker_agent/aws/deadline/__init__.py index 68b2af3d..361424a2 100644 --- a/src/deadline_worker_agent/aws/deadline/__init__.py +++ b/src/deadline_worker_agent/aws/deadline/__init__.py @@ -820,3 +820,18 @@ def record_sync_outputs_telemetry_event(queue_id: str, summary: SummaryStatistic event_type="com.amazon.rum.deadline.worker_agent.sync_outputs_summary", event_details=details, ) + + +def record_sync_inputs_fail_telemetry_event( + queue_id: str, + failure_reason: str, +) -> None: + """Calls the telemetry client to record an event capturing the sync-inputs failure.""" + details = { + "queue_id": queue_id, + "failure_reason": failure_reason, + } + _get_deadline_telemetry_client().record_event( + event_type="com.amazon.rum.deadline.worker_agent.sync_inputs_failure", + event_details=details, + ) diff --git a/src/deadline_worker_agent/sessions/session.py b/src/deadline_worker_agent/sessions/session.py index 98d23b9b..1e6eab01 100644 --- a/src/deadline_worker_agent/sessions/session.py +++ b/src/deadline_worker_agent/sessions/session.py @@ -64,7 +64,11 @@ ) from deadline.job_attachments.progress_tracker import ProgressReportMetadata, SummaryStatistics -from ..aws.deadline import record_sync_inputs_telemetry_event, record_sync_outputs_telemetry_event +from ..aws.deadline import ( + record_sync_inputs_fail_telemetry_event, + record_sync_inputs_telemetry_event, + record_sync_outputs_telemetry_event, +) from ..scheduler.session_action_status import SessionActionStatus from ..sessions.errors import SessionActionError @@ -83,6 +87,14 @@ } TIME_DELTA_ZERO = timedelta() +# During a SYNC_INPUT_JOB_ATTACHMENTS session action, the transfer rate is periodically reported through +# a callback function. If a transfer rate lower than LOW_TRANSFER_RATE_THRESHOLD is observed in a series +# for LOW_TRANSFER_COUNT_THRESHOLD times, it is considered concerning or potentially stalled, and the +# session action is canceled. +LOW_TRANSFER_RATE_THRESHOLD = 10 * 10**3 # 10 KB/s +LOW_TRANSFER_COUNT_THRESHOLD = ( + 60 # Each progress report takes 1 sec at the longest, so 60 reports amount to 1 min in total. +) logger = getLogger(__name__) @@ -776,12 +788,45 @@ def sync_asset_inputs( if self._asset_sync is None: return - def progress_handler(job_upload_status: ProgressReportMetadata) -> bool: + low_transfer_count = 0 + + def progress_handler(job_attachments_download_status: ProgressReportMetadata) -> bool: + """ + Callback for Job Attachments' sync_inputs() to track the download progress. + Returns True if the operation should continue as normal or False to cancel. + """ + # Check the transfer rate from the progress report. It monitors for a series of + # alarmingly low transfer rates, and if the count exceeds the specified threshold, + # cancels the download and fails the current (SYNC_INPUT_JOB_ATTACHMENTS) action. + nonlocal low_transfer_count + transfer_rate = job_attachments_download_status.transferRate + + if transfer_rate < LOW_TRANSFER_RATE_THRESHOLD: + low_transfer_count += 1 + else: + low_transfer_count = 0 + if low_transfer_count >= LOW_TRANSFER_COUNT_THRESHOLD: + cancel.set() + action_status = ActionStatus( + state=ActionState.FAILED, + fail_message=( + f"Input syncing failed due to successive low transfer rates (< {LOW_TRANSFER_RATE_THRESHOLD / 1000} KB/s). " + f"The transfer rate was below the threshold for the last {self._seconds_to_minutes_str(LOW_TRANSFER_COUNT_THRESHOLD)}." + ), + ) + self.update_action(action_status) + # Send the telemetry data of input syncing failure due to insufficient download speed. + record_sync_inputs_fail_telemetry_event( + queue_id=self._queue_id, + failure_reason=(f"Insufficient download speed: {action_status.fail_message}"), + ) + return False + self.update_action( action_status=ActionStatus( state=ActionState.RUNNING, - status_message=job_upload_status.progressMessage, - progress=job_upload_status.progress, + status_message=job_attachments_download_status.progressMessage, + progress=job_attachments_download_status.progress, ), ) return not cancel.is_set() @@ -888,6 +933,18 @@ def progress_handler(job_upload_status: ProgressReportMetadata) -> bool: # sort here since we're modifying that internal list appending to the list. self._session._path_mapping_rules.sort(key=lambda rule: -len(rule.source_path.parts)) + def _seconds_to_minutes_str(self, seconds: int) -> str: + minutes = seconds // 60 + remaining_seconds = seconds % 60 + if minutes > 0 and remaining_seconds > 0: + return f"{minutes} minute{'s' if minutes != 1 else ''} {remaining_seconds} second{'s' if remaining_seconds != 1 else ''}" + elif minutes > 0: + return f"{minutes} minute{'s' if minutes != 1 else ''}" + elif remaining_seconds == 0: + return "0 seconds" + else: + return f"{remaining_seconds} second{'s' if remaining_seconds != 1 else ''}" + def update_action(self, action_status: ActionStatus) -> None: """Callback called on every Open Job Description status/progress update and the completion/exit of the current action. diff --git a/test/unit/sessions/test_session.py b/test/unit/sessions/test_session.py index e7fad32b..11c3cf96 100644 --- a/test/unit/sessions/test_session.py +++ b/test/unit/sessions/test_session.py @@ -30,8 +30,10 @@ from deadline_worker_agent.api_models import EnvironmentAction, TaskRunAction from deadline_worker_agent.sessions import Session -from deadline_worker_agent.sessions import session as session_module +import deadline_worker_agent.sessions.session as session_mod from deadline_worker_agent.sessions.session import ( + LOW_TRANSFER_RATE_THRESHOLD, + LOW_TRANSFER_COUNT_THRESHOLD, CurrentAction, SessionActionStatus, ) @@ -51,9 +53,11 @@ JobAttachmentS3Settings, ) from deadline.job_attachments.os_file_permission import PosixFileSystemPermissionSettings -from deadline.job_attachments.progress_tracker import SummaryStatistics - -import deadline_worker_agent.sessions.session as session_mod +from deadline.job_attachments.progress_tracker import ( + ProgressReportMetadata, + ProgressStatus, + SummaryStatistics, +) @pytest.fixture(params=(PosixSessionUser(user="some-user", group="some-group"),)) @@ -121,15 +125,13 @@ def action_update_lock() -> MagicMock: @pytest.fixture(autouse=True) def mock_telemetry_event_for_sync_inputs() -> Generator[MagicMock, None, None]: - with patch.object(session_module, "record_sync_inputs_telemetry_event") as mock_telemetry_event: + with patch.object(session_mod, "record_sync_inputs_telemetry_event") as mock_telemetry_event: yield mock_telemetry_event @pytest.fixture(autouse=True) def mock_telemetry_event_for_sync_outputs() -> Generator[MagicMock, None, None]: - with patch.object( - session_module, "record_sync_outputs_telemetry_event" - ) as mock_telemetry_event: + with patch.object(session_mod, "record_sync_outputs_telemetry_event") as mock_telemetry_event: yield mock_telemetry_event @@ -662,6 +664,77 @@ def test_sync_asset_inputs( sync_asset_inputs_args_sequence ) + def test_sync_asset_inputs_cancellation_by_low_transfer_rate( + self, + session: Session, + mock_asset_sync: MagicMock, + ): + """ + Tests that the session is canceled if it observes a series of alarmingly low transfer rates. + """ + + # Mock out the Job Attachment's sync_inputs function to report multiple consecutive low transfer rates + # (lower than the threshold) via callback function. + def mock_sync_inputs(on_downloading_files, *args, **kwargs): + low_transfer_rate_report = ProgressReportMetadata( + status=ProgressStatus.DOWNLOAD_IN_PROGRESS, + progress=0.0, + transferRate=LOW_TRANSFER_RATE_THRESHOLD / 2, + progressMessage="", + ) + for _ in range(LOW_TRANSFER_COUNT_THRESHOLD): + on_downloading_files(low_transfer_rate_report) + return ({}, {}) + + mock_asset_sync.sync_inputs = mock_sync_inputs + mock_cancel = MagicMock(spec=Event) + + with patch.object(session, "update_action") as mock_update_action, patch.object( + session_mod, "record_sync_inputs_fail_telemetry_event" + ) as mock_record_sync_inputs_fail_telemetry_event: + session.sync_asset_inputs( + cancel=mock_cancel, + job_attachment_details=JobAttachmentDetails( + manifests=[], + job_attachments_file_system=JobAttachmentsFileSystem.COPIED, + ), + ) + mock_cancel.set.assert_called_once() + mock_update_action.assert_called_with( + ActionStatus( + state=ActionState.FAILED, + fail_message=( + f"Input syncing failed due to successive low transfer rates (< {LOW_TRANSFER_RATE_THRESHOLD / 1000} KB/s). " + f"The transfer rate was below the threshold for the last {session._seconds_to_minutes_str(LOW_TRANSFER_COUNT_THRESHOLD)}." + ), + ), + ) + mock_record_sync_inputs_fail_telemetry_event.assert_called_once_with( + queue_id="queue-aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + failure_reason=( + "Insufficient download speed: " + f"Input syncing failed due to successive low transfer rates (< {LOW_TRANSFER_RATE_THRESHOLD / 1000} KB/s). " + f"The transfer rate was below the threshold for the last {session._seconds_to_minutes_str(LOW_TRANSFER_COUNT_THRESHOLD)}." + ), + ) + + @pytest.mark.parametrize( + "seconds, expected_str", + [ + (0, "0 seconds"), + (1, "1 second"), + (30, "30 seconds"), + (60, "1 minute"), + (61, "1 minute 1 second"), + (90, "1 minute 30 seconds"), + (120, "2 minutes"), + (121, "2 minutes 1 second"), + (150, "2 minutes 30 seconds"), + ], + ) + def test_seconds_to_minutes_str(self, session: Session, seconds: int, expected_str: str): + assert session._seconds_to_minutes_str(seconds) == expected_str + class TestSessionSyncAssetOutputs: @pytest.fixture(autouse=True) @@ -1254,7 +1327,7 @@ def test_success_task_run( def mock_now(*arg, **kwarg) -> datetime: return action_complete_time - with patch.object(session_module, "datetime") as mock_datetime, patch.object( + with patch.object(session_mod, "datetime") as mock_datetime, patch.object( session, "_sync_asset_outputs" ) as mock_sync_asset_outputs: mock_datetime.now.side_effect = mock_now @@ -1335,7 +1408,7 @@ def test_success_task_run_fail_output_sync( def mock_now(*arg, **kwarg) -> datetime: return action_complete_time - with patch.object(session_module, "datetime") as mock_datetime, patch.object( + with patch.object(session_mod, "datetime") as mock_datetime, patch.object( session, "_sync_asset_outputs", side_effect=sync_outputs_exception ) as mock_sync_asset_outputs: mock_datetime.now.side_effect = mock_now