Introduce SOKEmbedding using Sparse Operation Kit #863

Merged
merged 52 commits into from
Mar 21, 2023
52 commits
0141ad4
new sok class
Nov 4, 2022
d7ad9ba
new sok class
Nov 4, 2022
7825963
test sok dynamic variable
Nov 7, 2022
475145a
test sok dynamic variable
Nov 7, 2022
4ef7a9e
bug fix comma
Nov 7, 2022
720eacf
add some comments and test distributed var
Nov 16, 2022
7a3a177
format the comments
Nov 17, 2022
54f795b
assert condition in sok lookup sparse
Dec 6, 2022
f3136b7
Merge branch 'main' into fea-sok-integration-wj
edknv Dec 14, 2022
7555c6f
Move SOKEmbedding to a separate file
edknv Dec 14, 2022
d0111b1
Clean up
edknv Dec 14, 2022
debb6e2
Clean up
edknv Dec 14, 2022
a4e3ffc
fix some import and param bug
Dec 27, 2022
97d51f5
remove some unused variable
Dec 28, 2022
b978afa
remove intial vals
Dec 28, 2022
4f22c0f
fix import
Dec 28, 2022
16fb414
reorder the import
Dec 28, 2022
9b37d4a
Merge branch 'main' into fea-sok-integration-wj
marcromeyn Jan 9, 2023
557af98
Merge branch 'main' into fea-sok-integration-wj
edknv Jan 9, 2023
f41f52f
fix import error in test embedding
Jan 12, 2023
f03e4ef
Merge branch 'fea-sok-integration-wj' of https://github.com/NVIDIA-Me…
Jan 12, 2023
fe34f9d
format the code
Jan 13, 2023
98bb17b
change the way of import
Jan 13, 2023
b0de517
Merge branch 'main' into fea-sok-integration-wj
WonderingWJ Jan 15, 2023
3e3c3b7
Merge branch 'main' into fea-sok-integration-wj
Jan 15, 2023
f7ff20f
Merge branch 'fea-sok-integration-wj' of https://github.com/NVIDIA-Me…
Jan 15, 2023
f3743e0
Merge branch 'main' into fea-sok-integration-wj
edknv Jan 27, 2023
4e8116c
Merge branch 'main' into fea-sok-integration-wj
edknv Feb 8, 2023
c266cd2
Merge branch 'main' into fea-sok-integration-wj
rnyak Feb 13, 2023
c3373e8
support sp_weights in lookup
Feb 15, 2023
d7531f6
Add unit tests for SOKEmbedding (#980)
edknv Feb 15, 2023
1e4628d
remove sok from ci since no gpu
edknv Feb 15, 2023
25a789d
Merge branch 'main' into fea-sok-integration-wj
edknv Feb 15, 2023
3c0104e
lint
edknv Feb 16, 2023
c8949e2
Merge branch 'main' into fea-sok-integration-wj
edknv Mar 1, 2023
b17fea4
Merge branch 'main' into fea-sok-integration-wj
edknv Mar 6, 2023
296a151
Merge branch 'main' into fea-sok-integration-wj
rnyak Mar 8, 2023
ef1c0bc
Merge branch 'main' into fea-sok-integration-wj
edknv Mar 8, 2023
b716adc
pip install sparse_operation_kit in tox.ini
edknv Mar 8, 2023
b1c1031
fix spelling
edknv Mar 8, 2023
91f25d1
resolve init method issue
Mar 12, 2023
d0671c9
fix schema issue
Mar 12, 2023
0eb266d
add indices and weights in from_pretrained
Mar 12, 2023
e9cc754
schema issue in from_pretrained
Mar 12, 2023
5532ab7
init method in DET
Mar 12, 2023
c9d3baf
tensor type in sok assing
Mar 12, 2023
33abba0
resolve config passing
Mar 15, 2023
4a9465d
Merge branch 'main' into fea-sok-integration-wj
WonderingWJ Mar 15, 2023
b2a01c2
lint
edknv Mar 17, 2023
b862a2a
Merge branch 'main' into fea-sok-integration-wj
edknv Mar 17, 2023
0fd059a
install sok in tox
edknv Mar 17, 2023
fcaefc3
Merge branch 'main' into fea-sok-integration-wj
edknv Mar 18, 2023
16 changes: 16 additions & 0 deletions examples/usecases/multi-gpu/install_sparse_operation_kit.sh
@@ -0,0 +1,16 @@
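#!/bin/bash

set -e

# Directory in which HugeCTR is cloned (first positional argument).
ROOT_DIR=$1

cd "$ROOT_DIR"

rm -rf hugectr/

git clone -b release-23.02 https://github.com/NVIDIA-Merlin/HugeCTR.git hugectr

cd hugectr/sparse_operation_kit/
python setup.py install

# HUGECTR_HOME is not set by this script; if the caller does not export it,
# the variable expands to nothing and no path is removed.
rm -rf ${HUGECTR_HOME}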
20 changes: 20 additions & 0 deletions merlin/models/tf/distributed/backend.py
@@ -1,6 +1,14 @@
import tensorflow as tf

from merlin.core.dispatch import HAS_GPU

hvd = None
hvd_installed = False

sok = None
sok_installed = False


try:
    import horovod.tensorflow.keras as hvd  # noqa: F401

@@ -11,3 +19,15 @@

if hvd_installed:
    hvd.init()

if HAS_GPU:
    try:
        from sparse_operation_kit import experiment as sok  # noqa: F401

        sok_installed = True
    except (ImportError, tf.errors.NotFoundError):
        pass


if sok_installed:
    sok.init()
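
The guarded import in `backend.py` follows a common optional-dependency pattern: try the import, record a flag, and let callers check the flag before using the module. Below is a minimal, dependency-free sketch of that idea. `optional_import` is a hypothetical helper and not part of Merlin Models, and it only catches `ImportError`, whereas the diff above also catches `tf.errors.NotFoundError`.

```python
import importlib
from types import ModuleType
from typing import Optional, Tuple


def optional_import(name: str) -> Tuple[Optional[ModuleType], bool]:
    """Try to import `name`; return (module, True) on success, (None, False) otherwise."""
    try:
        return importlib.import_module(name), True
    except ImportError:
        # ModuleNotFoundError is a subclass of ImportError, so it is covered here.
        return None, False


# A module that exists in the standard library.
json_mod, json_installed = optional_import("json")
# A module name assumed not to exist (illustrative only).
missing_mod, missing_installed = optional_import("module_that_does_not_exist_xyz")

print(json_installed, missing_installed)
```

Keeping the flag separate from the module object lets a class such as `SOKEmbedding` raise a clear `ImportError` at construction time instead of failing later on a `None` attribute access.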
232 changes: 232 additions & 0 deletions merlin/models/tf/distributed/embedding.py
@@ -0,0 +1,232 @@
from typing import Union

import tensorflow as tf

from merlin.models.tf.distributed.backend import hvd_installed, sok, sok_installed
from merlin.models.tf.inputs.embedding import EmbeddingTableBase
from merlin.models.utils.schema_utils import (
    create_categorical_column,
    schema_to_tensorflow_metadata_json,
    tensorflow_metadata_json_to_schema,
)
from merlin.schema import ColumnSchema


@tf.keras.utils.register_keras_serializable(package="merlin.models")
class SOKEmbedding(EmbeddingTableBase):
    """
    Wrap GPU-accelerated operations dedicated to sparse training / inference.

    dim: int
        The last dimension of the variable.
    vocab_sizes: list
        The number of rows of each variable in the list.
    initializer: string, list = "uniform"
        When it's a string, it specifies the initializer used to generate initial values.
        For sok.DynamicVariable, currently only "random" or the string of a float
        value (meaning a constant initializer) is supported.
        For sok.Variable, it is compatible with tf.Variable.
        Default value is "uniform".
        When it's a list, it specifies the values in the embedding table.
        For sok.DynamicVariable, initializer[i] must be a list of [indices, values],
        used as the initial indices and values for the i-th sok.DynamicVariable.
        For sok.Variable, initializer[i] must be a numpy array with shape
        [vocab_sizes[i], embedding_vec_size],
        used as the initial value for the i-th sok.Variable.
    use_dynamic_variable: bool = False
        Whether to use sok.DynamicVariable instead of sok.Variable. DynamicVariable
        can allocate memory dynamically. Variable is a model-parallel distributed variable.
    localized: list = None
        When using sok.Variable, two modes are available: distributed (Distributed
        Variable) and localized (Localized Variable). If set to None, a Distributed
        Variable is used; otherwise a Localized Variable is used, where the list
        indicates which GPU each variable is placed on.
        Default is None.
    """

    def __init__(
        self,
        dim: int,
        *col_schemas: ColumnSchema,
        vocab_sizes: list,
        initializer: Union[str, tf.Tensor, list] = "uniform",
        use_dynamic_variable=False,
        localized=None,
        trainable=True,
        name=None,
        dtype=None,
        **kwargs,
    ):
        if not hvd_installed or not sok_installed:
            raise ImportError(
                "'horovod' and 'sparse_operation_kit' are required to use "
                f"{self.__class__.__name__}."
            )

        super(SOKEmbedding, self).__init__(
            dim, *col_schemas, trainable=trainable, name=name, dtype=dtype, **kwargs
        )
        self._embedding_vec_size = dim
        self._vocab_sizes = vocab_sizes
        self._use_dynamic_variable = use_dynamic_variable
        self._localized = localized
        self._initializer = initializer
        self._vars = []
        if self._localized is None and self._use_dynamic_variable is False:
            for i in range(len(vocab_sizes)):
                if isinstance(initializer, str):
                    v = sok.Variable(
                        shape=[self._vocab_sizes[i], self._embedding_vec_size],
                        initializer=tf.keras.initializers.get(initializer),
                        dtype=tf.float32,
                    )
                else:
                    v = sok.Variable(initializer[i])
                # Register every table, not only the last one created.
                self._trainable_weights.append(v)
                self._vars.append(v)
        else:
            for i in range(len(vocab_sizes)):
                if self._use_dynamic_variable:
                    if isinstance(initializer, str):
                        v = sok.DynamicVariable(
                            dimension=self._embedding_vec_size, initializer=initializer
                        )
                    else:
                        v = sok.DynamicVariable(
                            dimension=self._embedding_vec_size, initializer="random"
                        )
                        indices = tf.convert_to_tensor(initializer[i][0], dtype=tf.int64)
                        values = tf.convert_to_tensor(initializer[i][1], dtype=tf.float32)
                        sok.assign(v, indices, values)
                elif self._localized is not None:
                    if isinstance(initializer, str):
                        v = sok.Variable(
                            shape=[self._vocab_sizes[i], self._embedding_vec_size],
                            initializer=tf.keras.initializers.get(initializer),
                            dtype=tf.float32,
                            mode="localized:%d" % self._localized[i],
                        )
                    else:
                        v = sok.Variable(
                            initializer[i],
                            mode="localized:%d" % self._localized[i],
                        )
                else:
                    raise ValueError("Wrong Configuration!!!")
                self._trainable_weights.append(v)
                self._vars.append(v)

    def call(self, inputs, combiners, training=True):
        """
        inputs: list, tuple
            a list or tuple of tf.SparseTensor or tf.RaggedTensor.
        combiners: list, tuple
            a list or tuple of strings to specify the combiner of each lookup.
        """
        is_list = isinstance(inputs, list) or isinstance(inputs, tuple)
        if is_list:
            for cur_input in inputs:
                if not isinstance(cur_input, tf.SparseTensor):
                    if not isinstance(cur_input, tf.RaggedTensor):
                        raise ValueError(
                            "The input must be a list of tf.SparseTensor or tf.RaggedTensor"
                        )
                    else:
                        if not len(cur_input.shape) == 2:
                            raise ValueError("The rank of input RaggedTensor must be 2")
        else:
            # A single tensor was passed: validate `inputs` itself
            # (`cur_input` is not defined on this branch).
            if not isinstance(inputs, tf.SparseTensor):
                if not isinstance(inputs, tf.RaggedTensor):
                    raise ValueError(
                        "The input must be a tf.SparseTensor or tf.RaggedTensor"
                    )
                else:
                    if not len(inputs.shape) == 2:
                        raise ValueError("The rank of input RaggedTensor must be 2")
        emb_vectors = sok.lookup_sparse(
            params=self._vars,
            sp_ids=inputs,
            combiners=combiners,
        )
        return emb_vectors

    @classmethod
    def from_pretrained(
        cls,
        dim: int,
        vocab_sizes: list,
        data: list,
        trainable=True,
        name=None,
        col_schema=None,
        use_dynamic_variable=True,
        localized=None,
        **kwargs,
    ) -> "SOKEmbedding":
        """Create an embedding table from pre-trained embeddings.

        Parameters
        ----------
        data :
            A list of numpy.array, or a list of dicts
            {"indice": numpy.array, "values": numpy.array}
        trainable : bool
            Whether the layer should be trained or not.
        name : str
            The name of the layer.
        """

        if not col_schema:
            if not name:
                raise ValueError("`name` is required when not using a ColumnSchema")
            col_schema = create_categorical_column(name, sum(vocab_sizes) - 1)

        weights = []
        for i, item in enumerate(data):
            if use_dynamic_variable:
                if isinstance(item, dict) and "indice" in item and "values" in item:
                    weights.append([item["indice"], item["values"]])
                else:
                    raise ValueError("DynamicVariable should be initialized with indice and values")
            else:
                weights.append(item)

        return cls(
            dim,
            col_schema,
            vocab_sizes=vocab_sizes,
            name=name,
            initializer=weights,
            use_dynamic_variable=use_dynamic_variable,
            localized=localized,
            trainable=trainable,
            **kwargs,
        )

    def get_config(self):
        config = super().get_config()
        config["dim"] = self.dim

        schema = schema_to_tensorflow_metadata_json(self.schema)
        config["schema"] = schema
        config["vocab_sizes"] = self._vocab_sizes
        config["initializer"] = self._initializer
        config["use_dynamic_variable"] = self._use_dynamic_variable
        config["localized"] = self._localized

        return config

    @classmethod
    def from_config(cls, config):
        dim = config.pop("dim")
        schema = tensorflow_metadata_json_to_schema(config.pop("schema"))
        vocab_size = config.pop("vocab_sizes")
        initializer = config.pop("initializer")
        use_dynamic_variable = config.pop("use_dynamic_variable")
        localized = config.pop("localized")

        return cls(
            dim,
            *schema,
            vocab_sizes=vocab_size,
            initializer=initializer,
            use_dynamic_variable=use_dynamic_variable,
            localized=localized,
            **config,
        )
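
The constructor in `embedding.py` picks one of three variable kinds from `use_dynamic_variable` and `localized`. The branch logic can be sketched without SOK or a GPU as a small pure function. `select_variable_kind` and the labels `"dynamic"` and `"distributed"` are hypothetical names used only for illustration; only the `"localized:%d"` format string comes from the diff.

```python
from typing import List, Optional


def select_variable_kind(
    index: int,
    use_dynamic_variable: bool = False,
    localized: Optional[List[int]] = None,
) -> str:
    """Return a label for the kind of variable that table `index` would get."""
    if use_dynamic_variable:
        # In the constructor, the dynamic branch is checked before `localized`.
        return "dynamic"
    if localized is not None:
        # Same format string the constructor passes as `mode=` to sok.Variable.
        return "localized:%d" % localized[index]
    # Neither option set: a distributed (model-parallel) variable is created.
    return "distributed"


print(select_variable_kind(0))
print(select_variable_kind(1, localized=[0, 1]))
print(select_variable_kind(0, use_dynamic_variable=True, localized=[0]))
```

Under these assumptions, passing both `use_dynamic_variable=True` and a `localized` list yields a dynamic variable, so the `localized` list is ignored in that case.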
1 change: 1 addition & 0 deletions requirements/horovod.txt
@@ -1 +1,2 @@
horovod
sparse_operation_kit
43 changes: 43 additions & 0 deletions tests/unit/tf/horovod/test_embedding.py
@@ -0,0 +1,43 @@
import numpy as np
import pytest
import tensorflow as tf

from merlin.core.dispatch import HAS_GPU
from merlin.models.tf.distributed.embedding import SOKEmbedding
from merlin.schema import ColumnSchema, Tags


@pytest.mark.skipif(not HAS_GPU, reason="No GPU available")
class TestSOKEmbedding:
    sample_column_schema = ColumnSchema(
        "item_id",
        dtype=np.int32,
        properties={"domain": {"min": 0, "max": 10, "name": "item_id"}},
        tags=[Tags.CATEGORICAL],
    )

    def test_sok_embedding_basic(self):
        embedding = SOKEmbedding(16, self.sample_column_schema, vocab_sizes=[10])
        inputs = [tf.ragged.constant([[0, 1, 0], [1, 0]])]
        combiners = ["sum"]
        outputs = embedding(inputs, combiners)
        assert outputs[0].shape == (2, 16)

    def test_sok_embedding_pretrained(self):
        weights = {}
        indices = np.array([0, 1, 2])
        values = np.arange(3 * 16).reshape(3, 16)
        weights["indice"] = indices
        weights["values"] = values
        embedding = SOKEmbedding.from_pretrained(
            16, vocab_sizes=[10], data=[weights], name="item_id"
        )
        inputs = [tf.ragged.constant([[0, 1, 0], [1, 0]])]
        combiners = ["sum"]
        outputs = embedding(inputs, combiners)
        assert outputs[0].shape == (2, 16)

    def test_sok_embedding_config(self):
        embedding = SOKEmbedding(16, self.sample_column_schema, vocab_sizes=[10], name="item_id")
        config = embedding.get_config()
        _ = SOKEmbedding.from_config(config)
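
The pretrained test passes one dict per table with the keys `indice` and `values`, which `from_pretrained` converts into `[indices, values]` pairs before calling the constructor. Here is a rough sketch of that normalization step that runs without SOK. `normalize_pretrained` is a hypothetical helper name, and plain Python lists stand in for the numpy arrays used in the test.

```python
from typing import Any, List, Sequence


def normalize_pretrained(data: Sequence[Any], use_dynamic_variable: bool = True) -> List[Any]:
    """Mirror the per-table conversion done in `from_pretrained`."""
    weights = []
    for item in data:
        if use_dynamic_variable:
            # Dynamic variables need explicit indices alongside the values.
            if isinstance(item, dict) and "indice" in item and "values" in item:
                weights.append([item["indice"], item["values"]])
            else:
                raise ValueError("DynamicVariable should be initialized with indice and values")
        else:
            # Static variables take the full table as-is.
            weights.append(item)
    return weights


# One table with three rows of dimension 2 (illustrative values).
table = {"indice": [0, 1, 2], "values": [[0.0, 0.1], [1.0, 1.1], [2.0, 2.1]]}
pairs = normalize_pretrained([table])
print(pairs[0][0])
```

With `use_dynamic_variable=False`, the same helper returns the input items unchanged, which matches the shape `[vocab_sizes[i], dim]` expected for a static `sok.Variable`.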
12 changes: 9 additions & 3 deletions tox.ini
@@ -32,19 +32,25 @@ commands =
; Runs GPU-based tests.
allowlist_externals =
    horovodrun
    sh
#deps =
#    -rrequirements/test.txt
passenv =
    OPAL_PREFIX
setenv =
    TF_GPU_ALLOCATOR=cuda_malloc_async
    CPATH={env:CPATH}{:}{envdir}/hugectr/include
    LD_LIBRARY_PATH=${envdir}/hugectr/include/lib{:}/usr/local/lib/python3.8/dist-packages/tensorflow{:}{env:LD_LIBRARY_PATH}
    LIBRARY_PATH=${envdir}/hugectr/lib{:}{env:LIBRARY_PATH}
sitepackages=true
commands =
    python -m pip install --upgrade git+https://github.com/NVIDIA-Merlin/core.git@{posargs:main}
    #python -m pip install --upgrade git+https://github.com/NVIDIA-Merlin/dataloader.git@{posargs:main}
    python -m pip install --upgrade git+https://github.com/bschifferer/dataloader.git@change_output
    python -m pip install --upgrade git+https://github.com/NVIDIA-Merlin/dataloader.git@{posargs:main}
    python -m pip install --upgrade git+https://github.com/NVIDIA-Merlin/nvtabular.git@{posargs:main}
    python -m pip install --upgrade git+https://github.com/bschifferer/dataloader.git@change_output
    # TODO: Move SOK installation to ci-runner dockerfile
    # Install SOK
    sh examples/usecases/multi-gpu/install_sparse_operation_kit.sh {envdir}
    # Run multi-gpu tests marked with `horovod` marker
    horovodrun -np 2 sh examples/usecases/multi-gpu/hvd_wrapper.sh python -m pytest -m horovod -rxs tests/unit

[testenv:py38-horovod-cpu]