diff --git a/.github/workflows/pytest-gpu.yaml b/.github/workflows/pytest-gpu.yaml index d2faab47af..ec123c26a5 100644 --- a/.github/workflows/pytest-gpu.yaml +++ b/.github/workflows/pytest-gpu.yaml @@ -55,7 +55,6 @@ jobs: run: | set -ex python -m pip install mosaicml-cli - mcli init --mcloud mcli version - name: Submit Run id: tests diff --git a/llmfoundry/callbacks/__init__.py b/llmfoundry/callbacks/__init__.py index 62ffcd565c..08e9337681 100644 --- a/llmfoundry/callbacks/__init__.py +++ b/llmfoundry/callbacks/__init__.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 try: + from llmfoundry.callbacks.async_eval_callback import AsyncEval from llmfoundry.callbacks.eval_gauntlet_callback import EvalGauntlet from llmfoundry.callbacks.fdiff_callback import FDiffMetrics from llmfoundry.callbacks.generate_callback import Generate @@ -28,4 +29,5 @@ 'EvalGauntlet', 'ModelGauntlet', 'HuggingFaceCheckpointer', + 'AsyncEval', ] diff --git a/llmfoundry/callbacks/async_eval_callback.py b/llmfoundry/callbacks/async_eval_callback.py new file mode 100644 index 0000000000..8352a9e283 --- /dev/null +++ b/llmfoundry/callbacks/async_eval_callback.py @@ -0,0 +1,375 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +"""Run the eval loop asynchronously as part of a MosaicML platform run. + +This callback is currently experimental. The API may change in the future. +""" + +import logging +import os +from pathlib import Path +from typing import Any, Dict, Optional, Union + +from composer.callbacks import CheckpointSaver +from composer.core import Callback, Event, State, Time, TimeUnit +from composer.loggers import Logger +from composer.loggers.mosaicml_logger import (MOSAICML_PLATFORM_ENV_VAR, + RUN_NAME_ENV_VAR) +from composer.utils import dist +from composer.utils.misc import create_interval_scheduler + +from mcli import ComputeConfig, Run, RunConfig, create_run, get_run + +log = logging.getLogger(__name__) + +REQUIRED_PARAMS_FOR_EVAL = { + 'device_eval_batch_size', + 'icl_tasks', # only required for eval, may not be specified in pure training + 'max_seq_len', + 'model', # converted into models + 'tokenizer', # converted into models +} +OPTIONAL_PARAMS_FOR_EVAL = { + 'dist_timeout', + 'eval_gauntlet', + 'fsdp_config', + 'icl_subset_num_batches', + 'loggers', + 'precision', + 'python_log_level', + 'seed', +} + +RUN_NAME_PREFIX = 'eval' +MAX_RUN_NAME_BASE_LENGTH = 55 + + +def get_run_name(training_run_name: str, current_interval: str) -> str: + """Get the new eval run name. + + Args: + training_run_name: The name of the current training run + current_interval: The current interval string of the training run + + Returns: + The new run name + """ + name_without_uuid_suffix = training_run_name.rsplit('-', 1)[0] + + max_length = MAX_RUN_NAME_BASE_LENGTH - len(RUN_NAME_PREFIX) - len( + current_interval) - 2 + + # A run name that is too long will fail a createRun call + if len(name_without_uuid_suffix) > max_length: + new_name = name_without_uuid_suffix[:max_length] + log.warning( + f'Training run name {name_without_uuid_suffix} may be too long,' + + f' truncating to {new_name}') + name_without_uuid_suffix = new_name + + return f'{RUN_NAME_PREFIX}-{current_interval}-{name_without_uuid_suffix}' + + +def get_latest_checkpoint(event: Event, state: State) -> Optional[str]: + """Get the latest checkpoint from the training run. + + Args: + event: The current run event + state: The current state of the training run + + Returns: + The path to the latest checkpoint, or None if there is not a latest checkpoint + """ + checkpointer = None + for callback in state.callbacks: + if isinstance(callback, CheckpointSaver): + checkpointer = callback + break + + if not checkpointer: + log.warning('No checkpoint saver callback found') + return None + + if not checkpointer.saved_checkpoints: + log.warning('No saved checkpoints found on the checkpointer') + return None + + latest = checkpointer.saved_checkpoints[-1] + return str(Path(latest).parts[-1]) + + +def get_eval_parameters( + parameters: Dict[str, Any], + checkpoint: str, + training_run_name: str, +) -> Dict[str, Any]: + """Get the parameters needed for the eval run. + + Args: + parameters: The parameters from the training run + checkpoint: The path to the latest checkpoint + training_run_name: The name of the training run + + Returns: + The parameters needed for the eval run as a dict + """ + looking_for = REQUIRED_PARAMS_FOR_EVAL.copy() + + # Go through all parameters and pull out the ones needed for eval + subset_keys = {} + for key in parameters: + if key in OPTIONAL_PARAMS_FOR_EVAL: + subset_keys[key] = parameters[key] + elif key in REQUIRED_PARAMS_FOR_EVAL: + subset_keys[key] = parameters[key] + looking_for.remove(key) + + if looking_for: + raise Exception( + f'Missing the following required parameters for async eval: {looking_for}' + ) + + for logger, config in subset_keys.get('loggers', {}).items(): + if logger == 'wandb': + config['group'] = config.pop('name', training_run_name) + + # mlflow currently does not support grouping, so this will just launch + # a new mlflow run + + # Create new eval models list + model = subset_keys.pop('model') + + model_name = model.get('name', None) + if not model_name: + raise Exception(f'Async evaluation requires "name" keys for models') + new_models = { + 'model_name': model_name, + 'model': model, + 'load_path': checkpoint + } + + tokenizer = subset_keys.pop('tokenizer', None) + if tokenizer is not None: + new_models['tokenizer'] = tokenizer + subset_keys['models'] = [new_models] + return subset_keys + + +def validate_interval(interval: Union[str, int, Time], + save_interval: Union[str, int, Time]) -> Time: + if isinstance(save_interval, str): + new_save_interval: Time = Time.from_timestring(save_interval) + elif isinstance(save_interval, int): + new_save_interval: Time = Time(save_interval, TimeUnit.EPOCH) + else: + new_save_interval: Time = save_interval + + if isinstance(interval, str): + result: Time = Time.from_timestring(interval) + elif isinstance(interval, int): + result: Time = Time(interval, TimeUnit.EPOCH) + else: + result: Time = interval + + if new_save_interval.unit != result.unit: + raise ValueError( + 'Save interval and async eval interval must be in the same unit') + if result < new_save_interval: + raise ValueError( + 'Async eval interval must be equal or greater (less frequent) than save interval' + ) + if result.value % new_save_interval.value != 0: + raise ValueError( + 'Async eval interval must be a multiple of save interval') + return result + + +class AsyncEval(Callback): + """Run the eval loop asynchronously as part of a MosaicML platform run. + + This callback is currently experimental. The API may change in the future. + + Args: + training_config: Dict[str, Any]: The config from the training run + interval: Union[str, int, Time]: The interval describing how often eval runs should be + launched. If an integer, it will be assumed to be in :attr:`.TimeUnit.EPOCH`. + Otherwise, the unit must be either :attr:`.TimeUnit.EPOCH`, :attr:`.TimeUnit.BATCH`, + :attr:`.TimeUnit.TOKEN`, or :attr:`.TimeUnit.SAMPLE`. + compute: Optional[Union[ComputeConfig, Dict[str, Any]]]: The compute configuration to + use for the eval run. If not provided, the same cluster as the current run and a + single, full GPU node will be used. + """ + + def __init__( + self, + training_config: Dict[str, Any], + interval: Union[str, int, Time], + compute: Optional[Union[ComputeConfig, Dict[str, Any]]] = None, + ): + + for required in ('save_interval', 'save_folder'): + if required not in training_config: + raise ValueError(f'{required} required for async eval') + + self.checkpoint_save_folder = training_config['save_folder'] + self.training_config = training_config + self.interval = validate_interval(interval, + self.training_config['save_interval']) + self.check_interval = create_interval_scheduler( + interval, + # There is a custom close to ensure that the final checkpoint + # (which is the most important) is evaled after it is written + include_end_of_training=False, + ) + self.compute = compute + self.last_checkpoint: Optional[str] = None + + # Run these during init to fail fast in any of the error cases + self.current_run = self._get_current_run() + get_eval_parameters( + parameters=training_config, + checkpoint='test', + training_run_name=self.current_run.name, + ) + log.info( + f'Initialized AsyncEval callback. Will generate runs at interval {interval}' + ) + + def run_event(self, event: Event, state: State, logger: Logger) -> None: + del logger + + should_launch_run = all([ + state.get_elapsed_duration() is not None, + self.check_interval(state, event), + dist.get_global_rank() == 0, + ]) + + if should_launch_run: + current_interval = state.timestamp.get(self.interval.unit) + checkpoint = get_latest_checkpoint(event, state) + if not checkpoint: + return # warnings logged in get_latest_checkpoint + + # TODO: ensure the checkpoint is fully written before launching the eval run + full_checkpoint = f'{self.checkpoint_save_folder}/{checkpoint}' + if full_checkpoint == self.last_checkpoint: + # Do not eval a checkpoint that has already been evaluated. + log.info( + 'Skipping async eval because the checkpoint has not changed' + ) + return + + self.launch_run(full_checkpoint, current_interval) + self.last_checkpoint = full_checkpoint + + def close(self, state: State, logger: Logger) -> None: + del logger + + if dist.get_global_rank() != 0: + return + + save_latest_filename = self.training_config.get('save_latest_filename', + None) + + if not save_latest_filename: + rank = dist.get_global_rank() + save_latest_filename = f'latest-rank{rank}.pt' + + checkpoint = f'{self.checkpoint_save_folder}/{save_latest_filename}' + self.launch_run(checkpoint, state.timestamp.get(self.interval.unit)) + + def _get_current_run(self) -> Run: + if os.environ.get(MOSAICML_PLATFORM_ENV_VAR, + 'false').lower() == 'false': + raise RuntimeError( + 'AsyncEval callback is only supported when running on the MosaicML platform' + ) + + run_name = os.environ.get(RUN_NAME_ENV_VAR, None) + if not run_name: + raise RuntimeError( + 'RUN_NAME environment variable must be set to use the AsyncEval callback' + ) + + # Allows the MapiException to be raised if the run doesn't exist + return get_run(run_name, include_details=True) + + def launch_run(self, checkpoint: str, current_interval: Time) -> Run: + log.info(f'Launching eval run for {checkpoint} at {current_interval}') + + cfg = self.current_run.submitted_config + default_compute = { + 'gpus': 8, + 'cluster': self.current_run.cluster, + } + + run_name = get_run_name(self.current_run.name, str(current_interval)) + + params = get_eval_parameters( + parameters=self.training_config, + checkpoint=checkpoint, + training_run_name=self.current_run.name, + ) + params['run_name'] = run_name + + integrations = cfg.integrations + found_llm_foundry, installation_path = False, 'llm-foundry' + for i in integrations: + if i['integration_type'] != 'git_repo': + continue + + if not i['git_repo'].endswith('llm-foundry'): # detects forks + continue + + found_llm_foundry = True + if i.get('path'): + installation_path = i['path'] + + if not found_llm_foundry: + from llmfoundry import __version__ as latest_foundry_version + + # If github integration is not found, foundry is likely installed + # through the run command. In this case, we'll add the integration + # so the eval run will still work. However, it could cause unexpected + # behaviors because its not using custom repos or branches specified + # in the training run. For this reason, we'll log a warning + version = f'v{latest_foundry_version}' + log.warning( + 'No github integration found for llm-foundry. Adding installation ' + + f'to eval run for latest foundry release ({version}). ' + + 'To use a fork, custom branch, or custom version, configure ' + + 'llm-foundry installation through a github integration') + integrations.append({ + 'integration_type': 'git_repo', + 'git_repo': 'mosaicml/llm-foundry', + 'git_branch': version, + 'pip_install': '-e .[gpu]', + 'ssh_clone': False, + }) + + # This will record the timestamp and make it available for grouping + # and plotting in wandb + metadata = cfg.metadata + metadata['eval_timestamp'] = current_interval.value + metadata['eval_timestamp_unit'] = current_interval.unit.value + + # TODO: This just runs an eval run, but we also want to attach the + # deployment, which would require a hf conversion and parametrizing the + # dependent_deployment in the run config + command = f'cd {installation_path}/scripts \n composer eval/eval.py $PARAMETERS' + run_config = RunConfig( + name=run_name, + image=self.current_run.image, + compute=self.compute or default_compute, + command=command, + integrations=integrations, + env_variables=cfg.env_variables, + metadata=cfg.metadata, + parameters=params, + ) + + log.info(f'Creating new run with config: \n{run_config}') + new_run = create_run(run_config) + log.info(f'Launched new run {new_run.name} inside eval loop') + return new_run diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index a672fbee55..404ad604ab 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -31,9 +31,9 @@ from torchmetrics import Metric from transformers import AutoTokenizer, PreTrainedTokenizerBase -from llmfoundry.callbacks import (EvalGauntlet, FDiffMetrics, GlobalLRScaling, - HuggingFaceCheckpointer, LayerFreezing, - MonolithicCheckpointSaver, +from llmfoundry.callbacks import (AsyncEval, EvalGauntlet, FDiffMetrics, + GlobalLRScaling, HuggingFaceCheckpointer, + LayerFreezing, MonolithicCheckpointSaver, ScheduledGarbageCollector) from llmfoundry.data.dataloader import build_dataloader from llmfoundry.optim import (DecoupledAdaLRLion, DecoupledClipLion, @@ -157,8 +157,11 @@ def build_icl_data_and_gauntlet( return icl_evaluators, logger_keys, eval_gauntlet_cb -def build_callback(name: str, kwargs: Union[DictConfig, Dict[str, - Any]]) -> Callback: +def build_callback( + name: str, + kwargs: Union[DictConfig, Dict[str, Any]], + config: Any = None, +) -> Callback: if name == 'lr_monitor': return LRMonitor() elif name == 'memory_monitor': @@ -205,21 +208,32 @@ def build_callback(name: str, kwargs: Union[DictConfig, Dict[str, if isinstance(kwargs, DictConfig): kwargs = om.to_object(kwargs) # pyright: ignore return HuggingFaceCheckpointer(**kwargs) + elif name == 'async_eval': + if config is None: + raise ValueError( + 'Parameters config is required for async eval callback') + + return AsyncEval(**kwargs, training_config=config) else: raise ValueError(f'Not sure how to build callback: {name}') def build_logger(name: str, kwargs: Dict[str, Any]) -> LoggerDestination: + kwargs_dict = { + k: v if isinstance(v, str) else om.to_container(v, resolve=True) + for k, v in kwargs.items() + } + if name == 'wandb': - return WandBLogger(**kwargs) + return WandBLogger(**kwargs_dict) elif name == 'tensorboard': - return TensorboardLogger(**kwargs) + return TensorboardLogger(**kwargs_dict) elif name == 'in_memory_logger': - return InMemoryLogger(**kwargs) + return InMemoryLogger(**kwargs_dict) elif name == 'mlflow': - return MLFlowLogger(**kwargs) + return MLFlowLogger(**kwargs_dict) elif name == 'inmemory': - return InMemoryLogger(**kwargs) + return InMemoryLogger(**kwargs_dict) else: raise ValueError(f'Not sure how to build logger: {name}') diff --git a/scripts/eval/eval.py b/scripts/eval/eval.py index 5c74b9fd8f..8dbe91e6d2 100644 --- a/scripts/eval/eval.py +++ b/scripts/eval/eval.py @@ -1,6 +1,7 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import copy import logging import os import sys @@ -10,6 +11,7 @@ import pandas as pd import torch +from composer.loggers import MosaicMLLogger from composer.loggers.logger_destination import LoggerDestination from composer.models.base import ComposerModel from composer.trainer import Trainer @@ -24,7 +26,8 @@ from llmfoundry.utils.builders import (add_metrics_to_eval_loaders, build_evaluators, build_logger, build_tokenizer) -from llmfoundry.utils.config_utils import pop_config, process_init_device +from llmfoundry.utils.config_utils import (log_config, pop_config, + process_init_device) log = logging.getLogger(__name__) @@ -116,6 +119,8 @@ def evaluate_model( precision: str, eval_gauntlet_df: Optional[pd.DataFrame], icl_subset_num_batches: Optional[int], + metadata: Optional[Dict[str, str]], + logged_config: DictConfig, ): log.info(f'Evaluating model: {model_cfg.model_name}') @@ -146,6 +151,20 @@ def evaluate_model( for name, logger_cfg in loggers_cfg.items() ] + if metadata is not None: + # Flatten the metadata for logging + loggers_cfg.pop('metadata', None) + loggers_cfg.update(metadata, merge=True) + + # Find the MosaicMLLogger + mosaicml_logger = next(( + logger for logger in loggers if isinstance(logger, MosaicMLLogger)), + None) + + if mosaicml_logger is not None: + mosaicml_logger.log_metrics(metadata) + mosaicml_logger._flush_metadata(force_flush=True) + if fsdp_config and model_cfg.model.get('load_in_8bit', False): raise ValueError( 'The FSDP config block is not supported when loading ' + @@ -179,6 +198,7 @@ def evaluate_model( assert composer_model is not None + log.info(f'Building trainer for {model_cfg.model_name}...') trainer = Trainer( run_name=run_name, seed=seed, @@ -195,6 +215,10 @@ def evaluate_model( python_log_level=python_log_level, ) + log.info('Evaluation config:') + log_config(logged_config) + + log.info(f'Starting eval for {model_cfg.model_name}...') if torch.cuda.is_available(): torch.cuda.synchronize() a = time.time() @@ -202,12 +226,17 @@ def evaluate_model( if torch.cuda.is_available(): torch.cuda.synchronize() b = time.time() + log.info(f'Ran {model_cfg.model_name} eval in: {b-a} seconds') return (trainer, logger_keys, eval_gauntlet_callback, eval_gauntlet_df) def main(cfg: DictConfig) -> Tuple[List[Trainer], pd.DataFrame]: om.resolve(cfg) + + # Create copy of config for logging + logged_cfg: DictConfig = copy.deepcopy(cfg) + model_configs: ListConfig = pop_config(cfg, 'models', must_exist=True) eval_gauntlet_config: Optional[Union[str, DictConfig]] = pop_config( cfg, 'eval_gauntlet', must_exist=False, default_value=None) @@ -272,6 +301,12 @@ def main(cfg: DictConfig) -> Tuple[List[Trainer], pd.DataFrame]: 'icl_subset_num_batches', must_exist=False, default_value=None) + metadata: Optional[Dict[str, str]] = pop_config(cfg, + 'metadata', + must_exist=False, + default_value=None, + convert=True) + # Pop out interpolation variables. pop_config(cfg, 'model_name_or_path', must_exist=False, default_value=None) @@ -315,7 +350,9 @@ def main(cfg: DictConfig) -> Tuple[List[Trainer], pd.DataFrame]: python_log_level=python_log_level, precision=precision, eval_gauntlet_df=eval_gauntlet_df, - icl_subset_num_batches=icl_subset_num_batches) + icl_subset_num_batches=icl_subset_num_batches, + metadata=metadata, + logged_config=logged_cfg) trainers.append(trainer) if eval_gauntlet_callback is not None: diff --git a/scripts/train/train.py b/scripts/train/train.py index 2c1099ff00..ef7a3b91db 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -24,6 +24,7 @@ from llmfoundry import (COMPOSER_MODEL_REGISTRY, ComposerHFCausalLM, MPTForCausalLM) +from llmfoundry.callbacks import AsyncEval from llmfoundry.data.dataloader import build_dataloader from llmfoundry.utils.builders import (add_metrics_to_eval_loaders, build_algorithm, build_callback, @@ -505,10 +506,13 @@ def main(cfg: DictConfig) -> Trainer: # Callbacks callbacks: List[Callback] = [ - build_callback(str(name), callback_cfg) + build_callback(str(name), callback_cfg, om.to_container(logged_cfg)) for name, callback_cfg in callback_configs.items() ] if callback_configs else [] + use_async_eval = any( + isinstance(callback, AsyncEval) for callback in callbacks) + # Algorithms algorithms = [ build_algorithm(str(name), algorithm_cfg) @@ -529,17 +533,19 @@ def main(cfg: DictConfig) -> Trainer: ## Evaluation log.info('Building eval loader...') eval_icl_seq_len: int = icl_seq_len if icl_seq_len else max_seq_len + # TODO: evaluators should not be built at all if use_async_eval is True + # This will be fixed when eval_loader support is fully added to AsyncEval evaluators, _, eval_gauntlet_callback = build_evaluators( eval_loader_config, - icl_tasks_config, - eval_gauntlet_config, + icl_tasks_config if not use_async_eval else None, + eval_gauntlet_config if not use_async_eval else None, tokenizer=tokenizer, device_eval_batch_size=device_eval_batch_size, icl_seq_len=eval_icl_seq_len, icl_subset_num_batches=icl_subset_num_batches, ) - if eval_gauntlet_callback is not None: + if eval_gauntlet_callback is not None and not use_async_eval: callbacks.append(eval_gauntlet_callback) # Build Model diff --git a/setup.py b/setup.py index 2283e60d9c..923705699c 100644 --- a/setup.py +++ b/setup.py @@ -58,7 +58,7 @@ 'einops==0.5.0', 'omegaconf>=2.2.3,<3', 'slack-sdk<4', - 'mosaicml-cli>=0.3,<1', + 'mosaicml-cli>=0.5.27,<1', 'onnx==1.14.0', 'onnxruntime==1.15.1', 'cmake>=3.25.0,<=3.26.3', # required for triton-pre-mlir below diff --git a/tests/callbacks/test_async_eval_callback.py b/tests/callbacks/test_async_eval_callback.py new file mode 100644 index 0000000000..b3a1e98f79 --- /dev/null +++ b/tests/callbacks/test_async_eval_callback.py @@ -0,0 +1,331 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import datetime +from copy import deepcopy +from unittest.mock import MagicMock, patch + +import pytest +from composer.core import Time, TimeUnit + +from llmfoundry.callbacks.async_eval_callback import (AsyncEval, + get_eval_parameters, + get_run_name, + validate_interval) +from mcli import Run, RunConfig, RunStatus + +# here +RUN_NAME = 'foo_bar-1234' +BASIC_PARAMS = { + 'save_interval': '1ba', + 'save_folder': 'foobar', + 'device_eval_batch_size': 2, + 'icl_tasks': 'icl_task_example', + 'max_seq_len': 3, + 'model': { + 'name': 'model_example', + 'config_overrides': { + 'attn_config': { + 'foo': 'bar' + } + } + }, + 'tokenizer': { + 'tokenizer_example': 'tokenizer_example', + }, +} + + +def test_get_run_name(): + a = get_run_name('foo-1234', '1ba') + assert a == 'eval-1ba-foo' + + # Run name should be truncated + b = get_run_name(50 * 'foo' + '-1234', '1ba') + assert b == 'eval-1ba-foofoofoofoofoofoofoofoofoofoofoofoofoofoofoof' + + +@pytest.fixture(autouse=True, scope='module') +def set_os_env_vars(): + with patch.dict('os.environ', { + 'MOSAICML_PLATFORM': 'true', + 'RUN_NAME': RUN_NAME + }): + yield + + +def test_fails_when_not_on_platform(): + with patch.dict('os.environ', {'MOSAICML_PLATFORM': 'false'}): + with pytest.raises( + Exception, + match= + 'AsyncEval callback is only supported when running on the MosaicML platform' + ): + AsyncEval(BASIC_PARAMS, interval='2ba') + + +def test_fails_when_no_run_name(): + with patch.dict('os.environ', { + 'MOSAICML_PLATFORM': 'true', + 'RUN_NAME': '' + }): + with pytest.raises( + Exception, + match= + 'RUN_NAME environment variable must be set to use the AsyncEval callback' + ): + AsyncEval(BASIC_PARAMS, interval='2ba') + + +def test_get_eval_parameters(): + with pytest.raises( + Exception, + match='Missing the following required parameters for async eval:'): + get_eval_parameters({}, 'checkpoints/file', RUN_NAME) + + # minimal example + params = get_eval_parameters(BASIC_PARAMS, 'checkpoints/file', RUN_NAME) + assert params == { + 'device_eval_batch_size': + 2, + 'icl_tasks': + 'icl_task_example', + 'max_seq_len': + 3, + 'models': [{ + 'model_name': 'model_example', + 'model': { + 'name': 'model_example', + 'config_overrides': { + 'attn_config': { + 'foo': 'bar' + }, + }, + }, + 'tokenizer': { + 'tokenizer_example': 'tokenizer_example' + }, + 'load_path': 'checkpoints/file', + }], + } + + # maximal example + params2 = get_eval_parameters( + { + # required + **BASIC_PARAMS, + # optional + 'dist_timeout': 1, + 'eval_gauntlet': 'eval_gauntlet_example', + 'fsdp_config': { + 'fsdp_cfg_example': 'fsdp_cfg_example' + }, + 'icl_subset_num_batches': 4, + 'loggers': { + 'wandb': { + 'init_kwargs': { + 'fee': 'bee' + } + } + }, + 'precision': 'precision_example', + 'python_log_level': 'debug', + 'seed': 5, + # ignore this + 'ignore_this': 'ignore_this', + }, + 'checkpoints/file', + RUN_NAME, + ) + assert params2 == { + 'device_eval_batch_size': 2, + 'icl_tasks': 'icl_task_example', + 'max_seq_len': 3, + 'dist_timeout': 1, + 'models': [{ + 'model_name': 'model_example', + 'model': { + 'name': 'model_example', + 'config_overrides': { + 'attn_config': { + 'foo': 'bar' + }, + }, + }, + 'tokenizer': { + 'tokenizer_example': 'tokenizer_example' + }, + 'load_path': 'checkpoints/file', + }], + 'eval_gauntlet': 'eval_gauntlet_example', + 'fsdp_config': { + 'fsdp_cfg_example': 'fsdp_cfg_example' + }, + 'icl_subset_num_batches': 4, + 'loggers': { + 'wandb': { + 'group': 'foo_bar-1234', + 'init_kwargs': { + 'fee': 'bee' + }, + } + }, + 'precision': 'precision_example', + 'python_log_level': 'debug', + 'seed': 5, + } + + +def test_validate_interval(): + with pytest.raises(ValueError): + validate_interval('1ba', '1ep') # different units + with pytest.raises(ValueError): + validate_interval('1ba', '2ba') # checkpointing happens less often + with pytest.raises(ValueError): + validate_interval('3ba', '2ba') # not a multiple + + assert validate_interval('2ba', '1ba') == Time(2, TimeUnit.BATCH) + two_epochs = Time(2, TimeUnit.EPOCH) + assert validate_interval(2, 2) == two_epochs + assert validate_interval(two_epochs, two_epochs) == two_epochs + assert validate_interval('2ep', two_epochs) == two_epochs + + +FAKE_RUN = Run( + run_uid='123', + name=RUN_NAME, + image='fake-image', + status=RunStatus.RUNNING, + created_at=datetime.datetime(2021, 1, 1), + updated_at=datetime.datetime(2021, 1, 1), + created_by='me', + priority='low', + preemptible=False, + retry_on_system_failure=True, + cluster='c1z2', + gpu_type='a100', + gpus=16, + cpus=0, + node_count=2, + latest_resumption=None, # type: ignore + submitted_config=RunConfig( + name=RUN_NAME, + image='fake-image', + command='echo hi', + parameters={}, + ), +) + + +@patch('llmfoundry.callbacks.async_eval_callback.get_run', + return_value=FAKE_RUN) +@patch('llmfoundry.callbacks.async_eval_callback.create_run', + return_value=FAKE_RUN) +def test_async_eval_callback_minimal(mock_create_run: MagicMock, + mock_get_run: MagicMock): + callback = AsyncEval(BASIC_PARAMS, + interval='2ba', + compute={ + 'cluster': 'c2z3', + 'nodes': 2, + }) + assert callback.current_run.name == RUN_NAME + assert mock_get_run.call_count == 1 + assert mock_get_run.call_args[0][0] == RUN_NAME + + launch_time = Time(1, TimeUnit.BATCH) + callback.launch_run('checkpoint/path', launch_time) + assert mock_create_run.call_count == 1 + + run_config_created = mock_create_run.call_args[0][0] + assert run_config_created.name == 'eval-1ba-foo_bar' + assert run_config_created.image == 'fake-image' + + metadata = run_config_created.metadata + assert 'eval_timestamp' in metadata + assert isinstance(metadata['eval_timestamp'], int) + assert metadata['eval_timestamp'] == launch_time.value + + assert 'eval_timestamp_unit' in metadata + assert isinstance(metadata['eval_timestamp_unit'], str) + assert metadata['eval_timestamp_unit'] == launch_time.unit.value + + assert 'cd llm-foundry/scripts' in run_config_created.command + + integrations = run_config_created.integrations + assert len(integrations) == 1 + assert integrations[0]['integration_type'] == 'git_repo' + assert integrations[0]['git_repo'] == 'mosaicml/llm-foundry' + assert integrations[0]['git_branch'].startswith('v') + + compute = run_config_created.compute + assert compute['cluster'] == 'c2z3' + assert compute['nodes'] == 2 + + parameters = run_config_created.parameters + assert parameters['device_eval_batch_size'] == 2 + assert parameters['icl_tasks'] == 'icl_task_example' + assert parameters['max_seq_len'] == 3 + assert parameters['models'] == [{ + 'model_name': 'model_example', + 'model': { + 'name': 'model_example', + 'config_overrides': { + 'attn_config': { + 'foo': 'bar' + }, + }, + }, + 'tokenizer': { + 'tokenizer_example': 'tokenizer_example' + }, + 'load_path': 'checkpoint/path', + }] + assert parameters['run_name'] == 'eval-1ba-foo_bar' # original run + + +INTEGRATION_GIT_LLMFOUNDRY = { + 'integration_type': 'git_repo', + 'git_repo': 'mosaicml/llm-foundry', + 'git_branch': 'custom_branch', + 'path': 'custom/llm-foundry', + 'pip_install': '-e .[gpu]', + 'ssh_clone': False, +} +INTEGRATION_GIT_RANDOM = { + 'integration_type': 'git_repo', + 'git_repo': 'another-repo', + 'git_branch': 'foobar', +} + +FAKE_RUN_WITH_INTEGRATIONS = deepcopy(FAKE_RUN) +FAKE_RUN_WITH_INTEGRATIONS.submitted_config.integrations = [ + INTEGRATION_GIT_LLMFOUNDRY, INTEGRATION_GIT_RANDOM +] + + +@patch('llmfoundry.callbacks.async_eval_callback.get_run', + return_value=FAKE_RUN_WITH_INTEGRATIONS) +@patch('llmfoundry.callbacks.async_eval_callback.create_run', + return_value=FAKE_RUN_WITH_INTEGRATIONS) +def test_async_eval_callback_integrations(mock_create_run: MagicMock, + mock_get_run: MagicMock): + callback = AsyncEval(BASIC_PARAMS, + interval='2ba', + compute={ + 'cluster': 'c2z3', + 'nodes': 2, + }) + assert mock_get_run.call_count == 1 + + callback.launch_run('checkpoint/path', Time(1, TimeUnit.BATCH)) + assert mock_create_run.call_count == 1 + run_config_created = mock_create_run.call_args[0][0] + + assert len(run_config_created.integrations) == 2 + # order should be retained + assert run_config_created.integrations[0] == INTEGRATION_GIT_LLMFOUNDRY + assert run_config_created.integrations[1] == INTEGRATION_GIT_RANDOM + + custom_path = run_config_created.integrations[0]['path'] + assert f'cd {custom_path}/scripts' in run_config_created.command diff --git a/tests/utils/test_builders.py b/tests/utils/test_builders.py index 5c38ed8602..9be6630075 100644 --- a/tests/utils/test_builders.py +++ b/tests/utils/test_builders.py @@ -12,6 +12,7 @@ import torch.nn as nn from composer.callbacks import Generate from composer.core import Evaluator +from composer.loggers import WandBLogger from omegaconf import DictConfig, ListConfig from omegaconf import OmegaConf as om from transformers import PreTrainedTokenizerBase @@ -20,8 +21,8 @@ from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper from llmfoundry.utils.builders import (add_metrics_to_eval_loaders, build_callback, build_eval_loaders, - build_evaluators, build_optimizer, - build_tokenizer) + build_evaluators, build_logger, + build_optimizer, build_tokenizer) @pytest.mark.parametrize('tokenizer_name,tokenizer_kwargs', [ @@ -49,7 +50,7 @@ def test_tokenizer_builder(tokenizer_name: str, tokenizer_kwargs: dict): def test_build_callback_fails(): with pytest.raises(ValueError): - build_callback('nonexistent_callback', {}) + build_callback('nonexistent_callback', {}, {}) @pytest.mark.parametrize( @@ -65,12 +66,15 @@ def test_build_generate_callback( autospec=True) as mock_generate: mock_generate.return_value = None build_callback( - 'generate_callback', { + 'generate_callback', + { 'prompts': ['hello'], interval_key: interval_value, 'foo': 'bar', 'something': 'else', - }) + }, + {}, + ) assert mock_generate.call_count == 1 _, _, kwargs = mock_generate.mock_calls[0] @@ -85,11 +89,15 @@ def test_build_generate_callback_unspecified_interval(): with mock.patch.object(Generate, '__init__', autospec=True) as mock_generate: mock_generate.return_value = None - build_callback('generate_callback', { - 'prompts': ['hello'], - 'foo': 'bar', - 'something': 'else', - }) + build_callback( + 'generate_callback', + { + 'prompts': ['hello'], + 'foo': 'bar', + 'something': 'else', + }, + {}, + ) def test_build_hf_checkpointer_callback(): @@ -111,7 +119,8 @@ def test_build_hf_checkpointer_callback(): 'save_folder': save_folder, 'save_interval': save_interval, 'mlflow_logging_config': mlflow_logging_config_dict - })) + }), + config={}) assert mock_hf_checkpointer.call_count == 1 _, _, kwargs = mock_hf_checkpointer.mock_calls[0] @@ -122,6 +131,31 @@ def test_build_hf_checkpointer_callback(): assert kwargs['mlflow_logging_config'] == mlflow_logging_config_dict +def test_build_logger(): + with pytest.raises(ValueError): + _ = build_logger('unknown', {}) + + logger_cfg = DictConfig({ + 'project': 'foobar', + 'init_kwargs': { + 'config': { + 'foo': 'bar', + } + } + }) + wandb_logger = build_logger('wandb', logger_cfg) # type: ignore + assert isinstance(wandb_logger, WandBLogger) + assert wandb_logger.project == 'foobar' + + # confirm the typing conversion from DictConfig to dict, + # wandb.init() will fail if config is not explicitly + # dict type + ik = wandb_logger._init_kwargs + assert ik == {'config': {'foo': 'bar'}, 'project': 'foobar'} + assert isinstance(ik, dict) + assert isinstance(ik['config'], dict) + + class _DummyModule(nn.Module): def __init__(self, device: str = 'cpu', dtype: torch.dtype = torch.float32):