From 1eef7b89192b9b78d7307f36be8d6c7b5677df52 Mon Sep 17 00:00:00 2001
From: Oliver Holworthy
Date: Fri, 2 Dec 2022 10:20:24 +0000
Subject: [PATCH] Add `pre` to ModelBlock fit/evaluate

---
 merlin/models/tf/models/base.py        | 27 +++++++++++++++++++++++---
 tests/unit/tf/models/test_retrieval.py |  2 +-
 2 files changed, 25 insertions(+), 4 deletions(-)

diff --git a/merlin/models/tf/models/base.py b/merlin/models/tf/models/base.py
index db6d914ea4..c87b52192a 100644
--- a/merlin/models/tf/models/base.py
+++ b/merlin/models/tf/models/base.py
@@ -168,6 +168,7 @@ def fit(
         workers=1,
         use_multiprocessing=False,
         train_metrics_steps=1,
+        pre=None,
         **kwargs,
     ):
         x = _maybe_convert_merlin_dataset(x, batch_size, **kwargs)
@@ -175,13 +176,23 @@
             validation_data, batch_size, shuffle=shuffle, **kwargs
         )
         callbacks = self._add_metrics_callback(callbacks, train_metrics_steps)
+
         fit_kwargs = {
             k: v
             for k, v in locals().items()
-            if k not in ["self", "kwargs", "train_metrics_steps", "__class__"]
+            if k not in ["self", "kwargs", "train_metrics_steps", "pre", "__class__"]
         }
 
-        return super().fit(**fit_kwargs)
+        if pre:
+            self._reset_compile_cache()
+            self.train_pre = pre
+
+        out = super().fit(**fit_kwargs)
+
+        if pre:
+            del self.train_pre
+
+        return out
 
     def evaluate(
         self,
@@ -196,11 +207,16 @@
         workers=1,
         use_multiprocessing=False,
         return_dict=False,
+        pre=None,
         **kwargs,
     ):
         x = _maybe_convert_merlin_dataset(x, batch_size, **kwargs)
 
-        return super().evaluate(
+        if pre:
+            self._reset_compile_cache()
+            self.test_pre = pre
+
+        out = super().evaluate(
             x,
             y,
             batch_size,
@@ -215,6 +231,11 @@
             **kwargs,
         )
 
+        if pre:
+            del self.test_pre
+
+        return out
+
     def compute_output_shape(self, input_shape):
         return self.block.compute_output_shape(input_shape)
 
diff --git a/tests/unit/tf/models/test_retrieval.py b/tests/unit/tf/models/test_retrieval.py
index c441790963..6e28827cf8 100644
--- a/tests/unit/tf/models/test_retrieval.py
+++ b/tests/unit/tf/models/test_retrieval.py
@@ -987,5 +987,5 @@ def test_youtube_dnn_topk_evaluation(sequence_testing_data: Dataset, run_eagerly
     topk_model = model.to_top_k_encoder(k=20)
     topk_model.compile(run_eagerly=run_eagerly)
 
-    metrics = topk_model.evaluate(dataloader, return_dict=True)
+    metrics = topk_model.evaluate(dataloader, return_dict=True, pre=predict_next)
     assert all([metric >= 0 for metric in metrics.values()])
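
Usage note (not part of the patch): a minimal sketch of how the new `pre` argument could be called, following the updated test's `pre=predict_next`. Based on the diff, the block passed as `pre` is attached as `train_pre`/`test_pre` only for the duration of the `fit`/`evaluate` call (after resetting the compile cache) and is removed again before the result is returned. The names `model`, `seq_schema`, `train_loader`, and `eval_loader` are placeholders, and the use of `mm.SequencePredictNext` here is an assumption about a suitable batch-transform block, not something this patch defines.

    import merlin.models.tf as mm

    # Assumed setup: `seq_schema` selects the sequential features and
    # `predict_next` rewrites each batch so the next item becomes the target
    # (mirroring `pre=predict_next` in the updated test).
    predict_next = mm.SequencePredictNext(schema=seq_schema, target="item_id")

    # `pre` only applies to this call: the model's compile cache is reset,
    # the block is set as `train_pre` / `test_pre`, and it is deleted again
    # once the call returns.
    model.fit(train_loader, pre=predict_next)
    metrics = model.evaluate(eval_loader, return_dict=True, pre=predict_next)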