Uses merlin-dataloader package (#547)
* Uses merlin-dataloader package

* Adds missing requirements file

* Rename to dataloader; Move column selection; Use only max value count
edknv authored and sararb committed Nov 30, 2022
1 parent 712c206 commit aebc66c
Showing 8 changed files with 105 additions and 47 deletions.
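In practice, the commit replaces the NVTabular loader classes (TorchAsyncItr / DLDataLoader) with merlin.dataloader.torch.Loader and moves column selection and sparse-feature metadata onto the dataset schema. A minimal sketch of the new loading path; the path and column handling here are illustrative assumptions, not code from this commit:

from merlin.dataloader.torch import Loader
from merlin.io.dataset import Dataset

# Hypothetical parquet dataset; column roles (categorical/continuous/target)
# and list/padding information are expected to live on dataset.schema
# (see _augment_schema in merlin_standard_lib/utils/misc_utils.py below).
dataset = Dataset("/data/train", engine="parquet")

loader = Loader(dataset, batch_size=1024, shuffle=True, drop_last=False)

for features, targets in loader:
    ...  # each batch is a (features, targets) pair produced by the merlin Loader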
50 changes: 45 additions & 5 deletions merlin_standard_lib/utils/misc_utils.py
@@ -177,12 +177,12 @@ def get_object_size(obj, seen=None):

def validate_dataset(paths_or_dataset, batch_size, buffer_size, engine, reader_kwargs):
"""
Util function to load NVTabular Dataset from disk
Util function to load merlin.io.Dataset from disk
Parameters
----------
paths_or_dataset: Union[nvtabular.Dataset, str]
Path to dataset to load of nvtabular Dataset,
paths_or_dataset: Union[merlin.io.dataset.Dataset, str]
Path to dataset to load or Merlin Dataset,
if Dataset, return the object.
batch_size: int
batch size for Dataloader.
@@ -196,9 +196,9 @@ def validate_dataset(paths_or_dataset, batch_size, buffer_size, engine, reader_k
Additional arguments of the specified reader.
"""
try:
from nvtabular.io import Dataset
from merlin.io.dataset import Dataset
except ImportError:
raise ValueError("NVTabular is necessary for this function, please install: " "nvtabular.")
raise ValueError("Merlin Core is necessary for this function, please install: merlin-core.")

# TODO: put this in parent class and allow
# torch dataset to leverage as well?
@@ -236,4 +236,44 @@ def validate_dataset(paths_or_dataset, batch_size, buffer_size, engine, reader_k
reader_kwargs["batch_size"] = buffer_size
else:
reader_kwargs["part_mem_fraction"] = buffer_size

return Dataset(files, engine=engine, **reader_kwargs)


def _augment_schema(
schema,
cats=None,
conts=None,
labels=None,
sparse_names=None,
sparse_max=None,
sparse_as_dense=False,
):
from merlin.schema import ColumnSchema, Tags

schema = schema.select_by_name(conts + cats + labels)

labels = [labels] if isinstance(labels, str) else labels
for label in labels or []:
schema[label] = schema[label].with_tags(Tags.TARGET)
for label in cats or []:
schema[label] = schema[label].with_tags(Tags.CATEGORICAL)
for label in conts or []:
schema[label] = schema[label].with_tags(Tags.CONTINUOUS)

# Set the appropriate properties for the sparse_names/sparse_max/sparse_as_dense
for col in sparse_names or []:
cs = schema[col]
properties = cs.properties
if sparse_max and col in sparse_max:
properties["value_count"] = {"max": sparse_max[col]}
schema[col] = ColumnSchema(
name=cs.name,
tags=cs.tags,
dtype=cs.dtype,
is_list=True,
is_ragged=not sparse_as_dense,
properties=properties,
)

return schema
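For reference, a hypothetical call to the new _augment_schema helper; the column names and lengths are invented, and the real call site is in transformers4rec/torch/utils/data_utils.py further down in this diff:

from merlin.io.dataset import Dataset
from merlin_standard_lib.utils.misc_utils import _augment_schema

dataset = Dataset("/data/train", engine="parquet")  # hypothetical path
dataset.schema = _augment_schema(
    dataset.schema,
    cats=["item_id", "category"],                # tagged Tags.CATEGORICAL
    conts=["price"],                             # tagged Tags.CONTINUOUS
    labels=["click"],                            # tagged Tags.TARGET
    sparse_names=["item_id", "category"],        # rewritten as list columns
    sparse_max={"item_id": 20, "category": 20},  # stored as properties["value_count"]["max"]
    sparse_as_dense=True,                        # is_ragged=False, i.e. padded to the max length
)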
2 changes: 2 additions & 0 deletions requirements/dataloader.txt
@@ -0,0 +1,2 @@
merlin-core>=0.2.0
merlin-dataloader>=0.0.2
1 change: 1 addition & 0 deletions setup.py
@@ -43,6 +43,7 @@ def read_requirements(filename):
"base": read_requirements("requirements/base.txt"),
"pytorch": read_requirements("requirements/pytorch.txt"),
"nvtabular": read_requirements("requirements/nvtabular.txt"),
"dataloader": read_requirements("requirements/dataloader.txt"),
"docs": read_requirements("requirements/docs.txt"),
"dev": read_requirements("requirements/dev.txt"),
}
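(Assuming the distribution keeps the transformers4rec name, the new extra would presumably be installed as pip install transformers4rec[dataloader]; this command is an inference from the extras_require entry above, not something documented in the commit itself.)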
3 changes: 2 additions & 1 deletion tox.ini
@@ -16,6 +16,7 @@ commands =
deps = -rrequirements/test.txt
commands =
python -m pip install --upgrade git+https://github.com/NVIDIA-Merlin/core.git@{posargs:main}
python -m pip install --upgrade git+https://github.com/NVIDIA-Merlin/dataloader.git@{posargs:main}
python -m pip install --upgrade git+https://github.com/NVIDIA-Merlin/NVTabular.git@{posargs:main}

python -m pytest -rsx --cov-config tests/.coveragerc --cov-report term-missing --cov=. tests
@@ -34,6 +35,7 @@ deps =
-rrequirements/test.txt
commands =
python -m pip install --upgrade git+https://github.com/NVIDIA-Merlin/core.git
python -m pip install --upgrade git+https://github.com/NVIDIA-Merlin/dataloader.git
python -m pip install --upgrade git+https://github.com/NVIDIA-Merlin/NVTabular.git

python -m pytest -rsx --cov-config tests/.coveragerc --cov-report term-missing --cov=. tests
@@ -54,4 +56,3 @@ deps = -rrequirements/docs.txt
commands =
sphinx-multiversion --dump-metadata docs/source docs/build/html | jq "keys"
sphinx-multiversion docs/source docs/build/html

4 changes: 2 additions & 2 deletions transformers4rec/config/trainer.py
@@ -71,10 +71,10 @@ class T4RecTrainingArguments(TrainingArguments):
)

data_loader_engine: str = field(
default="nvtabular",
default="merlin",
metadata={
"help": "Parquet data loader engine. "
"'nvtabular': GPU-accelerated parquet data loader from NVTabular, 'pyarrow': read whole parquet into memory."
"'merlin': GPU-accelerated parquet data loader from Merlin, 'pyarrow': read whole parquet into memory."
},
)
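A hypothetical training configuration using the new default; the output directory and batch size are invented, and "nvtabular" remains usable because the dataloader registry below keeps it as an alias:

from transformers4rec.config.trainer import T4RecTrainingArguments

training_args = T4RecTrainingArguments(
    output_dir="./checkpoints",      # hypothetical path
    data_loader_engine="merlin",     # new default; "nvtabular" resolves to the same loader
    per_device_train_batch_size=256,
)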

28 changes: 17 additions & 11 deletions transformers4rec/data/preprocessing.py
@@ -93,6 +93,12 @@ def session_aggregator(
session_data: Union[pandas.DataFrame, cudf.DataFrame]
session-level dataset with list features.
"""
try:
from merlin.dag import ColumnSelector
from merlin.io.dataset import Dataset
except ImportError:
raise ValueError("merlin-core is necessary for this function, please install it")

try:
import nvtabular as nvt
except ImportError:
@@ -128,7 +134,7 @@
session_column = session_column[0]

# define groupby operator
groupby_feats = nvt.ColumnSelector(schema.column_names)
groupby_feats = ColumnSelector(schema.column_names)
groupby_features = groupby_feats >> nvt.ops.Groupby(
groupby_cols=[session_column.name], aggs=groupby_dict, name_sep="-"
)
@@ -146,7 +152,7 @@
selected_features = groupby_features[non_list_variable] + groupby_features_trim

workflow = nvt.Workflow(selected_features)
dataset = nvt.Dataset(data, cpu=False)
dataset = Dataset(data, cpu=False)
workflow.fit(dataset)
session_data = workflow.transform(dataset).to_ddf().compute()

@@ -176,7 +182,7 @@
Parameters
-----
data: Union[nvtabular.Dataset, dask_cudf.DataFrame]
data: Union[merlin.io.dataset.Dataset, dask_cudf.DataFrame]
Dataset to split into time-based splits.
output_dir: str
Output path to save the time-based splits.
@@ -230,7 +236,7 @@ def _save_time_based_splits_gpu(
cudf, cupy and dask_cudf
Parameters
-----
data: Union[nvtabular.Dataset, dask_cudf.DataFrame]
data: Union[merlin.io.dataset.Dataset, dask_cudf.DataFrame]
Dataset to split into time-based splits.
output_dir: str
Output path to save the time-based splits.
@@ -250,15 +256,15 @@
import cudf
import cupy
import dask_cudf
import nvtabular as nvt
from merlin.io.dataset import Dataset
except ImportError:
raise ValueError(
"Rapids is necessary for this function, please install: "
"cudf, cupy, dask_cudf & nvtabular."
"cudf, cupy, dask_cudf, & merlin-core."
)

if isinstance(data, dask_cudf.DataFrame):
data = nvt.Dataset(data)
data = Dataset(data)
if not isinstance(partition_col, list):
partition_col = [partition_col]

@@ -312,7 +318,7 @@ def _save_time_based_splits_cpu(
Parameters
-----
data: Union[nvtabular.Dataset, dask_cudf.DataFrame]
data: Union[merlin.io.dataset.Dataset, dask_cudf.DataFrame]
Dataset to split into time-based splits.
output_dir: str
Output path to save the time-based splits.
@@ -331,16 +337,16 @@
try:
import dask
import numpy as np
import nvtabular as nvt
import pandas as pd
from merlin.io.dataset import Dataset
except ImportError:
raise ValueError(
"Rapids is necessary for this function, please install: "
"cudf, cupy, dask_cudf & nvtabular."
"cudf, cupy, dask_cudf & merlin-core."
)

if isinstance(data, dask.DataFrame):
data = nvt.Dataset(data)
data = Dataset(data)
if not isinstance(partition_col, list):
partition_col = [partition_col]

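The main API change in session_aggregator is that ColumnSelector now comes from merlin.dag and the input is wrapped in merlin.io.Dataset instead of nvt.Dataset. A minimal sketch of that pattern, with hypothetical column names and an input cudf DataFrame df assumed to exist:

import nvtabular as nvt
from merlin.dag import ColumnSelector
from merlin.io.dataset import Dataset

groupby_feats = ColumnSelector(["session_id", "item_id", "timestamp"])
groupby_features = groupby_feats >> nvt.ops.Groupby(
    groupby_cols=["session_id"],
    aggs={"item_id": ["list"], "timestamp": ["first"]},
    name_sep="-",
)

workflow = nvt.Workflow(groupby_features)
dataset = Dataset(df, cpu=False)  # df: a cudf DataFrame created elsewhere
workflow.fit(dataset)
session_data = workflow.transform(dataset).to_ddf().compute()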
46 changes: 28 additions & 18 deletions transformers4rec/torch/utils/data_utils.py
@@ -171,26 +171,39 @@ def from_schema(


if dependencies.is_gpu_dataloader_available():
from nvtabular.loader.torch import DLDataLoader
from nvtabular.loader.torch import TorchAsyncItr as DataLoader
import torch
from merlin.dataloader.torch import Loader
from torch.utils.data import DataLoader

from merlin_standard_lib.utils.misc_utils import validate_dataset
from merlin_standard_lib.utils.misc_utils import _augment_schema, validate_dataset

class DLDataLoaderWrapper(DLDataLoader):
class DLDataLoader(DataLoader):
"""
This class is an extension of the torch dataloader.
It is required to support the FastAI framework.
Setting the batch size directly to DLDataLoader makes it 3x slower.
So we set it as an alternative attribute and use it within
the T4Rec Trainer during evaluation.
# TODO : run experiments with new nvt dataloader
# TODO : run experiments with new merlin-dataloader
"""

def __init__(self, *args, **kwargs) -> None:
if "batch_size" in kwargs:
self._batch_size = kwargs.pop("batch_size")
super().__init__(*args, **kwargs)

@dataloader_registry.register_with_multiple_names("nvtabular_dataloader", "nvtabular")
class NVTabularDataLoader(T4RecDataLoader, DLDataLoaderWrapper):
@property
def device(self):
return torch.device("cuda" if torch.cuda.is_available() else "cpu")

def __len__(self):
return len(self.dataset)

@dataloader_registry.register_with_multiple_names(
"merlin_dataloader", "merlin", "nvtabular_dataloader", "nvtabular"
)
class MerlinDataLoader(T4RecDataLoader, DLDataLoader):
def __init__(
self,
paths_or_dataset,
@@ -225,7 +238,7 @@ def __init__(
self.drop_last = drop_last

self.set_dataset(buffer_size, engine, reader_kwargs)
self.dataset.schema = self.dataset.schema.select_by_name(conts + cats + labels)

self.dataset.schema = _augment_schema(
self.dataset.schema, cats, conts, labels, sparse_names, sparse_max, sparse_as_dense
)

if (global_rank is not None) and (self.dataset.npartitions < global_size):
logger.warning(
@@ -236,11 +252,8 @@ def __init__(
)
self.dataset = self.dataset.repartition(npartitions=global_size)

loader = DataLoader(
loader = Loader(
self.dataset,
cats,
conts,
labels,
self.batch_size,
shuffle,
seed_fn=seed_fn,
@@ -249,12 +262,9 @@ def __init__(
global_size=global_size,
global_rank=global_rank,
drop_last=drop_last,
sparse_names=sparse_names,
sparse_max=sparse_max,
sparse_as_dense=sparse_as_dense,
)

DLDataLoaderWrapper.__init__(
DLDataLoader.__init__(
self,
loader,
collate_fn=collate_fn,
@@ -317,7 +327,7 @@ def from_schema(
schema = schema.select_by_name(categorical_features + continuous_features + targets)
sparse_names = sparse_names or schema.select_by_tag(Tag.LIST).column_names
sparse_max = sparse_max or {name: max_sequence_length for name in sparse_names}
nvt_loader = cls(
loader = cls(
paths_or_dataset,
batch_size=batch_size,
max_sequence_length=max_sequence_length,
@@ -335,7 +345,7 @@
**kwargs,
)

return nvt_loader
return loader


class ParquetDataset(Dataset):
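A hypothetical use of the renamed loader in a GPU environment, assuming from_schema takes (schema, paths_or_dataset, ...) in the order suggested by the call above; schema and the data path are placeholders:

from transformers4rec.torch.utils.data_utils import MerlinDataLoader

train_loader = MerlinDataLoader.from_schema(
    schema,                  # a Merlin schema describing the parquet columns
    "/data/train",           # hypothetical path to parquet files
    batch_size=256,
    max_sequence_length=20,  # also used as the default sparse_max for list columns
)

for batch in train_loader:
    ...  # batches come from the underlying merlin Loader

The old registry names ("nvtabular_dataloader", "nvtabular") still resolve to this class, so existing configs keep working.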
18 changes: 8 additions & 10 deletions transformers4rec/utils/dependencies.py
@@ -13,16 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#


def is_nvtabular_available() -> bool:
try:
import nvtabular
except ImportError:
nvtabular = None
return nvtabular is not None


def is_gpu_dataloader_available() -> bool:
try:
import cudf
@@ -39,3 +29,11 @@ def is_pyarrow_available() -> bool:
except ImportError:
pyarrow = None
return pyarrow is not None


def is_merlin_dataloader_available() -> bool:
    try:
        import merlin.dataloader as merlin_dataloader
    except ImportError:
        merlin_dataloader = None
    return merlin_dataloader is not None
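A hypothetical guard built on the new helper, mirroring how the other is_*_available checks are used:

from transformers4rec.utils.dependencies import is_merlin_dataloader_available

if is_merlin_dataloader_available():
    from merlin.dataloader.torch import Loader
else:
    raise ImportError("merlin-dataloader is required; install it directly or via the 'dataloader' extra")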
