Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Registry for test templates #35

Merged
merged 9 commits into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/cloudai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import cloudai.schema.test_template # noqa
from cloudai.parser.system_parser.slurm_system_parser import SlurmSystemParser
from cloudai.parser.system_parser.standalone_system_parser import StandaloneSystemParser
from cloudai.runner.slurm.slurm_runner import SlurmRunner
Expand Down Expand Up @@ -106,6 +105,12 @@
Registry().add_strategy(GradingStrategy, [SlurmSystem], [ChakraReplay], ChakraReplayGradingStrategy)
Registry().add_strategy(CommandGenStrategy, [SlurmSystem], [ChakraReplay], ChakraReplaySlurmCommandGenStrategy)

Registry().add_test_template("ChakraReplay", ChakraReplay)
Registry().add_test_template("JaxToolbox", JaxToolbox)
Registry().add_test_template("NcclTest", NcclTest)
Registry().add_test_template("NeMoLauncher", NeMoLauncher)
Registry().add_test_template("Sleep", Sleep)
Registry().add_test_template("UCCTest", UCCTest)

__all__ = [
"Grader",
Expand Down
33 changes: 33 additions & 0 deletions src/cloudai/_core/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class Registry(metaclass=Singleton):
],
Type[Union[TestTemplateStrategy, ReportGenerationStrategy, JobIdRetrievalStrategy]],
] = {}
test_templates_map: Dict[str, Type[TestTemplate]] = {}

def add_system_parser(self, name: str, value: Type[BaseSystemParser]) -> None:
"""
Expand Down Expand Up @@ -138,3 +139,35 @@ def update_strategy(
):
raise ValueError(f"Invalid strategy implementation {value}, should be subclass of 'TestTemplateStrategy'.")
self.strategies_map[key] = value

def add_test_template(self, name: str, value: Type[TestTemplate]) -> None:
"""
Add a new test template implementation mapping.

Args:
name (str): The name of the test template.
value (Type[TestTemplate]): The test template implementation.

Raises:
ValueError: If the test template implementation already exists.
"""
if name in self.test_templates_map:
raise ValueError(f"Duplicating implementation for '{name}', use 'update()' for replacement.")
self.update_test_template(name, value)

def update_test_template(self, name: str, value: Type[TestTemplate]) -> None:
"""
Create or replace test template implementation mapping.

Args:
name (str): The name of the test template.
value (Type[TestTemplate]): The test template implementation.

Raises:
ValueError: If value is not a subclass of TestTemplate.
"""
if not issubclass(value, TestTemplate):
raise ValueError(
f"Invalid test template implementation for '{name}', should be subclass of 'TestTemplate'."
)
self.test_templates_map[name] = value
15 changes: 1 addition & 14 deletions src/cloudai/parser/core/test_template_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,26 +144,13 @@ def _get_test_template_class(self, name: str) -> Type[TestTemplate]:
Returns:
Type[TestTemplate]: A subclass of TestTemplate corresponding to the given name.
"""
template_classes = self._enumerate_test_template_classes()
template_classes = Registry().test_templates_map

if name in template_classes:
return template_classes[name]
else:
raise ValueError(f"Unsupported test_template name: {name}")

@staticmethod
def _enumerate_test_template_classes() -> Dict[str, Type[TestTemplate]]:
"""
Dynamically enumerate all subclasses of TestTemplate available in the current namespace.

Maps their class names to the class objects.

Returns
Dict[str, Type[TestTemplate]]: A dictionary mapping class names to
TestTemplate subclasses.
"""
return {cls.__name__: cls for cls in TestTemplate.__subclasses__()}

def _extract_args(self, args: Dict[str, Any]) -> Dict[str, Any]:
"""
Extract arguments, maintaining their structure, and includes 'values' and 'default' fields where they exist.
Expand Down
11 changes: 11 additions & 0 deletions tests/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,14 @@ def test_runners():
def test_strategies(key: tuple, value: type):
strategies = Registry().strategies_map
assert strategies[key] == value


def test_test_templates():
test_templates = Registry().test_templates_map
assert len(test_templates) == 6
assert test_templates["ChakraReplay"] == ChakraReplay
assert test_templates["JaxToolbox"] == JaxToolbox
assert test_templates["NcclTest"] == NcclTest
assert test_templates["NeMoLauncher"] == NeMoLauncher
assert test_templates["Sleep"] == Sleep
assert test_templates["UCCTest"] == UCCTest
26 changes: 26 additions & 0 deletions tests/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,3 +176,29 @@ def test_add_multiple_strategies(self, registry: Registry):
assert registry.strategies_map[(MyStrategy, MySystem, AnotherTestTemplate)] == MyStrategy
assert registry.strategies_map[(MyStrategy, AnotherSystem, MyTestTemplate)] == MyStrategy
assert registry.strategies_map[(MyStrategy, AnotherSystem, AnotherTestTemplate)] == MyStrategy


class TestRegistry__TestTemplatesMap:
"""This test verifies Registry class functionality.

Since Registry is a Singleton, the order of cases is important.
Only covers the test_templates_map attribute.
"""

def test_add_test_template(self, registry: Registry):
registry.add_test_template("test_template", MyTestTemplate)
assert registry.test_templates_map["test_template"] == MyTestTemplate

def test_add_test_template_duplicate(self, registry: Registry):
with pytest.raises(ValueError) as exc_info:
registry.add_test_template("test_template", MyTestTemplate)
assert "Duplicating implementation for 'test_template'" in str(exc_info.value)

def test_update_test_template(self, registry: Registry):
registry.update_test_template("test_template", AnotherTestTemplate)
assert registry.test_templates_map["test_template"] == AnotherTestTemplate

def test_invalid_type(self, registry: Registry):
with pytest.raises(ValueError) as exc_info:
registry.update_test_template("TestTemplate", str) # pyright: ignore
assert "Invalid test template implementation for 'TestTemplate'" in str(exc_info.value)