From 5173a96155c11355744a5aa287749ce59f05a41e Mon Sep 17 00:00:00 2001 From: sararb Date: Tue, 8 Nov 2022 15:58:17 -0500 Subject: [PATCH] add support of different thresholds `k` in topk-encoder --- merlin/models/tf/core/encoder.py | 51 +++++++++++++++++++++++++++++- merlin/models/tf/outputs/topk.py | 30 +++++++++++++++--- tests/unit/tf/core/test_encoder.py | 22 ++++++++++--- 3 files changed, 93 insertions(+), 10 deletions(-) diff --git a/merlin/models/tf/core/encoder.py b/merlin/models/tf/core/encoder.py index 1478f2268c..75931a4823 100644 --- a/merlin/models/tf/core/encoder.py +++ b/merlin/models/tf/core/encoder.py @@ -1,3 +1,19 @@ +# +# Copyright (c) 2021, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + from typing import Dict, Optional, Union import numpy as np @@ -139,7 +155,7 @@ def batch_predict( return merlin.io.Dataset(predictions) - def call(self, inputs, training=False, testing=False, targets=None): + def call(self, inputs, training=False, testing=False, targets=None, **kwargs): return combinators.call_sequentially( list(self.to_call), inputs=inputs, @@ -147,6 +163,7 @@ def call(self, inputs, training=False, testing=False, targets=None): targets=targets, training=training, testing=testing, + **kwargs, ) def build(self, input_shape): @@ -344,6 +361,38 @@ def from_candidate_dataset( ) return cls(query_encoder, topk_output, **kwargs) + def compile( + self, + optimizer="rmsprop", + loss=None, + metrics=None, + loss_weights=None, + weighted_metrics=None, + run_eagerly=None, + steps_per_execution=None, + jit_compile=None, + k: int = None, + **kwargs, + ): + """Extend the compile method of `BaseModel` to set the threshold `k` + of the top-k encoder. + """ + if k is not None: + self.topk_layer._k = k + self.k = k + BaseModel.compile( + self, + optimizer=optimizer, + loss=loss, + metrics=metrics, + weighted_metrics=weighted_metrics, + run_eagerly=run_eagerly, + loss_weights=loss_weights, + steps_per_execution=steps_per_execution, + jit_compile=jit_compile, + **kwargs, + ) + @property def topk_layer(self): return self.blocks[-1].to_call diff --git a/merlin/models/tf/outputs/topk.py b/merlin/models/tf/outputs/topk.py index 4fb84cbe2a..937b984297 100644 --- a/merlin/models/tf/outputs/topk.py +++ b/merlin/models/tf/outputs/topk.py @@ -1,3 +1,19 @@ +# +# Copyright (c) 2021, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + from typing import Optional, Union import tensorflow as tf @@ -90,7 +106,7 @@ def extract_ids_embeddings(self, data: merlin.io.Dataset, check_unique_ids: bool ids = tf.squeeze(ids) return ids, values - def call(self, inputs: tf.Tensor, targets=None, testing=False, **kwargs) -> tf.Tensor: + def call(self, inputs: tf.Tensor, targets=None, testing=False, k=None, **kwargs) -> tf.Tensor: """Method to return the tuple of top-k (ids, scores)""" raise NotImplementedError() @@ -166,9 +182,10 @@ def index(self, candidates: tf.Tensor, identifiers: Optional[tf.Tensor] = None) def call( self, - inputs, - targets=None, - testing=False, + inputs: tf.Tensor, + targets: tf.Tensor = None, + testing: bool = False, + k: int = None, ) -> Union[Prediction, TopKPrediction]: """Compute the scores between the query inputs and all indexed candidates, then retrieve the top-k candidates with the highest scores. @@ -181,7 +198,10 @@ def call( The tensor of positive candidates testing: bool Flag that indicates whether in evaluation mode, by default False + k: int + Number of candidates to return """ + k = k if k is not None else self._k if self._candidates is None: raise ValueError( "You should call the `index` method first to " "set the _candidates index." @@ -195,7 +215,7 @@ def call( f" dimension of {tf.shape(self._candidates)[1]} ", ) scores = self._score(inputs, self._candidates) - top_scores, top_idx = tf.math.top_k(scores, k=self._k) + top_scores, top_idx = tf.math.top_k(scores, k=k) top_ids = tf.gather(self._ids, top_idx) if testing: assert targets is not None, ValueError( diff --git a/tests/unit/tf/core/test_encoder.py b/tests/unit/tf/core/test_encoder.py index 0e2664e5a5..07ffe0ab63 100644 --- a/tests/unit/tf/core/test_encoder.py +++ b/tests/unit/tf/core/test_encoder.py @@ -46,6 +46,7 @@ def test_encoder_block(music_streaming_data: Dataset): def test_topk_encoder(music_streaming_data: Dataset): TOP_K = 10 + BATCH_SIZE = 32 music_streaming_data.schema = music_streaming_data.schema.select_by_name( ["user_id", "item_id", "country", "user_age"] ) @@ -68,12 +69,12 @@ def test_topk_encoder(music_streaming_data: Dataset): # 2. Get candidates embeddings for the top-k encoder candidate_features = unique_rows_by_features(music_streaming_data, Tags.ITEM, Tags.ITEM_ID) candidates = retrieval_model.candidate_embeddings( - candidate_features, batch_size=10, index=Tags.ITEM_ID + candidate_features, batch_size=BATCH_SIZE, index=Tags.ITEM_ID ) # 3. Set data-loader for top-k recommendation loader = mm.Loader( - music_streaming_data, batch_size=32, transform=mm.ToTarget(schema, "item_id") + music_streaming_data, batch_size=BATCH_SIZE, transform=mm.ToTarget(schema, "item_id") ) batch = next(iter(loader)) @@ -84,7 +85,7 @@ def test_topk_encoder(music_streaming_data: Dataset): # 5. Get top-k predictions batch_output = topk_encoder(batch[0]) predict_output = topk_encoder.predict(loader) - assert list(batch_output.scores.shape) == [32, TOP_K] + assert list(batch_output.scores.shape) == [BATCH_SIZE, TOP_K] assert list(predict_output.scores.shape) == [100, TOP_K] # 6. Compute top-k evaluation metrics (using the whole candidates catalog) @@ -116,10 +117,23 @@ def test_topk_encoder(music_streaming_data: Dataset): loaded_topk_encoder = tf.keras.models.load_model(tmpdir) batch_output = loaded_topk_encoder(batch[0]) - assert list(batch_output.scores.shape) == [32, TOP_K] + assert list(batch_output.scores.shape) == [BATCH_SIZE, TOP_K] tf.debugging.assert_equal( topk_encoder.topk_layer._candidates, loaded_topk_encoder.topk_layer._candidates, ) assert not loaded_topk_encoder.topk_layer._candidates.trainable + + # 9. Change the top-k threshold + scores = topk_encoder(batch[0], k=20) + assert list(scores.scores.shape) == [BATCH_SIZE, 20] + scores = topk_encoder(batch[0], k=30) + assert list(scores.scores.shape) == [BATCH_SIZE, 30] + + topk_encoder.compile(k=20) + scores = topk_encoder.predict(loader) + assert list(scores.scores.shape) == [100, 20] + topk_encoder.compile(k=30) + scores = topk_encoder.predict(loader) + assert list(scores.scores.shape) == [100, 30]