diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5e7985c1e6..3ca4d5ce6e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,6 +1,7 @@
 # cuML 0.16.0 (Date TBD)
 
 ## New Features
+- PR #2698: Distributed TF-IDF transformer
 
 ## Improvements
 - PR #2741: Use rapids build packages in conda environments
diff --git a/python/cuml/dask/feature_extraction/__init__.py b/python/cuml/dask/feature_extraction/__init__.py
new file mode 100644
index 0000000000..493c5362c1
--- /dev/null
+++ b/python/cuml/dask/feature_extraction/__init__.py
@@ -0,0 +1,17 @@
+#
+# Copyright (c) 2020, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from cuml.dask.feature_extraction import text
diff --git a/python/cuml/dask/feature_extraction/text/__init__.py b/python/cuml/dask/feature_extraction/text/__init__.py
new file mode 100644
index 0000000000..2275e9f125
--- /dev/null
+++ b/python/cuml/dask/feature_extraction/text/__init__.py
@@ -0,0 +1,17 @@
+#
+# Copyright (c) 2020, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from cuml.dask.feature_extraction.text.tfidf_transformer import TfidfTransformer
diff --git a/python/cuml/dask/feature_extraction/text/tfidf_transformer.py b/python/cuml/dask/feature_extraction/text/tfidf_transformer.py
new file mode 100644
index 0000000000..353badfc84
--- /dev/null
+++ b/python/cuml/dask/feature_extraction/text/tfidf_transformer.py
@@ -0,0 +1,208 @@
+#
+# Copyright (c) 2020, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import dask
+from toolz import first
+import dask.array
+
+from cuml.common import with_cupy_rmm
+from cuml.dask.common.base import BaseEstimator
+from cuml.dask.common.base import DelayedTransformMixin
+from cuml.dask.common.utils import wait_and_raise_from_futures
+from cuml.dask.common.func import reduce
+from cuml.dask.common.input_utils import DistributedDataHandler
+
+from cuml.feature_extraction.text import TfidfTransformer as s_TfidfTransformer
+
+
+class TfidfTransformer(BaseEstimator, DelayedTransformMixin):
+
+    """
+    Distributed TF-IDF transformer
+
+    Examples
+    --------
+
+    .. code-block:: python
+
+        import cupy as cp
+        from sklearn.datasets import fetch_20newsgroups
+        from sklearn.feature_extraction.text import CountVectorizer
+        from dask_cuda import LocalCUDACluster
+        from dask.distributed import Client
+        from cuml.dask.common import to_sparse_dask_array
+        from cuml.dask.naive_bayes import MultinomialNB
+        import dask
+        from cuml.dask.feature_extraction.text import TfidfTransformer
+
+        # Create a local CUDA cluster
+        cluster = LocalCUDACluster()
+        client = Client(cluster)
+
+        # Load corpus
+        twenty_train = fetch_20newsgroups(subset='train',
+                                          shuffle=True, random_state=42)
+        cv = CountVectorizer()
+        xformed = cv.fit_transform(twenty_train.data).astype(cp.float32)
+        X = to_sparse_dask_array(xformed, client)
+
+        y = dask.array.from_array(twenty_train.target, asarray=False,
+                                  fancy=False).astype(cp.int32)
+
+        multi_gpu_transformer = TfidfTransformer()
+        X_transformed = multi_gpu_transformer.fit_transform(X)
+        X_transformed.compute_chunk_sizes()
+
+        model = MultinomialNB()
+        model.fit(X_transformed, y)
+        model.score(X_transformed, y)
+
+    Output:
+
+    .. code-block:: python
+
+        array(0.93264981)
+
+    """
+
+    def __init__(self, client=None, verbose=False, **kwargs):
+        """
+        Create a new distributed TF-IDF transformer instance
+
+        Parameters
+        ----------
+        client : dask.distributed.Client, optional
+            Dask client to use
+        """
+        super(TfidfTransformer, self).__init__(
+            client=client, verbose=verbose, **kwargs
+        )
+
+        self.datatype = "cupy"
+
+        # Make any potential model args available and catch any potential
+        # ValueErrors before distributed training begins.
+        self._set_internal_model(s_TfidfTransformer(**kwargs))
+
+    @staticmethod
+    @with_cupy_rmm
+    def _set_doc_stats(X, kwargs):
+        model = s_TfidfTransformer(**kwargs)
+        # Below is only required if we have to set stats
+        if model.use_idf:
+            model._set_doc_stats(X)
+
+        return model
+
+    @staticmethod
+    def _merge_stats_to_model(models):
+        # Name mangling resolves these private attributes against the
+        # single-GPU TfidfTransformer, since both classes share the
+        # name TfidfTransformer.
+        modela = first(models)
+        if modela.use_idf:
+            for model in models[1:]:
+                modela.__n_samples += model.__n_samples
+                modela.__df += model.__df
+        return modela
+
+    @staticmethod
+    def _set_idf_diag(model):
+        model._set_idf_diag()
+        return model
+
+    @with_cupy_rmm
+    def fit(self, X):
+        """
+        Fit the distributed TF-IDF transformer
+
+        Parameters
+        ----------
+        X : dask.Array with blocks containing dense or sparse cupy arrays
+
+        Returns
+        -------
+        cuml.dask.feature_extraction.text.TfidfTransformer
+            Current model instance
+        """
+
+        # Only Dask.Array supported for now
+        if not isinstance(X, dask.array.core.Array):
+            raise ValueError("Only dask.Array is supported for X")
+
+        if len(X.chunks[1]) != 1:
+            raise ValueError(
+                "X must be chunked by row only. "
" + "Multi-dimensional chunking is not supported" + ) + + # We don't' do anything if we don't need idf + if not self.internal_model.use_idf: + return self + + futures = DistributedDataHandler.create(X, self.client) + + models = [ + self.client.submit( + self._set_doc_stats, part, self.kwargs, pure=False + ) + for w, part in futures.gpu_futures + ] + + models = reduce(models, self._merge_stats_to_model, client=self.client) + + wait_and_raise_from_futures([models]) + + models = self.client.submit(self._set_idf_diag, models, pure=False) + + wait_and_raise_from_futures([models]) + + self._set_internal_model(models) + + return self + + @staticmethod + def _get_part(parts, idx): + return parts[idx] + + @staticmethod + def _get_size(arrs): + return arrs.shape[0] + + def fit_transform(self, X): + return self.fit(X).transform(X) + + def transform(self, X): + """ + Use distributed TFIDFTransformer to transforme the + given set of data samples. + + Parameters + ---------- + + X : dask.Array with blocks containing dense or sparse cupy arrays + + + Returns + ------- + + dask.Array with blocks containing transformed sparse cupy arrays + + """ + if not isinstance(X, dask.array.core.Array): + raise ValueError("Only dask.Array is supported for X") + + return self._transform( + X, n_dims=2, delayed=True, output_collection_type="cupy" + ) diff --git a/python/cuml/feature_extraction/_tfidf.py b/python/cuml/feature_extraction/_tfidf.py index e2bd397919..614145bef4 100644 --- a/python/cuml/feature_extraction/_tfidf.py +++ b/python/cuml/feature_extraction/_tfidf.py @@ -13,13 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from sklearn.utils.validation import FLOAT_DTYPES from cuml.common.exceptions import NotFittedError import cupy as cp import cupyx from cuml.common import with_cupy_rmm from cuml.common.sparsefuncs import csr_row_normalize_l1, csr_row_normalize_l2 from cuml.common.sparsefuncs import csr_diag_mul +from cuml.common.array import CumlArray +from cuml import Base def _sparse_document_frequency(X): @@ -30,7 +31,18 @@ def _sparse_document_frequency(X): return cp.diff(X.indptr) -class TfidfTransformer: +def _get_dtype(X): + """ + Returns the valid dtype for tf-idf transformer + """ + import numpy as np + FLOAT_DTYPES = (np.float64, np.float32, np.float16) + + dtype = X.dtype if X.dtype in FLOAT_DTYPES else cp.float32 + return dtype + + +class TfidfTransformer(Base): """Transform a count matrix to a normalized tf or tf-idf representation Tf means term-frequency while tf-idf means term-frequency times inverse document-frequency. This is a common term weighting scheme in information @@ -91,11 +103,53 @@ class TfidfTransformer: def __init__(self, *, norm='l2', use_idf=True, smooth_idf=True, sublinear_tf=False): + + super(TfidfTransformer, self).__init__(...) 
         self.norm = norm
         self.use_idf = use_idf
         self.smooth_idf = smooth_idf
         self.sublinear_tf = sublinear_tf
 
+    @with_cupy_rmm
+    def _set_doc_stats(self, X):
+        """
+        We set the following document-level statistics here:
+            n_samples
+            n_features
+            df (document frequency)
+        """
+        # Should not have a cost if already sparse
+        output_dtype = _get_dtype(X)
+        X = self._convert_to_csr(X, output_dtype)
+        n_samples, n_features = X.shape
+        df = _sparse_document_frequency(X)
+        df = df.astype(output_dtype, copy=False)
+        self.__df = CumlArray(df)
+        self.__n_samples = n_samples
+        self.__n_features = n_features
+
+        return
+
+    @with_cupy_rmm
+    def _set_idf_diag(self):
+        """
+        Sets the idf_diagonal sparse array
+        """
+        # perform idf smoothing if required
+        df = self.__df.to_output('cupy') + int(self.smooth_idf)
+        n_samples = self.__n_samples + int(self.smooth_idf)
+
+        # log+1 instead of log makes sure terms with zero idf don't get
+        # suppressed entirely.
+        idf = cp.log(n_samples / df) + 1
+        self._idf_diag = cp.sparse.dia_matrix(
+            (idf, 0),
+            shape=(self.__n_features, self.__n_features),
+            dtype=df.dtype
+        )
+        # Free up the memory occupied by the df array
+        del self.__df
+
     @with_cupy_rmm
     def fit(self, X):
         """Learn the idf vector (global term weights).
@@ -105,26 +159,11 @@ def fit(self, X):
         X : array-like of shape n_samples, n_features
             A matrix of term/token counts.
         """
-        dtype = X.dtype if X.dtype in FLOAT_DTYPES else cp.float32
-        X = self._convert_to_csr(X, dtype)
-
+        output_dtype = _get_dtype(X)
+        X = self._convert_to_csr(X, output_dtype)
         if self.use_idf:
-            n_samples, n_features = X.shape
-            df = _sparse_document_frequency(X)
-            df = df.astype(dtype, copy=False)
-
-            # perform idf smoothing if required
-            df += int(self.smooth_idf)
-            n_samples += int(self.smooth_idf)
-
-            # log+1 instead of log makes sure terms with zero idf don't get
-            # suppressed entirely.
-            idf = cp.log(n_samples / df) + 1
-            self._idf_diag = cupyx.scipy.sparse.dia_matrix(
-                (idf, 0),
-                shape=(n_features, n_features),
-                dtype=dtype
-            )
+            self._set_doc_stats(X)
+            self._set_idf_diag()
 
         return self
 
@@ -147,7 +186,7 @@ def transform(self, X, copy=True):
         if copy:
             X = X.copy()
 
-        dtype = X.dtype if X.dtype in FLOAT_DTYPES else cp.float32
+        dtype = _get_dtype(X)
         X = self._convert_to_csr(X, dtype)
 
         if X.dtype != dtype:
diff --git a/python/cuml/test/conftest.py b/python/cuml/test/conftest.py
index 8734686cba..315a575dbf 100644
--- a/python/cuml/test/conftest.py
+++ b/python/cuml/test/conftest.py
@@ -110,7 +110,13 @@ def patched__setattr__(self, name, value):
 
         supported_type = get_supported_input_type(value)
 
-        if (supported_type == CumlArray):
+        if name == 'idf_':
+            # We skip this check because the 'idf_' setter for tfidf takes a
+            # sparse diagonal matrix while the getter returns a cupy array;
+            # see the discussion at:
+            # https://github.com/rapidsai/cuml/pull/2698/files#r471865982
+            pass
+        elif (supported_type == CumlArray):
             assert name.startswith("_"), "Invalid CumlArray Use! CumlArray \
                 attributes need a leading underscore. Attribute: '{}' In: {}" \
                 .format(name, self.__repr__())
diff --git a/python/cuml/test/dask/test_tfidf.py b/python/cuml/test/dask/test_tfidf.py
new file mode 100644
index 0000000000..9461f07e2d
--- /dev/null
+++ b/python/cuml/test/dask/test_tfidf.py
@@ -0,0 +1,122 @@
+# Copyright (c) 2019-2020, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import pytest
+import numpy as np
+import cupy as cp
+from scipy.sparse import csr_matrix as scipy_csr_matrix
+from cupy.sparse import csr_matrix as cp_csr_matrix
+import dask.array as da
+import dask
+
+from cuml.dask.feature_extraction.text import TfidfTransformer
+from sklearn.feature_extraction.text import (
+    TfidfTransformer as SkTfidfTransformer,
+)
+
+
+# Testing util functions
+def generate_dask_array(np_array, n_parts):
+    """
+    Creates a dask array from a 2d numpy array
+    """
+    n_samples = np_array.shape[0]
+    n_samples_per_part = int(n_samples / n_parts)
+    chunks = [n_samples_per_part] * n_parts
+    samples_last_row = n_samples - ((n_parts - 1) * n_samples_per_part)
+    chunks[-1] = samples_last_row
+    chunks = tuple(chunks)
+    return da.from_array(np_array, chunks=(chunks, -1))
+
+
+def create_cp_sparse_ar_from_dense_np_ar(ar, dtype=np.float32):
+    """
+    Creates a sparse GPU array from a dense CPU array
+    """
+    return cp_csr_matrix(scipy_csr_matrix(ar), dtype=dtype)
+
+
+def create_cp_sparse_dask_array(np_ar, n_parts):
+    """
+    Creates a sparse GPU dask array from the given numpy array
+    """
+    ar = generate_dask_array(np_ar, n_parts)
+    meta = dask.array.from_array(cp_csr_matrix(cp.zeros(1, dtype=cp.float32)))
+    ar = ar.map_blocks(create_cp_sparse_ar_from_dense_np_ar, meta=meta)
+    return ar
+
+
+def create_scipy_sparse_array_from_dask_cp_sparse_array(ar):
+    """
+    Creates a CPU (scipy) sparse array from the given dask array of
+    sparse cupy matrices. This will probably not be needed once we have
+    https://github.com/cupy/cupy/issues/3178
+    """
+    meta = dask.array.from_array(scipy_csr_matrix(np.zeros(1, dtype=ar.dtype)))
+    ar = ar.map_blocks(lambda x: x.get(), meta=meta)
+    ar = ar.compute()
+    return ar
+
+
+# data_ids correspond to data, order is important
+data_ids = ["base_case", "diag", "empty_feature", "123", "empty_doc"]
+data = [
+    np.array(
+        [
+            [0, 1, 1, 1, 0, 0, 1, 0, 1],
+            [0, 2, 0, 1, 0, 1, 1, 0, 1],
+            [1, 0, 0, 1, 1, 0, 1, 1, 1],
+            [0, 1, 1, 1, 0, 0, 1, 0, 1],
+        ]
+    ),
+    np.array([[1, 1, 1], [1, 1, 0], [1, 0, 0]]),
+    np.array([[1, 1, 0], [1, 1, 0], [1, 0, 0]]),
+    np.array([[1], [2], [3]]),
+    np.array([[1, 1, 1], [1, 1, 0], [0, 0, 0]]),
+]
+
+
+@pytest.mark.mg
+@pytest.mark.parametrize("data", data, ids=data_ids)
+@pytest.mark.parametrize("norm", ["l1", "l2", None])
+@pytest.mark.parametrize("use_idf", [True, False])
+@pytest.mark.parametrize("smooth_idf", [True, False])
+@pytest.mark.parametrize("sublinear_tf", [True, False])
+def test_tfidf_transformer(
+    data, norm, use_idf, smooth_idf, sublinear_tf, client
+):
+    # Test across multiple numbers of partitions
+    for n_parts in range(1, data.shape[0]):
+        dask_sp_array = create_cp_sparse_dask_array(data, n_parts)
+        tfidf = TfidfTransformer(
+            norm=norm,
+            use_idf=use_idf,
+            smooth_idf=smooth_idf,
+            sublinear_tf=sublinear_tf,
+        )
+        sk_tfidf = SkTfidfTransformer(
+            norm=norm,
+            use_idf=use_idf,
+            smooth_idf=smooth_idf,
+            sublinear_tf=sublinear_tf,
+        )
+
+        res = tfidf.fit_transform(dask_sp_array)
+        res = create_scipy_sparse_array_from_dask_cp_sparse_array(
+            res
+        ).todense()
+        ref = sk_tfidf.fit_transform(data).todense()
+
+        cp.testing.assert_array_almost_equal(res, ref)
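
Note for reviewers: the distributed `fit` above is a map-reduce over per-partition document statistics. `_set_doc_stats` computes `(n_samples, df)` on each partition, `_merge_stats_to_model` sums the pairs in a tree reduction, and `_set_idf_diag` builds the idf diagonal once from the merged totals. The standalone sketch below (CPU-only, numpy/scipy/sklearn; `partition_stats`, `merge_stats`, and `idf_from_stats` are illustrative names, not cuml API) reproduces that flow and checks the smoothed-idf formula against scikit-learn:

```python
import numpy as np
from functools import reduce
from scipy.sparse import csr_matrix
from sklearn.feature_extraction.text import TfidfTransformer


def partition_stats(X_part):
    """Per-partition statistics: row count and per-feature document frequency."""
    df = np.asarray((X_part > 0).sum(axis=0)).ravel()
    return X_part.shape[0], df


def merge_stats(a, b):
    """Associative merge used by the tree reduction across partitions."""
    return a[0] + b[0], a[1] + b[1]


def idf_from_stats(n_samples, df, smooth_idf=True):
    """Smoothed idf; log+1 keeps terms with zero idf from being suppressed."""
    s = int(smooth_idf)
    return np.log((n_samples + s) / (df + s)) + 1


# Two "partitions" of a tiny term-count matrix, standing in for dask chunks.
X = csr_matrix(np.array([[0, 1, 1],
                         [1, 2, 0],
                         [1, 0, 0],
                         [0, 1, 1]], dtype=np.float64))
parts = [X[:2], X[2:]]

n, df = reduce(merge_stats, (partition_stats(p) for p in parts))
print(idf_from_stats(n, df))                    # merged idf vector
print(TfidfTransformer(norm=None).fit(X).idf_)  # sklearn reference, identical
```

Because `merge_stats` is associative and commutative, the reduction can be evaluated in any tree shape across workers, which is what allows the distributed fit to compute global idf statistics in a single pass over the partitions.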
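A second note, on the conftest exception: the single-GPU transformer keeps its fitted state as a sparse diagonal matrix (`_idf_diag`), while the public `idf_` attribute is exposed as a dense cupy vector, which is why the generic CumlArray naming check is skipped for `idf_`. A small illustration of that asymmetry (requires a GPU and the patched cuml; `_idf_diag` is an internal attribute, accessed here only for demonstration, and the `idf_` behavior is as described in the conftest comment):

```python
import cupy as cp
from cupyx.scipy.sparse import csr_matrix
from cuml.feature_extraction.text import TfidfTransformer

# A tiny term-count matrix on the GPU.
counts = csr_matrix(cp.array([[1., 1., 0.],
                              [1., 0., 1.]], dtype=cp.float32))
tfidf = TfidfTransformer().fit(counts)

print(type(tfidf._idf_diag))  # internal state: a cupyx.scipy.sparse dia_matrix
print(tfidf.idf_)             # public attribute: a dense cupy array
```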