diff --git a/pyproject.toml b/pyproject.toml index 0ca44451..43e29a02 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -104,6 +104,13 @@ ignore = [ "F811", ] +[tool.ruff.lint.per-file-ignores] +# We need to use a platform assertion to short-circuit mypy type checking on non-Windows platforms +# https://mypy.readthedocs.io/en/stable/common_issues.html#python-version-and-system-platform-checks +# This causes imports to come after regular Python statements causing flake8 rule E402 to be flagged +"src/deadline_worker_agent/**/*windows*.py" = ["E402"] +"test/**/*windows*.py" = ["E402"] + [tool.ruff.lint.isort] known-first-party = [ "deadline_worker_agent", @@ -153,6 +160,9 @@ omit = [ "*/scheduler/**/*.py", "*/worker.py", ] +plugins = [ + "coverage_conditional_plugin" +] [tool.coverage.paths] source = [ "src/" ] @@ -161,6 +171,25 @@ source = [ "src/" ] show_missing = true fail_under = 78 +# https://github.com/wemake-services/coverage-conditional-plugin +[tool.coverage.coverage_conditional_plugin.omit] +"sys_platform != 'win32'" = [ + "src/deadline_worker_agent/windows/*.py", + "src/deadline_worker_agent/installer/win_installer.py" +] + +[tool.coverage.coverage_conditional_plugin.rules] +# This cannot be empty otherwise coverage-conditional-plugin crashes with: +# AttributeError: 'NoneType' object has no attribute 'items' +# +# =========== WARNING TO REVIEWERS ============ +# +# Any rules added here are ran through Python's +# eval() function so watch for code injection +# attacks. +# +# =========== WARNING TO REVIEWERS ============ + [tool.semantic_release] # Can be removed or set to true once we are v1 major_on_zero = false diff --git a/requirements-testing.txt b/requirements-testing.txt index 85ce1191..8e4ab999 100644 --- a/requirements-testing.txt +++ b/requirements-testing.txt @@ -1,4 +1,5 @@ coverage[toml] ~= 7.4 +coverage-conditional-plugin == 0.9.* deadline-cloud-test-fixtures == 0.5.* pytest ~= 8.1 pytest-cov == 4.1.* diff --git a/src/deadline_worker_agent/installer/__init__.py b/src/deadline_worker_agent/installer/__init__.py index 81245f9b..38e7d855 100644 --- a/src/deadline_worker_agent/installer/__init__.py +++ b/src/deadline_worker_agent/installer/__init__.py @@ -34,8 +34,8 @@ def install() -> None: fleet_id=args.fleet_id, region=args.region, worker_agent_program=scripts_path, - no_install_service=not args.install_service, - start=args.service_start, + install_service=args.install_service, + start_service=args.service_start, confirm=args.confirmed, allow_shutdown=args.allow_shutdown, parser=arg_parser, @@ -96,8 +96,8 @@ class ParsedCommandLineArguments(Namespace): fleet_id: str region: str user: str - password: Optional[str] - group: Optional[str] + password: Optional[str] = None + group: Optional[str] = None confirmed: bool service_start: bool allow_shutdown: bool @@ -184,7 +184,10 @@ def get_argument_parser() -> ArgumentParser: # pragma: no cover if sys.platform == "win32": parser.add_argument( "--password", - help="The password for the AWS Deadline Cloud Worker Agent user. Defaults to generating a password.", + help=( + "The password for the AWS Deadline Cloud Worker Agent user. Defaults to generating a password " + "if the user does not exist or prompting for the password if the user pre-exists." + ), required=False, default=None, ) diff --git a/src/deadline_worker_agent/installer/win_installer.py b/src/deadline_worker_agent/installer/win_installer.py index d3e10d4b..422aefed 100644 --- a/src/deadline_worker_agent/installer/win_installer.py +++ b/src/deadline_worker_agent/installer/win_installer.py @@ -8,14 +8,10 @@ import shutil import string import sys -import typing from argparse import ArgumentParser +from getpass import getpass from pathlib import Path - -from deadline_worker_agent.file_system_operations import ( - _set_windows_permissions, - FileSystemPermissionEnum, -) +from typing import Optional import deadline.client.config.config_file import pywintypes @@ -23,9 +19,18 @@ import win32net import win32netcon import win32security +import win32service +import win32serviceutil import winerror +from openjd.sessions import BadCredentialsException, WindowsSessionUser from win32comext.shell import shell +from ..file_system_operations import ( + _set_windows_permissions, + FileSystemPermissionEnum, +) +from ..windows.win_service import WorkerAgentWindowsService + # Defaults DEFAULT_WA_USER = "deadline-worker" @@ -172,6 +177,9 @@ def ensure_local_agent_user(username: str, password: str) -> None: """ if check_user_existence(username): logging.info(f"Agent User {username} already exists") + # This is only to verify the credentials. It will raise a BadCredentialsError if the + # credentials cannot be used to logon the user + WindowsSessionUser(user=username, password=password) else: logging.info(f"Creating Agent user {username}") user_info = { @@ -254,7 +262,7 @@ def update_config_file( deadline_config_sub_directory: str, farm_id: str, fleet_id: str, - shutdown_on_stop: typing.Optional[bool] = None, + shutdown_on_stop: Optional[bool] = None, ) -> None: """ Updates the worker configuration file, creating it from the example if it does not exist. @@ -435,6 +443,155 @@ def update_deadline_client_config( os.environ.update(old_environ) +def _install_service( + *, + agent_user_name: str, + password: str, +) -> None: + """Installs the Windows Service that hosts the Worker Agent + + Parameters + agent_user_name(str): Worker Agent's account username + password(str): The Worker Agent's user account password + """ + + # If the username does not contain the domain, then assume the local domain + # https://learn.microsoft.com/en-us/windows/win32/secauthn/user-name-formats + if "\\" not in agent_user_name and "@" not in agent_user_name: + agent_user_name = f".\\{agent_user_name}" + + # Determine the Windows Service configuration. This uses the same logic as + # win32serviceutil.HandleCommandLine() so that the service can be debugged + # using: + # + # python -m deadline_worker_agent.windows.win_service debug + service_class_str = win32serviceutil.GetServiceClassString(WorkerAgentWindowsService) + service_name = WorkerAgentWindowsService._svc_name_ + service_display_name = WorkerAgentWindowsService._svc_display_name_ + service_description = getattr(WorkerAgentWindowsService, "_svc_description_", None) + exe_name = getattr(WorkerAgentWindowsService, "_exe_name_", None) + exe_args = getattr(WorkerAgentWindowsService, "_exe_args_", None) + + # Configure the service to start on boot + startup = win32service.SERVICE_AUTO_START + + logging.info(f'Configuring Windows Service "{service_display_name}"...') + try: + win32serviceutil.InstallService( + service_class_str, + service_name, + service_display_name, + serviceDeps=None, + startType=startup, + bRunInteractive=None, + userName=agent_user_name, + password=password, + exeName=exe_name, + perfMonIni=None, + perfMonDll=None, + exeArgs=exe_args, + description=service_description, + delayedstart=False, + ) + except win32service.error as exc: + if exc.winerror != winerror.ERROR_SERVICE_EXISTS: + raise + logging.info(f'Service "{service_display_name}" already exists, updating instead...') + win32serviceutil.ChangeServiceConfig( + service_class_str, + service_name, + serviceDeps=None, + startType=startup, + bRunInteractive=None, + userName=agent_user_name, + password=password, + exeName=exe_name, + displayName=service_display_name, + perfMonIni=None, + perfMonDll=None, + exeArgs=exe_args, + description=service_description, + delayedstart=False, + ) + logging.info(f'Successfully updated Windows Service "{service_display_name}"') + else: + logging.info(f'Successfully created Windows Service "{service_display_name}"') + + logging.info(f'Configuring the failure actions of Windows Service "{service_display_name}"...') + configure_service_failure_actions(service_name) + logging.info( + f'Successfully configured the failure actions for Window Service "{service_display_name}"' + ) + + +def configure_service_failure_actions(service_name): + """Configures the failure actions of the Windows Service. + + We use exponential backoff with a base of 2 seconds and doubling each iteration. This grows until + it reaches ~4m 16s and then repeats indefinitely at this interval. The backoff resets if the service + heals and stays alive for 20 minutes. + + This uses the ChangeServiceConfig2 win32 API: + https://learn.microsoft.com/en-us/windows/win32/api/winsvc/nf-winsvc-changeserviceconfig2w + + Notably, the third parameter of ChangeServiceConfig2 expects a SERVICE_FAILURE_ACTIONSW structure. + whose API reference docs best explains how Windows Service failure actions work: + https://learn.microsoft.com/en-us/windows/win32/api/winsvc/ns-winsvc-service_failure_actionsw#remarks + """ + + # pywin32's ChangeServiceConfig2 wrapper accepts tuples ofs: (action type, delay in ms) + # Exponential backoff with base of 2 seconds (2000 ms), doubling each iteration. + # The backoff grows from 2 seconds to ~4m 16s over 8 attempts totalling 510s (or 8m 30s). + actions = [(win32service.SC_ACTION_RESTART, 2000 * 2**i) for i in range(8)] + + logging.debug("Opening the Service Control Manager...") + scm = win32service.OpenSCManager(None, None, win32service.SC_MANAGER_ALL_ACCESS) + logging.debug("Successfully opened the Service Control Manager") + try: + logging.debug(f'Opening the Windows Service "{service_name}"') + service = win32service.OpenService(scm, service_name, win32service.SERVICE_ALL_ACCESS) + logging.debug(f'Successfully opened the Windows Service "{service_name}"') + + logging.debug(f'Modifying the failure actions of Windows Service "{service_name}...') + try: + win32service.ChangeServiceConfig2( + service, + win32service.SERVICE_CONFIG_FAILURE_ACTIONS, + { + # Repeat the last action (restart with ~4m 16s delay) until the service recovers + # for 20 minutes (in seconds) + "ResetPeriod": 20 * 60, + "RebootMsg": None, + "Command": None, + "Actions": actions, + }, + ) + logging.debug( + f'Successfully modified the failure actions of Windows Service "{service_name}...' + ) + finally: + logging.debug(f'Closing the Windows Service "{service_name}"..') + win32service.CloseServiceHandle(service) + logging.debug(f'Successfully closed the Windows Service "{service_name}"') + finally: + logging.debug("Closing the Service Control Manager...") + win32service.CloseServiceHandle(scm) + logging.debug("Successfully closed the Service Control Manager") + + +def _start_service() -> None: + """Starts the Windows Service hosting the Worker Agent""" + service_name = WorkerAgentWindowsService._svc_name_ + + logging.info(f'Starting service "{service_name}"...') + try: + win32serviceutil.StartService(serviceName=service_name) + except Exception as e: + logging.warning(f'Failed to start service "{service_name}": {e}') + else: + logging.info(f'Successfully started service "{service_name}"') + + def start_windows_installer( farm_id: str, fleet_id: str, @@ -442,11 +599,11 @@ def start_windows_installer( worker_agent_program: Path, allow_shutdown: bool, parser: ArgumentParser, - password: typing.Optional[str] = None, user_name: str = DEFAULT_WA_USER, + password: Optional[str] = None, group_name: str = DEFAULT_JOB_GROUP, - no_install_service: bool = False, - start: bool = False, + install_service: bool = False, + start_service: bool = False, confirm: bool = False, telemetry_opt_out: bool = False, ): @@ -469,8 +626,6 @@ def print_helping_info_and_exit(): elif not validate_deadline_id("fleet", fleet_id): logging.error(f"Not a valid value for Fleet id: {fleet_id}") print_helping_info_and_exit() - if not password: - password = generate_password() # Check that user has Administrator privileges if not shell.IsUserAnAdmin(): @@ -479,6 +634,18 @@ def print_helping_info_and_exit(): # Print configuration print_banner() + + if not password: + if check_user_existence(user_name): + password = getpass("Agent user password: ") + try: + WindowsSessionUser(user_name, password=password) + except BadCredentialsException: + print("ERROR: Password incorrect") + sys.exit(1) + else: + password = generate_password() + print( f"Farm ID: {farm_id}\n" f"Fleet ID: {fleet_id}\n" @@ -487,9 +654,11 @@ def print_helping_info_and_exit(): f"Worker job group: {group_name}\n" f"Worker agent program path: {str(worker_agent_program)}\n" f"Allow worker agent shutdown: {allow_shutdown}\n" - f"Start service: {start}\n" + f"Install Windows service: {install_service}\n" + f"Start service: {start_service}" f"Telemetry opt-out: {telemetry_opt_out}" ) + print() # Confirm installation if not confirm: @@ -515,9 +684,11 @@ def print_helping_info_and_exit(): # Check if the job group exists, and create it if not ensure_local_queue_user_group_exists(group_name) + # Add the worker agent user to the job group add_user_to_group(group_name, user_name) + # Create directories and configure their permissions agent_dirs = provision_directories(user_name) update_config_file( str(agent_dirs.deadline_config_subdir), @@ -539,3 +710,14 @@ def print_helping_info_and_exit(): settings={"telemetry.opt_out": "true"}, ) logging.info("Opted out of client telemetry") + + # Install the Windows service if specified + if install_service: + _install_service( + agent_user_name=user_name, + password=password, + ) + + # Start the Windows service if specified + if start_service: + _start_service() diff --git a/src/deadline_worker_agent/scheduler/scheduler.py b/src/deadline_worker_agent/scheduler/scheduler.py index 57bbbd31..c7d02f46 100644 --- a/src/deadline_worker_agent/scheduler/scheduler.py +++ b/src/deadline_worker_agent/scheduler/scheduler.py @@ -17,10 +17,10 @@ import logging import os import stat +import sys from openjd.sessions import ActionState, ActionStatus, SessionUser from openjd.sessions import LOG as OPENJD_SESSION_LOG -from openjd.sessions import ActionState, ActionStatus from deadline.job_attachments.asset_sync import AssetSync from ..aws.deadline import update_worker @@ -54,7 +54,12 @@ from ..startup.config import JobsRunAsUserOverride from ..utils import MappingWithCallbacks from ..file_system_operations import FileSystemPermissionEnum, make_directory, touch_file -from ..windows_credentials_resolver import WindowsCredentialsResolver + +if sys.platform == "win32": + from ..windows.win_credentials_resolver import WindowsCredentialsResolver +else: + WindowsCredentialsResolver = Any + logger = LOGGER @@ -179,6 +184,7 @@ def __init__( worker_persistence_dir: Path, worker_logs_dir: Path | None, retain_session_dir: bool = False, + stop: Event | None = None, ) -> None: """Queue of Worker Sessions and their actions @@ -198,7 +204,7 @@ def __init__( self._executor = ThreadPoolExecutor(max_workers=100) self._sessions = SessionMap(cleanup_session_user_processes=cleanup_session_user_processes) self._wakeup = Event() - self._shutdown = Event() + self._shutdown = stop or Event() self._farm_id = farm_id self._fleet_id = fleet_id self._worker_id = worker_id @@ -283,6 +289,9 @@ def run(self) -> None: raise finally: self._drain_scheduler() + if os.name == "nt": + assert self._windows_credentials_resolver is not None + self._windows_credentials_resolver.clear() def _drain_scheduler(self) -> None: # Note: diff --git a/src/deadline_worker_agent/sessions/job_entities/job_entities.py b/src/deadline_worker_agent/sessions/job_entities/job_entities.py index b38a260e..fa664247 100644 --- a/src/deadline_worker_agent/sessions/job_entities/job_entities.py +++ b/src/deadline_worker_agent/sessions/job_entities/job_entities.py @@ -6,7 +6,7 @@ from logging import getLogger from threading import Event, Thread from typing import Any, Iterator, Iterable, TYPE_CHECKING, TypeVar, Union, cast, Optional -from ...windows_credentials_resolver import WindowsCredentialsResolver +import sys from ...api_models import ( EntityIdentifier, @@ -41,8 +41,14 @@ BaseEntityErrorFields, EntityDetails, ) + + if sys.platform == "win32": + from ...windows.win_credentials_resolver import WindowsCredentialsResolver + else: + WindowsCredentialsResolver = Any else: BaseEntityErrorFields = Any + WindowsCredentialsResolver = Any S = TypeVar( diff --git a/src/deadline_worker_agent/startup/entrypoint.py b/src/deadline_worker_agent/startup/entrypoint.py index aa8d0a1c..836192ec 100644 --- a/src/deadline_worker_agent/startup/entrypoint.py +++ b/src/deadline_worker_agent/startup/entrypoint.py @@ -8,7 +8,9 @@ import os import subprocess import sys +from getpass import getuser from logging.handlers import TimedRotatingFileHandler +from threading import Event from typing import Optional from pathlib import Path @@ -38,7 +40,7 @@ _logger = logging.getLogger(__name__) -def entrypoint(cli_args: Optional[list[str]] = None) -> None: +def entrypoint(cli_args: Optional[list[str]] = None, *, stop: Optional[Event] = None) -> None: """Entrypoint for the Worker Agent. The worker gets registered and then polls for tasks to complete. @@ -135,6 +137,7 @@ def filter(self, record: logging.LogRecord) -> bool: host_metrics_logging=config.host_metrics_logging, host_metrics_logging_interval_seconds=config.host_metrics_logging_interval_seconds, retain_session_dir=config.retain_session_dir, + stop=stop, ) try: worker_sessions.run() @@ -308,6 +311,13 @@ def _log_agent_info() -> None: _logger.info(f"Platform: {sys.platform}") _logger.info("Agent Version: %s", __version__) _logger.info("Installed at: %s", str(Path(__file__).resolve().parent.parent)) + try: + user = getuser() + except Exception: + # This is best-effort. If we cannot determine the user we will not log + pass + else: + _logger.info("Running as: %s", user) _logger.info("Dependency versions installed:") _logger.info("\topenjd.model: %s", openjd_model_version) _logger.info("\topenjd.sessions: %s", openjd_sessions_version) diff --git a/src/deadline_worker_agent/windows/__init__.py b/src/deadline_worker_agent/windows/__init__.py new file mode 100644 index 00000000..8d929cc8 --- /dev/null +++ b/src/deadline_worker_agent/windows/__init__.py @@ -0,0 +1 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. diff --git a/src/deadline_worker_agent/windows_credentials_resolver.py b/src/deadline_worker_agent/windows/win_credentials_resolver.py similarity index 56% rename from src/deadline_worker_agent/windows_credentials_resolver.py rename to src/deadline_worker_agent/windows/win_credentials_resolver.py index 941cfc13..64579275 100644 --- a/src/deadline_worker_agent/windows_credentials_resolver.py +++ b/src/deadline_worker_agent/windows/win_credentials_resolver.py @@ -1,22 +1,42 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# This assertion short-circuits mypy from type checking this module on platforms other than Windows +# https://mypy.readthedocs.io/en/stable/common_issues.html#python-version-and-system-platform-checks +import sys + +assert sys.platform == "win32" + + import json import os +from ctypes.wintypes import HANDLE as cHANDLE from datetime import datetime, timedelta, timezone from logging import getLogger -from typing import Dict, Optional +from typing import Any, Dict, Optional, TYPE_CHECKING from botocore.client import BaseClient from botocore.exceptions import ClientError from botocore.retries.standard import RetryContext - from openjd.sessions import WindowsSessionUser, BadCredentialsException +from pywintypes import HANDLE as PyHANDLE +from win32security import ( + LogonUser, + LOGON32_LOGON_NETWORK_CLEARTEXT, + LOGON32_PROVIDER_DEFAULT, +) +from win32profile import LoadUserProfile, PI_NOUI, UnloadUserProfile + +if TYPE_CHECKING: + from _win32typing import PyHKEY +else: + PyHKEY = Any -from .boto import ( +from ..boto import ( OTHER_BOTOCORE_CONFIG, NoOverflowExponentialBackoff as Backoff, Session as BotoSession, ) +from . import win_service logger = getLogger(__name__) @@ -27,10 +47,14 @@ def __init__( windows_session_user: Optional[WindowsSessionUser], last_fetched_at: datetime, last_accessed: datetime, + user_profile: Optional[PyHKEY] = None, + logon_token: Optional[PyHANDLE] = None, ): self.windows_session_user = windows_session_user self.last_fetched_at = last_fetched_at self.last_accessed = last_accessed + self.user_profile = user_profile + self.logon_token = logon_token class WindowsCredentialsResolver: @@ -44,7 +68,7 @@ def __init__( self, boto_session: BotoSession, ) -> None: - if os.name != "nt": + if os.name != "nt": # pragma: no cover raise RuntimeError("Windows credentials resolver can only be used on Windows") self._boto_session = boto_session self._user_cache: Dict[str, _WindowsCredentialsCacheEntry] = {} @@ -92,14 +116,52 @@ def _fetch_secret_from_secrets_manager(self, secretArn: str) -> dict: raise ValueError(f"Contents of secret {secretArn} is not valid JSON.") def prune_cache(self): - now = datetime.now(tz=timezone.utc) + # If we are running as a Windows Service, we maintain a logon token for the user and + # do not need to persist the password nor rotate it. + if win_service.is_windows_session_zero(): + return + # Filter out entries that haven't been accessed in the last CACHE_EXPIRATION hours + now = datetime.now(tz=timezone.utc) self._user_cache = { key: value for key, value in self._user_cache.items() if now - value.last_accessed < self.CACHE_EXPIRATION } + def clear(self): + """Clears all users from the cache and cleans up any open resources""" + if win_service.is_windows_session_zero(): + for user in self._user_cache.values(): + if user.windows_session_user: + logger.info( + f"Removing user {user.windows_session_user.user} from the windows credentials resolver cache" + ) + if user.user_profile: + # https://timgolden.me.uk/pywin32-docs/win32profile__UnloadUserProfile_meth.html + UnloadUserProfile(user.windows_session_user.logon_token, user.user_profile) + assert user.logon_token is not None + user.logon_token.Close() + self._user_cache.clear() + + @staticmethod + def _user_cache_key(*, user_name: str, password_arn: str) -> str: + """Returns the cache key for a given user and password ARN + + This behavior differs in a Windows Service. Through experimentation, we can use the + LogonUserW and CreateProcessAsUserW win32 APIs. We can cache the Windows logon token + handle from LogonUserW indefinitely which should remain valid after password rotations. + + Outside a Windows Service, we must use the CreateProcessWithLogonW API which requires + a username and password. For this reason, our cache key should use the password secret + ARN since a change of secret may imply a change of password. + """ + if win_service.is_windows_session_zero(): + return user_name + else: + # Create a composite key using user and arn + return f"{user_name}_{password_arn}" + def get_windows_session_user(self, user: str, passwordArn: str) -> WindowsSessionUser: # Raises ValueError on problems so that the scheduler can cleanly fail the associated jobs # Any failure here should be cached so that we wait self.RETRY_AFTER minutes before fetching @@ -107,8 +169,10 @@ def get_windows_session_user(self, user: str, passwordArn: str) -> WindowsSessio # Create a composite key using user and arn should_fetch = True - user_key = f"{user}_{passwordArn}" + user_key = self._user_cache_key(user_name=user, password_arn=passwordArn) windows_session_user: Optional[WindowsSessionUser] = None + logon_token: Optional[PyHANDLE] = None + user_profile: Optional[PyHKEY] = None # Prune the cache before fetching or returning the user self.prune_cache() @@ -145,13 +209,39 @@ def get_windows_session_user(self, user: str, passwordArn: str) -> WindowsSessio f'Contents of secret {passwordArn} did not match the expected format: {"password":"value"}' ) else: - try: - # OpenJD will test the ultimate validity of the credentials when creating a WindowsSessionUser - windows_session_user = WindowsSessionUser(user=user, password=password) - except BadCredentialsException: - logger.error( - f"Username and/or password within {passwordArn} were not correct" - ) + if win_service.is_windows_session_zero(): + try: + # https://timgolden.me.uk/pywin32-docs/win32profile__LoadUserProfile_meth.html + logon_token = LogonUser( + Username=user, + LogonType=LOGON32_LOGON_NETWORK_CLEARTEXT, + LogonProvider=LOGON32_PROVIDER_DEFAULT, + Password=password, + Domain=None, + ) + # https://timgolden.me.uk/pywin32-docs/win32profile__LoadUserProfile_meth.html + user_profile = LoadUserProfile( + logon_token, + { + "UserName": user, + "Flags": PI_NOUI, + "ProfilePath": None, + }, + ) + windows_session_user = WindowsSessionUser( + user=user, + logon_token=cHANDLE(int(logon_token)), + ) + except OSError as e: + logger.error(f'Error logging on as "{user}": {e}') + else: + try: + # OpenJD will test the ultimate validity of the credentials when creating a WindowsSessionUser + windows_session_user = WindowsSessionUser(user=user, password=password) + except BadCredentialsException: + logger.error( + f"Username and/or password within {passwordArn} were not correct" + ) # Cache the WindowsSessionUser object, last fetched at, and last accessed time for future use # If the credentials were not valid cache that too to prevent repeated calls to SecretsManager @@ -159,6 +249,8 @@ def get_windows_session_user(self, user: str, passwordArn: str) -> WindowsSessio windows_session_user=windows_session_user, last_fetched_at=datetime.now(tz=timezone.utc), last_accessed=datetime.now(tz=timezone.utc), + logon_token=logon_token, + user_profile=user_profile, ) if not windows_session_user: diff --git a/src/deadline_worker_agent/windows/win_service.py b/src/deadline_worker_agent/windows/win_service.py new file mode 100644 index 00000000..d441a890 --- /dev/null +++ b/src/deadline_worker_agent/windows/win_service.py @@ -0,0 +1,84 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +import socket +import logging +from functools import cache +from threading import Event + +import win32process +import win32serviceutil +import win32service +import win32ts +import servicemanager + +from deadline_worker_agent.startup.entrypoint import entrypoint + + +logger = logging.getLogger(__name__) + + +class WorkerAgentWindowsService(win32serviceutil.ServiceFramework): + # Pywin32 Service Configuration + _exe_name_ = "DeadlineWorkerService.exe" + _svc_name_ = "DeadlineWorker" + _svc_display_name_ = "AWS Deadline Cloud Worker" + _svc_description_ = ( + "Service hosting the AWS Deadline Cloud Worker Agent. Connects to AWS " + "Deadline Cloud and runs jobs as a worker in a fleet." + ) + + _stop_event: Event + + def __init__(self, args): + win32serviceutil.ServiceFramework.__init__(self, args) + self._stop_event = Event() + socket.setdefaulttimeout(60) + + def SvcStop(self): + """Invoked when the Windows Service is being stopped""" + self.ReportServiceStatus(win32service.SERVICE_STOP_PENDING) + logger.info("Windows Service is being stopped") + self._stop_event.set() + + def SvcDoRun(self): + """The main entrypoint called after the service is started""" + servicemanager.LogMsg( + servicemanager.EVENTLOG_INFORMATION_TYPE, + servicemanager.PYS_SERVICE_STARTED, + (self._svc_name_, ""), + ) + entrypoint(cli_args=[], stop=self._stop_event) + servicemanager.LogMsg( + servicemanager.EVENTLOG_INFORMATION_TYPE, + servicemanager.PYS_SERVICE_STOPPED, + (self._svc_name_, ""), + ) + logger.info("Stop status sent to Windows Service Controller") + + +def _get_current_process_session() -> int: + """Returns the Windows session ID number for the current process + + Returns + ------- + int + The session ID of the current process + """ + process_id = win32process.GetCurrentProcessId() + return win32ts.ProcessIdToSessionId(process_id) + + +@cache +def is_windows_session_zero() -> bool: + """Returns whether the current Python process is running in Windows session 0. + + Returns + ------- + bool + True if the current process is running in Windows session 0 + """ + return _get_current_process_session() == 0 + + +if __name__ == "__main__": + win32serviceutil.HandleCommandLine(WorkerAgentWindowsService) diff --git a/src/deadline_worker_agent/worker.py b/src/deadline_worker_agent/worker.py index 82962f0c..bf352a7a 100644 --- a/src/deadline_worker_agent/worker.py +++ b/src/deadline_worker_agent/worker.py @@ -91,6 +91,7 @@ def __init__( host_metrics_logging: bool, host_metrics_logging_interval_seconds: float | None = None, retain_session_dir: bool = False, + stop: Event | None = None, ) -> None: self._deadline_client = deadline_client self._s3_client = s3_client @@ -110,8 +111,9 @@ def __init__( worker_persistence_dir=worker_persistence_dir, worker_logs_dir=worker_logs_dir, retain_session_dir=retain_session_dir, + stop=stop, ) - self._stop = Event() + self._stop = stop or Event() self._boto_session = boto_session self._worker_persistence_dir = worker_persistence_dir self._retain_session_dir = retain_session_dir @@ -124,12 +126,23 @@ def __init__( logger=logger, interval_s=host_metrics_logging_interval_seconds ) - signal.signal(signal.SIGTERM, self._signal_handler) - signal.signal(signal.SIGINT, self._signal_handler) - if os.name == "posix": + signal.signal(signal.SIGTERM, self._signal_handler) + signal.signal(signal.SIGINT, self._signal_handler) # TODO: Remove this once WA is stable or put behind a debug flag signal.signal(signal.SIGUSR1, self._output_thread_stacks) # type: ignore + elif os.name == "nt": + from .windows.win_service import is_windows_session_zero + + # If we are in session 0, we are running as a Windows Service using pywin32 + # pywin32's pythonservice.exe owns the main thread and the Python application + # appears to run on a secondary thread. Python only allows registering signal + # handlers on the main thread and we only need them in the interactive case + # anyways + if not is_windows_session_zero(): + signal.signal(signal.SIGTERM, self._signal_handler) + signal.signal(signal.SIGINT, self._signal_handler) + signal.signal(signal.SIGBREAK, self._signal_handler) # type: ignore[attr-defined] def _signal_handler(self, signum: int, frame: FrameType | None = None) -> None: """ @@ -137,7 +150,12 @@ def _signal_handler(self, signum: int, frame: FrameType | None = None) -> None: gracefully wind-down what it's currently doing. This will set the _interrupted flag to True when we get such a signal. """ - if signum in (signal.SIGTERM, signal.SIGINT): + if ( + signum in (signal.SIGTERM, signal.SIGINT) + or + # This is to relax mypy since signal.SIGBREAK is only defined on Windows + (sys.platform == "win32" and signum == getattr(signal, "SIGBREAK")) + ): logger.info(f"Received signal {signum}. Initiating application shutdown.") self._interrupted = True self._scheduler.shutdown( diff --git a/test/unit/install/test_windows_installer.py b/test/unit/install/test_windows_installer.py index 86080f8b..3520542f 100644 --- a/test/unit/install/test_windows_installer.py +++ b/test/unit/install/test_windows_installer.py @@ -2,13 +2,16 @@ import string import sys +import sysconfig +from pathlib import Path +from typing import Generator +from unittest.mock import Mock, call, patch, MagicMock + import pytest if sys.platform != "win32": pytest.skip("Windows-specific tests", allow_module_level=True) -import pywintypes -from pywintypes import error as PyWinTypesError from deadline_worker_agent.installer.win_installer import ( ensure_local_queue_user_group_exists, ensure_local_agent_user, @@ -16,14 +19,19 @@ start_windows_installer, validate_deadline_id, ) -import sysconfig -from pathlib import Path + +from pywintypes import error as PyWinTypesError +from win32comext.shell import shell +from win32service import error as win_service_error +from win32serviceutil import GetServiceClassString +import win32netcon +import win32service +import winerror + from deadline_worker_agent import installer as installer_mod from deadline_worker_agent.installer import ParsedCommandLineArguments, install -import pytest -from unittest.mock import patch, MagicMock -import win32netcon -from win32comext.shell import shell +from deadline_worker_agent.installer import win_installer +from deadline_worker_agent.windows.win_service import WorkerAgentWindowsService def test_start_windows_installer( @@ -44,8 +52,8 @@ def test_start_windows_installer( fleet_id=parsed_args.fleet_id, region=parsed_args.region, worker_agent_program=Path(sysconfig.get_path("scripts")), - no_install_service=not parsed_args.install_service, - start=parsed_args.service_start, + install_service=parsed_args.install_service, + start_service=parsed_args.service_start, confirm=parsed_args.confirmed, parser=mock_get_arg_parser(), user_name=parsed_args.user, @@ -71,8 +79,8 @@ def test_start_windows_installer_fails_when_run_as_non_admin_user( fleet_id=parsed_args.fleet_id, region=parsed_args.region, worker_agent_program=Path(sysconfig.get_path("scripts")), - no_install_service=not parsed_args.install_service, - start=parsed_args.service_start, + install_service=parsed_args.install_service, + start_service=parsed_args.service_start, confirm=parsed_args.confirmed, parser=mock_get_arg_parser(), user_name=parsed_args.user, @@ -110,7 +118,7 @@ def test_unexpected_error_code_handling(group_name): with patch("win32net.NetLocalGroupGetInfo", side_effect=MockPyWinTypesError(9999)), patch( "win32net.NetLocalGroupAdd" ) as mock_group_add, patch("logging.error"): - with pytest.raises(pywintypes.error): + with pytest.raises(PyWinTypesError): ensure_local_queue_user_group_exists(group_name) mock_group_add.assert_not_called() @@ -192,3 +200,457 @@ def test_non_valid_deadline_id1(): def test_non_valid_deadline_id_with_wrong_prefix(): assert not validate_deadline_id("deadline", "line-123e4567e89b12d3a456426655441234") + + +class TestInstallService: + """Test cases for the install_service() function""" + + @pytest.fixture(autouse=True) + def mock_configure_service_failure_actions(self) -> Generator[Mock, None, None]: + with patch.object( + win_installer, "configure_service_failure_actions", new_callable=Mock + ) as mock_configure_service_failure_actions: + yield mock_configure_service_failure_actions + + def test_install_service_fresh_successful( + self, + mock_configure_service_failure_actions: Mock, + ) -> None: + """Tests that the installer calls pywin32's InstallService function to install the + Windows Service with the correct arguments and that succeeds as a fresh install""" + # GIVEN + agent_user_name = "myagentuser" + password = "apassword" + expected_service_display_name = WorkerAgentWindowsService._svc_display_name_ + + with ( + patch.object(win_installer.win32serviceutil, "InstallService") as mock_install_service, + patch.object(win_installer.logging, "info") as mock_logging_info, + ): + # WHEN + win_installer._install_service( + agent_user_name=agent_user_name, + password=password, + ) + + # THEN + mock_install_service.assert_called_once_with( + GetServiceClassString(WorkerAgentWindowsService), + WorkerAgentWindowsService._svc_name_, + expected_service_display_name, + serviceDeps=None, + startType=win32service.SERVICE_AUTO_START, + bRunInteractive=None, + userName=f".\\{agent_user_name}", + password=password, + exeName=getattr(WorkerAgentWindowsService, "_exe_name_", None), + perfMonIni=None, + perfMonDll=None, + exeArgs=getattr(WorkerAgentWindowsService, "_exe_args_", None), + description=getattr(WorkerAgentWindowsService, "_svc_description_", None), + delayedstart=False, + ) + mock_logging_info.assert_has_calls( + calls=[ + call(f'Configuring Windows Service "{expected_service_display_name}"...'), + call(f'Successfully created Windows Service "{expected_service_display_name}"'), + call( + f'Configuring the failure actions of Windows Service "{expected_service_display_name}"...' + ), + call( + f'Successfully configured the failure actions for Window Service "{expected_service_display_name}"' + ), + ], + ) + mock_configure_service_failure_actions.assert_called_once_with( + WorkerAgentWindowsService._svc_name_ + ) + + @pytest.mark.parametrize( + argnames=("install_service_exception",), + argvalues=( + pytest.param( + win_service_error( + winerror.ERROR_SERVICE_LOGON_FAILED, + "InstallService", + "some error message", + ), + id="win-service-error-not-existing", + ), + pytest.param( + Exception("some other error"), + id="non-win-service-error", + ), + ), + ) + def test_install_service_fresh_fail( + self, + install_service_exception: Exception, + mock_configure_service_failure_actions: Mock, + ) -> None: + """Tests how the _install_service() function deals with exceptions raised by + pywin32's InstallService function other than the one we expect to handle if the service + already exists. + + The exception should not be handled and raised as-is. + """ + # GIVEN + agent_user_name = "myagentuser" + password = "apassword" + expected_service_display_name = WorkerAgentWindowsService._svc_display_name_ + + with ( + patch.object( + win_installer.win32serviceutil, + "InstallService", + side_effect=install_service_exception, + ) as mock_install_service, + patch.object(win_installer.logging, "info") as mock_logging_info, + ): + # WHEN + def when(): + win_installer._install_service( + agent_user_name=agent_user_name, + password=password, + ) + + # THEN + with pytest.raises(type(install_service_exception)) as raise_ctx: + when() + + assert raise_ctx.value is install_service_exception + mock_install_service.assert_called_once_with( + GetServiceClassString(WorkerAgentWindowsService), + WorkerAgentWindowsService._svc_name_, + expected_service_display_name, + serviceDeps=None, + startType=win32service.SERVICE_AUTO_START, + bRunInteractive=None, + userName=f".\\{agent_user_name}", + password=password, + exeName=getattr(WorkerAgentWindowsService, "_exe_name_", None), + perfMonIni=None, + perfMonDll=None, + exeArgs=getattr(WorkerAgentWindowsService, "_exe_args_", None), + description=getattr(WorkerAgentWindowsService, "_svc_description_", None), + delayedstart=False, + ) + mock_logging_info.assert_called_once_with( + f'Configuring Windows Service "{expected_service_display_name}"...' + ) + mock_configure_service_failure_actions.assert_not_called() + + def test_install_service_existing_success( + self, + mock_configure_service_failure_actions: Mock, + ) -> None: + """Tests the behaviour of the _install_service function if the call to pywin32's + InstallService function fails because the service already exists. + + The function is expected to catch this exception and instead call pywin32's + ChangeServiceConfig function. This test asserts that ChangeServiceConfig is called + with the correct arguments.""" + # GIVEN + agent_user_name = "myagentuser" + password = "apassword" + expected_service_display_name = WorkerAgentWindowsService._svc_display_name_ + install_service_error = win_service_error( + winerror.ERROR_SERVICE_EXISTS, + "InstallService", + "service alreadyt exists", + ) + + with ( + patch.object( + win_installer.win32serviceutil, "InstallService", side_effect=install_service_error + ) as mock_install_service, + patch.object( + win_installer.win32serviceutil, "ChangeServiceConfig" + ) as mock_change_service_config, + patch.object(win_installer.logging, "info") as mock_logging_info, + ): + # WHEN + win_installer._install_service( + agent_user_name=agent_user_name, + password=password, + ) + + # THEN + mock_install_service.assert_called_once_with( + GetServiceClassString(WorkerAgentWindowsService), + WorkerAgentWindowsService._svc_name_, + expected_service_display_name, + serviceDeps=None, + startType=win32service.SERVICE_AUTO_START, + bRunInteractive=None, + userName=f".\\{agent_user_name}", + password=password, + exeName=getattr(WorkerAgentWindowsService, "_exe_name_", None), + perfMonIni=None, + perfMonDll=None, + exeArgs=getattr(WorkerAgentWindowsService, "_exe_args_", None), + description=getattr(WorkerAgentWindowsService, "_svc_description_", None), + delayedstart=False, + ) + mock_change_service_config.assert_called_once_with( + GetServiceClassString(WorkerAgentWindowsService), + WorkerAgentWindowsService._svc_name_, + serviceDeps=None, + startType=win32service.SERVICE_AUTO_START, + bRunInteractive=None, + userName=f".\\{agent_user_name}", + password=password, + exeName=getattr(WorkerAgentWindowsService, "_exe_name_", None), + displayName=expected_service_display_name, + perfMonIni=None, + perfMonDll=None, + exeArgs=getattr(WorkerAgentWindowsService, "_exe_args_", None), + description=getattr(WorkerAgentWindowsService, "_svc_description_", None), + delayedstart=False, + ) + mock_logging_info.assert_has_calls( + calls=[ + call(f'Configuring Windows Service "{expected_service_display_name}"...'), + call( + f'Service "{expected_service_display_name}" already exists, updating instead...' + ), + call(f'Successfully updated Windows Service "{expected_service_display_name}"'), + call( + f'Configuring the failure actions of Windows Service "{expected_service_display_name}"...' + ), + call( + f'Successfully configured the failure actions for Window Service "{expected_service_display_name}"' + ), + ], + ) + mock_configure_service_failure_actions.assert_called_once_with( + WorkerAgentWindowsService._svc_name_ + ) + + +class TestConfigureServiceFailureActions: + """Test cases for configure_service_failure_actions()""" + + @pytest.fixture(autouse=True) + def mock_win32_service(self) -> Generator[Mock, None, None]: + with patch.object(win_installer, "win32service", new_callable=Mock) as mock_win32_service: + yield mock_win32_service + + @pytest.fixture + def mock_open_sc_manager(self, mock_win32_service: Mock) -> Mock: + return mock_win32_service.OpenSCManager + + @pytest.fixture + def mock_open_service(self, mock_win32_service: Mock) -> Mock: + return mock_win32_service.OpenService + + @pytest.fixture + def mock_close_service_handle(self, mock_win32_service: Mock) -> Mock: + return mock_win32_service.CloseServiceHandle + + @pytest.fixture + def mock_change_service_config2(self, mock_win32_service: Mock) -> Mock: + return mock_win32_service.ChangeServiceConfig2 + + @pytest.fixture + def mock_logging_debug(self) -> Generator[Mock, None, None]: + with patch.object(win_installer.logging, "debug") as mock_logging_debug: + yield mock_logging_debug + + @pytest.fixture(params=("svc1", "svc2")) + def service_name(self, request) -> str: + return request.param + + def test_success( + self, + mock_win32_service: Mock, + mock_open_sc_manager: Mock, + mock_open_service: Mock, + mock_close_service_handle: Mock, + mock_change_service_config2: Mock, + service_name: str, + mock_logging_debug: MagicMock, + ) -> None: + # WHEN + win_installer.configure_service_failure_actions(service_name) + + # THEN + mock_open_sc_manager.assert_called_once_with( + None, None, mock_win32_service.SC_MANAGER_ALL_ACCESS + ) + mock_open_service.assert_called_once_with( + mock_open_sc_manager.return_value, + service_name, + mock_win32_service.SERVICE_ALL_ACCESS, + ) + mock_change_service_config2.assert_called_once_with( + mock_open_service.return_value, + mock_win32_service.SERVICE_CONFIG_FAILURE_ACTIONS, + { + "ResetPeriod": 1200, + "RebootMsg": None, + "Command": None, + "Actions": [ + (mock_win32_service.SC_ACTION_RESTART, 2000 * 2**i) for i in range(8) + ], + }, + ) + mock_close_service_handle.assert_has_calls( + [call(mock_open_service.return_value), call(mock_open_sc_manager.return_value)], + any_order=False, + ) + assert mock_close_service_handle.call_count == 2 + mock_logging_debug.assert_has_calls( + [ + call("Opening the Service Control Manager..."), + call("Successfully opened the Service Control Manager"), + call(f'Opening the Windows Service "{service_name}"'), + call(f'Successfully opened the Windows Service "{service_name}"'), + call(f'Modifying the failure actions of Windows Service "{service_name}...'), + call( + f'Successfully modified the failure actions of Windows Service "{service_name}...' + ), + call(f'Closing the Windows Service "{service_name}"..'), + call(f'Successfully closed the Windows Service "{service_name}"'), + call("Closing the Service Control Manager..."), + call("Successfully closed the Service Control Manager"), + ], + any_order=False, + ) + + def test_fail_open_scm( + self, + mock_win32_service: Mock, + mock_open_sc_manager: Mock, + mock_open_service: Mock, + mock_close_service_handle: Mock, + mock_change_service_config2: Mock, + mock_logging_debug: Mock, + service_name: str, + ) -> None: + # GIVEN + error = Exception("some error") + mock_open_sc_manager.side_effect = error + + # WHEN + def when() -> None: + win_installer.configure_service_failure_actions(service_name) + + # THEN + with pytest.raises(type(error)) as raise_ctxt: + when() + assert raise_ctxt.value is error + mock_open_sc_manager.assert_called_once_with( + None, None, mock_win32_service.SC_MANAGER_ALL_ACCESS + ) + mock_open_service.assert_not_called() + mock_change_service_config2.assert_not_called() + mock_close_service_handle.assert_not_called() + mock_logging_debug.assert_called_once_with("Opening the Service Control Manager...") + + def test_fail_open_service( + self, + mock_win32_service: Mock, + mock_open_sc_manager: Mock, + mock_open_service: Mock, + mock_close_service_handle: Mock, + mock_change_service_config2: Mock, + mock_logging_debug: Mock, + service_name: str, + ) -> None: + # GIVEN + error = Exception("some error") + mock_open_service.side_effect = error + + # WHEN + def when() -> None: + win_installer.configure_service_failure_actions(service_name) + + # THEN + with pytest.raises(type(error)) as raise_ctxt: + when() + assert raise_ctxt.value is error + mock_open_sc_manager.assert_called_once_with( + None, None, mock_win32_service.SC_MANAGER_ALL_ACCESS + ) + mock_open_service.assert_called_once_with( + mock_open_sc_manager.return_value, + service_name, + mock_win32_service.SERVICE_ALL_ACCESS, + ) + mock_change_service_config2.assert_not_called() + mock_close_service_handle.assert_called_once_with(mock_open_sc_manager.return_value) + mock_logging_debug.assert_has_calls( + [ + call("Opening the Service Control Manager..."), + call("Successfully opened the Service Control Manager"), + call(f'Opening the Windows Service "{service_name}"'), + call("Closing the Service Control Manager..."), + call("Successfully closed the Service Control Manager"), + ], + any_order=False, + ) + assert mock_logging_debug.call_count == 5 + + def test_fail_change_service_config2( + self, + mock_win32_service: Mock, + mock_open_sc_manager: Mock, + mock_open_service: Mock, + mock_close_service_handle: Mock, + mock_change_service_config2: Mock, + mock_logging_debug: Mock, + service_name: str, + ) -> None: + # GIVEN + error = Exception("some error") + mock_change_service_config2.side_effect = error + + # WHEN + def when() -> None: + win_installer.configure_service_failure_actions(service_name) + + # THEN + with pytest.raises(type(error)) as raise_ctxt: + when() + assert raise_ctxt.value is error + mock_open_sc_manager.assert_called_once_with( + None, None, mock_win32_service.SC_MANAGER_ALL_ACCESS + ) + mock_open_service.assert_called_once_with( + mock_open_sc_manager.return_value, + service_name, + mock_win32_service.SERVICE_ALL_ACCESS, + ) + mock_change_service_config2.assert_called_once_with( + mock_open_service.return_value, + mock_win32_service.SERVICE_CONFIG_FAILURE_ACTIONS, + { + "ResetPeriod": 1200, + "RebootMsg": None, + "Command": None, + "Actions": [ + (mock_win32_service.SC_ACTION_RESTART, 2000 * 2**i) for i in range(8) + ], + }, + ) + mock_close_service_handle.assert_has_calls( + [call(mock_open_service.return_value), call(mock_open_sc_manager.return_value)], + any_order=False, + ) + assert mock_close_service_handle.call_count == 2 + mock_logging_debug.assert_has_calls( + [ + call("Opening the Service Control Manager..."), + call("Successfully opened the Service Control Manager"), + call(f'Opening the Windows Service "{service_name}"'), + call(f'Successfully opened the Windows Service "{service_name}"'), + call(f'Modifying the failure actions of Windows Service "{service_name}...'), + call(f'Closing the Windows Service "{service_name}"..'), + call(f'Successfully closed the Windows Service "{service_name}"'), + call("Closing the Service Control Manager..."), + call("Successfully closed the Service Control Manager"), + ], + any_order=False, + ) + assert mock_logging_debug.call_count == 9 diff --git a/test/unit/startup/test_entrypoint.py b/test/unit/startup/test_entrypoint.py index a1bbe2a2..b23c0079 100644 --- a/test/unit/startup/test_entrypoint.py +++ b/test/unit/startup/test_entrypoint.py @@ -121,10 +121,13 @@ def mock_worker_run() -> Generator[MagicMock, None, None]: @pytest.fixture(autouse=True) def mock_windows_credentials_resolver() -> Generator[MagicMock, None, None]: - with patch.object( - scheduler_mod.WindowsCredentialsResolver, "get_windows_session_user" - ) as mock_windows_credentials_resolver: - yield mock_windows_credentials_resolver + if sys.platform == "win32": + with patch.object( + scheduler_mod.WindowsCredentialsResolver, "get_windows_session_user" + ) as mock_windows_credentials_resolver: + yield mock_windows_credentials_resolver + else: + yield MagicMock() @pytest.fixture @@ -605,6 +608,7 @@ def test_passes_worker_logs_dir( host_metrics_logging=ANY, host_metrics_logging_interval_seconds=ANY, retain_session_dir=ANY, + stop=ANY, ) diff --git a/test/unit/test_windows_credentials_resolver.py b/test/unit/test_windows_credentials_resolver.py deleted file mode 100644 index 704e5d1c..00000000 --- a/test/unit/test_windows_credentials_resolver.py +++ /dev/null @@ -1,238 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - -from datetime import datetime, timedelta -from openjd.sessions import WindowsSessionUser, BadCredentialsException -from unittest.mock import patch, MagicMock -from typing import Generator -from pytest import fixture, mark -import botocore -import os -import pytest - -import deadline_worker_agent.windows_credentials_resolver as credentials_mod - - -class TestWindowsCredentialsResolver: - @fixture(autouse=True) - def now(self) -> datetime: - return datetime(2000, 1, 1) - - @fixture(autouse=True) - def datetime_mock(self, now: datetime) -> Generator[MagicMock, None, None]: - with patch.object(credentials_mod, "datetime") as mock: - mock.side_effect = lambda *args, **kwargs: datetime(*args, **kwargs) - mock.fromtimestamp.side_effect = lambda *args, **kwargs: datetime.fromtimestamp( - *args, **kwargs - ) - mock.now.return_value = now - yield mock - - @mark.skipif(os.name != "nt", reason="Windows-only test.") - def test_prune_cache(self, datetime_mock: MagicMock): - # GIVEN - mock_boto_session = MagicMock() - now = datetime(2023, 1, 1, 12, 0, 0) - datetime_mock.now.return_value = now - resolver = credentials_mod.WindowsCredentialsResolver(mock_boto_session) - - # Add a user to the cache that should be pruned - expired_user = WindowsSessionUser(user="expired_user", password="fake_password") - expired_entry = credentials_mod._WindowsCredentialsCacheEntry( - windows_session_user=expired_user, - last_fetched_at=now - timedelta(hours=13), - last_accessed=now - timedelta(hours=13), - ) - resolver._user_cache["expired_user_arn"] = expired_entry - - # Add a user to the cache that should be kept - valid_user = WindowsSessionUser(user="valid_user", password="fake_password") - valid_entry = credentials_mod._WindowsCredentialsCacheEntry( - windows_session_user=valid_user, - last_fetched_at=now - timedelta(hours=11), - last_accessed=now - timedelta(hours=11), - ) - resolver._user_cache["valid_user_arn"] = valid_entry - - # WHEN - resolver.prune_cache() - - # THEN - assert len(resolver._user_cache) == 1 - assert "valid_user_arn" in resolver._user_cache - assert "expired_user_arn" not in resolver._user_cache - - @mark.skipif(os.name != "nt", reason="Windows-only test.") - @patch( - "deadline_worker_agent.windows_credentials_resolver.WindowsCredentialsResolver._fetch_secret_from_secrets_manager" - ) - def test_get_windows_session_user_non_cached(self, fetch_secret_mock, datetime_mock): - # GIVEN - mock_boto_session = MagicMock() - now = datetime(2023, 1, 1, 12, 0, 0) - datetime_mock.now.return_value = now - resolver = credentials_mod.WindowsCredentialsResolver(mock_boto_session) - secret_data = {"password": "fake_password"} - fetch_secret_mock.return_value = secret_data - user = "new_user" - password_arn = "new_password_arn" - - # WHEN - result = resolver.get_windows_session_user(user, password_arn) - - # THEN - fetch_secret_mock.assert_called_once_with(password_arn) - assert isinstance(result, WindowsSessionUser) - assert result.user == user - assert result.password == secret_data["password"] - - @mark.skipif(os.name != "nt", reason="Windows-only test.") - @patch( - "deadline_worker_agent.windows_credentials_resolver.WindowsCredentialsResolver._fetch_secret_from_secrets_manager" - ) - def test_get_windows_session_user_no_password_in_secret(self, fetch_secret_mock, datetime_mock): - # GIVEN - mock_boto_session = MagicMock() - now = datetime(2023, 1, 1, 12, 0, 0) - datetime_mock.now.return_value = now - resolver = credentials_mod.WindowsCredentialsResolver(mock_boto_session) - secret_data = {"something-other-than-password": "fake_password"} - fetch_secret_mock.return_value = secret_data - user = "new_user" - password_arn = "new_password_arn" - - # WHEN - with pytest.raises(ValueError): - resolver.get_windows_session_user(user, password_arn) - - # THEN - fetch_secret_mock.assert_called_once_with(password_arn) - - @mark.skipif(os.name != "nt", reason="Windows-only test.") - @patch( - "deadline_worker_agent.windows_credentials_resolver.WindowsCredentialsResolver._fetch_secret_from_secrets_manager" - ) - def test_get_windows_session_user_cached(self, fetch_secret_mock, datetime_mock): - # GIVEN - mock_boto_session = MagicMock() - now = datetime(2023, 1, 1, 12, 0, 0) - datetime_mock.now.return_value = now - resolver = credentials_mod.WindowsCredentialsResolver(mock_boto_session) - password_arn = "password_arn" - user = "user" - user_obj = WindowsSessionUser(user=user, password="fake_cached_password") - cached_entry = credentials_mod._WindowsCredentialsCacheEntry( - windows_session_user=user_obj, - last_fetched_at=now - timedelta(hours=11), - last_accessed=now - timedelta(hours=11), - ) - resolver._user_cache[f"{user}_{password_arn}"] = cached_entry - secret_data = {"password": "fake_new_password"} - fetch_secret_mock.return_value = secret_data - - # WHEN - result = resolver.get_windows_session_user(user, password_arn) - - # THEN - fetch_secret_mock.assert_not_called() - assert isinstance(result, WindowsSessionUser) - assert result.user == user - assert result.password == "fake_cached_password" - - @mark.skipif(os.name != "nt", reason="Windows-only test.") - @patch( - "deadline_worker_agent.windows_credentials_resolver.WindowsCredentialsResolver._fetch_secret_from_secrets_manager" - ) - def test_get_windows_session_user_invalid_credentials(self, fetch_secret_mock, datetime_mock): - # GIVEN - mock_boto_session = MagicMock() - now = datetime(2023, 1, 1, 12, 0, 0) - datetime_mock.now.return_value = now - resolver = credentials_mod.WindowsCredentialsResolver(mock_boto_session) - secret_data = {"password": "fake_password"} - fetch_secret_mock.return_value = secret_data - user = "new_user" - password_arn = "new_password_arn" - - with patch( - "deadline_worker_agent.windows_credentials_resolver.WindowsSessionUser", - side_effect=BadCredentialsException("Invalid credentials"), - ): - # WHEN - with pytest.raises(ValueError): - resolver.get_windows_session_user(user, password_arn) - assert resolver._user_cache[f"{user}_{password_arn}"].windows_session_user is None - - @pytest.mark.parametrize( - "exception_code", - [ - "ResourceNotFoundException", - "InvalidRequestException", - "DecryptionFailure", - ], - ) - @mark.skipif(os.name != "nt", reason="Windows-only test.") - @patch( - "deadline_worker_agent.windows_credentials_resolver.WindowsCredentialsResolver._get_secrets_manager_client" - ) - def test_fetch_secrets_manager_non_retriable_exception( - self, secrets_manager_client_mock: MagicMock, exception_code: str - ): - # GIVEN - mock_boto_session = MagicMock() - resolver = credentials_mod.WindowsCredentialsResolver(mock_boto_session) - password_arn = "password_arn" - exc = botocore.exceptions.ClientError( - {"Error": {"Code": exception_code, "Message": "A message"}}, "GetSecretValue" - ) - secrets_manager_client_mock.side_effect = exc - - # THEN - with pytest.raises(RuntimeError): - resolver._fetch_secret_from_secrets_manager(password_arn) - - @pytest.mark.parametrize( - "exception_code", - [ - "InternalServiceError", - "ThrottlingException", - ], - ) - @mark.skipif(os.name != "nt", reason="Windows-only test.") - @patch( - "deadline_worker_agent.windows_credentials_resolver.WindowsCredentialsResolver._get_secrets_manager_client" - ) - def test_fetch_secrets_manager_retriable_exception( - self, secrets_manager_client_mock: MagicMock, exception_code: str - ): - # GIVEN - mock_boto_session = MagicMock() - resolver = credentials_mod.WindowsCredentialsResolver(mock_boto_session) - password_arn = "password_arn" - exc = botocore.exceptions.ClientError( - {"Error": {"Code": exception_code, "Message": "A message"}}, "GetSecretValue" - ) - secrets_manager_client_mock.side_effect = exc - - # THEN - # Assert raising DeadlineRequestUnrecoverableError after 10 retries - with pytest.raises(RuntimeError): - resolver._fetch_secret_from_secrets_manager(password_arn) - assert secrets_manager_client_mock.call_count == 10 - - @mark.skipif(os.name != "nt", reason="Windows-only test.") - @patch( - "deadline_worker_agent.windows_credentials_resolver.WindowsCredentialsResolver._get_secrets_manager_client" - ) - def test_fetch_secrets_manager_non_json_secret_exception( - self, - secrets_manager_client_mock: MagicMock, - ): - # GIVEN - mock_boto_session = MagicMock() - resolver = credentials_mod.WindowsCredentialsResolver(mock_boto_session) - password_arn = "password_arn" - secrets_manager_client_mock.get_secret_value.return_value = {"SecretString": "_a string_"} - - # THEN - with pytest.raises(ValueError): - resolver._fetch_secret_from_secrets_manager(password_arn) diff --git a/test/unit/test_worker.py b/test/unit/test_worker.py index b7079281..9f38e12e 100644 --- a/test/unit/test_worker.py +++ b/test/unit/test_worker.py @@ -155,6 +155,7 @@ def test_passes_worker_logs_dir( worker_persistence_dir=ANY, worker_logs_dir=worker_logs_dir, retain_session_dir=ANY, + stop=ANY, ) diff --git a/test/unit/windows/test_win_credentials_resolver.py b/test/unit/windows/test_win_credentials_resolver.py new file mode 100644 index 00000000..057d61a1 --- /dev/null +++ b/test/unit/windows/test_win_credentials_resolver.py @@ -0,0 +1,244 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + +from datetime import datetime, timedelta +from unittest.mock import patch, MagicMock +from typing import Generator +import botocore +import sys + +from openjd.sessions import WindowsSessionUser, BadCredentialsException +from pytest import fixture +import pytest + +# This if is required for two purposes: +# 1. It short-circuits mypy from type checking this module on platforms other than Windows +# https://mypy.readthedocs.io/en/stable/common_issues.html#python-version-and-system-platform-checks +# 2. It causes the tests to not be discovered/ran on non-Windows platforms +if sys.platform == "win32": + import deadline_worker_agent.windows.win_credentials_resolver as credentials_mod + + class TestWindowsCredentialsResolver: + @fixture(autouse=True) + def now(self) -> datetime: + return datetime(2000, 1, 1) + + @fixture(autouse=True) + def datetime_mock(self, now: datetime) -> Generator[MagicMock, None, None]: + with patch.object(credentials_mod, "datetime") as mock: + mock.side_effect = lambda *args, **kwargs: datetime(*args, **kwargs) + mock.fromtimestamp.side_effect = lambda *args, **kwargs: datetime.fromtimestamp( + *args, **kwargs + ) + mock.now.return_value = now + yield mock + + def test_prune_cache(self, datetime_mock: MagicMock): + # GIVEN + mock_boto_session = MagicMock() + now = datetime(2023, 1, 1, 12, 0, 0) + datetime_mock.now.return_value = now + resolver = credentials_mod.WindowsCredentialsResolver(mock_boto_session) + + # Add a user to the cache that should be pruned + expired_user = WindowsSessionUser(user="expired_user", password="fake_password") + expired_entry = credentials_mod._WindowsCredentialsCacheEntry( + windows_session_user=expired_user, + last_fetched_at=now - timedelta(hours=13), + last_accessed=now - timedelta(hours=13), + ) + resolver._user_cache["expired_user_arn"] = expired_entry + + # Add a user to the cache that should be kept + valid_user = WindowsSessionUser(user="valid_user", password="fake_password") + valid_entry = credentials_mod._WindowsCredentialsCacheEntry( + windows_session_user=valid_user, + last_fetched_at=now - timedelta(hours=11), + last_accessed=now - timedelta(hours=11), + ) + resolver._user_cache["valid_user_arn"] = valid_entry + + # WHEN + resolver.prune_cache() + + # THEN + assert len(resolver._user_cache) == 1 + assert "valid_user_arn" in resolver._user_cache + assert "expired_user_arn" not in resolver._user_cache + + @patch( + "deadline_worker_agent.windows.win_credentials_resolver.WindowsCredentialsResolver._fetch_secret_from_secrets_manager" + ) + def test_get_windows_session_user_non_cached(self, fetch_secret_mock, datetime_mock): + # GIVEN + mock_boto_session = MagicMock() + now = datetime(2023, 1, 1, 12, 0, 0) + datetime_mock.now.return_value = now + resolver = credentials_mod.WindowsCredentialsResolver(mock_boto_session) + secret_data = {"password": "fake_password"} + fetch_secret_mock.return_value = secret_data + user = "new_user" + password_arn = "new_password_arn" + + # WHEN + result = resolver.get_windows_session_user(user, password_arn) + + # THEN + fetch_secret_mock.assert_called_once_with(password_arn) + assert isinstance(result, WindowsSessionUser) + assert result.user == user + assert result.password == secret_data["password"] + + @patch( + "deadline_worker_agent.windows.win_credentials_resolver.WindowsCredentialsResolver._fetch_secret_from_secrets_manager" + ) + def test_get_windows_session_user_no_password_in_secret( + self, fetch_secret_mock, datetime_mock + ): + # GIVEN + mock_boto_session = MagicMock() + now = datetime(2023, 1, 1, 12, 0, 0) + datetime_mock.now.return_value = now + resolver = credentials_mod.WindowsCredentialsResolver(mock_boto_session) + secret_data = {"something-other-than-password": "fake_password"} + fetch_secret_mock.return_value = secret_data + user = "new_user" + password_arn = "new_password_arn" + + # WHEN + with pytest.raises(ValueError): + resolver.get_windows_session_user(user, password_arn) + + # THEN + fetch_secret_mock.assert_called_once_with(password_arn) + + @patch( + "deadline_worker_agent.windows.win_credentials_resolver.WindowsCredentialsResolver._fetch_secret_from_secrets_manager" + ) + def test_get_windows_session_user_cached(self, fetch_secret_mock, datetime_mock): + # GIVEN + mock_boto_session = MagicMock() + now = datetime(2023, 1, 1, 12, 0, 0) + datetime_mock.now.return_value = now + resolver = credentials_mod.WindowsCredentialsResolver(mock_boto_session) + password_arn = "password_arn" + user = "user" + user_obj = WindowsSessionUser(user=user, password="fake_cached_password") + cached_entry = credentials_mod._WindowsCredentialsCacheEntry( + windows_session_user=user_obj, + last_fetched_at=now - timedelta(hours=11), + last_accessed=now - timedelta(hours=11), + ) + resolver._user_cache[f"{user}_{password_arn}"] = cached_entry + secret_data = {"password": "fake_new_password"} + fetch_secret_mock.return_value = secret_data + + # WHEN + result = resolver.get_windows_session_user(user, password_arn) + + # THEN + fetch_secret_mock.assert_not_called() + assert isinstance(result, WindowsSessionUser) + assert result.user == user + assert result.password == "fake_cached_password" + + @patch( + "deadline_worker_agent.windows.win_credentials_resolver.WindowsCredentialsResolver._fetch_secret_from_secrets_manager" + ) + def test_get_windows_session_user_invalid_credentials( + self, fetch_secret_mock, datetime_mock + ): + # GIVEN + mock_boto_session = MagicMock() + now = datetime(2023, 1, 1, 12, 0, 0) + datetime_mock.now.return_value = now + resolver = credentials_mod.WindowsCredentialsResolver(mock_boto_session) + secret_data = {"password": "fake_password"} + fetch_secret_mock.return_value = secret_data + user = "new_user" + password_arn = "new_password_arn" + + with patch( + "deadline_worker_agent.windows.win_credentials_resolver.WindowsSessionUser", + side_effect=BadCredentialsException("Invalid credentials"), + ): + # WHEN + with pytest.raises(ValueError): + resolver.get_windows_session_user(user, password_arn) + assert ( + resolver._user_cache[f"{user}_{password_arn}"].windows_session_user is None + ) + + @pytest.mark.parametrize( + "exception_code", + [ + "ResourceNotFoundException", + "InvalidRequestException", + "DecryptionFailure", + ], + ) + @patch( + "deadline_worker_agent.windows.win_credentials_resolver.WindowsCredentialsResolver._get_secrets_manager_client" + ) + def test_fetch_secrets_manager_non_retriable_exception( + self, secrets_manager_client_mock: MagicMock, exception_code: str + ): + # GIVEN + mock_boto_session = MagicMock() + resolver = credentials_mod.WindowsCredentialsResolver(mock_boto_session) + password_arn = "password_arn" + exc = botocore.exceptions.ClientError( + {"Error": {"Code": exception_code, "Message": "A message"}}, "GetSecretValue" + ) + secrets_manager_client_mock.side_effect = exc + + # THEN + with pytest.raises(RuntimeError): + resolver._fetch_secret_from_secrets_manager(password_arn) + + @pytest.mark.parametrize( + "exception_code", + [ + "InternalServiceError", + "ThrottlingException", + ], + ) + @patch( + "deadline_worker_agent.windows.win_credentials_resolver.WindowsCredentialsResolver._get_secrets_manager_client" + ) + def test_fetch_secrets_manager_retriable_exception( + self, secrets_manager_client_mock: MagicMock, exception_code: str + ): + # GIVEN + mock_boto_session = MagicMock() + resolver = credentials_mod.WindowsCredentialsResolver(mock_boto_session) + password_arn = "password_arn" + exc = botocore.exceptions.ClientError( + {"Error": {"Code": exception_code, "Message": "A message"}}, "GetSecretValue" + ) + secrets_manager_client_mock.side_effect = exc + + # THEN + # Assert raising DeadlineRequestUnrecoverableError after 10 retries + with pytest.raises(RuntimeError): + resolver._fetch_secret_from_secrets_manager(password_arn) + assert secrets_manager_client_mock.call_count == 10 + + @patch( + "deadline_worker_agent.windows.win_credentials_resolver.WindowsCredentialsResolver._get_secrets_manager_client" + ) + def test_fetch_secrets_manager_non_json_secret_exception( + self, + secrets_manager_client_mock: MagicMock, + ): + # GIVEN + mock_boto_session = MagicMock() + resolver = credentials_mod.WindowsCredentialsResolver(mock_boto_session) + password_arn = "password_arn" + secrets_manager_client_mock.get_secret_value.return_value = { + "SecretString": "_a string_" + } + + # THEN + with pytest.raises(ValueError): + resolver._fetch_secret_from_secrets_manager(password_arn) diff --git a/test/unit/windows/test_win_service.py b/test/unit/windows/test_win_service.py new file mode 100644 index 00000000..ddd0a88d --- /dev/null +++ b/test/unit/windows/test_win_service.py @@ -0,0 +1,104 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +from unittest.mock import patch + +import pytest +import sys + +if sys.platform != "win32": + pytest.skip("Windows-specific tests", allow_module_level=True) + +from win32serviceutil import ServiceFramework + +from deadline_worker_agent.windows.win_service import WorkerAgentWindowsService +from deadline_worker_agent.windows import win_service + + +def test_get_current_process_session() -> None: + """Tests that the _get_current_process_session() function uses the expected pywin32 API calls""" + + # GIVEN + with ( + patch.object( + win_service.win32process, "GetCurrentProcessId" + ) as mock_get_current_process_id, + patch.object(win_service.win32ts, "ProcessIdToSessionId") as mock_process_id_to_session_id, + ): + # WHEN + result = win_service._get_current_process_session() + + # THEN + mock_get_current_process_id.assert_called_once_with() + mock_process_id_to_session_id.assert_called_once_with(mock_get_current_process_id.return_value) + assert result == mock_process_id_to_session_id.return_value + + +@pytest.mark.parametrize( + argnames="session_id,expected_result", + argvalues=( + pytest.param(0, True, id="session-zero"), + pytest.param(1, False, id="session-non-zero"), + ), +) +def test_is_windows_session_zero(session_id: int, expected_result: bool) -> None: + """Tests that the is_windows_session_zero() function returns true iff the return value of + _get_current_process_session is 0""" + + # GIVEN + # clear the cache decorator to ensure the function result is not cached between tests + win_service.is_windows_session_zero.cache_clear() + with patch.object(win_service, "_get_current_process_session", return_value=session_id): + # WHEN + result = win_service.is_windows_session_zero() + + # THEN + assert result == expected_result + + +def test_is_windows_session_zero_cached() -> None: + """Tests that the is_windows_session_zero() function caches the result between calls""" + + # GIVEN + # clear the cache decorator to ensure the function result is not cached on first run + win_service.is_windows_session_zero.cache_clear() + with patch.object( + win_service, "_get_current_process_session" + ) as mock_get_current_process_session: + # We make our mocked _get_current_process_session return different session IDs between calls + mock_get_current_process_session.side_effect = [0, 1] + first_result = win_service.is_windows_session_zero() + # WHEN + second_result = win_service.is_windows_session_zero() + + # THEN + assert first_result is True + assert second_result == first_result + mock_get_current_process_session.assert_called_once_with() + + +def test_svc_name() -> None: + """Tests that the service name (ID used for the service) is "DeadlineWorker" """ + # THEN + assert WorkerAgentWindowsService._svc_name_ == "DeadlineWorker" + + +def test_svc_description() -> None: + """Tests that the description of the service is correct""" + # THEN + assert WorkerAgentWindowsService._svc_description_ == ( + "Service hosting the AWS Deadline Cloud Worker Agent. Connects to AWS " + "Deadline Cloud and runs jobs as a worker in a fleet." + ) + + +def test_display_name() -> None: + """Tests that the display name of the service is "AWS Deadline Cloud Worker Agent" """ + # THEN + assert WorkerAgentWindowsService._svc_display_name_ == "AWS Deadline Cloud Worker" + + +def test_parent_class() -> None: + """Tests that the WorkerAgentWindowsService subclasses win32serviceutil.ServiceFramework""" + + # THEN + assert issubclass(WorkerAgentWindowsService, ServiceFramework)