Skip to content

Commit

Permalink
Add pre to ModelBlock fit/evaluate
Browse files Browse the repository at this point in the history
  • Loading branch information
oliverholworthy committed Dec 2, 2022
1 parent 1abd8d5 commit 1eef7b8
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 4 deletions.
27 changes: 24 additions & 3 deletions merlin/models/tf/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,20 +168,31 @@ def fit(
workers=1,
use_multiprocessing=False,
train_metrics_steps=1,
pre=None,
**kwargs,
):
x = _maybe_convert_merlin_dataset(x, batch_size, **kwargs)
validation_data = _maybe_convert_merlin_dataset(
validation_data, batch_size, shuffle=shuffle, **kwargs
)
callbacks = self._add_metrics_callback(callbacks, train_metrics_steps)

fit_kwargs = {
k: v
for k, v in locals().items()
if k not in ["self", "kwargs", "train_metrics_steps", "__class__"]
if k not in ["self", "kwargs", "train_metrics_steps", "pre", "__class__"]
}

return super().fit(**fit_kwargs)
if pre:
self._reset_compile_cache()
self.train_pre = pre

out = super().fit(**fit_kwargs)

if pre:
del self.train_pre

return out

def evaluate(
self,
Expand All @@ -196,11 +207,16 @@ def evaluate(
workers=1,
use_multiprocessing=False,
return_dict=False,
pre=None,
**kwargs,
):
x = _maybe_convert_merlin_dataset(x, batch_size, **kwargs)

return super().evaluate(
if pre:
self._reset_compile_cache()
self.test_pre = pre

out = super().evaluate(
x,
y,
batch_size,
Expand All @@ -215,6 +231,11 @@ def evaluate(
**kwargs,
)

if pre:
del self.test_pre

return out

def compute_output_shape(self, input_shape):
return self.block.compute_output_shape(input_shape)

Expand Down
2 changes: 1 addition & 1 deletion tests/unit/tf/models/test_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -987,5 +987,5 @@ def test_youtube_dnn_topk_evaluation(sequence_testing_data: Dataset, run_eagerly
topk_model = model.to_top_k_encoder(k=20)
topk_model.compile(run_eagerly=run_eagerly)

metrics = topk_model.evaluate(dataloader, return_dict=True)
metrics = topk_model.evaluate(dataloader, return_dict=True, pre=predict_next)
assert all([metric >= 0 for metric in metrics.values()])

0 comments on commit 1eef7b8

Please sign in to comment.