Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REVIEW]Distributed TF-IDF transformer #2698

Merged
merged 23 commits into from
Aug 28, 2020
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
9c42274
Merge pull request #1935 from rapidsai/branch-0.13
raydouglass Mar 30, 2020
7544c43
REL v0.13.0 release
GPUtester Mar 31, 2020
049eadd
Merge pull request #2342 from rapidsai/branch-0.14
raydouglass Jun 3, 2020
31202da
REL v0.14.0 release
GPUtester Jun 3, 2020
b51e0a9
working intial implimentation multi-node multi-gpu transformer
VibhuJawa Aug 14, 2020
f360a8b
verified working on a bigger dataset
VibhuJawa Aug 14, 2020
2349a97
added tests and code clean up
VibhuJawa Aug 17, 2020
1fd7033
working intial implimentation multi-node multi-gpu transformer
VibhuJawa Aug 14, 2020
4c3e8f0
verified working on a bigger dataset
VibhuJawa Aug 14, 2020
e85d485
added tests and code clean up
VibhuJawa Aug 17, 2020
dddb36c
added change-log
VibhuJawa Aug 17, 2020
bfa8640
Merge branch 'fea_mnmg_tfidf' of github.com:vibhujawa/cuml into fea_m…
VibhuJawa Aug 17, 2020
4cd0356
style check satisfied dask/tfidf_transformer.py
VibhuJawa Aug 17, 2020
404a274
style check test_tfidf.py
VibhuJawa Aug 17, 2020
704fe31
style check for test/dask/test_tfidf.py and feature_extraction/_tfidf.py
VibhuJawa Aug 17, 2020
3625017
removed comments from _tfidf.py
VibhuJawa Aug 18, 2020
85ac30c
added idf_ excecption to conftest.py
VibhuJawa Aug 18, 2020
75af2d1
fixed minor style_check_issue_with__tfidf.py
VibhuJawa Aug 18, 2020
492d5ab
resolved conflicts
VibhuJawa Aug 27, 2020
befe6a1
switched to cumlarray for doc_frequency and resolved conflicts
VibhuJawa Aug 27, 2020
1e2b41c
added incorrectly removed space back to CHANGELOG.MD
VibhuJawa Aug 27, 2020
e3bbbde
Apply suggestions from code review
VibhuJawa Aug 27, 2020
e6033f7
removed sklearn utils dependency
VibhuJawa Aug 27, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# cuML 0.16.0 (Date TBD)

## New Features
- PR #2698: Distributed TF-IDF transformer

## Improvements

Expand Down
17 changes: 17 additions & 0 deletions python/cuml/dask/feature_extraction/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#
# Copyright (c) 2020, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from cuml.dask.feature_extraction import text
17 changes: 17 additions & 0 deletions python/cuml/dask/feature_extraction/text/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#
# Copyright (c) 2020, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from cuml.dask.feature_extraction.text.tfidf_transformer import TfidfTransformer
208 changes: 208 additions & 0 deletions python/cuml/dask/feature_extraction/text/tfidf_transformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
#
# Copyright (c) 2020, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import dask
from toolz import first
import dask.array

from cuml.common import with_cupy_rmm
from cuml.dask.common.base import BaseEstimator
from cuml.dask.common.base import DelayedTransformMixin
from cuml.dask.common.utils import wait_and_raise_from_futures
from cuml.dask.common.func import reduce
from cuml.dask.common.input_utils import DistributedDataHandler

from cuml.feature_extraction.text import TfidfTransformer as s_TfidfTransformer


