diff --git a/merlin/models/tf/inputs/embedding.py b/merlin/models/tf/inputs/embedding.py index b38807b817..d4b612077f 100644 --- a/merlin/models/tf/inputs/embedding.py +++ b/merlin/models/tf/inputs/embedding.py @@ -615,6 +615,7 @@ def _get_dim(col, embedding_dims, infer_dim_fn): return dim +@tf.keras.utils.register_keras_serializable(package="merlin.models") class AverageEmbeddingsByWeightFeature(tf.keras.layers.Layer): def __init__(self, weight_feature_name: str, axis=1, **kwargs): """Computes the weighted average of a Tensor based @@ -691,6 +692,13 @@ def from_schema_convention(schema: Schema, weight_features_name_suffix: str = "_ return seq_combiners + def get_config(self): + config = super().get_config() + config["axis"] = self.axis + config["weight_feature_name"] = self.weight_feature_name + + return config + @dataclass class EmbeddingOptions: