Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use tf.function for list column operations #938

Merged
merged 1 commit into from
Dec 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions merlin/models/tf/outputs/topk.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from tensorflow.keras.layers import Layer

import merlin.io
from merlin.core.dispatch import DataFrameType
from merlin.core.dispatch import DataFrameType, make_df
from merlin.models.tf.core.base import Block, block_registry
from merlin.models.tf.core.prediction import Prediction, TopKPrediction
from merlin.models.tf.outputs.base import MetricsFn, ModelOutput
Expand Down Expand Up @@ -100,7 +100,7 @@ def extract_ids_embeddings(self, data: merlin.io.Dataset, check_unique_ids: bool
if check_unique_ids:
self._check_unique_ids(data=data)
values = tf_utils.df_to_tensor(data)
ids = tf_utils.df_to_tensor(data.index)
ids = tf_utils.df_to_tensor(make_df({"index": data.index}))

if len(ids.shape) == 2:
ids = tf.squeeze(ids)
Expand Down
1 change: 1 addition & 0 deletions merlin/models/tf/transforms/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -881,6 +881,7 @@ def _get_seq_features_shapes(self, inputs: TabularData):

return seq_features_shapes, sequence_length

@tf.function
def _broadcast(self, inputs, target):
seq_features_shapes, sequence_length = self._get_seq_features_shapes(inputs)
if len(seq_features_shapes) > 0:
Expand Down
2 changes: 2 additions & 0 deletions merlin/models/tf/transforms/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,7 @@ class SequenceTargetAsInput(SequenceTransform):
so that the tensors sequences can be processed
"""

@tf.function
def call(
self, inputs: TabularData, targets=None, training=False, testing=False, **kwargs
) -> Prediction:
Expand Down Expand Up @@ -441,6 +442,7 @@ def __init__(
self.masking_prob = masking_prob
super().__init__(schema, target, **kwargs)

@tf.function
def compute_mask(self, inputs, mask=None):
"""Selects (masks) some positions of the targets to be predicted.
This method is called by Keras after call()
Expand Down
1 change: 1 addition & 0 deletions merlin/models/tf/utils/tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,7 @@ def get_sub_blocks(blocks: Sequence[Block]) -> List[Block]:
return list(result_blocks)


@tf.function
def list_col_to_ragged(col: Tuple[tf.Tensor, tf.Tensor]):
values = col[0][:, 0]
row_lengths = col[1][:, 0]
Expand Down
2 changes: 1 addition & 1 deletion requirements/test.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
-r dev.txt
-r pytorch.txt

tensorflow<2.10
tensorflow<2.11
numpy<1.24
24 changes: 18 additions & 6 deletions tests/unit/tf/test_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,17 +72,29 @@ def test_nested_list():
)

batch = next(iter(loader))

# [[1,2,3],[3,1],[...],[]]
nested_data_col = tf.RaggedTensor.from_row_lengths(
batch[0]["data"][0][:, 0], tf.cast(batch[0]["data"][1][:, 0], tf.int32)
).to_tensor()
@tf.function
def _ragged_for_nested_data_col():
nested_data_col = tf.RaggedTensor.from_row_lengths(
batch[0]["data"][0][:, 0], tf.cast(batch[0]["data"][1][:, 0], tf.int32)
).to_tensor()
return nested_data_col

nested_data_col = _ragged_for_nested_data_col()
true_data_col = tf.reshape(
tf.ragged.constant(df.iloc[:batch_size, 0].tolist()).to_tensor(), [batch_size, -1]
)

# [1,2,3]
multihot_data2_col = tf.RaggedTensor.from_row_lengths(
batch[0]["data2"][0][:, 0], tf.cast(batch[0]["data2"][1][:, 0], tf.int32)
).to_tensor()
@tf.function
def _ragged_for_multihot_data_col():
multihot_data2_col = tf.RaggedTensor.from_row_lengths(
batch[0]["data2"][0][:, 0], tf.cast(batch[0]["data2"][1][:, 0], tf.int32)
).to_tensor()
return multihot_data2_col

multihot_data2_col = _ragged_for_multihot_data_col()
true_data2_col = tf.reshape(
tf.ragged.constant(df.iloc[:batch_size, 1].tolist()).to_tensor(), [batch_size, -1]
)
Expand Down