class TfidfTransformer(BaseEstimator, DelayedTransformMixin):

    """
    Distributed TF-IDF transformer

    Per-partition document statistics are computed on the workers, merged
    into a single model, and that model is then used to transform each
    partition of the input.

    Examples
    --------

    .. code-block:: python

        import cupy as cp
        from sklearn.datasets import fetch_20newsgroups
        from sklearn.feature_extraction.text import CountVectorizer
        from dask_cuda import LocalCUDACluster
        from dask.distributed import Client
        from cuml.dask.common import to_sparse_dask_array
        from cuml.dask.naive_bayes import MultinomialNB
        import dask
        from cuml.dask.feature_extraction.text import TfidfTransformer

        # Create a local CUDA cluster
        cluster = LocalCUDACluster()
        client = Client(cluster)

        # Load corpus
        twenty_train = fetch_20newsgroups(subset='train',
                                          shuffle=True, random_state=42)
        cv = CountVectorizer()
        xformed = cv.fit_transform(twenty_train.data).astype(cp.float32)
        X = to_sparse_dask_array(xformed, client)

        y = dask.array.from_array(twenty_train.target, asarray=False,
                                  fancy=False).astype(cp.int32)

        multi_gpu_transformer = TfidfTransformer()
        X_transformed = multi_gpu_transformer.fit_transform(X)
        X_transformed.compute_chunk_sizes()

        model = MultinomialNB()
        model.fit(X_transformed, y)
        model.score(X_transformed, y)

    Output:

    .. code-block:: python

        array(0.93264981)

    """

    def __init__(self, client=None, verbose=False, **kwargs):

        """
        Create new distributed TF-IDF transformer instance

        Parameters
        -----------

        client : dask.distributed.Client optional Dask client to use
        verbose : forwarded to the base estimator
        **kwargs : forwarded to the base estimator and to the single-GPU
            cuml.feature_extraction.text.TfidfTransformer constructor
        """
        super(TfidfTransformer, self).__init__(
            client=client, verbose=verbose, **kwargs
        )

        self.datatype = "cupy"

        # Make any potential model args available and catch any potential
        # ValueErrors before distributed training begins.
        self._set_internal_model(s_TfidfTransformer(**kwargs))

    @staticmethod
    @with_cupy_rmm
    def _set_doc_stats(X, kwargs):
        """
        Build a single-GPU model from ``kwargs`` and, when idf is enabled,
        record the document statistics of the partition ``X`` on it.
        """
        model = s_TfidfTransformer(**kwargs)
        # Statistics are only needed when idf weighting is requested
        if model.use_idf:
            model._set_doc_stats(X)

        return model

    @staticmethod
    def _merge_stats_to_model(models):
        """
        Accumulate ``n_samples`` and ``df`` of every model in ``models``
        into the first one (in place) and return that first model.

        When the first model has ``use_idf`` disabled it is returned
        untouched.
        """
        merged = first(models)
        if merged.use_idf:
            for other in models[1:]:
                merged.n_samples += other.n_samples
                merged.df += other.df
        return merged

    @staticmethod
    def _set_idf_diag(model):
        """
        Ask ``model`` to build its idf diagonal and return the same model.
        """
        model._set_idf_diag()
        return model

    @with_cupy_rmm
    def fit(self, X):

        """
        Fit distributed TFIDF Transformer

        Parameters
        ----------

        X : dask.Array with blocks containing dense or sparse cupy arrays

        Returns
        -------

        cuml.dask.feature_extraction.text.TfidfTransformer
            current model instance

        Raises
        ------

        ValueError
            If X is not a dask.Array, is not 2-dimensional, or is chunked
            along the column axis.
        """

        # Only Dask.Array supported for now
        if not isinstance(X, dask.array.core.Array):
            raise ValueError("Only dask.Array is supported for X")

        # Guard the rank first: X.chunks[1] below would otherwise raise an
        # IndexError for 1-D input instead of a meaningful ValueError.
        if X.ndim != 2:
            raise ValueError("X must be a 2-dimensional dask.Array")

        if len(X.chunks[1]) != 1:
            raise ValueError(
                "X must be chunked by row only. "
                "Multi-dimensional chunking is not supported"
            )

        # We don't do anything if we don't need idf
        if not self.internal_model.use_idf:
            return self

        futures = DistributedDataHandler.create(X, self.client)

        # One partial model (holding document statistics) per partition
        partial_models = [
            self.client.submit(
                self._set_doc_stats, part, self.kwargs, pure=False
            )
            for _, part in futures.gpu_futures
        ]

        merged_model = reduce(
            partial_models, self._merge_stats_to_model, client=self.client
        )

        wait_and_raise_from_futures([merged_model])

        # The idf diagonal is built once, from the merged statistics
        fitted_model = self.client.submit(
            self._set_idf_diag, merged_model, pure=False
        )

        wait_and_raise_from_futures([fitted_model])

        self._set_internal_model(fitted_model)

        return self

    @staticmethod
    def _get_part(parts, idx):
        """
        Return the element of ``parts`` at position ``idx``.
        """
        return parts[idx]

    @staticmethod
    def _get_size(arrs):
        """
        Return the size of the first dimension of ``arrs``.
        """
        return arrs.shape[0]

    def fit_transform(self, X):
        """
        Fit the transformer on X and return the transformed X.

        Equivalent to calling ``fit(X)`` followed by ``transform(X)``.
        """
        return self.fit(X).transform(X)

    def transform(self, X):
        """
        Use distributed TFIDFTransformer to transform the
        given set of data samples.

        Parameters
        ----------

        X : dask.Array with blocks containing dense or sparse cupy arrays


        Returns
        -------

        dask.Array with blocks containing transformed sparse cupy arrays

        Raises
        ------

        ValueError
            If X is not a dask.Array.
        """
        if not isinstance(X, dask.array.core.Array):
            raise ValueError("Only dask.Array is supported for X")

        return self._transform(
            X, n_dims=2, delayed=True, output_collection_type="cupy"
        )
83 changes: 62 additions & 21 deletions python/cuml/feature_extraction/_tfidf.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from cuml.common import with_cupy_rmm
from cuml.common.sparsefuncs import csr_row_normalize_l1, csr_row_normalize_l2
from cuml.common.sparsefuncs import csr_diag_mul
from cuml import Base


def _sparse_document_frequency(X):
Expand All @@ -29,7 +30,15 @@ def _sparse_document_frequency(X):
return cp.diff(X.indptr)


