diff --git a/merlin/models/tf/losses/pairwise.py b/merlin/models/tf/losses/pairwise.py index 03c795a063..97045a4cec 100644 --- a/merlin/models/tf/losses/pairwise.py +++ b/merlin/models/tf/losses/pairwise.py @@ -72,6 +72,10 @@ def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor: tf.Tensor Loss per example """ + tf.assert_equal(tf.rank(y_true), 2, f"Targets must be 2-D tensor (got {y_true.shape})") + + tf.assert_equal(tf.rank(y_pred), 2, f"Predictions must be 2-D tensor (got {y_pred.shape})") + ( positives_scores, negatives_scores, diff --git a/merlin/models/tf/outputs/contrastive.py b/merlin/models/tf/outputs/contrastive.py index 9f712d11d0..f5eb0497f6 100644 --- a/merlin/models/tf/outputs/contrastive.py +++ b/merlin/models/tf/outputs/contrastive.py @@ -226,7 +226,6 @@ def outputs( # To ensure that the output is always fp32, avoiding numerical # instabilities with mixed_float16 policy outputs = tf.cast(outputs, tf.float32) - outputs = tf.squeeze(outputs) targets = tf.concat( [ diff --git a/tests/unit/tf/outputs/test_contrastive.py b/tests/unit/tf/outputs/test_contrastive.py index 90b8b077b3..04c92863dc 100644 --- a/tests/unit/tf/outputs/test_contrastive.py +++ b/tests/unit/tf/outputs/test_contrastive.py @@ -190,6 +190,21 @@ def test_contrastive_only_positive_when_not_training(ecommerce_data: Dataset): ) +@pytest.mark.parametrize("run_eagerly", [True, False]) +def test_contrastive_output_with_pairwise_loss(ecommerce_data: Dataset, run_eagerly): + model = mm.RetrievalModelV2( + query=mm.Encoder(ecommerce_data.schema.select_by_tag(Tags.USER), mm.MLPBlock([2])), + candidate=mm.Encoder(ecommerce_data.schema.select_by_tag(Tags.ITEM), mm.MLPBlock([2])), + output=mm.ContrastiveOutput( + ecommerce_data.schema.select_by_tag(Tags.ITEM_ID), + negative_samplers="in-batch", + candidate_name="item", + ), + ) + model.compile(run_eagerly=run_eagerly, loss="bpr-max") + _ = model.fit(ecommerce_data, batch_size=50, epochs=1) + + def _retrieval_inputs_(batch_size): users_embeddings = tf.random.uniform(shape=(batch_size, 5), dtype=tf.float32) items_embeddings = tf.random.uniform(shape=(batch_size, 5), dtype=tf.float32)