diff --git a/src/cloudai/__init__.py b/src/cloudai/__init__.py index 84626b10..7edafbff 100644 --- a/src/cloudai/__init__.py +++ b/src/cloudai/__init__.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from cloudai.installer.slurm_installer import SlurmInstaller +from cloudai.installer.standalone_installer import StandaloneInstaller from cloudai.parser.system_parser.slurm_system_parser import SlurmSystemParser from cloudai.parser.system_parser.standalone_system_parser import StandaloneSystemParser from cloudai.runner.slurm.slurm_runner import SlurmRunner @@ -61,7 +63,7 @@ from ._core.registry import Registry from .grader import Grader -from .installer import Installer +from .installer.installer import Installer from .parser.core.parser import Parser from .report_generator import ReportGenerator from .runner.core.runner import Runner @@ -112,6 +114,9 @@ Registry().add_test_template("Sleep", Sleep) Registry().add_test_template("UCCTest", UCCTest) +Registry().add_installer("slurm", SlurmInstaller) +Registry().add_installer("standalone", StandaloneInstaller) + __all__ = [ "Grader", "Installer", diff --git a/src/cloudai/_core/registry.py b/src/cloudai/_core/registry.py index c34f627a..d0b81f1d 100644 --- a/src/cloudai/_core/registry.py +++ b/src/cloudai/_core/registry.py @@ -1,5 +1,6 @@ from typing import Dict, List, Tuple, Type, Union +from cloudai.installer.base_installer import BaseInstaller from cloudai.parser.core.base_system_parser import BaseSystemParser from cloudai.runner.core.base_runner import BaseRunner from cloudai.schema.core.strategy.job_id_retrieval_strategy import JobIdRetrievalStrategy @@ -34,6 +35,7 @@ class Registry(metaclass=Singleton): Type[Union[TestTemplateStrategy, ReportGenerationStrategy, JobIdRetrievalStrategy]], ] = {} test_templates_map: Dict[str, Type[TestTemplate]] = {} + installers_map: Dict[str, Type[BaseInstaller]] = {} def add_system_parser(self, name: str, value: Type[BaseSystemParser]) -> None: """ @@ -171,3 +173,33 @@ def update_test_template(self, name: str, value: Type[TestTemplate]) -> None: f"Invalid test template implementation for '{name}', should be subclass of 'TestTemplate'." ) self.test_templates_map[name] = value + + def add_installer(self, name: str, value: Type[BaseInstaller]) -> None: + """ + Add a new installer implementation mapping. + + Args: + name (str): The name of the installer. + value (Type[BaseInstaller]): The installer implementation. + + Raises: + ValueError: If the installer implementation already exists. + """ + if name in self.installers_map: + raise ValueError(f"Duplicating implementation for '{name}', use 'update()' for replacement.") + self.update_installer(name, value) + + def update_installer(self, name: str, value: Type[BaseInstaller]) -> None: + """ + Create or replace installer implementation mapping. + + Args: + name (str): The name of the installer. + value (Type[BaseInstaller]): The installer implementation. + + Raises: + ValueError: If value is not a subclass of BaseInstaller. + """ + if not issubclass(value, BaseInstaller): + raise ValueError(f"Invalid installer implementation for '{name}', should be subclass of 'BaseInstaller'.") + self.installers_map[name] = value diff --git a/src/cloudai/installer/__init__.py b/src/cloudai/installer/__init__.py deleted file mode 100644 index 20fe0da9..00000000 --- a/src/cloudai/installer/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .base_installer import BaseInstaller -from .installer import Installer -from .slurm_installer import SlurmInstaller -from .standalone_installer import StandaloneInstaller - -__all__ = [ - "StandaloneInstaller", - "SlurmInstaller", - "Installer", - "BaseInstaller", -] diff --git a/src/cloudai/installer/installer.py b/src/cloudai/installer/installer.py index 160395ee..dec96e11 100644 --- a/src/cloudai/installer/installer.py +++ b/src/cloudai/installer/installer.py @@ -13,8 +13,9 @@ # limitations under the License. import logging -from typing import Callable, Iterable +from typing import Iterable +from cloudai._core.registry import Registry from cloudai.schema.core.system import System from cloudai.schema.core.test_template import TestTemplate @@ -36,8 +37,6 @@ class Installer: logger (logging.Logger): Logger for capturing installation activities. """ - _installers = {} - def __init__(self, system: System): """ Initialize the Installer with a system object and installation path. @@ -49,25 +48,6 @@ def __init__(self, system: System): self.logger.info("Initializing Installer with system configuration.") self.installer = self.create_installer(system) - @classmethod - def register(cls, scheduler_type: str) -> Callable: - """ - Register installer subclasses under specific scheduler types. - - Args: - scheduler_type (str): The scheduler type string that the installer - subclass can handle. - - Returns: - Callable: A decorator function that registers the installer class. - """ - - def decorator(installer_class): - cls._installers[scheduler_type] = installer_class - return installer_class - - return decorator - @classmethod def create_installer(cls, system: System) -> BaseInstaller: """ @@ -85,7 +65,8 @@ def create_installer(cls, system: System) -> BaseInstaller: system's scheduler. """ scheduler_type = system.scheduler - installer_class = cls._installers.get(scheduler_type) + registry = Registry() + installer_class = registry.installers_map.get(scheduler_type) if installer_class is None: raise NotImplementedError(f"No installer available for scheduler: {scheduler_type}") return installer_class(system) diff --git a/src/cloudai/installer/slurm_installer.py b/src/cloudai/installer/slurm_installer.py index 07d8654d..e6b5faee 100644 --- a/src/cloudai/installer/slurm_installer.py +++ b/src/cloudai/installer/slurm_installer.py @@ -18,16 +18,13 @@ from typing import Iterable, cast import toml - from cloudai.schema.core.system import System from cloudai.schema.core.test_template import TestTemplate from cloudai.schema.system import SlurmSystem from .base_installer import BaseInstaller -from .installer import Installer -@Installer.register("slurm") class SlurmInstaller(BaseInstaller): """ Installer for systems that use the Slurm scheduler. diff --git a/src/cloudai/installer/standalone_installer.py b/src/cloudai/installer/standalone_installer.py index 73667e91..46641d47 100644 --- a/src/cloudai/installer/standalone_installer.py +++ b/src/cloudai/installer/standalone_installer.py @@ -13,10 +13,8 @@ # limitations under the License. from .base_installer import BaseInstaller -from .installer import Installer -@Installer.register("standalone") class StandaloneInstaller(BaseInstaller): """ Installer for systems that do not use a scheduler (standalone systems). diff --git a/tests/test_init.py b/tests/test_init.py index 1459b496..78823bd6 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -1,6 +1,8 @@ import cloudai # noqa: F401 import pytest from cloudai._core.registry import Registry +from cloudai.installer.slurm_installer import SlurmInstaller +from cloudai.installer.standalone_installer import StandaloneInstaller from cloudai.schema.core.strategy.command_gen_strategy import CommandGenStrategy from cloudai.schema.core.strategy.grading_strategy import GradingStrategy from cloudai.schema.core.strategy.install_strategy import InstallStrategy @@ -111,3 +113,10 @@ def test_test_templates(): assert test_templates["NeMoLauncher"] == NeMoLauncher assert test_templates["Sleep"] == Sleep assert test_templates["UCCTest"] == UCCTest + + +def test_installers(): + installers = Registry().installers_map + assert len(installers) == 2 + assert installers["standalone"] == StandaloneInstaller + assert installers["slurm"] == SlurmInstaller diff --git a/tests/test_registry.py b/tests/test_registry.py index 1d1fcf18..aa28e7d8 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -1,5 +1,6 @@ import pytest -from cloudai._core.registry import Registry +from cloudai import Registry +from cloudai.installer.base_installer import BaseInstaller from cloudai.parser.core.base_system_parser import BaseSystemParser from cloudai.runner.core.base_runner import BaseRunner from cloudai.schema.core.strategy.job_id_retrieval_strategy import JobIdRetrievalStrategy @@ -202,3 +203,37 @@ def test_invalid_type(self, registry: Registry): with pytest.raises(ValueError) as exc_info: registry.update_test_template("TestTemplate", str) # pyright: ignore assert "Invalid test template implementation for 'TestTemplate'" in str(exc_info.value) + + +class MyInstaller(BaseInstaller): + pass + + +class AnotherInstaller(BaseInstaller): + pass + + +class TestRegistry__Installers: + """This test verifies Registry class functionality. + + Since Registry is a Singleton, the order of cases is important. + Only covers the installers_map attribute. + """ + + def test_add_installer(self, registry: Registry): + registry.add_installer("installer", MyInstaller) + assert registry.installers_map["installer"] == MyInstaller + + def test_add_installer_duplicate(self, registry: Registry): + with pytest.raises(ValueError) as exc_info: + registry.add_installer("installer", MyInstaller) + assert "Duplicating implementation for 'installer'" in str(exc_info.value) + + def test_update_installer(self, registry: Registry): + registry.update_installer("installer", AnotherInstaller) + assert registry.installers_map["installer"] == AnotherInstaller + + def test_invalid_type(self, registry: Registry): + with pytest.raises(ValueError) as exc_info: + registry.update_installer("TestInstaller", str) # pyright: ignore + assert "Invalid installer implementation for 'TestInstaller'" in str(exc_info.value)