Skip to content

Commit

Permalink
fix contrastive output (#800)
Browse files Browse the repository at this point in the history
  • Loading branch information
sararb authored Oct 11, 2022
1 parent abab6e1 commit 8ac9090
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 1 deletion.
4 changes: 4 additions & 0 deletions merlin/models/tf/losses/pairwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
tf.Tensor
Loss per example
"""
tf.assert_equal(tf.rank(y_true), 2, f"Targets must be 2-D tensor (got {y_true.shape})")

tf.assert_equal(tf.rank(y_pred), 2, f"Predictions must be 2-D tensor (got {y_pred.shape})")

(
positives_scores,
negatives_scores,
Expand Down
1 change: 0 additions & 1 deletion merlin/models/tf/outputs/contrastive.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,6 @@ def outputs(
# To ensure that the output is always fp32, avoiding numerical
# instabilities with mixed_float16 policy
outputs = tf.cast(outputs, tf.float32)
outputs = tf.squeeze(outputs)

targets = tf.concat(
[
Expand Down
15 changes: 15 additions & 0 deletions tests/unit/tf/outputs/test_contrastive.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,21 @@ def test_contrastive_only_positive_when_not_training(ecommerce_data: Dataset):
)


@pytest.mark.parametrize("run_eagerly", [True, False])
def test_contrastive_output_with_pairwise_loss(ecommerce_data: Dataset, run_eagerly):
model = mm.RetrievalModelV2(
query=mm.Encoder(ecommerce_data.schema.select_by_tag(Tags.USER), mm.MLPBlock([2])),
candidate=mm.Encoder(ecommerce_data.schema.select_by_tag(Tags.ITEM), mm.MLPBlock([2])),
output=mm.ContrastiveOutput(
ecommerce_data.schema.select_by_tag(Tags.ITEM_ID),
negative_samplers="in-batch",
candidate_name="item",
),
)
model.compile(run_eagerly=run_eagerly, loss="bpr-max")
_ = model.fit(ecommerce_data, batch_size=50, epochs=1)


def _retrieval_inputs_(batch_size):
users_embeddings = tf.random.uniform(shape=(batch_size, 5), dtype=tf.float32)
items_embeddings = tf.random.uniform(shape=(batch_size, 5), dtype=tf.float32)
Expand Down

0 comments on commit 8ac9090

Please sign in to comment.