Remove sparse tensor output type for list features #103

Merged
44 changes: 0 additions & 44 deletions merlin/dataloader/loader_base.py
@@ -404,25 +404,6 @@ def _get_segment_lengths(self, num_samples):
idx.append(num_samples - num_full_batches * self.batch_size)
return idx

def _to_sparse_tensor(self, values_offset, column_name):
"""
Create a sparse representation of the input tensor.
values_offset is either a tensor or a tuple of (values, offsets) tensors.
"""
seq_limit = self.sparse_max[column_name]
values, offsets, diff_offsets, num_rows = self._pull_values_offsets(values_offset)
max_seq_len = self._get_max_seq_len(diff_offsets)
if max_seq_len > seq_limit:
raise ValueError(
"The default sequence length has been configured "
+ f"to {seq_limit} but the "
+ f"largest sequence in this batch have {max_seq_len} length"
)
sparse_as_dense = column_name in self.sparse_as_dense
return self._build_sparse_tensor(
values, offsets, diff_offsets, num_rows, seq_limit, sparse_as_dense
)

def _to_tensor(self, gdf):
"""
One of the mandatory functions a child class needs
@@ -511,12 +492,6 @@ def _process_batch(self, tensors):
X[k] = v

X = ungroup_values_offsets(X)
for column_name in self.sparse_names:
if column_name in self.sparse_max:
tensor = (X[f"{column_name}__values"], X[f"{column_name}__offsets"])
X.pop(f"{column_name}__values")
X.pop(f"{column_name}__offsets")
X[column_name] = self._to_sparse_tensor(tensor, column_name)

# Return a tensor if we have only one label column, but return a
# dictionary of tensors if there are multiple label columns, since
@@ -628,32 +603,13 @@ def input_schema(self, value):
)
self.label_names = value.select_by_tag(Tags.TARGET).column_names

self.sparse_names = []
self.sparse_max = {}
self.sparse_as_dense = set()
self.dtype_reverse_map = {}

for col_name, col_spec in self._input_schema.column_schemas.items():
if col_spec.dtype not in self.dtype_reverse_map:
self.dtype_reverse_map[col_spec.dtype] = [col_name]
else:
self.dtype_reverse_map[col_spec.dtype].append(col_name)
if col_spec.is_list:
self.sparse_names.append(col_name)

value_count = col_spec.value_count
if value_count and value_count.max:
self.sparse_max[col_name] = value_count.max

if not col_spec.is_ragged:
self.sparse_as_dense.add(col_name)

if not value_count:
# TODO: error message linking to docs
raise ValueError(
f"Dense column {col_name} doesn't have the max value_count defined"
" in the schema"
)

if self._transform_graph is not None:
self._transforms = self._transform_graph.construct_schema(
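With `_to_sparse_tensor` and the sparse-conversion loop in `_process_batch` removed, list columns stay in the values/offsets form produced by `ungroup_values_offsets`. The helper below is a minimal sketch (the function name is illustrative; the `__values`/`__offsets` key layout is taken from the updated tests further down) of how a consumer can pull the two flat tensors for a list column out of a batch dictionary:

def split_list_feature(features, column_name):
    """Illustrative helper, not part of the loader API: fetch the flat
    values/offsets pair that a list column is now returned as."""
    values = features[f"{column_name}__values"]    # all list elements, flattened
    offsets = features[f"{column_name}__offsets"]  # row boundaries, length batch_size + 1
    row_lengths = offsets[1:] - offsets[:-1]       # per-row list lengths
    return values, offsets, row_lengths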
50 changes: 0 additions & 50 deletions merlin/dataloader/tensorflow.py
@@ -244,62 +244,12 @@ def _to_tensor(self, gdf):
def _sum(self, tensor):
return tf.reduce_sum(tensor)

def _pull_values_offsets(self, values_offset):
"""
values_offset is either a tuple (values, offsets) or just a values tensor.
Splits the input into values, offsets, row-length differences, and the
number of rows, as the first step of building a sparse representation.
"""
diff_offsets = None
if isinstance(values_offset, tuple):
values = tf.reshape(values_offset[0], [-1])
offsets = tf.reshape(values_offset[1], [-1])
else:
values = tf.reshape(values_offset, [-1])
offsets = tf.range(tf.shape(values)[0], dtype=tf.int64)
num_rows = len(offsets)
diff_offsets = offsets[1:] - offsets[:-1]
return values, offsets, diff_offsets, num_rows

def _row_lengths_to_offsets(self, row_lengths):
zero_value = tf.constant([0], dtype=row_lengths.dtype)
if len(row_lengths.shape) == 2:
zero_value = tf.expand_dims(zero_value, axis=0)
return tf.concat([zero_value, tf.cumsum(row_lengths)], axis=0)

def _get_max_seq_len(self, diff_offsets):
# get_max_seq_len, return int
return int(tf.math.reduce_max(diff_offsets))

def _get_indices(self, offsets, diff_offsets):
# Building the indices to reconstruct the sparse tensors
row_ids = tf.range(len(offsets), dtype=tf.int64)

row_ids_repeated = tf.repeat(row_ids, diff_offsets)
row_offset_repeated = tf.repeat(offsets, diff_offsets)
col_ids = tf.range(len(row_offset_repeated), dtype=tf.int64) - row_offset_repeated
indices = tf.concat(
values=[tf.expand_dims(row_ids_repeated, -1), tf.expand_dims(col_ids, -1)],
axis=1,
)
return indices

def _get_sparse_tensor(self, values, indices, num_rows, seq_limit):
sparse_tensor = tf.sparse.SparseTensor(
indices=indices, values=values, dense_shape=[num_rows, seq_limit]
)
return sparse_tensor

def _build_sparse_tensor(
self, values, offsets, diff_offsets, num_rows, seq_limit, sparse_as_dense
):
ragged = tf.RaggedTensor.from_row_splits(values=values, row_splits=offsets)
tensor = tf.RaggedTensor.from_tensor(ragged.to_tensor(shape=[None, seq_limit])).to_sparse()
if sparse_as_dense:
tensor = tf.sparse.to_dense(tensor)
return tensor

def _process_batch(self, tensors):
to_return = super()._process_batch(tensors)

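For TensorFlow users who still need the old padded or sparse layout, roughly the conversion that the removed `_build_sparse_tensor` performed can be reproduced outside the loader. A minimal sketch, assuming the illustrative name `lists_to_padded` and that `max_seq_len` plays the role of the old `sparse_max` / `value_count.max` setting:

import tensorflow as tf

def lists_to_padded(values, offsets, max_seq_len, as_sparse=False):
    # Sketch only: rebuild a [num_rows, max_seq_len] tensor from the flat
    # values/offsets pair the loader now returns for list columns.
    ragged = tf.RaggedTensor.from_row_splits(values=values, row_splits=offsets)
    dense = ragged.to_tensor(shape=[None, max_seq_len])
    return tf.sparse.from_dense(dense) if as_sparse else dense

# Example: rows [1, 2] and [3, 4, 5], padded to a maximum length of 4.
vals = tf.constant([1, 2, 3, 4, 5], dtype=tf.int64)
offs = tf.constant([0, 2, 5], dtype=tf.int64)
print(lists_to_padded(vals, offs, max_seq_len=4))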
36 changes: 0 additions & 36 deletions merlin/dataloader/torch.py
@@ -136,31 +136,6 @@ def _tensor_split(self, tensor, idx, axis=0):
def _reshape_dim(self, tensor):
return tensor.view(-1)

def _pull_values_offsets(self, values_offset):
# pull values and offsets apart, returning (values, offsets, diff_offsets, num_rows)
if isinstance(values_offset, tuple):
values = values_offset[0].flatten()
offsets = values_offset[1].flatten()
else:
values = values_offset.flatten()
offsets = torch.arange(values.size()[0], device=self.device)
num_rows = len(offsets) - 1
diff_offsets = offsets[1:] - offsets[:-1]
return values, offsets, diff_offsets, num_rows

def _get_max_seq_len(self, diff_offsets):
return int(diff_offsets.max())

# Building the indices to reconstruct the sparse tensors

def _get_indices(self, offsets, diff_offsets):
row_ids = torch.arange(len(offsets) - 1, device=self.device)
row_ids_repeated = torch.repeat_interleave(row_ids, diff_offsets)
row_offset_repeated = torch.repeat_interleave(offsets[:-1], diff_offsets)
col_ids = torch.arange(len(row_offset_repeated), device=self.device) - row_offset_repeated
indices = torch.cat([row_ids_repeated.unsqueeze(-1), col_ids.unsqueeze(-1)], axis=1)
return indices

def _sum(self, tensor):
return tensor.sum()

@@ -170,17 +145,6 @@ def _row_lengths_to_offsets(self, row_lengths):
zero_value = zero_value.view(-1, 1)
return torch.cat((zero_value, torch.cumsum(row_lengths, 0)))

def _build_sparse_tensor(
self, values, offsets, diff_offsets, num_rows, seq_limit, sparse_as_dense
):
indices = self._get_indices(offsets, diff_offsets)
sparse_tensor = torch.sparse_coo_tensor(
indices.T, values, torch.Size([num_rows, seq_limit]), device=self.device
)
if sparse_as_dense:
sparse_tensor = sparse_tensor.to_dense()
return sparse_tensor

def _process_batch(self, tensors):
to_return = super()._process_batch(tensors)

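The equivalent post-hoc conversion in PyTorch, mirroring what the removed `_build_sparse_tensor` did with `torch.sparse_coo_tensor`; the padding loop below is a simplified sketch rather than the removed implementation:

import torch

def lists_to_padded(values, offsets, max_seq_len, as_sparse=False):
    # Sketch only: rebuild a [num_rows, max_seq_len] tensor from the flat
    # values/offsets pair the loader now returns for list columns.
    num_rows = len(offsets) - 1
    dense = torch.zeros(num_rows, max_seq_len, dtype=values.dtype)
    for row in range(num_rows):
        start, end = int(offsets[row]), int(offsets[row + 1])
        dense[row, : end - start] = values[start:end]
    return dense.to_sparse() if as_sparse else dense

# Example: rows [1, 2] and [3, 4, 5], padded to a maximum length of 4.
vals = torch.tensor([1, 2, 3, 4, 5])
offs = torch.tensor([0, 2, 5])
print(lists_to_padded(vals, offs, max_seq_len=4, as_sparse=True))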
72 changes: 0 additions & 72 deletions tests/unit/dataloader/test_tf_dataloader.py
@@ -476,78 +476,6 @@ def test_multigpu_partitioning(dataset, batch_size, global_rank):
assert indices == [global_rank]


@pytest.mark.parametrize("sparse_dense", [False, True])
def test_sparse_tensors(tmpdir, sparse_dense):
# create small dataset, add values to sparse_list
json_sample = {
"conts": {},
"cats": {
"spar1": {
"dtype": None,
"cardinality": 50,
"min_entry_size": 1,
"max_entry_size": 5,
"multi_min": 2,
"multi_max": 4,
"multi_avg": 3,
},
"spar2": {
"dtype": None,
"cardinality": 50,
"min_entry_size": 1,
"max_entry_size": 5,
"multi_min": 3,
"multi_max": 5,
"multi_avg": 4,
},
# "": {"dtype": None, "cardinality": 500, "min_entry_size": 1, "max_entry_size": 5},
},
"labels": {"rating": {"dtype": None, "cardinality": 2}},
}
datagen = pytest.importorskip("nvtabular.tools.data_gen")

cols = datagen._get_cols_from_schema(json_sample)
df_gen = datagen.DatasetGen(datagen.UniformDistro(), gpu_frac=0.0001)
target_path = os.path.join(tmpdir, "input/")
os.mkdir(target_path)
df_files = df_gen.full_df_create(10000, cols, output=target_path)
spa_lst = ["spar1", "spar2"]
spa_mx = {"spar1": 5, "spar2": 6}
batch_size = 10

ds = Dataset(df_files)
schema = ds.schema
for col_name in spa_lst:
schema[col_name] = schema[col_name].with_tags(Tags.CATEGORICAL)
if not sparse_dense:
schema[col_name] = schema[col_name].with_properties(
{"value_count": {"max": spa_mx[col_name]}}
)

for col_name in []:
schema[col_name] = schema[col_name].with_tags(Tags.CONTINUOUS)
for col_name in ["rating"]:
schema[col_name] = schema[col_name].with_tags(Tags.TARGET)
ds.schema = schema

data_itr = tf_dataloader.Loader(
ds,
batch_size=batch_size,
)
for batch in data_itr:
feats, labs = batch
for col in spa_lst:
# grab row lengths
if not sparse_dense:
feature_tensor = feats[f"{col}"]
assert list(feature_tensor.shape) == [batch_size, spa_mx[col]]
assert isinstance(feature_tensor, tf.sparse.SparseTensor)
else:
feature_tensor = feats[f"{col}__offsets"]
assert feature_tensor.shape[0] == batch_size + 1
assert not isinstance(feature_tensor, tf.sparse.SparseTensor)
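
A sketch of the check that remains meaningful once the sparse output path is gone, reusing `data_itr`, `spa_lst`, and `batch_size` from the removed test above (the `__values`/`__offsets` keys are assumed based on the corresponding PyTorch test):

for feats, labs in data_itr:
    for col in spa_lst:
        values = feats[f"{col}__values"]
        offsets = feats[f"{col}__offsets"]
        # one offset per row boundary, and no SparseTensor anywhere
        assert offsets.shape[0] == batch_size + 1
        assert not isinstance(values, tf.sparse.SparseTensor)
        assert not isinstance(offsets, tf.sparse.SparseTensor)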


@pytest.mark.skipif(
os.environ.get("NR_USER") is not None,
reason="not working correctly in ci environment",
43 changes: 0 additions & 43 deletions tests/unit/dataloader/test_torch_dataloader.py
@@ -279,49 +279,6 @@ def test_mh_support(multihot_dataset):
assert idx > 0


@pytest.mark.parametrize("sparse_dense", [False, True])
def test_sparse_tensors(sparse_dense):
# create small dataset, add values to sparse_list
df = make_df(
{
"spar1": [[1, 2, 3, 4], [4, 2, 4, 4], [1, 3, 4, 3], [1, 1, 3, 3]],
"spar2": [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10], [11, 12, 13, 14], [15, 16]],
}
)
spa_lst = ["spar1", "spar2"]
spa_mx = {"spar1": 5, "spar2": 6}
batch_size = 2

ds = Dataset(df)
schema = ds.schema
for col_name in spa_lst:
if not sparse_dense:
schema[col_name] = schema[col_name].with_properties(
{"value_count": {"max": spa_mx[col_name]}}
)
ds.schema = schema

dataloader = torch_dataloader.Loader(ds, batch_size=batch_size)
for batch in dataloader:
feats, labs = batch
for col in spa_lst:
if sparse_dense:
assert col + "__values" in feats
assert col + "__offsets" in feats
feature_tensor = feats[col + "__offsets"]
else:
feature_tensor = feats[col]
if not sparse_dense:
assert list(feature_tensor.shape) == [batch_size, spa_mx[col]]
assert feature_tensor.is_sparse
else:
assert not feature_tensor[0].is_sparse

# add dict sparse_max entry for each target
# iterate dataloader grab sparse columns
# ensure they are correct structurally


@pytest.mark.parametrize("batch_size", [1000])
@pytest.mark.parametrize("cpu", [False, True] if HAS_GPU else [True])
def test_dataloader_schema(df, dataset, batch_size, cpu):