diff --git a/dask_ml/feature_extraction/text.py b/dask_ml/feature_extraction/text.py index a647ddee7..fd1fe02ca 100644 --- a/dask_ml/feature_extraction/text.py +++ b/dask_ml/feature_extraction/text.py @@ -8,14 +8,18 @@ import dask.bag as db import dask.dataframe as dd import distributed +import pandas as pd import numpy as np import scipy.sparse import sklearn.base import sklearn.feature_extraction.text +import sklearn.preprocessing from dask.delayed import Delayed from distributed import get_client, wait from sklearn.utils.validation import check_is_fitted +FLOAT_DTYPES = (np.float64, np.float32, np.float16) + class _BaseHasher(sklearn.base.BaseEstimator): @property @@ -116,6 +120,34 @@ def _hasher(self): return sklearn.feature_extraction.text.FeatureHasher +def _n_samples(X): + """Count the number of samples in dask.array.Array X.""" + def chunk_n_samples(chunk, axis, keepdims): + return np.array([chunk.shape[0]], dtype=np.int64) + + return da.reduction(X, + chunk=chunk_n_samples, + aggregate=np.sum, + concatenate=False, + dtype=np.int64) + + +def _document_frequency(X, dtype): + """Count the number of non-zero values for each feature in dask array X.""" + def chunk_doc_freq(chunk, axis, keepdims): + if scipy.sparse.isspmatrix_csr(chunk): + return np.bincount(chunk.indices, minlength=chunk.shape[1]) + else: + return np.diff(chunk.indptr) + + return da.reduction(X, + chunk=chunk_doc_freq, + aggregate=np.sum, + axis=0, + concatenate=False, + dtype=dtype) + + class CountVectorizer(sklearn.feature_extraction.text.CountVectorizer): """Convert a collection of text documents to a matrix of token counts @@ -140,7 +172,9 @@ class CountVectorizer(sklearn.feature_extraction.text.CountVectorizer): Examples -------- The Dask-ML implementation currently requires that ``raw_documents`` - is a :class:`dask.bag.Bag` of documents (lists of strings). + is either a :class:`dask.bag.Bag` of documents (lists of strings) or + a :class:`dask.dataframe.Series` of documents (Series of strings) + with partitions of type :class:`pandas.Series`. >>> from dask_ml.feature_extraction.text import CountVectorizer >>> import dask.bag as db @@ -152,10 +186,25 @@ class CountVectorizer(sklearn.feature_extraction.text.CountVectorizer): ... 'And this is the third one.', ... 'Is this the first document?', ... ] - >>> corpus = db.from_sequence(corpus, npartitions=2) + >>> corpus_bag = db.from_sequence(corpus, npartitions=2) >>> vectorizer = CountVectorizer() - >>> X = vectorizer.fit_transform(corpus) - dask.array>> X = vectorizer.fit_transform(corpus_bag) + dask.array + >>> X.compute().toarray() + array([[0, 1, 1, 1, 0, 0, 1, 0, 1], + [0, 2, 0, 1, 0, 1, 1, 0, 1], + [1, 0, 0, 1, 1, 0, 1, 1, 1], + [0, 1, 1, 1, 0, 0, 1, 0, 1]]) + >>> vectorizer.get_feature_names() + ['and', 'document', 'first', 'is', 'one', 'second', 'the', 'third', 'this'] + + >>> import dask.dataframe as dd + >>> import pandas as pd + >>> corpus_dds = dd.from_pandas(pd.Series(corpus), npartitions=2) + >>> vectorizer = CountVectorizer() + >>> X = vectorizer.fit_transform(corpus_dds) + dask.array >>> X.compute().toarray() array([[0, 1, 1, 1, 0, 0, 1, 0, 1], @@ -166,10 +215,35 @@ class CountVectorizer(sklearn.feature_extraction.text.CountVectorizer): ['and', 'document', 'first', 'is', 'one', 'second', 'the', 'third', 'this'] """ + def get_CountVectorizer_params(self, deep=True): + """ + Get CountVectorizer parameters (names and values) for this + estimator (self), whether it is an instance of CountVectorizer or an + instance of a subclass of CountVectorizer. 
+ + Parameters + ---------- + deep : bool, default=True + If True, will return the CountVectorizer parameters for this + estimator and contained subobjects that are estimators. + + Returns + ------- + params : dict + Parameter names mapped to their values. + """ + out = dict() + for key in CountVectorizer._get_param_names(): + value = getattr(self, key) + if deep and hasattr(value, "get_params"): + deep_items = value.get_params().items() + out.update((key + "__" + k, val) for k, val in deep_items) + out[key] = value + return out + def fit_transform(self, raw_documents, y=None): - params = self.get_params() + params = self.get_CountVectorizer_params() vocabulary = params.pop("vocabulary") - vocabulary_for_transform = vocabulary if self.vocabulary is not None: @@ -181,19 +255,21 @@ def fit_transform(self, raw_documents, y=None): fixed_vocabulary = False # Case 2: learn vocabulary from the data. vocabularies = raw_documents.map_partitions(_build_vocabulary, params) - vocabulary = vocabulary_for_transform = _merge_vocabulary( - *vocabularies.to_delayed() - ) + vocabulary = vocabulary_for_transform = ( + _merge_vocabulary(*vocabularies.to_delayed())) vocabulary_for_transform = vocabulary_for_transform.persist() vocabulary_ = vocabulary.compute() n_features = len(vocabulary_) - result = raw_documents.map_partitions( - _count_vectorizer_transform, vocabulary_for_transform, params - ) - - meta = scipy.sparse.eye(0, format="csr", dtype=self.dtype) - result = build_array(result, n_features, meta) + meta = scipy.sparse.csr_matrix((0, n_features), dtype=self.dtype) + if isinstance(raw_documents, dd.Series): + result = raw_documents.map_partitions( + _count_vectorizer_transform, vocabulary_for_transform, + params, meta=meta) + else: + result = raw_documents.map_partitions( + _count_vectorizer_transform, vocabulary_for_transform, params) + result = build_array(result, n_features, meta) self.vocabulary_ = vocabulary_ self.fixed_vocabulary_ = fixed_vocabulary @@ -201,7 +277,7 @@ def fit_transform(self, raw_documents, y=None): return result def transform(self, raw_documents): - params = self.get_params() + params = self.get_CountVectorizer_params() vocabulary = params.pop("vocabulary") if vocabulary is None: @@ -215,18 +291,287 @@ def transform(self, raw_documents): except ValueError: vocabulary_for_transform = dask.delayed(vocabulary) else: - (vocabulary_for_transform,) = client.scatter( - (vocabulary,), broadcast=True - ) + (vocabulary_for_transform,) = client.scatter((vocabulary,), + broadcast=True) else: vocabulary_for_transform = vocabulary n_features = vocabulary_length(vocabulary_for_transform) - transformed = raw_documents.map_partitions( - _count_vectorizer_transform, vocabulary_for_transform, params + meta = scipy.sparse.csr_matrix((0, n_features), dtype=self.dtype) + if isinstance(raw_documents, dd.Series): + result = raw_documents.map_partitions( + _count_vectorizer_transform, vocabulary_for_transform, + params, meta=meta) + else: + transformed = raw_documents.map_partitions( + _count_vectorizer_transform, vocabulary_for_transform, params) + result = build_array(transformed, n_features, meta) + return result + +class TfidfTransformer(sklearn.feature_extraction.text.TfidfTransformer): + """Transform a count matrix to a normalized tf or tf-idf representation + + See Also + -------- + sklearn.feature_extraction.text.TfidfTransformer + + Examples + -------- + >>> from dask_ml.feature_extraction.text import TfidfTransformer + >>> from dask_ml.feature_extraction.text import CountVectorizer + >>> from 
sklearn.pipeline import Pipeline + >>> import numpy as np + >>> corpus = ['this is the first document', + ... 'this document is the second document', + ... 'and this is the third one', + ... 'is this the first document'] + >>> X = CountVectorizer().fit_transform(corpus) + dask.array + >>> X.compute().toarray() + array([[0, 1, 1, 1, 0, 0, 1, 0, 1], + [0, 2, 0, 1, 0, 1, 1, 0, 1], + [1, 0, 0, 1, 1, 0, 1, 1, 1], + [0, 1, 1, 1, 0, 0, 1, 0, 1]]) + >>> transformer = TfidfTransformer().fit(X) + TfidfTransformer() + >>> transformer.idf_ + array([1.91629073, 1.22314355, 1.51082562, 1. , 1.91629073, + 1.91629073, 1. , 1.91629073, 1. ]) + >>> transformer.transform(X).compute().shape + (4, 9) + """ + def fit(self, X, y=None): + """Learn the idf vector (global term weights). + + Parameters + ---------- + X : sparse matrix of shape n_samples, n_features) + A matrix of term/token counts. + """ + def get_idf_diag(X, dtype): + n_samples = _n_samples(X) # X.shape[0] is not yet known + n_features = X.shape[1] + df = _document_frequency(X, dtype) + + # perform idf smoothing if required + df += int(self.smooth_idf) + n_samples += int(self.smooth_idf) + + # log+1 instead of log makes sure terms with zero idf don't get + # suppressed entirely. + return np.log(n_samples / df) + 1 + + dtype = X.dtype if X.dtype in FLOAT_DTYPES else np.float64 + + if self.use_idf: + self._idf_diag = get_idf_diag(X, dtype) + + return self + + def transform(self, X, copy=True): + """Transform a count matrix to a tf or tf-idf representation + + Parameters + ---------- + X : sparse matrix of (n_samples, n_features) + a matrix of term/token counts + + copy : bool, default=True + Whether to copy X and operate on the copy or perform in-place + operations. + + Returns + ------- + vectors : sparse matrix of shape (n_samples, n_features) + """ + # X = self._validate_data( + # X, accept_sparse="csr", dtype=FLOAT_DTYPES, copy=copy, reset=False + # ) + # if not sp.issparse(X): + # X = sp.csr_matrix(X, dtype=np.float64) + + def _astype(chunk, Xdtype=np.float64): + return chunk.astype(Xdtype, copy=True) + + def _one_plus_log(chunk): + # transforms nonzero elements x of csr_matrix: x -> 1 + log(x) + c = chunk.copy() + c.data = np.log(chunk.data, dtype=chunk.data.dtype) + c.data += 1 + return c + + def _dot_idf_diag(chunk): + return chunk * self._idf_diag + + dtype = X.dtype if X.dtype in FLOAT_DTYPES else np.float64 + meta = scipy.sparse.eye(0, format="csr", dtype=dtype) + if X.dtype != dtype: + X = X.map_blocks(_astype, Xdtype=dtype, dtype=dtype, meta=meta) + + if self.sublinear_tf: + X = X.map_blocks(_one_plus_log, dtype=dtype, meta=meta) + + if self.use_idf: + # idf_ being a property, the automatic attributes detection + # does not work as usual and we need to specify the attribute + # name: + check_is_fitted(self, attributes=["idf_"], + msg="idf vector is not fitted") + self.__compute_idf() + X = X.map_blocks(_dot_idf_diag, dtype=dtype, meta=meta) + + if self.norm: + X = X.map_blocks(_normalize_transform, + dtype=dtype, + norm=self.norm, + meta=meta) + + return X + + def __compute_idf(self): + # if _idf_diag is still lazy, then it is computed here + if dask.is_dask_collection(self._idf_diag): + _idf_diag = self._idf_diag.compute() + n_features = len(_idf_diag) + self._idf_diag = scipy.sparse.diags( + _idf_diag, + offsets=0, + shape=(n_features, n_features), + format="csr", + dtype=_idf_diag.dtype) + + @property + def idf_(self): + """Inverse document frequency vector, only defined if `use_idf=True`. 
+ + Returns + ------- + ndarray of shape (n_features,) + """ + self.__compute_idf() + # if _idf_diag is not set, this will raise an attribute error, + # which means hasattr(self, "idf_") is False + return np.ravel(self._idf_diag.sum(axis=0)) + + +class TfidfVectorizer(sklearn.feature_extraction.text.TfidfVectorizer, + CountVectorizer): + r"""Convert a collection of raw documents to a matrix of TF-IDF features. + + Equivalent to :class:`CountVectorizer` followed by + :class:`TfidfTransformer`. + + See Also + -------- + sklearn.feature_extraction.text.TfidfVectorizer + + Examples + -------- + The Dask-ML implementation currently requires that ``raw_documents`` + is either a :class:`dask.bag.Bag` of documents (lists of strings) or + a :class:`dask.dataframe.Series` of documents (Series of strings) + with partitions of type :class:`pandas.Series`. + + >>> from dask_ml.feature_extraction.text import TfidfVectorizer + >>> import dask.bag as db + >>> from distributed import Client + >>> client = Client() + >>> corpus = [ + ... 'This is the first document.', + ... 'This document is the second document.', + ... 'And this is the third one.', + ... 'Is this the first document?', + ... ] + >>> corpus_bag = db.from_sequence(corpus, npartitions=2) + >>> vectorizer = TfidfVectorizer() + >>> X = vectorizer.fit_transform(corpus_bag) + dask.array + >>> X.compute().toarray() + array([[0. , 0.46979139, 0.58028582, 0.38408524, 0. , + 0. , 0.38408524, 0. , 0.38408524], + [0. , 0.6876236 , 0. , 0.28108867, 0. , + 0.53864762, 0.28108867, 0. , 0.28108867], + [0.51184851, 0. , 0. , 0.26710379, 0.51184851, + 0. , 0.26710379, 0.51184851, 0.26710379], + [0. , 0.46979139, 0.58028582, 0.38408524, 0. , + 0. , 0.38408524, 0. , 0.38408524]]) + >>> vectorizer.get_feature_names() + ['and', 'document', 'first', 'is', 'one', 'second', 'the', 'third', 'this'] + + >>> import dask.dataframe as dd + >>> import pandas as pd + >>> corpus_dds = dd.from_pandas(pd.Series(corpus), npartitions=2) + >>> vectorizer = TfidfVectorizer() + >>> X = vectorizer.fit_transform(corpus_dds) + dask.array + >>> X.compute().toarray() + array([[0. , 0.46979139, 0.58028582, 0.38408524, 0. , + 0. , 0.38408524, 0. , 0.38408524], + [0. , 0.6876236 , 0. , 0.28108867, 0. , + 0.53864762, 0.28108867, 0. , 0.28108867], + [0.51184851, 0. , 0. , 0.26710379, 0.51184851, + 0. , 0.26710379, 0.51184851, 0.26710379], + [0. , 0.46979139, 0.58028582, 0.38408524, 0. , + 0. , 0.38408524, 0. 
, 0.38408524]]) + >>> vectorizer.get_feature_names() + ['and', 'document', 'first', 'is', 'one', 'second', 'the', 'third', 'this'] + """ + + def __init__( + self, + *, + input="content", + encoding="utf-8", + decode_error="strict", + strip_accents=None, + lowercase=True, + preprocessor=None, + tokenizer=None, + analyzer="word", + stop_words=None, + token_pattern=r"(?u)\b\w\w+\b", + ngram_range=(1, 1), + max_df=1.0, + min_df=1, + max_features=None, + vocabulary=None, + binary=False, + dtype=np.float64, + norm="l2", + use_idf=True, + smooth_idf=True, + sublinear_tf=False, + ): + + super().__init__( + input=input, + encoding=encoding, + decode_error=decode_error, + strip_accents=strip_accents, + lowercase=lowercase, + preprocessor=preprocessor, + tokenizer=tokenizer, + analyzer=analyzer, + stop_words=stop_words, + token_pattern=token_pattern, + ngram_range=ngram_range, + max_df=max_df, + min_df=min_df, + max_features=max_features, + vocabulary=vocabulary, + binary=binary, + dtype=dtype, + ) + + self._tfidf = TfidfTransformer( + norm=norm, + use_idf=use_idf, + smooth_idf=smooth_idf, + sublinear_tf=sublinear_tf ) - meta = scipy.sparse.eye(0, format="csr", dtype=self.dtype) - return build_array(transformed, n_features, meta) def build_array(bag, n_features, meta): @@ -257,6 +602,10 @@ def vocabulary_length(vocabulary): raise ValueError(f"Unknown vocabulary type {type(vocabulary)}.") +def _normalize_transform(chunk, norm): + return sklearn.preprocessing.normalize(chunk, norm=norm) + + def _count_vectorizer_transform(partition, vocabulary, params): model = sklearn.feature_extraction.text.CountVectorizer( vocabulary=vocabulary, **params diff --git a/tests/feature_extraction/test_text.py b/tests/feature_extraction/test_text.py index 01323f106..a4c9207c6 100644 --- a/tests/feature_extraction/test_text.py +++ b/tests/feature_extraction/test_text.py @@ -183,3 +183,98 @@ def test_count_vectorizer_remote_vocabulary(): ) m.fit_transform(b) assert m.vocabulary_ is remote_vocabulary + + +@pytest.mark.parametrize("distributed", [True, False]) +@pytest.mark.parametrize("collection_type", ["Bag", "Series"]) +@pytest.mark.parametrize("norm", ["l1", "l2"]) +@pytest.mark.parametrize("use_idf", [True, False]) +@pytest.mark.parametrize("smooth_idf", [True, False]) +@pytest.mark.parametrize("sublinear_tf", [True, False]) +def test_tfidf_vectorizer(distributed, + collection_type, + norm, + use_idf, + smooth_idf, + sublinear_tf): + skl1 = (sklearn.feature_extraction.text + .TfidfVectorizer(norm=norm, + use_idf=use_idf, + smooth_idf=smooth_idf, + sublinear_tf=sublinear_tf)) + skl2 = (sklearn.feature_extraction.text + .TfidfVectorizer(norm=norm, + use_idf=use_idf, + smooth_idf=smooth_idf, + sublinear_tf=sublinear_tf)) + + JUNK_FOOD_DOCS_SUBLIST = JUNK_FOOD_DOCS[:2] + if collection_type == "Bag": + full_docs = db.from_sequence(JUNK_FOOD_DOCS, npartitions=2) + sub_docs = db.from_sequence(JUNK_FOOD_DOCS_SUBLIST, npartitions=2) + elif collection_type == "Series": + full_docs = dd.from_pandas(pd.Series(JUNK_FOOD_DOCS), npartitions=2) + sub_docs = dd.from_pandas(pd.Series(JUNK_FOOD_DOCS_SUBLIST), + npartitions=2) + + csr_skl1 = skl1.fit_transform(JUNK_FOOD_DOCS) + skl2 = skl2.fit(JUNK_FOOD_DOCS) + csr_skl2 = skl2.transform(JUNK_FOOD_DOCS) + + dml1 = (dask_ml.feature_extraction.text + .TfidfVectorizer(norm=norm, + use_idf=use_idf, + smooth_idf=smooth_idf, + sublinear_tf=sublinear_tf)) + dml2 = (dask_ml.feature_extraction.text + .TfidfVectorizer(norm=norm, + use_idf=use_idf, + smooth_idf=smooth_idf, + 
sublinear_tf=sublinear_tf)) + + if distributed: + client = Client() # noqa + else: + client = dummy_context() + + csr_dml1 = dml1.fit_transform(full_docs) + dml2 = dml2.fit(full_docs) + csr_dml2 = dml2.transform(full_docs) + + with client: + exclude = {"vocabulary_actor_", "stop_words_"} + if not use_idf: + # idf_ being a property, the automatic attributes detection + # does not work as usual so we will exclude it in this case: + exclude.add("idf_") + assert_estimator_equal(skl1, dml1, exclude=exclude) + assert isinstance(csr_dml1, da.Array) + assert isinstance(csr_dml1._meta, scipy.sparse.csr_matrix) + np.testing.assert_array_almost_equal(csr_skl1.toarray(), + csr_dml1.compute().toarray()) + + assert_estimator_equal(skl2, dml2, exclude=exclude) + assert isinstance(csr_dml2, da.Array) + assert isinstance(csr_dml2._meta, scipy.sparse.csr_matrix) + np.testing.assert_array_almost_equal(csr_skl2.toarray(), + csr_dml2.compute().toarray()) + + csr_dml1 = dml1.transform(full_docs) + assert isinstance(csr_dml1, da.Array) + assert isinstance(csr_dml1._meta, scipy.sparse.csr_matrix) + np.testing.assert_array_almost_equal(csr_skl1.toarray(), + csr_dml1.compute().toarray()) + + csr_skl1 = skl1.transform(JUNK_FOOD_DOCS_SUBLIST) + csr_dml1 = dml1.transform(sub_docs) + assert isinstance(csr_dml1, da.Array) + assert isinstance(csr_dml1._meta, scipy.sparse.csr_matrix) + np.testing.assert_array_almost_equal(csr_skl1.toarray(), + csr_dml1.compute().toarray()) + + csr_skl1 = skl2.transform(JUNK_FOOD_DOCS_SUBLIST) + csr_dml1 = dml2.transform(sub_docs) + assert isinstance(csr_dml1, da.Array) + assert isinstance(csr_dml1._meta, scipy.sparse.csr_matrix) + np.testing.assert_array_almost_equal(csr_skl1.toarray(), + csr_dml1.compute().toarray())
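
The per-partition transform used above (`_count_vectorizer_transform`) leans on a scikit-learn property: once the vocabulary is fixed, `CountVectorizer.transform` produces the same column layout on every partition, so the partition results can simply be stacked into one matrix. A minimal sketch of that invariant using plain scikit-learn and SciPy (the corpus and the two-way split are illustrative, not taken from the patch):

```python
import numpy as np
import scipy.sparse
from sklearn.feature_extraction.text import CountVectorizer

corpus = [
    "This is the first document.",
    "This document is the second document.",
    "And this is the third one.",
    "Is this the first document?",
]

# Learn one shared vocabulary (the role played by _merge_vocabulary in the patch).
vocab = CountVectorizer().fit(corpus).vocabulary_

# Transform two "partitions" independently with the fixed vocabulary.
parts = [corpus[:2], corpus[2:]]
blocks = [CountVectorizer(vocabulary=vocab).transform(p) for p in parts]
stacked = scipy.sparse.vstack(blocks, format="csr")

# Identical to transforming the whole corpus in one go.
full = CountVectorizer(vocabulary=vocab).transform(corpus)
assert np.array_equal(stacked.toarray(), full.toarray())
```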
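
`TfidfTransformer.fit` above reduces the count chunks to a document-frequency vector and, with `smooth_idf=True`, applies the smoothed IDF formula `log((n_samples + 1) / (df + 1)) + 1`; `transform` then scales columns by `diag(idf)` and normalizes rows. A small NumPy/SciPy sketch of that arithmetic on one in-memory count matrix, checked against scikit-learn's reference `TfidfTransformer` (the toy matrix is illustrative):

```python
import numpy as np
import scipy.sparse
import sklearn.preprocessing
from sklearn.feature_extraction.text import TfidfTransformer as SkTfidfTransformer

counts = scipy.sparse.csr_matrix(np.array([[1, 0, 2],
                                           [0, 1, 1],
                                           [1, 1, 0]]))
n_samples, n_features = counts.shape

# Document frequency: number of documents containing each term
# (what _document_frequency computes per chunk via np.bincount).
df = np.bincount(counts.indices, minlength=n_features)

# Smoothed IDF, as in get_idf_diag(): log((n + 1) / (df + 1)) + 1
idf = np.log((n_samples + 1) / (df + 1)) + 1

# tf-idf = counts * diag(idf), then row-wise L2 normalisation.
tfidf = counts.multiply(idf).tocsr()
tfidf = sklearn.preprocessing.normalize(tfidf, norm="l2")

# Matches scikit-learn's reference implementation with default settings.
expected = SkTfidfTransformer().fit_transform(counts)
assert np.allclose(tfidf.toarray(), expected.toarray())
```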
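
End to end, the new `TfidfVectorizer` accepts the same inputs as the Dask-ML `CountVectorizer`: a `dask.bag.Bag` of strings or a `dask.dataframe.Series` with pandas partitions. A usage sketch mirroring the docstrings above; it assumes this branch of dask-ml is installed, and a distributed `Client` (as in the docstring examples) is optional:

```python
import dask.bag as db
import dask.dataframe as dd
import pandas as pd
from dask_ml.feature_extraction.text import TfidfVectorizer

corpus = [
    "This is the first document.",
    "This document is the second document.",
    "And this is the third one.",
    "Is this the first document?",
]

# Input as a dask Bag of strings ...
corpus_bag = db.from_sequence(corpus, npartitions=2)
vectorizer = TfidfVectorizer()
X = vectorizer.fit_transform(corpus_bag)      # dask array of CSR chunks
print(X.compute().toarray().round(3))
print(sorted(vectorizer.vocabulary_))

# ... or as a dask Series of strings with pandas partitions.
corpus_series = dd.from_pandas(pd.Series(corpus), npartitions=2)
X2 = TfidfVectorizer().fit_transform(corpus_series)
print(X2.compute().toarray().round(3))
```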