Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] sklearn API for Gensim models #1462

Merged
merged 36 commits into from
Aug 18, 2017
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
f70c583
created sklearn wrapper for Doc2Vec
Jul 5, 2017
0c675fa
PEP8 fix
Jul 5, 2017
7210c69
added 'transform' function and refactored code
Jul 10, 2017
b733e25
updated d2v skl api code
Jul 11, 2017
7f198a1
added unittests for sklearn api for d2v model
Jul 11, 2017
8a12ef5
fixed flake8 errors
Jul 11, 2017
2c18b87
added skl api class for Text2Bow model
chinmayapancholi13 Jul 12, 2017
710d2ce
updated docstring for d2vmodel api
chinmayapancholi13 Jul 12, 2017
fe76d28
updated text2bow skl api code
chinmayapancholi13 Jul 13, 2017
9acaba5
added unittests for text2bow skl api class
chinmayapancholi13 Jul 13, 2017
8c5d04e
updated 'testPipeline' and 'testTransform' for text2bow
chinmayapancholi13 Jul 13, 2017
4101e30
added 'tokenizer' param to text2bow skl api
chinmayapancholi13 Jul 14, 2017
ed7a571
updated unittests for text2bow
chinmayapancholi13 Jul 14, 2017
66a8302
removed get_params and set_params functions from existing classes
chinmayapancholi13 Jul 18, 2017
75faa7a
added tfidf api class
chinmayapancholi13 Jul 18, 2017
8cbd2ba
added unittests for tfidf api class
chinmayapancholi13 Jul 18, 2017
48958c0
flake8 fixes
chinmayapancholi13 Jul 18, 2017
a852980
added skl api for hdpmodel
chinmayapancholi13 Jul 19, 2017
5a16e77
added unittests for hdp model api class
chinmayapancholi13 Jul 19, 2017
9191053
flake8 fixes
chinmayapancholi13 Jul 19, 2017
57257be
updated hdp api class
chinmayapancholi13 Jul 20, 2017
de2e11d
added 'testPartialFit' and 'testPipeline' tests for hdp api class
chinmayapancholi13 Jul 20, 2017
acdb6dd
flake8 fixes
chinmayapancholi13 Jul 20, 2017
1c7da8e
added skl API class for phrases
chinmayapancholi13 Aug 1, 2017
47a4214
added unit tests for phrases API class
chinmayapancholi13 Aug 1, 2017
9b32c4d
flake8 fixes
chinmayapancholi13 Aug 1, 2017
3a0977a
added 'testPartialFit' function for 'TestPhrasesTransformer'
chinmayapancholi13 Aug 1, 2017
687c3d7
updated 'testPipeline' function for 'TestText2BowTransformer'
chinmayapancholi13 Aug 1, 2017
7fa8632
updated skl api code as per PR 1473
chinmayapancholi13 Aug 15, 2017
d42c877
updated code for transform function for HDP transformer
chinmayapancholi13 Aug 15, 2017
3037620
updated tests as discussed in PR 1473
chinmayapancholi13 Aug 15, 2017
c52a0e2
added examples for new models in ipynb
chinmayapancholi13 Aug 16, 2017
3701eac
Merge branch 'develop' of https://github.com/RaRe-Technologies/gensim…
chinmayapancholi13 Aug 17, 2017
28a3aa3
unpinned sklearn version for running unit-tests
chinmayapancholi13 Aug 18, 2017
464a496
updated 'Pipeline' initialization format
chinmayapancholi13 Aug 18, 2017
9d65fdf
updated 'Pipeline' initialization format in ipynb
chinmayapancholi13 Aug 18, 2017
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions gensim/sklearn_integration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@
from .sklearn_wrapper_gensim_rpmodel import SklRpModel # noqa: F401
from .sklearn_wrapper_gensim_ldaseqmodel import SklLdaSeqModel # noqa: F401
from .sklearn_wrapper_gensim_w2vmodel import SklW2VModel # noqa: F401
from .d2vmodel import D2VTransformer # noqa: F401
from .sklearn_wrapper_gensim_atmodel import SklATModel # noqa: F401
121 changes: 121 additions & 0 deletions gensim/sklearn_integration/d2vmodel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2011 Radim Rehurek <radimrehurek@seznam.cz>
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html

"""
Scikit-learn interface for gensim's Doc2Vec model, for easy use of gensim with scikit-learn.
Follows scikit-learn API conventions.
"""

