-
Notifications
You must be signed in to change notification settings - Fork 143
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update padding of ragged features to enable dataloader change #647
Changes from 8 commits
f8e5830
ba923e2
86bfddb
ced153f
f4b91c3
1431f5d
86662c6
eab5492
43d535d
1a9b768
f65dd66
25564ad
363ab92
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,7 +12,7 @@ on: | |
|
||
jobs: | ||
gpu-ci: | ||
runs-on: 1GPU | ||
runs-on: 2GPU | ||
|
||
steps: | ||
- uses: actions/checkout@v3 | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
# | ||
# Copyright (c) 2023, NVIDIA CORPORATION. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
from itertools import accumulate | ||
|
||
import torch | ||
|
||
from transformers4rec.torch.utils.padding import pad_batch | ||
|
||
|
||
def _get_values_offsets(data): | ||
values = [] | ||
row_lengths = [] | ||
for row in data: | ||
row_lengths.append(len(row)) | ||
values += row | ||
offsets = [0] + list(accumulate(row_lengths)) | ||
return torch.tensor(values), torch.tensor(offsets) | ||
|
||
|
||
def test_pad_values_offsets_tuple(): | ||
data = [[1, 2], [], [3, 4, 5]] | ||
values, offsets = _get_values_offsets(data) | ||
|
||
x = {"a": (values, offsets)} | ||
|
||
padded_x = pad_batch(x, {"a": 5}) | ||
assert torch.equal( | ||
padded_x["a"], | ||
torch.tensor( | ||
[ | ||
[1, 2, 0, 0, 0], | ||
[0, 0, 0, 0, 0], | ||
[3, 4, 5, 0, 0], | ||
] | ||
), | ||
) | ||
|
||
|
||
def test_pad_values_offsets_dict(): | ||
data = [[1, 2], [], [3, 4, 5]] | ||
values, offsets = _get_values_offsets(data) | ||
|
||
x = {"a__values": values, "a__offsets": offsets} | ||
|
||
padded_x = pad_batch(x, {"a": 7}) | ||
assert torch.equal( | ||
padded_x["a"], | ||
torch.tensor( | ||
[ | ||
[1, 2, 0, 0, 0, 0, 0], | ||
[0, 0, 0, 0, 0, 0, 0], | ||
[3, 4, 5, 0, 0, 0, 0], | ||
] | ||
), | ||
) | ||
|
||
|
||
def test_pad_values_dense(): | ||
data = [[1, 2], [], [3, 4, 5]] | ||
values, offsets = _get_values_offsets(data) | ||
|
||
x = {"a__values": values, "a__offsets": offsets, "b": torch.tensor([[3, 6], [4, 1], [8, 4]])} | ||
|
||
padded_x = pad_batch(x, {"a": 7, "b": 3}) | ||
assert torch.equal( | ||
padded_x["a"], | ||
torch.tensor( | ||
[ | ||
[1, 2, 0, 0, 0, 0, 0], | ||
[0, 0, 0, 0, 0, 0, 0], | ||
[3, 4, 5, 0, 0, 0, 0], | ||
] | ||
), | ||
) | ||
assert torch.equal( | ||
padded_x["b"], | ||
torch.tensor( | ||
[ | ||
[3, 6, 0], | ||
[4, 1, 0], | ||
[8, 4, 0], | ||
] | ||
), | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,9 +28,9 @@ | |
from torch.utils.data import Dataset, IterableDataset | ||
|
||
from merlin_standard_lib import Schema | ||
from transformers4rec.torch.utils.padding import pad_batch | ||
|
||
from ...utils import dependencies | ||
from ..utils.schema_utils import _augment_schema | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
@@ -323,10 +323,6 @@ def __init__( | |
reader_kwargs["row_groups_per_part"] = row_groups_per_part | ||
self.set_dataset(buffer_size, engine, reader_kwargs) | ||
|
||
self.dataset.schema = _augment_schema( | ||
self.dataset.schema, cats, conts, labels, sparse_names, sparse_max, sparse_as_dense | ||
) | ||
|
||
if (global_rank is not None) and (self.dataset.npartitions < global_size): | ||
logger.warning( | ||
"UserWarning: User is advised to repartition the parquet file before training " | ||
|
@@ -345,6 +341,11 @@ def __init__( | |
f" GPUs ({global_size}). This will divide the work equally among GPUs" | ||
" for DDP training and ensure optimal performance." | ||
) | ||
|
||
self.dataset.schema = self._augment_schema( | ||
self.dataset.schema, cats=cats, conts=conts, labels=labels | ||
) | ||
|
||
loader = Loader( | ||
self.dataset, | ||
self.batch_size, | ||
|
@@ -355,7 +356,7 @@ def __init__( | |
global_size=global_size, | ||
global_rank=global_rank, | ||
drop_last=drop_last, | ||
) | ||
).map(self._get_pad_fn(sparse_max)) | ||
|
||
DLDataLoader.__init__( | ||
self, | ||
|
@@ -367,6 +368,34 @@ def __init__( | |
self.schema = schema | ||
self.max_sequence_length = max_sequence_length | ||
|
||
@staticmethod | ||
def _get_pad_fn(padding_lengths): | ||
def pad_fn(x, y): | ||
new_x = pad_batch(x, padding_lengths) | ||
new_y = pad_batch(y, padding_lengths) | ||
return new_x, new_y | ||
|
||
return pad_fn | ||
|
||
@staticmethod | ||
def _augment_schema( | ||
schema, | ||
cats=None, | ||
conts=None, | ||
labels=None, | ||
): | ||
schema = schema.select_by_name(conts + cats + labels) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe we should remove the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. that's clearer without the default now. it wouldn't have worked as None before this change either. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. On second thought, I've put the None back and added support for the None to the method by setting to an empty list if not provided. It doesn't appear to be captured by any tests, but this would enable training a model with only continuous or only categorical features. Currently you need to have at least one of each for the current version of this MerlinDataloader to work. |
||
|
||
labels = [labels] if isinstance(labels, str) else labels | ||
for label in labels or []: | ||
schema[label] = schema[label].with_tags(Tags.TARGET) | ||
for label in cats or []: | ||
schema[label] = schema[label].with_tags(Tags.CATEGORICAL) | ||
for label in conts or []: | ||
schema[label] = schema[label].with_tags(Tags.CONTINUOUS) | ||
|
||
return schema | ||
|
||
def set_dataset(self, buffer_size, engine, reader_kwargs): | ||
dataset = validate_dataset( | ||
self.paths_or_dataset, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
# | ||
# Copyright (c) 2023, NVIDIA CORPORATION. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
from typing import Dict, Optional, Tuple, Union | ||
|
||
import torch | ||
import torch.nn.functional as F | ||
|
||
|
||
def _pad_dense_tensor(t: torch.Tensor, length: Optional[int]) -> torch.Tensor: | ||
if length and len(t.shape) == 2: | ||
pad_diff = length - t.shape[1] | ||
return F.pad(input=t, pad=(0, pad_diff, 0, 0)) | ||
return t | ||
|
||
|
||
def _squeeze(tensor): | ||
if len(tensor.shape) == 2: | ||
return tensor.squeeze(1) | ||
return tensor | ||
|
||
|
||
def _get_indices(offsets, diff_offsets): | ||
row_ids = torch.arange(len(offsets) - 1, device=offsets.device) | ||
row_ids_repeated = torch.repeat_interleave(row_ids, diff_offsets) | ||
row_offset_repeated = torch.repeat_interleave(offsets[:-1], diff_offsets) | ||
col_ids = torch.arange(len(row_offset_repeated), device=offsets.device) - row_offset_repeated | ||
indices = torch.cat([row_ids_repeated.unsqueeze(-1), col_ids.unsqueeze(-1)], axis=1) | ||
return indices | ||
|
||
|
||
def _pad_ragged_tensor(values, offsets, padding_length): | ||
values = _squeeze(values) | ||
offsets = _squeeze(offsets) | ||
num_rows = len(offsets) - 1 | ||
diff_offsets = offsets[1:] - offsets[:-1] | ||
indices = _get_indices(offsets, diff_offsets) | ||
sparse_tensor = torch.sparse_coo_tensor( | ||
indices.T, values, torch.Size([num_rows, padding_length]), device=values.device | ||
) | ||
return sparse_tensor.to_dense() | ||
|
||
|
||
Batch = Dict[str, Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] | ||
|
||
|
||
def pad_batch(batch: Batch, padding_lengths: Dict[str, int]) -> Batch: | ||
"""Pad list features in a batch to padding length specified | ||
|
||
Parameters | ||
---------- | ||
X : Batch | ||
dictionary of tensors in batch | ||
padding_lengths : Dict[str, int] | ||
dictionary mapping list column name to padding length | ||
|
||
Returns | ||
------- | ||
Batch | ||
Batch with padded list features | ||
|
||
Raises | ||
------ | ||
ValueError | ||
If ragged column found with no padding length provided | ||
""" | ||
if batch is None or not isinstance(batch, dict): | ||
return batch | ||
|
||
batch_padded = {} | ||
for k, values in batch.items(): | ||
if k.endswith("__values"): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Wouldn't it be better to put There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Perhaps There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. updated in 43d535d |
||
col_name = k[:-8] | ||
offsets = batch[f"{col_name}__offsets"] | ||
padding_length = padding_lengths.get(col_name) | ||
if padding_length: | ||
padded_values = _pad_ragged_tensor(values, offsets, padding_length) | ||
batch_padded[col_name] = padded_values | ||
else: | ||
# Note: This exception can be removed if the model is | ||
# updated to support __values / __offsets inputs | ||
raise ValueError( | ||
f"Found ragged column '{col_name}' with unspecified padding length. " | ||
"Please provide a padding length for this feature " | ||
"to be converted to a dense tensor. " | ||
) | ||
elif k.endswith("__offsets"): | ||
continue | ||
elif isinstance(values, tuple): | ||
padding_length = padding_lengths.get(k) | ||
if padding_length: | ||
col_name = k | ||
values, offsets = values | ||
padded_values = _pad_ragged_tensor(values, offsets, padding_length) | ||
batch_padded[col_name] = padded_values | ||
else: | ||
raise ValueError( | ||
f"Found ragged column '{col_name}' with unspecified padding length. " | ||
"Please provide a padding length for this feature " | ||
"to be converted to a dense tensor. " | ||
) | ||
else: | ||
padding_length = padding_lengths.get(k) | ||
batch_padded[k] = _pad_dense_tensor(values, padding_length) | ||
|
||
return batch_padded |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Moved this from
schema_utiils
to here so that it's closer to the only place it's called in the codebase.