diff --git a/composer/callbacks/health_checker.py b/composer/callbacks/health_checker.py
index b052bb47cb..272a612eae 100644
--- a/composer/callbacks/health_checker.py
+++ b/composer/callbacks/health_checker.py
@@ -3,26 +3,18 @@
 """Check GPU Health during training."""
 
 import logging
+import os
 from collections import deque
 from datetime import datetime
 from typing import List, Optional, Tuple
 
-import torch
-
-try:
-    import pynvml
-except ImportError:
-    pynvml = None
-
-import os
-
 import numpy as np
-from slack_sdk.webhook import WebhookClient
+import torch
 
 from composer.core import Callback, State
 from composer.core.time import Timestamp
 from composer.loggers import Logger
-from composer.utils import dist
+from composer.utils import MissingConditionalImportError, dist
 
 log = logging.getLogger(__name__)
 
@@ -69,6 +61,14 @@ def __init__(
         if not self.slack_webhook_url:
             self.slack_webhook_url = os.environ.get('SLACK_WEBHOOK_URL', None)
 
+        if self.slack_webhook_url:
+            # fail fast if missing import
+            try:
+                import slack_sdk
+                del slack_sdk
+            except ImportError as e:
+                raise MissingConditionalImportError('health_checker', 'slack_sdk', None) from e
+
         self.last_sample = 0
         self.last_check = 0
 
@@ -133,6 +133,7 @@ def _alert(self, message: str, state: State) -> None:
 
         logging.warning(message)
         if self.slack_webhook_url:
+            from slack_sdk.webhook import WebhookClient
             client = WebhookClient(url=self.slack_webhook_url)
             client.send(text=message)
 
@@ -141,12 +142,13 @@ def _is_available() -> bool:
         if not torch.cuda.is_available():
             return False
         try:
+            import pynvml
             pynvml.nvmlInit()  # type: ignore
             return True
+        except ImportError as e:
+            raise MissingConditionalImportError('health_checker', 'pynvml', None) from e
         except pynvml.NVMLError_LibraryNotFound:  # type: ignore
             logging.warning('NVML not found, disabling GPU health checking')
-        except ImportError:
-            logging.warning('pynvml library not found, disabling GPU health checking.')
         except Exception as e:
             logging.warning(f'Error initializing NVML: {e}')
 
@@ -168,13 +170,19 @@ def sample(self) -> None:
                 self.samples.append(sample)
 
     def _sample(self) -> Optional[List]:
+        # pynvml is an optional dependency, so it is imported lazily here.
+        try:
+            import pynvml
+        except ImportError as e:
+            raise MissingConditionalImportError('health_checker', 'pynvml', None) from e
+
         try:
             samples = []
-            device_count = pynvml.nvmlDeviceGetCount()  # type: ignore
+            device_count = pynvml.nvmlDeviceGetCount()
             for i in range(device_count):
-                handle = pynvml.nvmlDeviceGetHandleByIndex(i)  # type: ignore
-                samples.append(pynvml.nvmlDeviceGetUtilizationRates(handle).gpu)  # type: ignore
-        except pynvml.NVMLError:  # type: ignore
+                handle = pynvml.nvmlDeviceGetHandleByIndex(i)
+                samples.append(pynvml.nvmlDeviceGetUtilizationRates(handle).gpu)
+        except pynvml.NVMLError:
             return None
         return samples
 