diff --git a/src/cloudai/__init__.py b/src/cloudai/__init__.py index a652626b..19852423 100644 --- a/src/cloudai/__init__.py +++ b/src/cloudai/__init__.py @@ -103,7 +103,9 @@ Registry().add_strategy(InstallStrategy, [SlurmSystem], [NcclTest], NcclTestSlurmInstallStrategy) Registry().add_strategy(InstallStrategy, [SlurmSystem], [NeMoLauncher], NeMoLauncherSlurmInstallStrategy) -Registry().add_strategy(ReportGenerationStrategy, [SlurmSystem], [NcclTest], NcclTestReportGenerationStrategy) +Registry().add_strategy( + ReportGenerationStrategy, [SlurmSystem, KubernetesSystem], [NcclTest], NcclTestReportGenerationStrategy +) Registry().add_strategy(CommandGenStrategy, [StandaloneSystem], [Sleep], SleepStandaloneCommandGenStrategy) Registry().add_strategy(CommandGenStrategy, [SlurmSystem], [Sleep], SleepSlurmCommandGenStrategy) Registry().add_strategy(JsonGenStrategy, [KubernetesSystem], [Sleep], SleepKubernetesJsonGenStrategy) diff --git a/tests/test_init.py b/tests/test_init.py index deb3a65b..25a7b51d 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -120,6 +120,7 @@ def test_runners(): ((ReportGenerationStrategy, SlurmSystem, ChakraReplay), ChakraReplayReportGenerationStrategy), ((ReportGenerationStrategy, SlurmSystem, JaxToolbox), JaxToolboxReportGenerationStrategy), ((ReportGenerationStrategy, SlurmSystem, NcclTest), NcclTestReportGenerationStrategy), + ((ReportGenerationStrategy, KubernetesSystem, NcclTest), NcclTestReportGenerationStrategy), ((ReportGenerationStrategy, SlurmSystem, NeMoLauncher), NeMoLauncherReportGenerationStrategy), ((ReportGenerationStrategy, SlurmSystem, Sleep), SleepReportGenerationStrategy), ((ReportGenerationStrategy, SlurmSystem, UCCTest), UCCTestReportGenerationStrategy),