[BUG] getting error when using AverageEmbeddingsByWeightFeature class in TT model #970

Closed
rnyak opened this issue Feb 1, 2023 · 0 comments · Fixed by #973
Labels: bug (Something isn't working), P0, status/needs-triage

rnyak commented Feb 1, 2023

Bug description

I would like to use the AverageEmbeddingsByWeightFeature class to take the weighted average of the embeddings of a list column in the user tower, but calling user_inputs on a batch raises the following error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[11], line 1
----> 1 user_inputs(batch[0])

File /usr/local/lib/python3.8/dist-packages/merlin/models/tf/core/tabular.py:478, in _tabular_call(self, inputs, pre, post, merge_with, aggregation, *args, **kwargs)
    475 inputs = self.pre_call(inputs, transformations=pre)
    477 # This will call the `call` method implemented by the super class.
--> 478 outputs = self.super().__call__(inputs, *args, **kwargs)  # type: ignore
    480 if isinstance(outputs, dict):
    481     outputs = self.post_call(
    482         outputs, transformations=post, merge_with=merge_with, aggregation=aggregation
    483     )

File /usr/local/lib/python3.8/dist-packages/merlin/models/config/schema.py:58, in SchemaMixin.__call__(self, *args, **kwargs)
     55 def __call__(self, *args, **kwargs):
     56     self.check_schema()
---> 58     return super().__call__(*args, **kwargs)

File /usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py:70, in filter_traceback.<locals>.error_handler(*args, **kwargs)
     67     filtered_tb = _process_traceback_frames(e.__traceback__)
     68     # To get the full stack trace, call:
     69     # `tf.debugging.disable_traceback_filtering()`
---> 70     raise e.with_traceback(filtered_tb) from None
     71 finally:
     72     del filtered_tb

File /usr/local/lib/python3.8/dist-packages/merlin/models/tf/core/combinators.py:566, in ParallelBlock.call(self, inputs, **kwargs)
    564 for name, layer in self.parallel_dict.items():
    565     layer_inputs = self._maybe_filter_layer_inputs_using_schema(name, layer, inputs)
--> 566     out = call_layer(layer, layer_inputs, **kwargs)
    567     if not isinstance(out, dict):
    568         out = {name: out}

File /usr/local/lib/python3.8/dist-packages/merlin/models/tf/utils/tf_utils.py:437, in call_layer(layer, inputs, *args, **kwargs)
    433         call_fn = type(layer).call
    435         filtered_kwargs = filter_kwargs(filtered_kwargs, call_fn, **_k)
--> 437 return layer(inputs, *args, **filtered_kwargs)

File /usr/local/lib/python3.8/dist-packages/merlin/models/tf/core/tabular.py:478, in _tabular_call(self, inputs, pre, post, merge_with, aggregation, *args, **kwargs)
    475 inputs = self.pre_call(inputs, transformations=pre)
    477 # This will call the `call` method implemented by the super class.
--> 478 outputs = self.super().__call__(inputs, *args, **kwargs)  # type: ignore
    480 if isinstance(outputs, dict):
    481     outputs = self.post_call(
    482         outputs, transformations=post, merge_with=merge_with, aggregation=aggregation
    483     )

File /usr/local/lib/python3.8/dist-packages/merlin/models/config/schema.py:58, in SchemaMixin.__call__(self, *args, **kwargs)
     55 def __call__(self, *args, **kwargs):
     56     self.check_schema()
---> 58     return super().__call__(*args, **kwargs)

File /usr/local/lib/python3.8/dist-packages/merlin/models/tf/core/combinators.py:566, in ParallelBlock.call(self, inputs, **kwargs)
    564 for name, layer in self.parallel_dict.items():
    565     layer_inputs = self._maybe_filter_layer_inputs_using_schema(name, layer, inputs)
--> 566     out = call_layer(layer, layer_inputs, **kwargs)
    567     if not isinstance(out, dict):
    568         out = {name: out}

File /usr/local/lib/python3.8/dist-packages/merlin/models/tf/utils/tf_utils.py:437, in call_layer(layer, inputs, *args, **kwargs)
    433         call_fn = type(layer).call
    435         filtered_kwargs = filter_kwargs(filtered_kwargs, call_fn, **_k)
--> 437 return layer(inputs, *args, **filtered_kwargs)

File /usr/local/lib/python3.8/dist-packages/merlin/models/config/schema.py:58, in SchemaMixin.__call__(self, *args, **kwargs)
     55 def __call__(self, *args, **kwargs):
     56     self.check_schema()
---> 58     return super().__call__(*args, **kwargs)

File /usr/local/lib/python3.8/dist-packages/merlin/models/tf/inputs/embedding.py:386, in EmbeddingTable.call(self, inputs, **kwargs)
    384             out[feature_name] = self._call_table(inputs[feature_name], **kwargs)
    385 else:
--> 386     out = self._call_table(inputs, **kwargs)
    388 return out

File /usr/local/lib/python3.8/dist-packages/merlin/models/tf/inputs/embedding.py:418, in EmbeddingTable._call_table(self, inputs, **kwargs)
    416         out = call_layer(self.table, inputs, **kwargs)
    417         if isinstance(self.sequence_combiner, tf.keras.layers.Layer):
--> 418             out = call_layer(self.sequence_combiner, out, **kwargs)
    419 else:
    420     out = call_layer(self.table, inputs, **kwargs)

File /usr/local/lib/python3.8/dist-packages/merlin/models/tf/utils/tf_utils.py:437, in call_layer(layer, inputs, *args, **kwargs)
    433         call_fn = type(layer).call
    435         filtered_kwargs = filter_kwargs(filtered_kwargs, call_fn, **_k)
--> 437 return layer(inputs, *args, **filtered_kwargs)

TypeError: Exception encountered when calling layer "channel_id_hist" "                 f"(type EmbeddingTable).

call() missing 1 required positional argument: 'features'

Call arguments received by layer "channel_id_hist" "                 f"(type EmbeddingTable):
  • inputs=('tf.Tensor(shape=(41, 1), dtype=int64)', 'tf.Tensor(shape=(16, 1), dtype=int32)')
  • kwargs={'training': 'None'}
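Reading this traceback, EmbeddingTable._call_table hands the list column's embeddings to the sequence combiner through call_layer with only training in kwargs, while AverageEmbeddingsByWeightFeature.call(self, inputs, features) requires features. The plain-Python sketch below reproduces that failure mode using hypothetical stand-ins (forward_matching_kwargs, CombinerSketch). It is not Merlin's code, only an illustration of why a dispatcher that filters kwargs by the callee's signature cannot supply an argument its caller never received.

```python
import inspect
from typing import Any, Dict


def forward_matching_kwargs(layer: Any, inputs: Any, **kwargs: Any) -> Any:
    """Hypothetical dispatcher: pass on only the kwargs that layer.call() declares."""
    accepted = inspect.signature(layer.call).parameters
    filtered = {k: v for k, v in kwargs.items() if k in accepted}
    return layer.call(inputs, **filtered)


class CombinerSketch:
    """Hypothetical combiner whose call() requires the features dict, like the one in the traceback."""

    def __init__(self, weight_feature_name: str) -> None:
        self.weight_feature_name = weight_feature_name

    def call(self, inputs: Any, features: Dict[str, Any]) -> Any:
        # Return the embeddings together with their weights to show the lookup succeeded.
        return inputs, features[self.weight_feature_name]


combiner = CombinerSketch("channel_id_hist_weights")


def call_without_features() -> str:
    """Mirror the failing call: the caller only has `training` to forward."""
    try:
        forward_matching_kwargs(combiner, [[0.1, 0.2]], training=None)
    except TypeError as err:
        return str(err)
    return ""


print(call_without_features())  # TypeError text mentioning the missing 'features' argument
```

Filtering kwargs by signature can only drop arguments, never add them, so in this sketch the combiner succeeds only when something upstream has already put features into kwargs.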

I also get the following error from model.fit():

/usr/local/lib/python3.8/dist-packages/merlin/schema/tags.py:148: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [<Tags.ITEM: 'item'>, <Tags.ID: 'id'>].
  warnings.warn(
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[12], line 2
      1 model.compile(optimizer="adam", run_eagerly=False, metrics=[mm.RecallAt(10), mm.NDCGAt(10)])
----> 2 model.fit(train, batch_size=128, epochs=2)

File /usr/local/lib/python3.8/dist-packages/merlin/models/tf/models/base.py:1157, in BaseModel.fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing, train_metrics_steps, pre, **kwargs)
   1154     self._reset_compile_cache()
   1155     self.train_pre = pre
-> 1157 out = super().fit(**fit_kwargs)
   1159 if pre:
   1160     del self.train_pre

File /usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py:70, in filter_traceback.<locals>.error_handler(*args, **kwargs)
     67     filtered_tb = _process_traceback_frames(e.__traceback__)
     68     # To get the full stack trace, call:
     69     # `tf.debugging.disable_traceback_filtering()`
---> 70     raise e.with_traceback(filtered_tb) from None
     71 finally:
     72     del filtered_tb

File /usr/local/lib/python3.8/dist-packages/merlin/models/tf/models/base.py:1494, in Model.call(self, inputs, targets, training, testing, output_context)
   1491     outputs, context = self._call_child(self.pre, outputs, context)
   1493 for block in self.blocks:
-> 1494     outputs, context = self._call_child(block, outputs, context)
   1496 if self.post:
   1497     outputs, context = self._call_child(self.post, outputs, context)

File /usr/local/lib/python3.8/dist-packages/merlin/models/tf/models/base.py:1523, in Model._call_child(self, child, inputs, context)
   1520 if any(isinstance(sub, ModelBlock) for sub in child.submodules):
   1521     del call_kwargs["features"]
-> 1523 outputs = call_layer(child, inputs, **call_kwargs)
   1524 if isinstance(outputs, Prediction):
   1525     targets = outputs.targets if outputs.targets is not None else context.targets

File /usr/local/lib/python3.8/dist-packages/merlin/models/tf/utils/tf_utils.py:437, in call_layer(layer, inputs, *args, **kwargs)
    433         call_fn = type(layer).call
    435         filtered_kwargs = filter_kwargs(filtered_kwargs, call_fn, **_k)
--> 437 return layer(inputs, *args, **filtered_kwargs)

File /usr/local/lib/python3.8/dist-packages/merlin/models/tf/core/tabular.py:478, in _tabular_call(self, inputs, pre, post, merge_with, aggregation, *args, **kwargs)
    475 inputs = self.pre_call(inputs, transformations=pre)
    477 # This will call the `call` method implemented by the super class.
--> 478 outputs = self.super().__call__(inputs, *args, **kwargs)  # type: ignore
    480 if isinstance(outputs, dict):
    481     outputs = self.post_call(
    482         outputs, transformations=post, merge_with=merge_with, aggregation=aggregation
    483     )

File /usr/local/lib/python3.8/dist-packages/merlin/models/config/schema.py:58, in SchemaMixin.__call__(self, *args, **kwargs)
     55 def __call__(self, *args, **kwargs):
     56     self.check_schema()
---> 58     return super().__call__(*args, **kwargs)

File /usr/local/lib/python3.8/dist-packages/merlin/models/tf/core/combinators.py:566, in ParallelBlock.call(self, inputs, **kwargs)
    564 for name, layer in self.parallel_dict.items():
    565     layer_inputs = self._maybe_filter_layer_inputs_using_schema(name, layer, inputs)
--> 566     out = call_layer(layer, layer_inputs, **kwargs)
    567     if not isinstance(out, dict):
    568         out = {name: out}

File /usr/local/lib/python3.8/dist-packages/merlin/models/tf/utils/tf_utils.py:437, in call_layer(layer, inputs, *args, **kwargs)
    433         call_fn = type(layer).call
    435         filtered_kwargs = filter_kwargs(filtered_kwargs, call_fn, **_k)
--> 437 return layer(inputs, *args, **filtered_kwargs)

File /usr/local/lib/python3.8/dist-packages/merlin/models/tf/core/encoder.py:185, in Encoder.__call__(self, inputs, **kwargs)
    182 if "features" in kwargs:
    183     kwargs.pop("features")
--> 185 return super().__call__(inputs, **kwargs)

File /usr/local/lib/python3.8/dist-packages/merlin/models/tf/core/encoder.py:161, in Encoder.call(self, inputs, training, testing, targets, **kwargs)
    160 def call(self, inputs, training=False, testing=False, targets=None, **kwargs):
--> 161     return combinators.call_sequentially(
    162         list(self.to_call),
    163         inputs=inputs,
    164         features=inputs,
    165         targets=targets,
    166         training=training,
    167         testing=testing,
    168         **kwargs,
    169     )

File /usr/local/lib/python3.8/dist-packages/merlin/models/tf/core/combinators.py:836, in call_sequentially(layers, inputs, **kwargs)
    834 outputs = inputs
    835 for layer in layers:
--> 836     outputs = call_layer(layer, outputs, **kwargs)
    838 return outputs

File /usr/local/lib/python3.8/dist-packages/merlin/models/tf/utils/tf_utils.py:437, in call_layer(layer, inputs, *args, **kwargs)
    433         call_fn = type(layer).call
    435         filtered_kwargs = filter_kwargs(filtered_kwargs, call_fn, **_k)
--> 437 return layer(inputs, *args, **filtered_kwargs)

File /usr/local/lib/python3.8/dist-packages/merlin/models/tf/core/tabular.py:478, in _tabular_call(self, inputs, pre, post, merge_with, aggregation, *args, **kwargs)
    475 inputs = self.pre_call(inputs, transformations=pre)
    477 # This will call the `call` method implemented by the super class.
--> 478 outputs = self.super().__call__(inputs, *args, **kwargs)  # type: ignore
    480 if isinstance(outputs, dict):
    481     outputs = self.post_call(
    482         outputs, transformations=post, merge_with=merge_with, aggregation=aggregation
    483     )

File /usr/local/lib/python3.8/dist-packages/merlin/models/config/schema.py:58, in SchemaMixin.__call__(self, *args, **kwargs)
     55 def __call__(self, *args, **kwargs):
     56     self.check_schema()
---> 58     return super().__call__(*args, **kwargs)

File /usr/local/lib/python3.8/dist-packages/merlin/models/tf/core/combinators.py:566, in ParallelBlock.call(self, inputs, **kwargs)
    564 for name, layer in self.parallel_dict.items():
    565     layer_inputs = self._maybe_filter_layer_inputs_using_schema(name, layer, inputs)
--> 566     out = call_layer(layer, layer_inputs, **kwargs)
    567     if not isinstance(out, dict):
    568         out = {name: out}

File /usr/local/lib/python3.8/dist-packages/merlin/models/tf/utils/tf_utils.py:437, in call_layer(layer, inputs, *args, **kwargs)
    433         call_fn = type(layer).call
    435         filtered_kwargs = filter_kwargs(filtered_kwargs, call_fn, **_k)
--> 437 return layer(inputs, *args, **filtered_kwargs)

File /usr/local/lib/python3.8/dist-packages/merlin/models/tf/core/tabular.py:478, in _tabular_call(self, inputs, pre, post, merge_with, aggregation, *args, **kwargs)
    475 inputs = self.pre_call(inputs, transformations=pre)
    477 # This will call the `call` method implemented by the super class.
--> 478 outputs = self.super().__call__(inputs, *args, **kwargs)  # type: ignore
    480 if isinstance(outputs, dict):
    481     outputs = self.post_call(
    482         outputs, transformations=post, merge_with=merge_with, aggregation=aggregation
    483     )

File /usr/local/lib/python3.8/dist-packages/merlin/models/config/schema.py:58, in SchemaMixin.__call__(self, *args, **kwargs)
     55 def __call__(self, *args, **kwargs):
     56     self.check_schema()
---> 58     return super().__call__(*args, **kwargs)

File /usr/local/lib/python3.8/dist-packages/merlin/models/tf/core/combinators.py:566, in ParallelBlock.call(self, inputs, **kwargs)
    564 for name, layer in self.parallel_dict.items():
    565     layer_inputs = self._maybe_filter_layer_inputs_using_schema(name, layer, inputs)
--> 566     out = call_layer(layer, layer_inputs, **kwargs)
    567     if not isinstance(out, dict):
    568         out = {name: out}

File /usr/local/lib/python3.8/dist-packages/merlin/models/tf/utils/tf_utils.py:437, in call_layer(layer, inputs, *args, **kwargs)
    433         call_fn = type(layer).call
    435         filtered_kwargs = filter_kwargs(filtered_kwargs, call_fn, **_k)
--> 437 return layer(inputs, *args, **filtered_kwargs)

File /usr/local/lib/python3.8/dist-packages/merlin/models/config/schema.py:58, in SchemaMixin.__call__(self, *args, **kwargs)
     55 def __call__(self, *args, **kwargs):
     56     self.check_schema()
---> 58     return super().__call__(*args, **kwargs)

File /usr/local/lib/python3.8/dist-packages/merlin/models/tf/inputs/embedding.py:386, in EmbeddingTable.call(self, inputs, **kwargs)
    384             out[feature_name] = self._call_table(inputs[feature_name], **kwargs)
    385 else:
--> 386     out = self._call_table(inputs, **kwargs)
    388 return out

File /usr/local/lib/python3.8/dist-packages/merlin/models/tf/inputs/embedding.py:418, in EmbeddingTable._call_table(self, inputs, **kwargs)
    416         out = call_layer(self.table, inputs, **kwargs)
    417         if isinstance(self.sequence_combiner, tf.keras.layers.Layer):
--> 418             out = call_layer(self.sequence_combiner, out, **kwargs)
    419 else:
    420     out = call_layer(self.table, inputs, **kwargs)

File /usr/local/lib/python3.8/dist-packages/merlin/models/tf/utils/tf_utils.py:437, in call_layer(layer, inputs, *args, **kwargs)
    433         call_fn = type(layer).call
    435         filtered_kwargs = filter_kwargs(filtered_kwargs, call_fn, **_k)
--> 437 return layer(inputs, *args, **filtered_kwargs)

File /usr/local/lib/python3.8/dist-packages/merlin/models/tf/inputs/embedding.py:639, in AverageEmbeddingsByWeightFeature.call(self, inputs, features)
    638 def call(self, inputs, features):
--> 639     weight_feature = features[self.weight_feature_name]
    640     if isinstance(inputs, tf.RaggedTensor) and not isinstance(weight_feature, tf.RaggedTensor):
    641         raise ValueError(
    642             f"If inputs is a tf.RaggedTensor, the weight feature ({self.weight_feature_name}) "
    643             f"should also be a tf.RaggedTensor (and not a {type(weight_feature)}), "
    644             "so that the list length can vary per example for both input embedding "
    645             "and weight features."
    646         )
KeyError: 'Exception encountered when calling layer "average_embeddings_by_weight_feature" "                 f"(type AverageEmbeddingsByWeightFeature).\n\nchannel_id_hist_weights\n\nCall arguments received by layer "average_embeddings_by_weight_feature" "                 f"(type AverageEmbeddingsByWeightFeature):\n  • inputs=<tf.RaggedTensor [[[-0.031741597, -0.0075855963, 0.017847631, -0.046845544, 0.018529568,\n   0.048255417, ....
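In this second traceback, features does reach the combiner (Encoder.call passes features=inputs), but the dict has no channel_id_hist_weights key. The frames also pass through ParallelBlock._maybe_filter_layer_inputs_using_schema, so one plausible cause, not confirmed here, is that a schema-based selection upstream dropped the weights column. The sketch below illustrates that failure mode with hypothetical helpers (select_columns, lookup_weights) and made-up values; it is not Merlin's implementation.

```python
from typing import Dict, Iterable, List


def select_columns(features: Dict[str, list], columns: Iterable[str]) -> Dict[str, list]:
    """Hypothetical schema-based filter: keep only the listed columns."""
    keep = set(columns)
    return {name: value for name, value in features.items() if name in keep}


def lookup_weights(features: Dict[str, list], weight_feature_name: str) -> List[float]:
    # Same lookup pattern as AverageEmbeddingsByWeightFeature.call in the traceback:
    # raises KeyError if the weights column was dropped upstream.
    return features[weight_feature_name]


# Illustrative batch: an id-list column plus its per-item weights.
batch = {
    "channel_id_hist": [3, 7],
    "channel_id_hist_weights": [0.5, 1.5],
}

# A tower "schema" that lists the id column but not its weights column.
user_tower_columns = ["channel_id_hist"]
filtered = select_columns(batch, user_tower_columns)

try:
    lookup_weights(filtered, "channel_id_hist_weights")
except KeyError as err:
    print("missing from features:", err)
```

In this sketch the lookup succeeds only if the weights column survives every filtering step between the raw batch and the combiner.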

Steps/Code to reproduce bug

Please run the code in this gist to reproduce the issue.

Note: this requires PR #968 to be merged. Alternatively, you can add a get_config() method at the end of the AverageEmbeddingsByWeightFeature class yourself.
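For the manual workaround, the usual Keras pattern is to extend the base config with every constructor argument so that from_config() can rebuild the layer. The sketch below shows that pattern on a hypothetical stand-in layer (WeightedAverageSketch); the axis argument is an assumption for illustration, and the actual attributes to serialize should be taken from the AverageEmbeddingsByWeightFeature constructor. call() is omitted because only serialization is shown.

```python
import tensorflow as tf


class WeightedAverageSketch(tf.keras.layers.Layer):
    """Hypothetical layer used only to show the get_config() pattern (call() omitted)."""

    def __init__(self, weight_feature_name: str, axis: int = 1, **kwargs):
        super().__init__(**kwargs)
        self.weight_feature_name = weight_feature_name
        self.axis = axis  # assumed constructor argument, for illustration

    def get_config(self):
        # Start from the base Keras config (name, trainable, dtype, ...) and add
        # each constructor argument so from_config() can call cls(**config).
        config = super().get_config()
        config.update({"weight_feature_name": self.weight_feature_name, "axis": self.axis})
        return config


layer = WeightedAverageSketch("channel_id_hist_weights", name="avg_by_weight")
restored = WeightedAverageSketch.from_config(layer.get_config())
print(restored.weight_feature_name, restored.axis)
```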

Expected behavior

It should work without any errors.
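To make the expected result concrete: for each example, the layer should reduce the list column's embeddings e_1..e_n with weights w_1..w_n to sum_i(w_i * e_i) / sum_i(w_i). The plain-Python sketch below spells this out for a ragged batch, where list lengths differ between examples (the reason the traceback's ValueError asks for both inputs and weights to be RaggedTensors). It assumes the average is normalized by the sum of the weights, which is not checked against the Merlin source here, and the function names are hypothetical.

```python
from typing import List, Sequence


def weighted_average(embeddings: Sequence[Sequence[float]], weights: Sequence[float]) -> List[float]:
    """Weighted mean of one example's list of embedding vectors."""
    if len(embeddings) != len(weights):
        raise ValueError("each list element needs exactly one weight")
    total = sum(weights)
    if total == 0:
        raise ValueError("weights must not sum to zero")
    dim = len(embeddings[0])
    return [sum(w * vec[d] for vec, w in zip(embeddings, weights)) / total for d in range(dim)]


def weighted_average_batch(
    batch_embeddings: Sequence[Sequence[Sequence[float]]],
    batch_weights: Sequence[Sequence[float]],
) -> List[List[float]]:
    """Apply weighted_average per example; list lengths may differ between examples."""
    return [weighted_average(e, w) for e, w in zip(batch_embeddings, batch_weights)]


# Example 1 has two list items, example 2 has one (a ragged batch).
embs = [[[1.0, 0.0], [0.0, 1.0]], [[2.0, 4.0]]]
wts = [[3.0, 1.0], [0.5]]
print(weighted_average_batch(embs, wts))  # [[0.75, 0.25], [2.0, 4.0]]
```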

Environment details

  • Merlin version:
  • Platform:
  • Python version:
  • PyTorch version (GPU?):
  • Tensorflow version (GPU?):

I am using the merlin-tensorflow:22.12 image with the latest main branches pulled.
