diff --git a/.github/mcp/mcp_pytest.py b/.github/mcp/mcp_pytest.py index c823d0b151..25f5153b4e 100644 --- a/.github/mcp/mcp_pytest.py +++ b/.github/mcp/mcp_pytest.py @@ -4,8 +4,9 @@ """Run pytest using MCP.""" import argparse +from concurrent.futures import TimeoutError -from mcli.sdk import RunConfig, RunStatus, create_run, follow_run_logs, wait_for_run_status +from mcli.sdk import RunConfig, RunStatus, create_run, get_run_logs, stop_runs, wait_for_run_status if __name__ == '__main__': @@ -74,13 +75,26 @@ # Create run run = create_run(config) - # Wait till run starts before fetching logs + # Wait until run starts before fetching logs run = wait_for_run_status(run, status='running') + print('Run started. Waiting for run to complete...') + + # Wait up to 30 minutes for run to complete + try: + run = wait_for_run_status(run, status='completed', timeout=60 * 30) + except TimeoutError: + print('Run timed out and did not complete in 30 minutes.') + + # Get run status and stop run + success = run.status == RunStatus.COMPLETED + print(f'Run completed with status: {run.status} (success={success})') + if run.status == RunStatus.RUNNING: + stop_runs([run]) + print('Run stopped.') # Print logs - for line in follow_run_logs(run): + for line in get_run_logs(run): print(line, end='') - # Fail if command exited with non-zero exit code - run = wait_for_run_status(run, 'completed') - assert run.status == RunStatus.COMPLETED + # Fail if command exited with non-zero exit code or timed out + assert success diff --git a/composer/algorithms/squeeze_excite/README.md b/composer/algorithms/squeeze_excite/README.md index 10d447d816..36887d6a7a 100644 --- a/composer/algorithms/squeeze_excite/README.md +++ b/composer/algorithms/squeeze_excite/README.md @@ -75,7 +75,7 @@ trainer = Trainer( model=model, train_dataloader=train_dataloader, eval_dataloader=eval_dataloader, - max_duration='10ep', + max_duration='1ep', algorithms=[algo] ) diff --git a/setup.py b/setup.py index af97f3db37..5656118ecd 100644 --- a/setup.py +++ b/setup.py @@ -170,6 +170,7 @@ def package_files(prefix: str, directory: str, extension: str): ] extra_deps['nlp'] = [ + 'datasets>=2.4,<3', 'transformers>=4.11,<5', ]