import numpy as np
from six import string_types
from sklearn.base import TransformerMixin, BaseEstimator
from sklearn.exceptions import NotFittedError

from gensim import models
from gensim.sklearn_integration import BaseSklearnWrapper


class D2VTransformer(BaseSklearnWrapper, TransformerMixin, BaseEstimator):
    """
    Scikit-learn style transformer around gensim's Doc2Vec model.

    `fit` trains a `gensim.models.Doc2Vec` instance from the stored parameters and
    `transform` returns one inferred vector per input document.
    """

    def __init__(self, dm_mean=None, dm=1, dbow_words=0, dm_concat=0,
                 dm_tag_count=1, docvecs=None, docvecs_mapfile=None,
                 comment=None, trim_rule=None, size=100, alpha=0.025,
                 window=5, min_count=5, max_vocab_size=None, sample=1e-3,
                 seed=1, workers=3, min_alpha=0.0001, hs=0, negative=5,
                 cbow_mean=1, hashfxn=hash, iter=5, sorted_vocab=1,
                 batch_words=10000):
        """
        Sklearn api for Doc2Vec model. See gensim.models.Doc2Vec and gensim.models.Word2Vec for parameter details.

        The parameters are only stored here; the underlying gensim model is created in `fit`.
        """
        # Stays None until `fit` is called; `transform` uses this to detect an unfitted model.
        self.gensim_model = None

        # attributes associated with gensim.models.Doc2Vec
        self.dm_mean = dm_mean
        self.dm = dm
        self.dbow_words = dbow_words
        self.dm_concat = dm_concat
        self.dm_tag_count = dm_tag_count
        self.docvecs = docvecs
        self.docvecs_mapfile = docvecs_mapfile
        self.comment = comment
        self.trim_rule = trim_rule

        # attributes associated with gensim.models.Word2Vec
        self.size = size
        self.alpha = alpha
        self.window = window
        self.min_count = min_count
        self.max_vocab_size = max_vocab_size
        self.sample = sample
        self.seed = seed
        self.workers = workers
        self.min_alpha = min_alpha
        self.hs = hs
        self.negative = negative
        self.cbow_mean = int(cbow_mean)  # stored as an int whatever type was passed in
        self.hashfxn = hashfxn
        self.iter = iter
        self.sorted_vocab = sorted_vocab
        self.batch_words = batch_words

    def get_params(self, deep=True):
        """
        Return all parameters as dictionary.

        `deep` is accepted for compatibility with the scikit-learn signature and is not used.
        """
        return {
            "dm_mean": self.dm_mean, "dm": self.dm, "dbow_words": self.dbow_words,
            "dm_concat": self.dm_concat, "dm_tag_count": self.dm_tag_count, "docvecs": self.docvecs,
            "docvecs_mapfile": self.docvecs_mapfile, "comment": self.comment, "trim_rule": self.trim_rule,
            "size": self.size, "alpha": self.alpha, "window": self.window, "min_count": self.min_count,
            "max_vocab_size": self.max_vocab_size, "sample": self.sample, "seed": self.seed,
            "workers": self.workers, "min_alpha": self.min_alpha, "hs": self.hs,
            "negative": self.negative, "cbow_mean": self.cbow_mean, "hashfxn": self.hashfxn,
            "iter": self.iter, "sorted_vocab": self.sorted_vocab, "batch_words": self.batch_words,
        }

    def set_params(self, **parameters):
        """
        Set parameters by delegating to the parent class, then return `self`.
        """
        super(D2VTransformer, self).set_params(**parameters)
        return self

    def fit(self, X, y=None):
        """
        Fit the model according to the given training data.
        Calls gensim.models.Doc2Vec with `X` as `documents` and the stored parameters.

        `y` is ignored. Returns `self`.
        """
        self.gensim_model = models.Doc2Vec(
            documents=X, dm_mean=self.dm_mean, dm=self.dm,
            dbow_words=self.dbow_words, dm_concat=self.dm_concat, dm_tag_count=self.dm_tag_count,
            docvecs=self.docvecs, docvecs_mapfile=self.docvecs_mapfile, comment=self.comment,
            trim_rule=self.trim_rule, size=self.size, alpha=self.alpha, window=self.window,
            min_count=self.min_count, max_vocab_size=self.max_vocab_size, sample=self.sample,
            seed=self.seed, workers=self.workers, min_alpha=self.min_alpha, hs=self.hs,
            negative=self.negative, cbow_mean=self.cbow_mean, hashfxn=self.hashfxn,
            iter=self.iter, sorted_vocab=self.sorted_vocab, batch_words=self.batch_words
        )
        return self

    def transform(self, docs):
        """
        Return the vector representations for the input documents.
        The input `docs` should be a list of lists like : [ ['calculus', 'mathematical'], ['geometry', 'operations', 'curves'] ]
        or a single document like : ['calculus', 'mathematical']

        The result is a numpy array with one row per document and `vector_size` columns.
        An empty `docs` yields an array with zero rows.

        Raises NotFittedError if `fit` has not been called yet.
        """
        if self.gensim_model is None:
            raise NotFittedError("This model has not been fitted yet. Call 'fit' with appropriate arguments before using this method.")

        # A single document (a python list of tokens) is wrapped so that the input is always
        # a list of documents. The length guard keeps empty input from raising IndexError.
        if len(docs) > 0 and isinstance(docs[0], string_types):
            docs = [docs]

        vectors = [self.gensim_model.infer_vector(doc) for doc in docs]
        return np.reshape(np.array(vectors), (len(docs), self.gensim_model.vector_size))

    def partial_fit(self, X):
        """
        Not supported for this transformer; always raises NotImplementedError.
        """
        raise NotImplementedError("'partial_fit' has not been implemented for D2VTransformer")
