Update padding of ragged features to enable dataloader change #647

Merged
97 changes: 97 additions & 0 deletions tests/utils/test_padding.py
@@ -0,0 +1,97 @@
#
# Copyright (c) 2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from itertools import accumulate

import torch

from transformers4rec.torch.utils.padding import pad_batch


def _get_values_offsets(data):
    values = []
    row_lengths = []
    for row in data:
        row_lengths.append(len(row))
        values += row
    offsets = [0] + list(accumulate(row_lengths))
    return torch.tensor(values), torch.tensor(offsets)


def test_pad_values_offsets_tuple():
    data = [[1, 2], [], [3, 4, 5]]
    values, offsets = _get_values_offsets(data)

    x = {"a": (values, offsets)}

    padded_x = pad_batch(x, {"a": 5})
    assert torch.equal(
        padded_x["a"],
        torch.tensor(
            [
                [1, 2, 0, 0, 0],
                [0, 0, 0, 0, 0],
                [3, 4, 5, 0, 0],
            ]
        ),
    )


def test_pad_values_offsets_dict():
    data = [[1, 2], [], [3, 4, 5]]
    values, offsets = _get_values_offsets(data)

    x = {"a__values": values, "a__offsets": offsets}

    padded_x = pad_batch(x, {"a": 7})
    assert torch.equal(
        padded_x["a"],
        torch.tensor(
            [
                [1, 2, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0],
                [3, 4, 5, 0, 0, 0, 0],
            ]
        ),
    )


def test_pad_values_dense():
    data = [[1, 2], [], [3, 4, 5]]
    values, offsets = _get_values_offsets(data)

    x = {"a__values": values, "a__offsets": offsets, "b": torch.tensor([[3, 6], [4, 1], [8, 4]])}

    padded_x = pad_batch(x, {"a": 7, "b": 3})
    assert torch.equal(
        padded_x["a"],
        torch.tensor(
            [
                [1, 2, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0],
                [3, 4, 5, 0, 0, 0, 0],
            ]
        ),
    )
    assert torch.equal(
        padded_x["b"],
        torch.tensor(
            [
                [3, 6, 0],
                [4, 1, 0],
                [8, 4, 0],
            ]
        ),
    )
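
For orientation, the ragged encoding these tests construct is the standard values/offsets scheme: all rows are concatenated into a single values tensor, and offsets[i]:offsets[i + 1] slices out row i. A quick sketch using the _get_values_offsets helper above:

data = [[1, 2], [], [3, 4, 5]]
values, offsets = _get_values_offsets(data)
assert values.tolist() == [1, 2, 3, 4, 5]
assert offsets.tolist() == [0, 2, 2, 5]  # row i spans values[offsets[i]:offsets[i + 1]]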
45 changes: 39 additions & 6 deletions transformers4rec/torch/utils/data_utils.py
@@ -28,9 +28,9 @@
 from torch.utils.data import Dataset, IterableDataset

 from merlin_standard_lib import Schema
+from transformers4rec.torch.utils.padding import pad_batch

 from ...utils import dependencies
-from ..utils.schema_utils import _augment_schema

 logger = logging.getLogger(__name__)

@@ -323,10 +323,6 @@ def __init__(
             reader_kwargs["row_groups_per_part"] = row_groups_per_part
         self.set_dataset(buffer_size, engine, reader_kwargs)

-        self.dataset.schema = _augment_schema(
-            self.dataset.schema, cats, conts, labels, sparse_names, sparse_max, sparse_as_dense
-        )
-
         if (global_rank is not None) and (self.dataset.npartitions < global_size):
             logger.warning(
                 "UserWarning: User is advised to repartition the parquet file before training "
@@ -345,6 +341,11 @@
                 f" GPUs ({global_size}). This will divide the work equally among GPUs"
                 " for DDP training and ensure optimal performance."
             )
+
+        self.dataset.schema = self._augment_schema(
+            self.dataset.schema, cats=cats, conts=conts, labels=labels
+        )
+
         loader = Loader(
             self.dataset,
             self.batch_size,
@@ -355,7 +356,7 @@
             global_size=global_size,
             global_rank=global_rank,
             drop_last=drop_last,
-        )
+        ).map(self._get_pad_fn(sparse_max))

         DLDataLoader.__init__(
             self,
@@ -367,6 +368,38 @@
         self.schema = schema
         self.max_sequence_length = max_sequence_length

+    @staticmethod
+    def _get_pad_fn(padding_lengths):
+        def pad_fn(x, y):
+            new_x = pad_batch(x, padding_lengths)
+            new_y = pad_batch(y, padding_lengths)
+            return new_x, new_y
+
+        return pad_fn
+
+    @staticmethod
+    def _augment_schema(
+        schema,
+        cats=None,
+        conts=None,
+        labels=None,
+    ):
+        cats = cats or []
+        conts = conts or []
+        labels = labels or []
+
+        schema = schema.select_by_name(conts + cats + labels)
+
+        labels = [labels] if isinstance(labels, str) else labels
+        for label in labels:
+            schema[label] = schema[label].with_tags(Tags.TARGET)
+        for label in cats:
+            schema[label] = schema[label].with_tags(Tags.CATEGORICAL)
+        for label in conts:
+            schema[label] = schema[label].with_tags(Tags.CONTINUOUS)
+
+        return schema
+
     def set_dataset(self, buffer_size, engine, reader_kwargs):
         dataset = validate_dataset(
             self.paths_or_dataset,

Review discussion on this hunk:

Member (author), on "def _augment_schema(": Moved this from schema_utils to here so that it's closer to the only place it's called in the codebase.

Contributor, on "schema = schema.select_by_name(conts + cats + labels)": Maybe we should remove the =None for cats, conts & labels?

Member (author): That's clearer without the default now. It wouldn't have worked as None before this change either.

Member (author): On second thought, I've put the None back and added support for None to the method by setting each argument to an empty list if not provided. It doesn't appear to be captured by any tests, but this would enable training a model with only continuous or only categorical features. Currently you need to have at least one of each for the current version of this MerlinDataLoader to work.
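
To illustrate the new .map hook above: the function returned by _get_pad_fn applies pad_batch to both the features and the targets of every batch the Loader yields. A minimal sketch of the effect on one batch (the feature name "a" and the padding length here are hypothetical; non-dict targets pass through pad_batch unchanged):

import torch

from transformers4rec.torch.utils.padding import pad_batch

# Hypothetical ragged feature "a" with rows [1, 2] and [3, 4, 5]
x = {
    "a__values": torch.tensor([1, 2, 3, 4, 5]),
    "a__offsets": torch.tensor([0, 2, 5]),
}
y = None  # pad_batch returns non-dict inputs as-is

sparse_max = {"a": 4}
new_x, new_y = pad_batch(x, sparse_max), pad_batch(y, sparse_max)
# new_x["a"] -> tensor([[1, 2, 0, 0],
#                       [3, 4, 5, 0]])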
102 changes: 102 additions & 0 deletions transformers4rec/torch/utils/padding.py
@@ -0,0 +1,102 @@
#
# Copyright (c) 2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Dict, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from merlin.table import TensorTable


def _pad_dense_tensor(t: torch.Tensor, length: Optional[int]) -> torch.Tensor:
    if length and len(t.shape) == 2:
        pad_diff = length - t.shape[1]
        return F.pad(input=t, pad=(0, pad_diff, 0, 0))
    return t


def _squeeze(tensor):
    if len(tensor.shape) == 2:
        return tensor.squeeze(1)
    return tensor


def _get_indices(offsets, diff_offsets):
    row_ids = torch.arange(len(offsets) - 1, device=offsets.device)
    row_ids_repeated = torch.repeat_interleave(row_ids, diff_offsets)
    row_offset_repeated = torch.repeat_interleave(offsets[:-1], diff_offsets)
    col_ids = torch.arange(len(row_offset_repeated), device=offsets.device) - row_offset_repeated
    indices = torch.cat([row_ids_repeated.unsqueeze(-1), col_ids.unsqueeze(-1)], axis=1)
    return indices


def _pad_ragged_tensor(values, offsets, padding_length):
    values = _squeeze(values)
    offsets = _squeeze(offsets)
    num_rows = len(offsets) - 1
    diff_offsets = offsets[1:] - offsets[:-1]
    indices = _get_indices(offsets, diff_offsets)
    sparse_tensor = torch.sparse_coo_tensor(
        indices.T, values, torch.Size([num_rows, padding_length]), device=values.device
    )
    return sparse_tensor.to_dense()


Batch = Dict[str, Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]


def pad_batch(batch: Batch, padding_lengths: Dict[str, int]) -> Batch:
    """Pad list features in a batch to the specified padding length

    Parameters
    ----------
    batch : Batch
        dictionary of tensors in the batch
    padding_lengths : Dict[str, int]
        dictionary mapping list column names to padding lengths

    Returns
    -------
    Batch
        Batch with padded list features

    Raises
    ------
    ValueError
        If a ragged column is found with no padding length provided
    """
    if batch is None or not isinstance(batch, dict):
        return batch

    batch_padded = {}
    for col_name, col in TensorTable(batch).items():
        if col.offsets is not None:
            padding_length = padding_lengths.get(col_name)
            if padding_length:
                padded_values = _pad_ragged_tensor(col.values, col.offsets, padding_length)
                batch_padded[col_name] = padded_values
            else:
                # Note: This exception can be removed if the model is
                # updated to support __values / __offsets inputs
                raise ValueError(
                    f"Found ragged column '{col_name}' with unspecified padding length. "
                    "Please provide a padding length for this feature "
                    "to be converted to a dense tensor."
                )
        else:
            padding_length = padding_lengths.get(col_name)
            batch_padded[col_name] = _pad_dense_tensor(col.values, padding_length)

    return batch_padded
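
The core of _pad_ragged_tensor is a sparse COO scatter: row and column indices are derived from the offsets, the ragged values become the non-zero entries, and to_dense() materializes the zero padding. A self-contained sketch of the same computation, using the rows from the tests above:

import torch

values = torch.tensor([1, 2, 3, 4, 5])
offsets = torch.tensor([0, 2, 2, 5])  # rows: [1, 2], [], [3, 4, 5]
padding_length = 5

row_lengths = offsets[1:] - offsets[:-1]  # [2, 0, 3]
row_ids = torch.repeat_interleave(torch.arange(len(offsets) - 1), row_lengths)  # [0, 0, 2, 2, 2]
col_ids = torch.arange(len(values)) - torch.repeat_interleave(offsets[:-1], row_lengths)  # [0, 1, 0, 1, 2]
indices = torch.stack([row_ids, col_ids])  # shape (2, nnz), COO convention

padded = torch.sparse_coo_tensor(indices, values, (len(offsets) - 1, padding_length)).to_dense()
# tensor([[1, 2, 0, 0, 0],
#         [0, 0, 0, 0, 0],
#         [3, 4, 5, 0, 0]])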
39 changes: 0 additions & 39 deletions transformers4rec/torch/utils/schema_utils.py
@@ -127,42 +127,3 @@ def _get_sparse_tensor(values, indices, num_rows, seq_limit):
     sparse_tensor = torch.sparse_coo_tensor(indices.T, values, torch.Size([num_rows, seq_limit]))

     return sparse_tensor.to_dense()
-
-
-def _augment_schema(
-    schema,
-    cats=None,
-    conts=None,
-    labels=None,
-    sparse_names=None,
-    sparse_max=None,
-    sparse_as_dense=False,
-):
-    from merlin.schema import ColumnSchema, Tags
-
-    schema = schema.select_by_name(conts + cats + labels)
-
-    labels = [labels] if isinstance(labels, str) else labels
-    for label in labels or []:
-        schema[label] = schema[label].with_tags(Tags.TARGET)
-    for label in cats or []:
-        schema[label] = schema[label].with_tags(Tags.CATEGORICAL)
-    for label in conts or []:
-        schema[label] = schema[label].with_tags(Tags.CONTINUOUS)
-
-    # Set the appropriate properties for the sparse_names/sparse_max/sparse_as_dense
-    for col in sparse_names or []:
-        cs = schema[col]
-        properties = cs.properties
-        if sparse_max and col in sparse_max:
-            properties["value_count"] = {"max": sparse_max[col]}
-        schema[col] = ColumnSchema(
-            name=cs.name,
-            tags=cs.tags,
-            dtype=cs.dtype,
-            is_list=True,
-            is_ragged=not sparse_as_dense,
-            properties=properties,
-        )
-
-    return schema
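
For completeness, a small sketch of what the relocated _augment_schema does: select only the requested columns, then tag each group by role. The column names here are hypothetical, and the Schema/ColumnSchema constructors are assumed from merlin.schema:

from merlin.schema import ColumnSchema, Schema, Tags

schema = Schema(
    [ColumnSchema("item_id-list"), ColumnSchema("price-list"), ColumnSchema("click")]
)

# Keep only the requested columns, then tag them by role
schema = schema.select_by_name(["price-list", "item_id-list", "click"])
schema["click"] = schema["click"].with_tags(Tags.TARGET)
schema["item_id-list"] = schema["item_id-list"].with_tags(Tags.CATEGORICAL)
schema["price-list"] = schema["price-list"].with_tags(Tags.CONTINUOUS)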