Skip to content

Commit

Permalink
Use tf.function for list column operations
Browse files Browse the repository at this point in the history
  • Loading branch information
edknv committed Dec 26, 2022
1 parent c52b2ca commit b74aa69
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 9 deletions.
4 changes: 2 additions & 2 deletions merlin/models/tf/outputs/topk.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from tensorflow.keras.layers import Layer

import merlin.io
from merlin.core.dispatch import DataFrameType
from merlin.core.dispatch import DataFrameType, make_df
from merlin.models.tf.core.base import Block, block_registry
from merlin.models.tf.core.prediction import Prediction, TopKPrediction
from merlin.models.tf.outputs.base import MetricsFn, ModelOutput
Expand Down Expand Up @@ -100,7 +100,7 @@ def extract_ids_embeddings(self, data: merlin.io.Dataset, check_unique_ids: bool
if check_unique_ids:
self._check_unique_ids(data=data)
values = tf_utils.df_to_tensor(data)
ids = tf_utils.df_to_tensor(data.index)
ids = tf_utils.df_to_tensor(make_df({"index": data.index}))

if len(ids.shape) == 2:
ids = tf.squeeze(ids)
Expand Down
1 change: 1 addition & 0 deletions merlin/models/tf/transforms/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -881,6 +881,7 @@ def _get_seq_features_shapes(self, inputs: TabularData):

return seq_features_shapes, sequence_length

@tf.function
def _broadcast(self, inputs, target):
seq_features_shapes, sequence_length = self._get_seq_features_shapes(inputs)
if len(seq_features_shapes) > 0:
Expand Down
2 changes: 2 additions & 0 deletions merlin/models/tf/transforms/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,7 @@ class SequenceTargetAsInput(SequenceTransform):
so that the tensors sequences can be processed
"""

@tf.function
def call(
self, inputs: TabularData, targets=None, training=False, testing=False, **kwargs
) -> Prediction:
Expand Down Expand Up @@ -441,6 +442,7 @@ def __init__(
self.masking_prob = masking_prob
super().__init__(schema, target, **kwargs)

@tf.function
def compute_mask(self, inputs, mask=None):
"""Selects (masks) some positions of the targets to be predicted.
This method is called by Keras after call()
Expand Down
1 change: 1 addition & 0 deletions merlin/models/tf/utils/tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,7 @@ def get_sub_blocks(blocks: Sequence[Block]) -> List[Block]:
return list(result_blocks)


@tf.function
def list_col_to_ragged(col: Tuple[tf.Tensor, tf.Tensor]):
values = col[0][:, 0]
row_lengths = col[1][:, 0]
Expand Down
2 changes: 1 addition & 1 deletion requirements/test.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
-r dev.txt
-r pytorch.txt

tensorflow<2.10
tensorflow<2.11
numpy<1.24
24 changes: 18 additions & 6 deletions tests/unit/tf/test_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,17 +72,29 @@ def test_nested_list():
)

batch = next(iter(loader))

# [[1,2,3],[3,1],[...],[]]
nested_data_col = tf.RaggedTensor.from_row_lengths(
batch[0]["data"][0][:, 0], tf.cast(batch[0]["data"][1][:, 0], tf.int32)
).to_tensor()
@tf.function
def _ragged_for_nested_data_col():
    """Return the "data" list column of the batch as a padded dense tensor.

    Reads ``batch`` from the enclosing test scope. The first column of
    ``batch[0]["data"][0]`` supplies the flat values and the first column
    of ``batch[0]["data"][1]`` (cast to int32) supplies the row lengths.
    """
    data_values, data_lengths = batch[0]["data"][0], batch[0]["data"][1]
    ragged = tf.RaggedTensor.from_row_lengths(
        values=data_values[:, 0],
        row_lengths=tf.cast(data_lengths[:, 0], tf.int32),
    )
    # Densify so the result can be compared against a regular tensor.
    return ragged.to_tensor()

nested_data_col = _ragged_for_nested_data_col()
true_data_col = tf.reshape(
tf.ragged.constant(df.iloc[:batch_size, 0].tolist()).to_tensor(), [batch_size, -1]
)

# [1,2,3]
multihot_data2_col = tf.RaggedTensor.from_row_lengths(
batch[0]["data2"][0][:, 0], tf.cast(batch[0]["data2"][1][:, 0], tf.int32)
).to_tensor()
@tf.function
def _ragged_for_multihot_data_col():
    """Return the "data2" list column of the batch as a padded dense tensor.

    Reads ``batch`` from the enclosing test scope. The first column of
    ``batch[0]["data2"][0]`` supplies the flat values and the first column
    of ``batch[0]["data2"][1]`` (cast to int32) supplies the row lengths.
    """
    data2_values, data2_lengths = batch[0]["data2"][0], batch[0]["data2"][1]
    ragged = tf.RaggedTensor.from_row_lengths(
        values=data2_values[:, 0],
        row_lengths=tf.cast(data2_lengths[:, 0], tf.int32),
    )
    # Densify so the result can be compared against a regular tensor.
    return ragged.to_tensor()

multihot_data2_col = _ragged_for_multihot_data_col()
true_data2_col = tf.reshape(
tf.ragged.constant(df.iloc[:batch_size, 1].tolist()).to_tensor(), [batch_size, -1]
)
Expand Down

0 comments on commit b74aa69

Please sign in to comment.