Merge pull request #2698 from VibhuJawa/fea_mnmg_tfidf
[REVIEW] Distributed TF-IDF transformer
cjnolet authored Aug 28, 2020
2 parents e7d677f + e6033f7 commit 147f795
Showing 7 changed files with 433 additions and 23 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -1,6 +1,7 @@
# cuML 0.16.0 (Date TBD)

## New Features
- PR #2698: Distributed TF-IDF transformer

## Improvements
- PR #2741: Use rapids build packages in conda environments
17 changes: 17 additions & 0 deletions python/cuml/dask/feature_extraction/__init__.py
@@ -0,0 +1,17 @@
#
# Copyright (c) 2020, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from cuml.dask.feature_extraction import text
17 changes: 17 additions & 0 deletions python/cuml/dask/feature_extraction/text/__init__.py
@@ -0,0 +1,17 @@
#
# Copyright (c) 2020, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from cuml.dask.feature_extraction.text.tfidf_transformer import TfidfTransformer
208 changes: 208 additions & 0 deletions python/cuml/dask/feature_extraction/text/tfidf_transformer.py
@@ -0,0 +1,208 @@
#
# Copyright (c) 2020, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import dask
from toolz import first
import dask.array

from cuml.common import with_cupy_rmm
from cuml.dask.common.base import BaseEstimator
from cuml.dask.common.base import DelayedTransformMixin
from cuml.dask.common.utils import wait_and_raise_from_futures
from cuml.dask.common.func import reduce
from cuml.dask.common.input_utils import DistributedDataHandler

from cuml.feature_extraction.text import TfidfTransformer as s_TfidfTransformer


class TfidfTransformer(BaseEstimator, DelayedTransformMixin):

"""
Distributed TF-IDF transformer

Examples
--------
import cupy as cp
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer
from dask_cuda import LocalCUDACluster
from dask.distributed import Client
from cuml.dask.common import to_sparse_dask_array
from cuml.dask.naive_bayes import MultinomialNB
import dask
from cuml.dask.feature_extraction.text import TfidfTransformer
# Create a local CUDA cluster
cluster = LocalCUDACluster()
client = Client(cluster)
# Load corpus
twenty_train = fetch_20newsgroups(subset='train',
shuffle=True, random_state=42)
cv = CountVectorizer()
xformed = cv.fit_transform(twenty_train.data).astype(cp.float32)
X = to_sparse_dask_array(xformed, client)
y = dask.array.from_array(twenty_train.target, asarray=False,
fancy=False).astype(cp.int32)
multi_gpu_transformer = TfidfTransformer()
X_transformed = multi_gpu_transformer.fit_transform(X)
X_transformed.compute_chunk_sizes()
model = MultinomialNB()
model.fit(X_transformed, y)
model.score(X_transformed, y)

Output:

.. code-block:: python

    array(0.93264981)
"""

def __init__(self, client=None, verbose=False, **kwargs):

"""
Create new distributed TF-IDF transformer instance
Parameters
-----------
client : dask.distributed.Client optional Dask client to use
"""
super(TfidfTransformer, self).__init__(
client=client, verbose=verbose, **kwargs
)

self.datatype = "cupy"

# Make any potential model args available and catch any potential
# ValueErrors before distributed training begins.
self._set_internal_model(s_TfidfTransformer(**kwargs))

@staticmethod
@with_cupy_rmm
def _set_doc_stats(X, kwargs):
model = s_TfidfTransformer(**kwargs)
# Document stats are only required when idf weighting is used
if model.use_idf:
model._set_doc_stats(X)

return model

@staticmethod
def _merge_stats_to_model(models):
modela = first(models)
if modela.use_idf:
for model in models[1:]:
modela.__n_samples += model.__n_samples
modela.__df += model.__df
return modela

@staticmethod
def _set_idf_diag(model):
model._set_idf_diag()
return model

@with_cupy_rmm
def fit(self, X):

"""
Fit distributed TFIDF Transformer
Parameters
----------
X : dask.Array with blocks containing dense or sparse cupy arrays
Returns
-------
cuml.dask.naive_bayes.TfidfTransformer current model instance
"""

# Only Dask.Array supported for now
if not isinstance(X, dask.array.core.Array):
raise ValueError("Only dask.Array is supported for X")

if len(X.chunks[1]) != 1:
raise ValueError(
"X must be chunked by row only. "
"Multi-dimensional chunking is not supported"
)

# We don't do anything if we don't need idf
if not self.internal_model.use_idf:
return self

futures = DistributedDataHandler.create(X, self.client)

models = [
self.client.submit(
self._set_doc_stats, part, self.kwargs, pure=False
)
for w, part in futures.gpu_futures
]

models = reduce(models, self._merge_stats_to_model, client=self.client)

wait_and_raise_from_futures([models])

models = self.client.submit(self._set_idf_diag, models, pure=False)

wait_and_raise_from_futures([models])

self._set_internal_model(models)

return self

@staticmethod
def _get_part(parts, idx):
return parts[idx]

@staticmethod
def _get_size(arrs):
return arrs.shape[0]

def fit_transform(self, X):
return self.fit(X).transform(X)

def transform(self, X):
"""
Use distributed TFIDFTransformer to transforme the
given set of data samples.
Parameters
----------
X : dask.Array with blocks containing dense or sparse cupy arrays
Returns
-------
dask.Array with blocks containing transformed sparse cupy arrays
"""
if not isinstance(X, dask.array.core.Array):
raise ValueError("Only dask.Array is supported for X")

return self._transform(
X, n_dims=2, delayed=True, output_collection_type="cupy"
)
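For quick reference, the following is a minimal end-to-end sketch of the distributed transformer added above. It follows the class docstring's example but substitutes a small synthetic count matrix for the 20newsgroups corpus; the toy counts and variable names are illustrative assumptions, not part of the PR.

import cupy as cp
import scipy.sparse
from dask_cuda import LocalCUDACluster
from dask.distributed import Client
from cuml.dask.common import to_sparse_dask_array
from cuml.dask.feature_extraction.text import TfidfTransformer

# Start a local, single-node multi-GPU cluster
cluster = LocalCUDACluster()
client = Client(cluster)

# Toy term-count matrix: 4 documents x 3 terms (illustrative data)
counts = scipy.sparse.csr_matrix(
    [[1, 0, 2],
     [0, 1, 0],
     [3, 0, 0],
     [1, 1, 1]], dtype=cp.float32)
X = to_sparse_dask_array(counts, client)

# Defaults mirror the single-GPU transformer:
# norm='l2', use_idf=True, smooth_idf=True, sublinear_tf=False
tfidf = TfidfTransformer()
X_tfidf = tfidf.fit_transform(X)  # lazy dask.Array of sparse cupy blocks
print(X_tfidf.compute().todense())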
83 changes: 61 additions & 22 deletions python/cuml/feature_extraction/_tfidf.py
@@ -13,13 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from sklearn.utils.validation import FLOAT_DTYPES
from cuml.common.exceptions import NotFittedError
import cupy as cp
import cupyx
from cuml.common import with_cupy_rmm
from cuml.common.sparsefuncs import csr_row_normalize_l1, csr_row_normalize_l2
from cuml.common.sparsefuncs import csr_diag_mul
from cuml.common.array import CumlArray
from cuml import Base


def _sparse_document_frequency(X):
@@ -30,7 +31,18 @@ def _sparse_document_frequency(X):
return cp.diff(X.indptr)


class TfidfTransformer:
def _get_dtype(X):
"""
Returns the valid dtype for tf-idf transformer
"""
import numpy as np
FLOAT_DTYPES = (np.float64, np.float32, np.float16)

dtype = X.dtype if X.dtype in FLOAT_DTYPES else cp.float32
return dtype


class TfidfTransformer(Base):
"""Transform a count matrix to a normalized tf or tf-idf representation
Tf means term-frequency while tf-idf means term-frequency times inverse
document-frequency. This is a common term weighting scheme in information
@@ -91,11 +103,53 @@ class TfidfTransformer:

def __init__(self, *, norm='l2', use_idf=True, smooth_idf=True,
sublinear_tf=False):

super(TfidfTransformer, self).__init__(...)
self.norm = norm
self.use_idf = use_idf
self.smooth_idf = smooth_idf
self.sublinear_tf = sublinear_tf

@with_cupy_rmm
def _set_doc_stats(self, X):
"""
We set the following document level statistics here:
n_samples
n_features
df (document frequency)
"""
# Should not have a cost if already sparse
output_dtype = _get_dtype(X)
X = self._convert_to_csr(X, output_dtype)
n_samples, n_features = X.shape
df = _sparse_document_frequency(X)
df = df.astype(output_dtype, copy=False)
self.__df = CumlArray(df)
self.__n_samples = n_samples
self.__n_features = n_features

return

@with_cupy_rmm
def _set_idf_diag(self):
"""
Sets the idf diagonal sparse matrix
"""
# perform idf smoothing if required
df = self.__df.to_output('cupy') + int(self.smooth_idf)
n_samples = self.__n_samples + int(self.smooth_idf)

# log+1 instead of log makes sure terms with zero idf don't get
# suppressed entirely.
idf = cp.log(n_samples / df) + 1
self._idf_diag = cp.sparse.dia_matrix(
(idf, 0),
shape=(self.__n_features, self.__n_features),
dtype=df.dtype
)
# Free up the memory occupied by the document-frequency counts
del self.__df

@with_cupy_rmm
def fit(self, X):
"""Learn the idf vector (global term weights).
@@ -105,26 +159,11 @@ def fit(self, X):
X : array-like of shape n_samples, n_features
A matrix of term/token counts.
"""
dtype = X.dtype if X.dtype in FLOAT_DTYPES else cp.float32
X = self._convert_to_csr(X, dtype)

output_dtype = _get_dtype(X)
X = self._convert_to_csr(X, output_dtype)
if self.use_idf:
n_samples, n_features = X.shape
df = _sparse_document_frequency(X)
df = df.astype(dtype, copy=False)

# perform idf smoothing if required
df += int(self.smooth_idf)
n_samples += int(self.smooth_idf)

# log+1 instead of log makes sure terms with zero idf don't get
# suppressed entirely.
idf = cp.log(n_samples / df) + 1
self._idf_diag = cupyx.scipy.sparse.dia_matrix(
(idf, 0),
shape=(n_features, n_features),
dtype=dtype
)
self._set_doc_stats(X)
self._set_idf_diag()

return self

Expand All @@ -147,7 +186,7 @@ def transform(self, X, copy=True):
if copy:
X = X.copy()

dtype = X.dtype if X.dtype in FLOAT_DTYPES else cp.float32
dtype = _get_dtype(X)

X = self._convert_to_csr(X, dtype)
if X.dtype != dtype:
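For intuition, the idf computation that this refactor splits across `_set_doc_stats` (per-partition document statistics) and `_set_idf_diag` (final smoothing and diagonal construction) reduces to the pure-CuPy sketch below; the toy counts and variable names are illustrative assumptions, not part of the PR.

import cupy as cp

# Toy dense term-count matrix: 3 documents x 4 terms (illustrative data)
counts = cp.array([[1., 0., 2., 0.],
                   [0., 1., 0., 0.],
                   [1., 1., 0., 1.]], dtype=cp.float32)
n_samples, n_features = counts.shape

# Stage 1 (_set_doc_stats): per-partition document frequency, i.e. the
# number of documents containing each term. Across workers, fit() sums
# the partial df vectors and n_samples via _merge_stats_to_model.
df = cp.count_nonzero(counts, axis=0).astype(cp.float32)

# Stage 2 (_set_idf_diag): smooth, then build the idf diagonal matrix.
smooth = 1  # smooth_idf=True
idf = cp.log((n_samples + smooth) / (df + smooth)) + 1  # log+1 keeps zero-idf terms
idf_diag = cp.sparse.dia_matrix((idf, 0),
                                shape=(n_features, n_features),
                                dtype=df.dtype)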
8 changes: 7 additions & 1 deletion python/cuml/test/conftest.py
@@ -110,7 +110,13 @@ def patched__setattr__(self, name, value):

supported_type = get_supported_input_type(value)

if (supported_type == CumlArray):
if name == 'idf_':
    # We skip this check because the tfidf 'idf_' setter takes a
    # sparse diagonal matrix while the getter returns a cupy array;
    # see discussion at:
    # https://github.com/rapidsai/cuml/pull/2698/files#r471865982
    pass
elif (supported_type == CumlArray):
assert name.startswith("_"), "Invalid CumlArray Use! CumlArray \
attributes need a leading underscore. Attribute: '{}' In: {}" \
.format(name, self.__repr__())