class TfidfTransformer:
def _get_dtype(X):
    """
    Returns the valid dtype for tf-idf transformer

    The dtype of X is kept when it is one of FLOAT_DTYPES; any other
    dtype falls back to cp.float32.
    """
    if X.dtype in FLOAT_DTYPES:
        return X.dtype
    return cp.float32


class TfidfTransformer(Base):
"""Transform a count matrix to a normalized tf or tf-idf representation
Tf means term-frequency while tf-idf means term-frequency times inverse
document-frequency. This is a common term weighting scheme in information
Expand Down Expand Up @@ -90,11 +99,58 @@ class TfidfTransformer:

def __init__(self, *, norm='l2', use_idf=True, smooth_idf=True,
             sublinear_tf=False, handle=None, verbose=False,
             output_type=None):
    """
    Store the tf-idf options and initialise the base estimator.

    ``handle``, ``verbose`` and ``output_type`` are forwarded to the base
    class; all remaining arguments are stored unchanged on the instance.
    """
    # The previous call passed a literal Ellipsis (`...`) as the first
    # positional argument of the base initialiser; forward real values.
    # NOTE(review): assumes cuml.Base.__init__ accepts the handle /
    # verbose / output_type keywords - confirm against cuml.Base.
    super(TfidfTransformer, self).__init__(handle=handle,
                                           verbose=verbose,
                                           output_type=output_type)
    self.norm = norm
    self.use_idf = use_idf
    self.smooth_idf = smooth_idf
    self.sublinear_tf = sublinear_tf

@with_cupy_rmm
def _set_doc_stats(self, X):
    """
    Record the document level statistics of X on this instance:
        n_samples
        n_features
        df (document frequency)
    """
    stats_dtype = _get_dtype(X)
    # Expected to be free when X is already sparse
    csr = self._convert_to_csr(X, stats_dtype)

    # TODO: mapping these attributes to cumlarray proved difficult, so
    # they are stored without trailing underscores for now
    self.df = _sparse_document_frequency(csr).astype(stats_dtype,
                                                     copy=False)
    self.n_samples, self.n_features = csr.shape

@with_cupy_rmm
def _set_idf_diag(self):
    """
    Build the sparse diagonal idf matrix (``_idf_diag``) from the stored
    document statistics, then drop those statistics from the instance.
    """
    # Smoothing, when enabled, adds one to both the document frequencies
    # and the sample count
    smoothing = int(self.smooth_idf)
    smoothed_df = self.df + smoothing
    smoothed_n_samples = self.n_samples + smoothing

    # Adding 1 to the log keeps terms with zero idf from being
    # suppressed entirely.
    idf_weights = cp.log(smoothed_n_samples / smoothed_df) + 1
    diag_shape = (self.n_features, self.n_features)
    self._idf_diag = cp.sparse.dia_matrix(
        (idf_weights, 0),
        shape=diag_shape,
        dtype=smoothed_df.dtype
    )

    # The raw statistics are no longer needed; release their memory
    del self.df, self.n_samples, self.n_features

@with_cupy_rmm
def fit(self, X):
"""Learn the idf vector (global term weights).
Expand All @@ -104,26 +160,11 @@ def fit(self, X):
X : array-like of shape n_samples, n_features
A matrix of term/token counts.
"""
dtype = X.dtype if X.dtype in FLOAT_DTYPES else cp.float32
X = self._convert_to_csr(X, dtype)

output_dtype = _get_dtype(X)
X = self._convert_to_csr(X, output_dtype)
if self.use_idf:
n_samples, n_features = X.shape
df = _sparse_document_frequency(X)
df = df.astype(dtype, copy=False)

# perform idf smoothing if required
df += int(self.smooth_idf)
n_samples += int(self.smooth_idf)

# log+1 instead of log makes sure terms with zero idf don't get
# suppressed entirely.
idf = cp.log(n_samples / df) + 1
self._idf_diag = cp.sparse.dia_matrix(
(idf, 0),
shape=(n_features, n_features),
dtype=dtype
)
self._set_doc_stats(X)
self._set_idf_diag()

return self

Expand All @@ -146,7 +187,7 @@ def transform(self, X, copy=True):
if copy:
X = X.copy()

dtype = X.dtype if X.dtype in FLOAT_DTYPES else cp.float32
dtype = _get_dtype(X)

X = self._convert_to_csr(X, dtype)
if X.dtype != dtype:
Expand Down
8 changes: 7 additions & 1 deletion python/cuml/test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,13 @@ def patched__setattr__(self, name, value):

supported_type = get_supported_input_type(value)

if (supported_type == CumlArray):
if name == 'idf_':
# We skip this test because the 'idf_' setter for tfidf returns
VibhuJawa marked this conversation as resolved.
Show resolved Hide resolved
# a sparse diagonal matrix and getter gets a cupy array
# see discussion at:
# https://github.com/rapidsai/cuml/pull/2698/files#r471865982
pass
elif (supported_type == CumlArray):
assert name.startswith("_"), "Invalid CumlArray Use! CumlArray \
attributes need a leading underscore. Attribute: '{}' In: {}" \
.format(name, self.__repr__())
Expand Down
Loading