From 2eef281b0be5de536957e4a7166ef8518f80dd28 Mon Sep 17 00:00:00 2001 From: woodybury Date: Thu, 7 Nov 2024 12:23:41 -0500 Subject: [PATCH] public get_callbacks for better extensions --- nam/train/core.py | 4 ++-- tests/test_nam/test_train/test_core.py | 25 +++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/nam/train/core.py b/nam/train/core.py index b719a3b..79cb1ae 100644 --- a/nam/train/core.py +++ b/nam/train/core.py @@ -1235,7 +1235,7 @@ def _remove_checkpoint(self, trainer: pl.Trainer, filepath: str) -> None: nam_path.unlink() -def _get_callbacks( +def get_callbacks( threshold_esr: Optional[float], user_metadata: Optional[UserMetadata] = None, settings_metadata: Optional[metadata.Settings] = None, @@ -1432,7 +1432,7 @@ def parse_user_latency( data_metadata = metadata.Data(latency=latency_analysis, checks=data_check_output) trainer = pl.Trainer( - callbacks=_get_callbacks( + callbacks=get_callbacks( threshold_esr, user_metadata=user_metadata, settings_metadata=settings_metadata, diff --git a/tests/test_nam/test_train/test_core.py b/tests/test_nam/test_train/test_core.py index 9544312..a38c123 100644 --- a/tests/test_nam/test_train/test_core.py +++ b/tests/test_nam/test_train/test_core.py @@ -295,5 +295,30 @@ def test_end_to_end(): assert isinstance(train_output.model, Model) +def test_get_callbacks(): + """ + Sanity check for get_callbacks with a custom extension callback and threshold_esr + """ + threshold_esr = 0.01 + callbacks = core.get_callbacks(threshold_esr=threshold_esr) + + # dumb example of a user-extended custom callback + class CustomCallback: + pass + extended_callbacks = callbacks + [CustomCallback()] + + # sanity default callbacks + assert any(isinstance(cb, core._ModelCheckpoint) for cb in extended_callbacks), \ + "Expected _ModelCheckpoint to be part of the default callbacks." + + # custom callback + assert any(isinstance(cb, CustomCallback) for cb in extended_callbacks), \ + "Expected CustomCallback to be added to the extended callbacks." + + # _ValidationStopping cb when threshold_esr is prvided + assert any(isinstance(cb, core._ValidationStopping) for cb in extended_callbacks), \ + "_ValidationStopping should still be present after adding a custom callback." + + if __name__ == "__main__": pytest.main()