74 changes: 74 additions & 0 deletions gensim/sklearn_integration/sklearn_wrapper_gensim_text2bow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2011 Radim Rehurek <radimrehurek@seznam.cz>
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html

"""
Scikit-learn interface for gensim's Dictionary (text to bag-of-words), for easy use of gensim with scikit-learn.
Follows scikit-learn API conventions.
"""

from six import string_types
from sklearn.base import TransformerMixin, BaseEstimator
from sklearn.exceptions import NotFittedError

from gensim.corpora import Dictionary
from gensim.sklearn_integration import BaseSklearnWrapper


class SklText2BowModel(BaseSklearnWrapper, TransformerMixin, BaseEstimator):
    """
    Scikit-learn style transformer that converts tokenised documents to bag-of-words
    format using a gensim `Dictionary`.
    """

    def __init__(self, prune_at=2000000):
        """
        Sklearn wrapper for Text2Bow model.

        `prune_at` is stored and forwarded to `gensim.corpora.Dictionary` when the
        dictionary is created.
        """
        # Stays None until `fit` or `partial_fit` is called.
        self.gensim_model = None
        self.prune_at = prune_at

    def get_params(self, deep=True):
        """
        Returns all parameters as dictionary.

        `deep` is accepted for compatibility with the scikit-learn signature and is not used.
        """
        return {"prune_at": self.prune_at}

    def set_params(self, **parameters):
        """
        Set all parameters by delegating to the parent class, then return `self`.
        """
        super(SklText2BowModel, self).set_params(**parameters)
        return self

    def fit(self, X, y=None):
        """
        Fit the model according to the given training data.

        Builds a new `Dictionary` from the documents in `X`. `y` is ignored. Returns `self`.
        """
        # Use the imported `Dictionary` name: the bare name `gensim` is not bound in this module.
        self.gensim_model = Dictionary(documents=X, prune_at=self.prune_at)
        return self

    def transform(self, docs):
        """
        Return the BOW format for the input documents.

        `docs` is either a python list of documents (each a python list of tokens) or a
        single document. The result is a python list holding one `doc2bow` result per
        document; an empty `docs` yields an empty list.

        Raises NotFittedError if neither `fit` nor `partial_fit` has been called yet.
        """
        if self.gensim_model is None:
            raise NotFittedError("This model has not been fitted yet. Call 'fit' with appropriate arguments before using this method.")

        # A single document (a python list of tokens) is wrapped so that the input is always
        # a python list of documents. The length guard keeps empty input from raising IndexError.
        if len(docs) > 0 and isinstance(docs[0], string_types):
            docs = [docs]

        return [self.gensim_model.doc2bow(doc) for doc in docs]

    def partial_fit(self, X):
        """
        Add the documents in `X` to the dictionary, creating it first if needed. Returns `self`.
        """
        if self.gensim_model is None:
            self.gensim_model = Dictionary(prune_at=self.prune_at)

        self.gensim_model.add_documents(X)
        return self
