From 9749e3a75008036f92a0df09ac9f4d6b53ada650 Mon Sep 17 00:00:00 2001 From: Taekyung Heo <7621438+TaekyungHeo@users.noreply.github.com> Date: Wed, 15 May 2024 17:24:45 -0400 Subject: [PATCH] Allow overriding of runner mappings in Runner class --- src/cloudai/runner/core/runner.py | 2 -- tests/runner/core/test_runner.py | 21 +++++++++++++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) create mode 100644 tests/runner/core/test_runner.py diff --git a/src/cloudai/runner/core/runner.py b/src/cloudai/runner/core/runner.py index 29a2cce9..7568da30 100644 --- a/src/cloudai/runner/core/runner.py +++ b/src/cloudai/runner/core/runner.py @@ -61,8 +61,6 @@ def register(cls, system_type: str) -> Callable: """ def decorator(runner_class: Type[BaseRunner]) -> Type[BaseRunner]: - if system_type in cls._runners: - raise KeyError(f"Runner for {system_type} already registered.") cls._runners[system_type] = runner_class return runner_class diff --git a/tests/runner/core/test_runner.py b/tests/runner/core/test_runner.py new file mode 100644 index 00000000..2bdc18da --- /dev/null +++ b/tests/runner/core/test_runner.py @@ -0,0 +1,21 @@ +from unittest.mock import patch + +from cloudai.runner.core import BaseRunner, Runner + + +def test_register_multiple_runners(): + """ + Test registering multiple different runners for the same type to ensure that + only the last registered runner is kept. + """ + with patch.dict("cloudai.runner.Runner._runners", clear=True): + + @Runner.register("slurm") + class FirstSlurmRunner(BaseRunner): + pass + + @Runner.register("slurm") + class SecondSlurmRunner(BaseRunner): + pass + + assert Runner._runners["slurm"] is SecondSlurmRunner