From f8e5830db5f3c2ec9210711b5729d56f92b79fb9 Mon Sep 17 00:00:00 2001
From: Oliver Holworthy
Date: Fri, 10 Mar 2023 09:48:10 +0000
Subject: [PATCH 01/11] Add Padding transform to enable removing sparse output
 from dataloader

---
 transformers4rec/torch/utils/data_utils.py | 102 ++++++++++++++++++++-
 1 file changed, 97 insertions(+), 5 deletions(-)

diff --git a/transformers4rec/torch/utils/data_utils.py b/transformers4rec/torch/utils/data_utils.py
index fceb4b62f5..6a32efcdd3 100644
--- a/transformers4rec/torch/utils/data_utils.py
+++ b/transformers4rec/torch/utils/data_utils.py
@@ -18,6 +18,8 @@
 import warnings
 from abc import ABC
 
+from typing import Dict, Optional
+
 import numpy as np
 import torch
 from merlin.dataloader.torch import Loader
@@ -323,10 +325,6 @@ def __init__(
             reader_kwargs["row_groups_per_part"] = row_groups_per_part
         self.set_dataset(buffer_size, engine, reader_kwargs)
 
-        self.dataset.schema = _augment_schema(
-            self.dataset.schema, cats, conts, labels, sparse_names, sparse_max, sparse_as_dense
-        )
-
         if (global_rank is not None) and (self.dataset.npartitions < global_size):
             logger.warning(
                 "UserWarning: User is advised to repartition the parquet file before training "
@@ -355,7 +353,7 @@ def __init__(
             global_size=global_size,
             global_rank=global_rank,
             drop_last=drop_last,
-        )
+        ).map(_get_pad_fn(sparse_max))
 
         DLDataLoader.__init__(
             self,
@@ -439,6 +437,100 @@ def from_schema(
         return loader
 
 
+import torch.nn.functional as F
+
+def _pad_dense_tensor(t: torch.Tensor, length: Optional[int]) -> torch.Tensor:
+    if length and len(t.shape) == 2:
+        pad_diff = length - t.shape[1]
+        return F.pad(input=t, pad=(0, pad_diff, 0, 0))
+    return t
+
+
+def _pad_ragged_tensor_1(values: torch.Tensor, offsets: torch.Tensor, padding_length: int):
+    num_rows = len(offsets) - 1
+    padded_values = torch.zeros(
+        (num_rows, padding_length),
+        dtype=values.dtype, device=values.device
+    )
+    for i in range(num_rows):
+        row_values = values[offsets[i] : offsets[i + 1]]
+        padded_values[i: len(row_values)] = row_values
+    return padded_values
+
+def transform_1(X, padding_length):
+    X_padded = {}
+    for k, values in X.items():
+        if k.endswith("__values"):
+            col_name = k[:-8]
+            offsets = X[f"{col_name}__offsets"]
+            padded_values = _pad_ragged_tensor_1(values, offsets, padding_length)
+            X_padded[col_name] = padded_values
+        elif k.endswith("__offsets"):
+            continue
+        elif isinstance(values, tuple):
+            values, offsets = values
+            padded_values = _pad_ragged_tensor_1(values, offsets, padding_length)
+            X_padded[col_name] = padded_values
+        else:
+            X_padded[k] = _pad_dense_tensor(values, padding_length)
+
+    return X_padded
+
+
+def _get_indices(offsets, diff_offsets):
+    row_ids = torch.arange(len(offsets) - 1, device=offsets.device)
+    row_ids_repeated = torch.repeat_interleave(row_ids, diff_offsets)
+    row_offset_repeated = torch.repeat_interleave(offsets[:-1], diff_offsets)
+    col_ids = torch.arange(len(row_offset_repeated), device=offsets.device) - row_offset_repeated
+    indices = torch.cat([row_ids_repeated.unsqueeze(-1), col_ids.unsqueeze(-1)], axis=1)
+    return indices
+
+
+def _pad_ragged_tensor_2(values, offsets, padding_length):
+    num_rows = len(offsets) - 1
+    diff_offsets = offsets[1:] - offsets[:-1]
+    indices = _get_indices(offsets, diff_offsets)
+    sparse_tensor = torch.sparse_coo_tensor(
+        indices.T, values, torch.Size([num_rows, padding_length]), device=values.device
+    )
+    return sparse_tensor.to_dense()
+
+
+def _pad_batch(X, padding_lengths, ragged_pad_fn):
+    X_padded = {}
+    for k, values in X.items():
+        if k.endswith("__values"):
+            col_name = k[:-8]
+            offsets = X[f"{col_name}__offsets"]
+            padding_length = padding_lengths.get(col_name)
+            padded_values = ragged_pad_fn(values, offsets, padding_length)
+            X_padded[col_name] = padded_values
+        elif k.endswith("__offsets"):
+            continue
+        elif isinstance(values, tuple):
+            padding_length = padding_lengths.get(k)
+            if padding_length:
+                values, offsets = values
+                padded_values = ragged_pad_fn(values, offsets, padding_length)
+                X_padded[col_name] = padded_values
+            else:
+                X_padded[k] = values
+        else:
+            padding_length = padding_lengths.get(k)
+            X_padded[k] = _pad_dense_tensor(values, padding_length)
+
+    return X_padded
+
+
+def _get_pad_fn(padding_lengths: Dict[str, int]):
+    def pad_fn(x, y):
+        new_x = _pad_batch(x, padding_lengths, _pad_ragged_tensor_2)
+        new_y = _pad_batch(y, padding_lengths, _pad_ragged_tensor_2)
+        return new_x, new_y
+
+    return pad_fn
+
+
 class ParquetDataset(Dataset):
     def __init__(self, parquet_file, cols_to_read, target_names, seq_features_len_pad_trim):
         self.cols_to_read = cols_to_read
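The core of the patch above is the ragged-to-dense conversion: each list feature arrives as a flat values tensor plus row offsets, and `_pad_ragged_tensor_2` scatters it into a dense matrix through a sparse COO tensor. A standalone sketch of the same idea in plain PyTorch (the function name and sample tensors here are illustrative, not part of the patch):

    import torch

    def pad_ragged(values, offsets, length):
        # Row lengths are consecutive offset differences.
        lengths = offsets[1:] - offsets[:-1]
        num_rows = len(offsets) - 1
        # (row, col) index for every ragged element, as _get_indices computes above.
        rows = torch.repeat_interleave(torch.arange(num_rows), lengths)
        cols = torch.arange(len(values)) - torch.repeat_interleave(offsets[:-1], lengths)
        # Scatter into a sparse tensor and densify; unset positions become zero padding.
        return torch.sparse_coo_tensor(
            torch.stack([rows, cols]), values, (num_rows, length)
        ).to_dense()

    values = torch.tensor([1, 2, 3, 4, 5])
    offsets = torch.tensor([0, 2, 2, 5])  # rows: [1, 2], [], [3, 4, 5]
    print(pad_ragged(values, offsets, length=5))
    # tensor([[1, 2, 0, 0, 0],
    #         [0, 0, 0, 0, 0],
    #         [3, 4, 5, 0, 0]])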
From ba923e2bfe05e591bc484c5f3947028db46a57d2 Mon Sep 17 00:00:00 2001
From: Oliver Holworthy
Date: Tue, 14 Mar 2023 12:47:47 +0000
Subject: [PATCH 02/11] Handle padding of 1-d and 2-d values/offsets

---
 transformers4rec/torch/utils/data_utils.py | 47 +++++++++------------
 1 file changed, 20 insertions(+), 27 deletions(-)

diff --git a/transformers4rec/torch/utils/data_utils.py b/transformers4rec/torch/utils/data_utils.py
index 6a32efcdd3..577ab0be08 100644
--- a/transformers4rec/torch/utils/data_utils.py
+++ b/transformers4rec/torch/utils/data_utils.py
@@ -17,11 +17,11 @@
 import logging
 import warnings
 from abc import ABC
-
 from typing import Dict, Optional
 
 import numpy as np
 import torch
+import torch.nn.functional as F
 from merlin.dataloader.torch import Loader
 from merlin.models.utils.misc_utils import validate_dataset
 from merlin.models.utils.registry import Registry
@@ -32,7 +32,6 @@
 from merlin_standard_lib import Schema
 
 from ...utils import dependencies
-from ..utils.schema_utils import _augment_schema
 
 logger = logging.getLogger(__name__)
 
@@ -343,6 +342,9 @@ def __init__(
                 f" GPUs ({global_size}). This will divide the work equally among GPUs"
                 " for DDP training and ensure optimal performance."
             )
+
+        self.dataset.schema = self.dataset.schema.select_by_name(cats + conts + labels)
+
         loader = Loader(
             self.dataset,
             self.batch_size,
@@ -439,8 +439,6 @@ def from_schema(
         return loader
 
 
-import torch.nn.functional as F
-
 def _pad_dense_tensor(t: torch.Tensor, length: Optional[int]) -> torch.Tensor:
     if length and len(t.shape) == 2:
         pad_diff = length - t.shape[1]
@@ -446,36 +446,24 @@ def _pad_dense_tensor(t: torch.Tensor, length: Optional[int]) -> torch.Tensor:
     return t
 
 
+def _squeeze(tensor):
+    if len(tensor.shape) == 2:
+        return tensor.squeeze(1)
+    return tensor
+
+
 def _pad_ragged_tensor_1(values: torch.Tensor, offsets: torch.Tensor, padding_length: int):
+    values = _squeeze(values)
+    offsets = _squeeze(offsets)
     num_rows = len(offsets) - 1
     padded_values = torch.zeros(
-        (num_rows, padding_length),
-        dtype=values.dtype, device=values.device
+        (num_rows, padding_length), dtype=values.dtype, device=values.device
     )
     for i in range(num_rows):
         row_values = values[offsets[i] : offsets[i + 1]]
-        padded_values[i: len(row_values)] = row_values
+        padded_values[i, : len(row_values)] = row_values
     return padded_values
 
-def transform_1(X, padding_length):
-    X_padded = {}
-    for k, values in X.items():
-        if k.endswith("__values"):
-            col_name = k[:-8]
-            offsets = X[f"{col_name}__offsets"]
-            padded_values = _pad_ragged_tensor_1(values, offsets, padding_length)
-            X_padded[col_name] = padded_values
-        elif k.endswith("__offsets"):
-            continue
-        elif isinstance(values, tuple):
-            values, offsets = values
-            padded_values = _pad_ragged_tensor_1(values, offsets, padding_length)
-            X_padded[col_name] = padded_values
-        else:
-            X_padded[k] = _pad_dense_tensor(values, padding_length)
-
-    return X_padded
-
 
 def _get_indices(offsets, diff_offsets):
     row_ids = torch.arange(len(offsets) - 1, device=offsets.device)
@@ -487,6 +475,8 @@ def _get_indices(offsets, diff_offsets):
 
 
 def _pad_ragged_tensor_2(values, offsets, padding_length):
+    values = _squeeze(values)
+    offsets = _squeeze(offsets)
     num_rows = len(offsets) - 1
     diff_offsets = offsets[1:] - offsets[:-1]
     indices = _get_indices(offsets, diff_offsets)
@@ -497,6 +487,9 @@ def _pad_ragged_tensor_2(values, offsets, padding_length):
 
 
 def _pad_batch(X, padding_lengths, ragged_pad_fn):
+    if not X:
+        return X
+
     X_padded = {}
     for k, values in X.items():
         if k.endswith("__values"):
@@ -512,7 +505,7 @@ def _pad_batch(X, padding_lengths, ragged_pad_fn):
             if padding_length:
                 values, offsets = values
                 padded_values = ragged_pad_fn(values, offsets, padding_length)
-                X_padded[col_name] = padded_values
+                X_padded[k] = padded_values
             else:
                 X_padded[k] = values
         else:
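The `_squeeze` helper added above guards against dataloader outputs where values and offsets carry a trailing unit dimension, i.e. shape (n, 1) rather than (n,). A minimal illustration in plain PyTorch (the assumed shapes are for demonstration only):

    import torch

    def squeeze_2d(t: torch.Tensor) -> torch.Tensor:
        # Collapse (n, 1) -> (n,); leave 1-d tensors untouched.
        return t.squeeze(1) if len(t.shape) == 2 else t

    values_2d = torch.tensor([[1], [2], [3]])  # 2-d variant
    offsets_1d = torch.tensor([0, 2, 3])       # 1-d variant
    print(squeeze_2d(values_2d).shape)   # torch.Size([3])
    print(squeeze_2d(offsets_1d).shape)  # torch.Size([3]) (unchanged)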
From 86bfddb1c3022dd0a3f3dbaeae31078854f5d007 Mon Sep 17 00:00:00 2001
From: Oliver Holworthy
Date: Tue, 14 Mar 2023 15:03:23 +0000
Subject: [PATCH 03/11] Move padding function to padding module

---
 tests/utils/test_padding.py                  | 106 +++++++++++++++++++
 transformers4rec/torch/utils/data_utils.py   |  95 ++---------------
 transformers4rec/torch/utils/padding.py      |  98 +++++++++++++++++
 transformers4rec/torch/utils/schema_utils.py |  20 +---
 4 files changed, 211 insertions(+), 108 deletions(-)
 create mode 100644 tests/utils/test_padding.py
 create mode 100644 transformers4rec/torch/utils/padding.py

diff --git a/tests/utils/test_padding.py b/tests/utils/test_padding.py
new file mode 100644
index 0000000000..07feab9f70
--- /dev/null
+++ b/tests/utils/test_padding.py
@@ -0,0 +1,106 @@
+#
+# Copyright (c) 2023, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from itertools import accumulate
+
+import torch
+
+from transformers4rec.torch.utils.padding import get_pad_fn
+
+
+def _get_values_offsets(data):
+    values = []
+    row_lengths = []
+    for row in data:
+        row_lengths.append(len(row))
+        values += row
+    offsets = [0] + list(accumulate(row_lengths))
+    return torch.tensor(values), torch.tensor(offsets)
+
+
+def test_pad_values_offsets_tuple():
+    data = [[1, 2], [], [3, 4, 5]]
+    pad_fn = get_pad_fn({"a": 5})
+    values, offsets = _get_values_offsets(data)
+
+    x = {"a": (values, offsets)}
+    y = torch.tensor([1, 0, 1])
+
+    padded_x, padded_y = pad_fn(x, y)
+    assert torch.equal(
+        padded_x["a"],
+        torch.tensor(
+            [
+                [1, 2, 0, 0, 0],
+                [0, 0, 0, 0, 0],
+                [3, 4, 5, 0, 0],
+            ]
+        ),
+    )
+    assert torch.equal(padded_y, y)
+
+
+def test_pad_values_offsets_dict():
+    data = [[1, 2], [], [3, 4, 5]]
+    pad_fn = get_pad_fn({"a": 7})
+    values, offsets = _get_values_offsets(data)
+
+    x = {"a__values": values, "a__offsets": offsets}
+    y = torch.tensor([1, 0, 1])
+
+    padded_x, padded_y = pad_fn(x, y)
+    assert torch.equal(
+        padded_x["a"],
+        torch.tensor(
+            [
+                [1, 2, 0, 0, 0, 0, 0],
+                [0, 0, 0, 0, 0, 0, 0],
+                [3, 4, 5, 0, 0, 0, 0],
+            ]
+        ),
+    )
+    assert torch.equal(padded_y, y)
+
+
+def test_pad_values_dense():
+    data = [[1, 2], [], [3, 4, 5]]
+    pad_fn = get_pad_fn({"a": 7, "b": 3})
+    values, offsets = _get_values_offsets(data)
+
+    x = {"a__values": values, "a__offsets": offsets, "b": torch.tensor([[3, 6], [4, 1], [8, 4]])}
+    y = torch.tensor([1, 0, 1])
+
+    padded_x, padded_y = pad_fn(x, y)
+    assert torch.equal(
+        padded_x["a"],
+        torch.tensor(
+            [
+                [1, 2, 0, 0, 0, 0, 0],
+                [0, 0, 0, 0, 0, 0, 0],
+                [3, 4, 5, 0, 0, 0, 0],
+            ]
+        ),
+    )
+    assert torch.equal(
+        padded_x["b"],
+        torch.tensor(
+            [
+                [3, 6, 0],
+                [4, 1, 0],
+                [8, 4, 0],
+            ]
+        ),
+    )
+    assert torch.equal(padded_y, y)
diff --git a/transformers4rec/torch/utils/data_utils.py b/transformers4rec/torch/utils/data_utils.py
index 577ab0be08..cea0d4d1c2 100644
--- a/transformers4rec/torch/utils/data_utils.py
+++ b/transformers4rec/torch/utils/data_utils.py
@@ -17,11 +17,9 @@
 import logging
 import warnings
 from abc import ABC
-from typing import Dict, Optional
 
 import numpy as np
 import torch
-import torch.nn.functional as F
 from merlin.dataloader.torch import Loader
 from merlin.models.utils.misc_utils import validate_dataset
 from merlin.models.utils.registry import Registry
@@ -30,6 +28,8 @@
 from torch.utils.data import Dataset, IterableDataset
 
 from merlin_standard_lib import Schema
+from transformers4rec.torch.utils import schema_utils
+from transformers4rec.torch.utils.padding import get_pad_fn
 
 from ...utils import dependencies
 
@@ -343,7 +343,9 @@ def __init__(
                 " for DDP training and ensure optimal performance."
             )
 
-        self.dataset.schema = self.dataset.schema.select_by_name(cats + conts + labels)
+        self.dataset.schema = schema_utils._augment_schema(
+            self.dataset.schema, cats=cats, conts=conts, labels=labels
+        )
 
         loader = Loader(
             self.dataset,
@@ -355,7 +357,7 @@ def __init__(
             global_size=global_size,
             global_rank=global_rank,
             drop_last=drop_last,
-        ).map(_get_pad_fn(sparse_max))
+        ).map(get_pad_fn(sparse_max))
 
         DLDataLoader.__init__(
             self,
@@ -439,91 +441,6 @@ def from_schema(
         return loader
 
 
-def _pad_dense_tensor(t: torch.Tensor, length: Optional[int]) -> torch.Tensor:
-    if length and len(t.shape) == 2:
-        pad_diff = length - t.shape[1]
-        return F.pad(input=t, pad=(0, pad_diff, 0, 0))
-    return t
-
-
-def _squeeze(tensor):
-    if len(tensor.shape) == 2:
-        return tensor.squeeze(1)
-    return tensor
-
-
-def _pad_ragged_tensor_1(values: torch.Tensor, offsets: torch.Tensor, padding_length: int):
-    values = _squeeze(values)
-    offsets = _squeeze(offsets)
-    num_rows = len(offsets) - 1
-    padded_values = torch.zeros(
-        (num_rows, padding_length), dtype=values.dtype, device=values.device
-    )
-    for i in range(num_rows):
-        row_values = values[offsets[i] : offsets[i + 1]]
-        padded_values[i, : len(row_values)] = row_values
-    return padded_values
-
-
-def _get_indices(offsets, diff_offsets):
-    row_ids = torch.arange(len(offsets) - 1, device=offsets.device)
-    row_ids_repeated = torch.repeat_interleave(row_ids, diff_offsets)
-    row_offset_repeated = torch.repeat_interleave(offsets[:-1], diff_offsets)
-    col_ids = torch.arange(len(row_offset_repeated), device=offsets.device) - row_offset_repeated
-    indices = torch.cat([row_ids_repeated.unsqueeze(-1), col_ids.unsqueeze(-1)], axis=1)
-    return indices
-
-
-def _pad_ragged_tensor_2(values, offsets, padding_length):
-    values = _squeeze(values)
-    offsets = _squeeze(offsets)
-    num_rows = len(offsets) - 1
-    diff_offsets = offsets[1:] - offsets[:-1]
-    indices = _get_indices(offsets, diff_offsets)
-    sparse_tensor = torch.sparse_coo_tensor(
-        indices.T, values, torch.Size([num_rows, padding_length]), device=values.device
-    )
-    return sparse_tensor.to_dense()
-
-
-def _pad_batch(X, padding_lengths, ragged_pad_fn):
-    if not X:
-        return X
-
-    X_padded = {}
-    for k, values in X.items():
-        if k.endswith("__values"):
-            col_name = k[:-8]
-            offsets = X[f"{col_name}__offsets"]
-            padding_length = padding_lengths.get(col_name)
-            padded_values = ragged_pad_fn(values, offsets, padding_length)
-            X_padded[col_name] = padded_values
-        elif k.endswith("__offsets"):
-            continue
-        elif isinstance(values, tuple):
-            padding_length = padding_lengths.get(k)
-            if padding_length:
-                values, offsets = values
-                padded_values = ragged_pad_fn(values, offsets, padding_length)
-                X_padded[k] = padded_values
-            else:
-                X_padded[k] = values
-        else:
-            padding_length = padding_lengths.get(k)
-            X_padded[k] = _pad_dense_tensor(values, padding_length)
-
-    return X_padded
-
-
-def _get_pad_fn(padding_lengths: Dict[str, int]):
-    def pad_fn(x, y):
-        new_x = _pad_batch(x, padding_lengths, _pad_ragged_tensor_2)
-        new_y = _pad_batch(y, padding_lengths, _pad_ragged_tensor_2)
-        return new_x, new_y
-
-    return pad_fn
-
-
 class ParquetDataset(Dataset):
     def __init__(self, parquet_file, cols_to_read, target_names, seq_features_len_pad_trim):
         self.cols_to_read = cols_to_read
diff --git a/transformers4rec/torch/utils/padding.py b/transformers4rec/torch/utils/padding.py
new file mode 100644
index 0000000000..125430c09a
--- /dev/null
+++ b/transformers4rec/torch/utils/padding.py
@@ -0,0 +1,98 @@
+#
+# Copyright (c) 2023, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from typing import Dict, Optional
+
+import torch
+import torch.nn.functional as F
+
+
+def _pad_dense_tensor(t: torch.Tensor, length: Optional[int]) -> torch.Tensor:
+    if length and len(t.shape) == 2:
+        pad_diff = length - t.shape[1]
+        return F.pad(input=t, pad=(0, pad_diff, 0, 0))
+    return t
+
+
+def _squeeze(tensor):
+    if len(tensor.shape) == 2:
+        return tensor.squeeze(1)
+    return tensor
+
+
+def _get_indices(offsets, diff_offsets):
+    row_ids = torch.arange(len(offsets) - 1, device=offsets.device)
+    row_ids_repeated = torch.repeat_interleave(row_ids, diff_offsets)
+    row_offset_repeated = torch.repeat_interleave(offsets[:-1], diff_offsets)
+    col_ids = torch.arange(len(row_offset_repeated), device=offsets.device) - row_offset_repeated
+    indices = torch.cat([row_ids_repeated.unsqueeze(-1), col_ids.unsqueeze(-1)], axis=1)
+    return indices
+
+
+def _pad_ragged_tensor(values, offsets, padding_length):
+    values = _squeeze(values)
+    offsets = _squeeze(offsets)
+    num_rows = len(offsets) - 1
+    diff_offsets = offsets[1:] - offsets[:-1]
+    indices = _get_indices(offsets, diff_offsets)
+    sparse_tensor = torch.sparse_coo_tensor(
+        indices.T, values, torch.Size([num_rows, padding_length]), device=values.device
+    )
+    return sparse_tensor.to_dense()
+
+
+def _pad_batch(X, padding_lengths, ragged_pad_fn):
+    if X is None or not isinstance(X, dict):
+        return X
+
+    X_padded = {}
+    for k, values in X.items():
+        if k.endswith("__values"):
+            col_name = k[:-8]
+            offsets = X[f"{col_name}__offsets"]
+            padding_length = padding_lengths.get(col_name)
+            if padding_length:
+                padded_values = ragged_pad_fn(values, offsets, padding_length)
+                X_padded[col_name] = padded_values
+            else:
+                raise ValueError(
+                    f"Found ragged column '{col_name}' with unspecified padding length. "
+                    "Please provide a padding length for this feature "
+                    "to be converted to a dense tensor. "
+                )
+        elif k.endswith("__offsets"):
+            continue
+        elif isinstance(values, tuple):
+            padding_length = padding_lengths.get(k)
+            if padding_length:
+                values, offsets = values
+                padded_values = ragged_pad_fn(values, offsets, padding_length)
+                X_padded[k] = padded_values
+            else:
+                X_padded[k] = values
+        else:
+            padding_length = padding_lengths.get(k)
+            X_padded[k] = _pad_dense_tensor(values, padding_length)
+
+    return X_padded
+
+
+def get_pad_fn(padding_lengths: Dict[str, int]):
+    def pad_fn(x, y):
+        new_x = _pad_batch(x, padding_lengths, _pad_ragged_tensor)
+        new_y = _pad_batch(y, padding_lengths, _pad_ragged_tensor)
+        return new_x, new_y
+
+    return pad_fn
diff --git a/transformers4rec/torch/utils/schema_utils.py b/transformers4rec/torch/utils/schema_utils.py
index a43421c5d5..82a94717e8 100644
--- a/transformers4rec/torch/utils/schema_utils.py
+++ b/transformers4rec/torch/utils/schema_utils.py
@@ -134,11 +134,8 @@ def _augment_schema(
     cats=None,
     conts=None,
     labels=None,
-    sparse_names=None,
-    sparse_max=None,
-    sparse_as_dense=False,
 ):
-    from merlin.schema import ColumnSchema, Tags
+    from merlin.schema import Tags
 
     schema = schema.select_by_name(conts + cats + labels)
 
@@ -150,19 +147,4 @@ def _augment_schema(
     for label in conts or []:
         schema[label] = schema[label].with_tags(Tags.CONTINUOUS)
 
-    # Set the appropriate properties for the sparse_names/sparse_max/sparse_as_dense
-    for col in sparse_names or []:
-        cs = schema[col]
-        properties = cs.properties
-        if sparse_max and col in sparse_max:
-            properties["value_count"] = {"max": sparse_max[col]}
-        schema[col] = ColumnSchema(
-            name=cs.name,
-            tags=cs.tags,
-            dtype=cs.dtype,
-            is_list=True,
-            is_ragged=not sparse_as_dense,
-            properties=properties,
-        )
-
     return schema
From ced153fe2b35a3463247a83c66972fd156a81f94 Mon Sep 17 00:00:00 2001
From: Oliver Holworthy
Date: Tue, 14 Mar 2023 15:13:37 +0000
Subject: [PATCH 04/11] Move _augment_schema to a method of MerlinDataLoader

---
 transformers4rec/torch/utils/data_utils.py   | 23 ++++++++++++++++++--
 transformers4rec/torch/utils/schema_utils.py | 21 ------------------
 2 files changed, 21 insertions(+), 23 deletions(-)

diff --git a/transformers4rec/torch/utils/data_utils.py b/transformers4rec/torch/utils/data_utils.py
index cea0d4d1c2..fd2ae55768 100644
--- a/transformers4rec/torch/utils/data_utils.py
+++ b/transformers4rec/torch/utils/data_utils.py
@@ -28,7 +28,6 @@
 from torch.utils.data import Dataset, IterableDataset
 
 from merlin_standard_lib import Schema
-from transformers4rec.torch.utils import schema_utils
 from transformers4rec.torch.utils.padding import get_pad_fn
 
 from ...utils import dependencies
@@ -343,7 +342,7 @@ def __init__(
                 " for DDP training and ensure optimal performance."
             )
 
-        self.dataset.schema = schema_utils._augment_schema(
+        self.dataset.schema = self._augment_schema(
             self.dataset.schema, cats=cats, conts=conts, labels=labels
         )
 
@@ -369,6 +368,26 @@ def __init__(
         self.schema = schema
         self.max_sequence_length = max_sequence_length
 
+    @staticmethod
+    def _augment_schema(
+        schema,
+        cats=None,
+        conts=None,
+        labels=None,
+    ):
+        schema = schema.select_by_name(conts + cats + labels)
+
+        labels = [labels] if isinstance(labels, str) else labels
+        for label in labels or []:
+            schema[label] = schema[label].with_tags(Tags.TARGET)
+        for label in cats or []:
+            schema[label] = schema[label].with_tags(Tags.CATEGORICAL)
+        for label in conts or []:
+            schema[label] = schema[label].with_tags(Tags.CONTINUOUS)
+
+        return schema
+
     def set_dataset(self, buffer_size, engine, reader_kwargs):
         dataset = validate_dataset(
             self.paths_or_dataset,
diff --git a/transformers4rec/torch/utils/schema_utils.py b/transformers4rec/torch/utils/schema_utils.py
index 82a94717e8..c5a2f93c0f 100644
--- a/transformers4rec/torch/utils/schema_utils.py
+++ b/transformers4rec/torch/utils/schema_utils.py
@@ -127,24 +127,3 @@ def _get_sparse_tensor(values, indices, num_rows, seq_limit):
     sparse_tensor = torch.sparse_coo_tensor(indices.T, values, torch.Size([num_rows, seq_limit]))
 
     return sparse_tensor.to_dense()
-
-
-def _augment_schema(
-    schema,
-    cats=None,
-    conts=None,
-    labels=None,
-):
-    from merlin.schema import Tags
-
-    schema = schema.select_by_name(conts + cats + labels)
-
-    labels = [labels] if isinstance(labels, str) else labels
-    for label in labels or []:
-        schema[label] = schema[label].with_tags(Tags.TARGET)
-    for label in cats or []:
-        schema[label] = schema[label].with_tags(Tags.CATEGORICAL)
-    for label in conts or []:
-        schema[label] = schema[label].with_tags(Tags.CONTINUOUS)
-
-    return schema
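`_augment_schema` accepts `labels` either as a single column name or as a list; the `[labels] if isinstance(labels, str) else labels` line normalizes both spellings before tagging. The pattern in isolation (a sketch with hypothetical column names):

    def normalize_labels(labels):
        # Accept a single column name or a list of names; treat None as empty.
        labels = [labels] if isinstance(labels, str) else labels
        return labels or []

    print(normalize_labels("click"))           # ['click']
    print(normalize_labels(["click", "buy"]))  # ['click', 'buy']
    print(normalize_labels(None))              # []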
From f4b91c3bac20ea19944a818431ef6a00263559ab Mon Sep 17 00:00:00 2001
From: Oliver Holworthy
Date: Tue, 14 Mar 2023 15:24:05 +0000
Subject: [PATCH 05/11] Move get_pad_fn to staticmethod on MerlinDataLoader

---
 tests/utils/test_padding.py                | 17 ++----
 transformers4rec/torch/utils/data_utils.py | 12 +++-
 transformers4rec/torch/utils/padding.py    | 66 ++++++++++++++--------
 3 files changed, 57 insertions(+), 38 deletions(-)

diff --git a/tests/utils/test_padding.py b/tests/utils/test_padding.py
index 07feab9f70..1aec66dc50 100644
--- a/tests/utils/test_padding.py
+++ b/tests/utils/test_padding.py
@@ -17,7 +17,7 @@
 
 import torch
 
-from transformers4rec.torch.utils.padding import get_pad_fn
+from transformers4rec.torch.utils.padding import pad_batch
 
 
 def _get_values_offsets(data):
@@ -32,13 +32,11 @@ def _get_values_offsets(data):
 
 def test_pad_values_offsets_tuple():
     data = [[1, 2], [], [3, 4, 5]]
-    pad_fn = get_pad_fn({"a": 5})
     values, offsets = _get_values_offsets(data)
 
     x = {"a": (values, offsets)}
-    y = torch.tensor([1, 0, 1])
 
-    padded_x, padded_y = pad_fn(x, y)
+    padded_x = pad_batch(x, {"a": 5})
     assert torch.equal(
         padded_x["a"],
         torch.tensor(
@@ -49,18 +47,15 @@ def test_pad_values_offsets_tuple():
             ]
         ),
     )
-    assert torch.equal(padded_y, y)
 
 
 def test_pad_values_offsets_dict():
     data = [[1, 2], [], [3, 4, 5]]
-    pad_fn = get_pad_fn({"a": 7})
     values, offsets = _get_values_offsets(data)
 
     x = {"a__values": values, "a__offsets": offsets}
-    y = torch.tensor([1, 0, 1])
 
-    padded_x, padded_y = pad_fn(x, y)
+    padded_x = pad_batch(x, {"a": 7})
     assert torch.equal(
         padded_x["a"],
         torch.tensor(
@@ -71,18 +66,15 @@ def test_pad_values_offsets_dict():
             ]
         ),
     )
-    assert torch.equal(padded_y, y)
 
 
 def test_pad_values_dense():
     data = [[1, 2], [], [3, 4, 5]]
-    pad_fn = get_pad_fn({"a": 7, "b": 3})
     values, offsets = _get_values_offsets(data)
 
     x = {"a__values": values, "a__offsets": offsets, "b": torch.tensor([[3, 6], [4, 1], [8, 4]])}
-    y = torch.tensor([1, 0, 1])
 
-    padded_x, padded_y = pad_fn(x, y)
+    padded_x = pad_batch(x, {"a": 7, "b": 3})
     assert torch.equal(
         padded_x["a"],
         torch.tensor(
@@ -103,4 +95,3 @@ def test_pad_values_dense():
             ]
         ),
     )
-    assert torch.equal(padded_y, y)
diff --git a/transformers4rec/torch/utils/data_utils.py b/transformers4rec/torch/utils/data_utils.py
index fd2ae55768..b0827d7f65 100644
--- a/transformers4rec/torch/utils/data_utils.py
+++ b/transformers4rec/torch/utils/data_utils.py
@@ -28,7 +28,7 @@
 from torch.utils.data import Dataset, IterableDataset
 
 from merlin_standard_lib import Schema
-from transformers4rec.torch.utils.padding import get_pad_fn
+from transformers4rec.torch.utils.padding import pad_batch
 
 from ...utils import dependencies
@@ -356,7 +356,7 @@ def __init__(
             global_size=global_size,
             global_rank=global_rank,
             drop_last=drop_last,
-        ).map(get_pad_fn(sparse_max))
+        ).map(self.get_pad_fn(sparse_max))
 
         DLDataLoader.__init__(
             self,
@@ -368,6 +368,14 @@ def __init__(
         self.schema = schema
         self.max_sequence_length = max_sequence_length
 
+    @staticmethod
+    def _get_pad_fn(padding_lengths):
+        def pad_fn(x, y):
+            new_x = pad_batch(x, padding_lengths)
+            new_y = pad_batch(y, padding_lengths)
+            return new_x, new_y
+
+        return pad_fn
 
     @staticmethod
     def _augment_schema(
diff --git a/transformers4rec/torch/utils/padding.py b/transformers4rec/torch/utils/padding.py
index 125430c09a..57c1bb811e 100644
--- a/transformers4rec/torch/utils/padding.py
+++ b/transformers4rec/torch/utils/padding.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from typing import Dict, Optional
+from typing import Dict, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -53,20 +53,44 @@ def _pad_ragged_tensor(values, offsets, padding_length):
     return sparse_tensor.to_dense()
 
 
-def _pad_batch(X, padding_lengths, ragged_pad_fn):
-    if X is None or not isinstance(X, dict):
-        return X
+Batch = Dict[str, Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]
+
+
+def pad_batch(batch: Batch, padding_lengths: Dict[str, int]) -> Batch:
+    """Pad list features in a batch to padding length specified
+
+    Parameters
+    ----------
+    batch : Batch
+        dictionary of tensors in batch
+    padding_lengths : Dict[str, int]
+        dictionary mapping list column name to padding length
+
+    Returns
+    -------
+    Batch
+        Batch with padded list features
+
+    Raises
+    ------
+    ValueError
+        If ragged column found with no padding length provided
+    """
+    if batch is None or not isinstance(batch, dict):
+        return batch
 
-    X_padded = {}
-    for k, values in X.items():
+    batch_padded = {}
+    for k, values in batch.items():
         if k.endswith("__values"):
             col_name = k[:-8]
-            offsets = X[f"{col_name}__offsets"]
+            offsets = batch[f"{col_name}__offsets"]
             padding_length = padding_lengths.get(col_name)
             if padding_length:
-                padded_values = ragged_pad_fn(values, offsets, padding_length)
-                X_padded[col_name] = padded_values
+                padded_values = _pad_ragged_tensor(values, offsets, padding_length)
+                batch_padded[col_name] = padded_values
             else:
+                # Note: This exception can be removed if the model is
+                # updated to support __values / __offsets inputs
                 raise ValueError(
                     f"Found ragged column '{col_name}' with unspecified padding length. "
                     "Please provide a padding length for this feature "
@@ -77,22 +101,18 @@ def _pad_batch(X, padding_lengths, ragged_pad_fn):
                     "to be converted to a dense tensor. "
                 )
         elif k.endswith("__offsets"):
             continue
         elif isinstance(values, tuple):
             padding_length = padding_lengths.get(k)
             if padding_length:
+                col_name = k
                 values, offsets = values
-                padded_values = ragged_pad_fn(values, offsets, padding_length)
-                X_padded[k] = padded_values
+                padded_values = _pad_ragged_tensor(values, offsets, padding_length)
+                batch_padded[col_name] = padded_values
             else:
-                X_padded[k] = values
+                raise ValueError(
+                    f"Found ragged column '{col_name}' with unspecified padding length. "
+                    "Please provide a padding length for this feature "
+                    "to be converted to a dense tensor. "
+                )
         else:
             padding_length = padding_lengths.get(k)
-            X_padded[k] = _pad_dense_tensor(values, padding_length)
-
-    return X_padded
-
-
-def get_pad_fn(padding_lengths: Dict[str, int]):
-    def pad_fn(x, y):
-        new_x = _pad_batch(x, padding_lengths, _pad_ragged_tensor)
-        new_y = _pad_batch(y, padding_lengths, _pad_ragged_tensor)
-        return new_x, new_y
-
-    return pad_fn
+            batch_padded[k] = _pad_dense_tensor(values, padding_length)
+
+    return batch_padded
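After this patch the public entry point is `pad_batch(batch, padding_lengths)`, and a ragged column with no configured length now raises instead of passing through. Usage against the patched module (assumes this branch is installed; tensor values mirror the tests above):

    import torch
    from transformers4rec.torch.utils.padding import pad_batch

    values = torch.tensor([1, 2, 3, 4, 5])
    offsets = torch.tensor([0, 2, 2, 5])  # rows: [1, 2], [], [3, 4, 5]

    padded = pad_batch({"a__values": values, "a__offsets": offsets}, {"a": 5})
    print(padded["a"].shape)  # torch.Size([3, 5])

    # With no padding length configured for "a", pad_batch raises:
    # ValueError: Found ragged column 'a' with unspecified padding length. ...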
" "Please provide a padding length for this feature " @@ -77,22 +101,18 @@ def _pad_batch(X, padding_lengths, ragged_pad_fn): elif isinstance(values, tuple): padding_length = padding_lengths.get(k) if padding_length: + col_name = k values, offsets = values - padded_values = ragged_pad_fn(values, offsets, padding_length) - X_padded[k] = padded_values + padded_values = _pad_ragged_tensor(values, offsets, padding_length) + batch_padded[col_name] = padded_values else: - X_padded[k] = values + raise ValueError( + f"Found ragged column '{col_name}' with unspecified padding length. " + "Please provide a padding length for this feature " + "to be converted to a dense tensor. " + ) else: padding_length = padding_lengths.get(k) - X_padded[k] = _pad_dense_tensor(values, padding_length) - - return X_padded - - -def get_pad_fn(padding_lengths: Dict[str, int]): - def pad_fn(x, y): - new_x = _pad_batch(x, padding_lengths, _pad_ragged_tensor) - new_y = _pad_batch(y, padding_lengths, _pad_ragged_tensor) - return new_x, new_y + batch_padded[k] = _pad_dense_tensor(values, padding_length) - return pad_fn + return batch_padded From 86662c6d4d1865de030162550e4cc73b15257d81 Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Tue, 14 Mar 2023 16:58:48 +0000 Subject: [PATCH 06/11] Correct name of pad method in MerlinDataLoader --- transformers4rec/torch/utils/data_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformers4rec/torch/utils/data_utils.py b/transformers4rec/torch/utils/data_utils.py index b0827d7f65..397c18052f 100644 --- a/transformers4rec/torch/utils/data_utils.py +++ b/transformers4rec/torch/utils/data_utils.py @@ -356,7 +356,7 @@ def __init__( global_size=global_size, global_rank=global_rank, drop_last=drop_last, - ).map(self.get_pad_fn(sparse_max)) + ).map(self._get_pad_fn(sparse_max)) DLDataLoader.__init__( self, From eab5492116bfc008258535f5ad554fd97a1b3941 Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Wed, 15 Mar 2023 12:20:30 +0000 Subject: [PATCH 07/11] Temporary change to `run-on` to check if `2PGU` is working --- .github/workflows/gpu-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/gpu-ci.yml b/.github/workflows/gpu-ci.yml index 2eda33c724..d1bcd7933e 100644 --- a/.github/workflows/gpu-ci.yml +++ b/.github/workflows/gpu-ci.yml @@ -12,7 +12,7 @@ on: jobs: gpu-ci: - runs-on: 1GPU + runs-on: 2GPU steps: - uses: actions/checkout@v3 From 43d535dfe975745eea63a3fe92ff80c684cccb8f Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Wed, 15 Mar 2023 12:57:32 +0000 Subject: [PATCH 08/11] Use `TensorTable` to simplify padding implementation --- transformers4rec/torch/utils/padding.py | 28 ++++++------------------- 1 file changed, 6 insertions(+), 22 deletions(-) diff --git a/transformers4rec/torch/utils/padding.py b/transformers4rec/torch/utils/padding.py index 57c1bb811e..c86ffefe72 100644 --- a/transformers4rec/torch/utils/padding.py +++ b/transformers4rec/torch/utils/padding.py @@ -17,6 +17,7 @@ import torch import torch.nn.functional as F +from merlin.table import TensorTable def _pad_dense_tensor(t: torch.Tensor, length: Optional[int]) -> torch.Tensor: @@ -80,13 +81,11 @@ def pad_batch(batch: Batch, padding_lengths: Dict[str, int]) -> Batch: return batch batch_padded = {} - for k, values in batch.items(): - if k.endswith("__values"): - col_name = k[:-8] - offsets = batch[f"{col_name}__offsets"] + for col_name, col in TensorTable(batch).items(): + if col.offsets is not None: padding_length = 
From 1a9b76856d8f8ec11cdca658e3627f306ab6a359 Mon Sep 17 00:00:00 2001
From: Oliver Holworthy
Date: Wed, 15 Mar 2023 13:04:45 +0000
Subject: [PATCH 09/11] Remove None defaults in `_augment_schema`

---
 transformers4rec/torch/utils/data_utils.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/transformers4rec/torch/utils/data_utils.py b/transformers4rec/torch/utils/data_utils.py
index 397c18052f..ee822345fe 100644
--- a/transformers4rec/torch/utils/data_utils.py
+++ b/transformers4rec/torch/utils/data_utils.py
@@ -380,9 +380,9 @@ def pad_fn(x, y):
     @staticmethod
     def _augment_schema(
         schema,
-        cats=None,
-        conts=None,
-        labels=None,
+        cats,
+        conts,
+        labels,
     ):
         schema = schema.select_by_name(conts + cats + labels)
 

From f65dd6678ab16dacd3db338f9387f4fb191a1186 Mon Sep 17 00:00:00 2001
From: Oliver Holworthy
Date: Wed, 15 Mar 2023 13:11:05 +0000
Subject: [PATCH 10/11] Restore `None` default in _augment_schema and set
 default to empty list

---
 transformers4rec/torch/utils/data_utils.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/transformers4rec/torch/utils/data_utils.py b/transformers4rec/torch/utils/data_utils.py
index ee822345fe..30a587817f 100644
--- a/transformers4rec/torch/utils/data_utils.py
+++ b/transformers4rec/torch/utils/data_utils.py
@@ -380,18 +380,22 @@ def pad_fn(x, y):
     @staticmethod
     def _augment_schema(
         schema,
-        cats,
-        conts,
-        labels,
+        cats=None,
+        conts=None,
+        labels=None,
     ):
+        cats = cats or []
+        conts = conts or []
+        labels = labels or []
+
         schema = schema.select_by_name(conts + cats + labels)
 
         labels = [labels] if isinstance(labels, str) else labels
-        for label in labels or []:
+        for label in labels:
             schema[label] = schema[label].with_tags(Tags.TARGET)
-        for label in cats or []:
+        for label in cats:
             schema[label] = schema[label].with_tags(Tags.CATEGORICAL)
-        for label in conts or []:
+        for label in conts:
             schema[label] = schema[label].with_tags(Tags.CONTINUOUS)
 
         return schema
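Patch 10's `cats = cats or []` lines exist because concatenating `None` with a list in `conts + cats + labels` raises a `TypeError`; coercing to empty lists keeps the call safe when a feature group is absent. The pattern in isolation (an illustrative sketch, not the method itself):

    def augment(cats=None, conts=None, labels=None):
        # None would break the list concatenation below; coerce to empty lists first.
        cats, conts, labels = cats or [], conts or [], labels or []
        return conts + cats + labels

    print(augment(cats=["item_id"]))                    # ['item_id']
    print(augment(cats=["item_id"], labels=["click"]))  # ['item_id', 'click']
    print(augment())                                    # []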
From 25564addfb7b392d3af842835f73ec982aa53e1e Mon Sep 17 00:00:00 2001
From: Oliver Holworthy
Date: Wed, 15 Mar 2023 16:28:51 +0000
Subject: [PATCH 11/11] Revert "Temporary change to `runs-on` to check if
 `2GPU` is working"

This reverts commit eab5492116bfc008258535f5ad554fd97a1b3941.
---
 .github/workflows/gpu-ci.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/gpu-ci.yml b/.github/workflows/gpu-ci.yml
index d1bcd7933e..2eda33c724 100644
--- a/.github/workflows/gpu-ci.yml
+++ b/.github/workflows/gpu-ci.yml
@@ -12,7 +12,7 @@ on:
 
 jobs:
   gpu-ci:
-    runs-on: 2GPU
+    runs-on: 1GPU
 
     steps:
       - uses: actions/checkout@v3
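Net effect of the series on padding behavior, as a usage sketch against the final state of the branch (assumes the branch plus merlin-core are installed; shapes follow the tests added in patch 03):

    import torch
    from transformers4rec.torch.utils.padding import pad_batch

    values = torch.tensor([1, 2, 3, 4, 5])
    offsets = torch.tensor([0, 2, 2, 5])
    batch = {
        "a__values": values,                          # split spelling
        "a__offsets": offsets,
        "b": (values, offsets),                       # tuple spelling
        "c": torch.tensor([[1, 2], [3, 4], [5, 6]]),  # dense feature
    }
    padded = pad_batch(batch, {"a": 5, "b": 5, "c": 4})
    print({k: tuple(v.shape) for k, v in padded.items()})
    # {'a': (3, 5), 'b': (3, 5), 'c': (3, 4)}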