73 changes: 73 additions & 0 deletions gensim/test/test_sklearn_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
from gensim.sklearn_integration.sklearn_wrapper_gensim_ldaseqmodel import SklLdaSeqModel
from gensim.sklearn_integration.sklearn_wrapper_gensim_w2vmodel import SklW2VModel
from gensim.sklearn_integration.sklearn_wrapper_gensim_atmodel import SklATModel
from gensim.sklearn_integration.d2vmodel import D2VTransformer
from gensim.corpora import mmcorpus, Dictionary
from gensim.models import doc2vec
from gensim import matutils

module_path = os.path.dirname(__file__) # needed because sample data files are located in the same folder
Expand Down Expand Up @@ -97,6 +99,9 @@
['advances', 'in', 'the', 'understanding', 'of', 'electromagnetism', 'or', 'nuclear', 'physics', 'led', 'directly', 'to', 'the', 'development', 'of', 'new', 'products', 'that', 'have', 'dramatically', 'transformed', 'modern', 'day', 'society']
]

# Wrap every tokenised text in a TaggedDocument whose only tag is its index in `w2v_texts`.
d2v_sentences = [doc2vec.TaggedDocument(words, [i]) for i, words in enumerate(w2v_texts)]


class TestSklLdaModelWrapper(unittest.TestCase):
def setUp(self):
numpy.random.seed(0) # set fixed seed to get similar values everytime
Expand Down Expand Up @@ -537,5 +542,73 @@ def testPipeline(self):
self.assertEqual(len(ret_val), len(author_list))


class TestD2VTransformerWrapper(unittest.TestCase):
    """Tests for `D2VTransformer`, the scikit-learn style wrapper around Doc2Vec."""

    def setUp(self):
        numpy.random.seed(0)  # set fixed seed to get similar values everytime
        self.model = D2VTransformer(min_count=1)
        self.model.fit(d2v_sentences)

    def testTransform(self):
        # transform multiple documents
        docs = [w2v_texts[0], w2v_texts[1], w2v_texts[2]]
        matrix = self.model.transform(docs)
        self.assertEqual(matrix.shape[0], 3)
        self.assertEqual(matrix.shape[1], self.model.size)

        # transform one document
        doc = w2v_texts[0]
        matrix = self.model.transform(doc)
        self.assertEqual(matrix.shape[0], 1)
        self.assertEqual(matrix.shape[1], self.model.size)

    def testSetGetParams(self):
        # updating only one param
        self.model.set_params(negative=20)
        model_params = self.model.get_params()
        self.assertEqual(model_params["negative"], 20)

        # verify that the value also reaches the underlying gensim model once it is re-fitted
        self.model.fit(d2v_sentences)
        self.assertEqual(getattr(self.model.gensim_model, 'negative'), 20)

    def testPipeline(self):
        numpy.random.seed(0)  # set fixed seed to get similar values everytime
        model = D2VTransformer(min_count=1)
        model.fit(d2v_sentences)

        class_dict = {'mathematics': 1, 'physics': 0}
        train_data = [
            (['calculus', 'mathematical'], 'mathematics'), (['geometry', 'operations', 'curves'], 'mathematics'),
            (['natural', 'nuclear'], 'physics'), (['science', 'electromagnetism', 'natural'], 'physics')
        ]
        train_input = [tokens for tokens, _ in train_data]
        train_target = [class_dict[label] for _, label in train_data]

        clf = linear_model.LogisticRegression(penalty='l2', C=0.1)
        clf.fit(model.transform(train_input), train_target)
        # `steps` is passed as a list of (name, estimator) pairs, the format scikit-learn documents.
        text_d2v = Pipeline([('features', model), ('classifier', clf)])
        score = text_d2v.score(train_input, train_target)
        self.assertGreater(score, 0.40)

    def testPersistence(self):
        model_dump = pickle.dumps(self.model)
        model_load = pickle.loads(model_dump)

        doc = w2v_texts[0]
        loaded_transformed_vecs = model_load.transform(doc)

        # sanity check for transformation operation
        self.assertEqual(loaded_transformed_vecs.shape[0], 1)
        self.assertEqual(loaded_transformed_vecs.shape[1], model_load.size)

        # comparing the original and loaded models
        original_transformed_vecs = self.model.transform(doc)
        passed = numpy.allclose(loaded_transformed_vecs, original_transformed_vecs, atol=1e-1)
        self.assertTrue(passed)

    def testModelNotFitted(self):
        # an unfitted wrapper must refuse to transform, whatever the input is
        d2vmodel_wrapper = D2VTransformer(min_count=1)
        self.assertRaises(NotFittedError, d2vmodel_wrapper.transform, 1)


# Run the test suite when this module is executed directly.
if __name__ == '__main__':
    unittest.main()