From b74aa6966d889bacf73bf27785892c68a45ff60b Mon Sep 17 00:00:00 2001
From: edknv
Date: Mon, 26 Dec 2022 13:32:50 -0800
Subject: [PATCH] Use tf.function for list column operations

---
 merlin/models/tf/outputs/topk.py        |  4 ++--
 merlin/models/tf/transforms/features.py |  1 +
 merlin/models/tf/transforms/sequence.py |  2 ++
 merlin/models/tf/utils/tf_utils.py      |  1 +
 requirements/test.txt                   |  2 +-
 tests/unit/tf/test_loader.py            | 24 ++++++++++++++++++------
 6 files changed, 25 insertions(+), 9 deletions(-)

diff --git a/merlin/models/tf/outputs/topk.py b/merlin/models/tf/outputs/topk.py
index 937b984297..933c92e969 100644
--- a/merlin/models/tf/outputs/topk.py
+++ b/merlin/models/tf/outputs/topk.py
@@ -20,7 +20,7 @@
 from tensorflow.keras.layers import Layer

 import merlin.io
-from merlin.core.dispatch import DataFrameType
+from merlin.core.dispatch import DataFrameType, make_df
 from merlin.models.tf.core.base import Block, block_registry
 from merlin.models.tf.core.prediction import Prediction, TopKPrediction
 from merlin.models.tf.outputs.base import MetricsFn, ModelOutput
@@ -100,7 +100,7 @@ def extract_ids_embeddings(self, data: merlin.io.Dataset, check_unique_ids: bool
         if check_unique_ids:
             self._check_unique_ids(data=data)
         values = tf_utils.df_to_tensor(data)
-        ids = tf_utils.df_to_tensor(data.index)
+        ids = tf_utils.df_to_tensor(make_df({"index": data.index}))

         if len(ids.shape) == 2:
             ids = tf.squeeze(ids)
diff --git a/merlin/models/tf/transforms/features.py b/merlin/models/tf/transforms/features.py
index 48c63e9cad..2f0406f0c5 100644
--- a/merlin/models/tf/transforms/features.py
+++ b/merlin/models/tf/transforms/features.py
@@ -881,6 +881,7 @@ def _get_seq_features_shapes(self, inputs: TabularData):

         return seq_features_shapes, sequence_length

+    @tf.function
     def _broadcast(self, inputs, target):
         seq_features_shapes, sequence_length = self._get_seq_features_shapes(inputs)
         if len(seq_features_shapes) > 0:
diff --git a/merlin/models/tf/transforms/sequence.py b/merlin/models/tf/transforms/sequence.py
index aad51da98d..279b728a86 100644
--- a/merlin/models/tf/transforms/sequence.py
+++ b/merlin/models/tf/transforms/sequence.py
@@ -371,6 +371,7 @@ class SequenceTargetAsInput(SequenceTransform):
     so that the tensors sequences can be processed
     """

+    @tf.function
     def call(
         self, inputs: TabularData, targets=None, training=False, testing=False, **kwargs
     ) -> Prediction:
@@ -441,6 +442,7 @@ def __init__(
         self.masking_prob = masking_prob
         super().__init__(schema, target, **kwargs)

+    @tf.function
     def compute_mask(self, inputs, mask=None):
         """Selects (masks) some positions of the targets to be predicted.
         This method is called by Keras after call()
diff --git a/merlin/models/tf/utils/tf_utils.py b/merlin/models/tf/utils/tf_utils.py
index d09afbbed1..5eea607458 100644
--- a/merlin/models/tf/utils/tf_utils.py
+++ b/merlin/models/tf/utils/tf_utils.py
@@ -459,6 +459,7 @@ def get_sub_blocks(blocks: Sequence[Block]) -> List[Block]:
     return list(result_blocks)


+@tf.function
 def list_col_to_ragged(col: Tuple[tf.Tensor, tf.Tensor]):
     values = col[0][:, 0]
     row_lengths = col[1][:, 0]
diff --git a/requirements/test.txt b/requirements/test.txt
index 6d53d632e3..0853df7f90 100644
--- a/requirements/test.txt
+++ b/requirements/test.txt
@@ -1,5 +1,5 @@
 -r dev.txt
 -r pytorch.txt

-tensorflow<2.10
+tensorflow<2.11
 numpy<1.24
diff --git a/tests/unit/tf/test_loader.py b/tests/unit/tf/test_loader.py
index c94fa430ec..7b51bc9704 100644
--- a/tests/unit/tf/test_loader.py
+++ b/tests/unit/tf/test_loader.py
@@ -72,17 +72,29 @@ def test_nested_list():
     )

     batch = next(iter(loader))
+
     # [[1,2,3],[3,1],[...],[]]
-    nested_data_col = tf.RaggedTensor.from_row_lengths(
-        batch[0]["data"][0][:, 0], tf.cast(batch[0]["data"][1][:, 0], tf.int32)
-    ).to_tensor()
+    @tf.function
+    def _ragged_for_nested_data_col():
+        nested_data_col = tf.RaggedTensor.from_row_lengths(
+            batch[0]["data"][0][:, 0], tf.cast(batch[0]["data"][1][:, 0], tf.int32)
+        ).to_tensor()
+        return nested_data_col
+
+    nested_data_col = _ragged_for_nested_data_col()
     true_data_col = tf.reshape(
         tf.ragged.constant(df.iloc[:batch_size, 0].tolist()).to_tensor(), [batch_size, -1]
     )
+
     # [1,2,3]
-    multihot_data2_col = tf.RaggedTensor.from_row_lengths(
-        batch[0]["data2"][0][:, 0], tf.cast(batch[0]["data2"][1][:, 0], tf.int32)
-    ).to_tensor()
+    @tf.function
+    def _ragged_for_multihot_data_col():
+        multihot_data2_col = tf.RaggedTensor.from_row_lengths(
+            batch[0]["data2"][0][:, 0], tf.cast(batch[0]["data2"][1][:, 0], tf.int32)
+        ).to_tensor()
+        return multihot_data2_col
+
+    multihot_data2_col = _ragged_for_multihot_data_col()
     true_data2_col = tf.reshape(
         tf.ragged.constant(df.iloc[:batch_size, 1].tolist()).to_tensor(), [batch_size, -1]
     )