[WIP] sklearn API for Gensim models #1462
Merged
menshikh-iv merged 36 commits into piskvorky:develop from chinmayapancholi13:skl_api_gensim on Aug 18, 2017
Changes from 8 commits

Commits (36)
f70c583 created sklearn wrapper for Doc2Vec
0c675fa PEP8 fix
7210c69 added 'transform' function and refactored code
b733e25 updated d2v skl api code
7f198a1 added unittests for sklearn api for d2v model
8a12ef5 fixed flake8 errors
2c18b87 added skl api class for Text2Bow model
chinmayapancholi13 710d2ce updated docstring for d2vmodel api
chinmayapancholi13 fe76d28 updated text2bow skl api code
chinmayapancholi13 9acaba5 added unittests for text2bow skl api class
chinmayapancholi13 8c5d04e updated 'testPipeline' and 'testTransform' for text2bow
chinmayapancholi13 4101e30 added 'tokenizer' param to text2bow skl api
chinmayapancholi13 ed7a571 updated unittests for text2bow
chinmayapancholi13 66a8302 removed get_params and set_params functions from existing classes
chinmayapancholi13 75faa7a added tfidf api class
chinmayapancholi13 8cbd2ba added unittests for tfidf api class
chinmayapancholi13 48958c0 flake8 fixes
chinmayapancholi13 a852980 added skl api for hdpmodel
chinmayapancholi13 5a16e77 added unittests for hdp model api class
chinmayapancholi13 9191053 flake8 fixes
chinmayapancholi13 57257be updated hdp api class
chinmayapancholi13 de2e11d added 'testPartialFit' and 'testPipeline' tests for hdp api class
chinmayapancholi13 acdb6dd flake8 fixes
chinmayapancholi13 1c7da8e added skl API class for phrases
chinmayapancholi13 47a4214 added unit tests for phrases API class
chinmayapancholi13 9b32c4d flake8 fixes
chinmayapancholi13 3a0977a added 'testPartialFit' function for 'TestPhrasesTransformer'
chinmayapancholi13 687c3d7 updated 'testPipeline' function for 'TestText2BowTransformer'
chinmayapancholi13 7fa8632 updated skl api code as per PR 1473
chinmayapancholi13 d42c877 updated code for transform function for HDP transformer
chinmayapancholi13 3037620 updated tests as discussed in PR 1473
chinmayapancholi13 c52a0e2 added examples for new models in ipynb
chinmayapancholi13 3701eac Merge branch 'develop' of https://github.com/RaRe-Technologies/gensim…
chinmayapancholi13 28a3aa3 unpinned sklearn version for running unit-tests
chinmayapancholi13 464a496 updated 'Pipeline' initialization format
chinmayapancholi13 9d65fdf updated 'Pipeline' initialization format in ipynb
gensim/sklearn_integration/d2vmodel.py
@@ -0,0 +1,121 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2011 Radim Rehurek <radimrehurek@seznam.cz>
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html

"""
Scikit learn interface for gensim for easy use of gensim with scikit-learn.
Follows scikit-learn API conventions.
"""

import numpy as np
from six import string_types
from sklearn.base import TransformerMixin, BaseEstimator
from sklearn.exceptions import NotFittedError

from gensim import models
from gensim.sklearn_integration import BaseSklearnWrapper


class D2VTransformer(BaseSklearnWrapper, TransformerMixin, BaseEstimator):
    """
    Base Doc2Vec module
    """

    def __init__(self, dm_mean=None, dm=1, dbow_words=0, dm_concat=0,
                 dm_tag_count=1, docvecs=None, docvecs_mapfile=None,
                 comment=None, trim_rule=None, size=100, alpha=0.025,
                 window=5, min_count=5, max_vocab_size=None, sample=1e-3,
                 seed=1, workers=3, min_alpha=0.0001, hs=0, negative=5,
                 cbow_mean=1, hashfxn=hash, iter=5, sorted_vocab=1,
                 batch_words=10000):
        """
        Sklearn api for Doc2Vec model. See gensim.models.Doc2Vec and gensim.models.Word2Vec for parameter details.
        """
        self.gensim_model = None
        self.dm_mean = dm_mean
        self.dm = dm
        self.dbow_words = dbow_words
        self.dm_concat = dm_concat
        self.dm_tag_count = dm_tag_count
        self.docvecs = docvecs
        self.docvecs_mapfile = docvecs_mapfile
        self.comment = comment
        self.trim_rule = trim_rule

        # attributes associated with gensim.models.Word2Vec
        self.size = size
        self.alpha = alpha
        self.window = window
        self.min_count = min_count
        self.max_vocab_size = max_vocab_size
        self.sample = sample
        self.seed = seed
        self.workers = workers
        self.min_alpha = min_alpha
        self.hs = hs
        self.negative = negative
        self.cbow_mean = int(cbow_mean)
        self.hashfxn = hashfxn
        self.iter = iter
        self.sorted_vocab = sorted_vocab
        self.batch_words = batch_words

    def get_params(self, deep=True):
        """
        Return all parameters as a dictionary.
        """
        return {"dm_mean": self.dm_mean, "dm": self.dm, "dbow_words": self.dbow_words,
                "dm_concat": self.dm_concat, "dm_tag_count": self.dm_tag_count, "docvecs": self.docvecs,
                "docvecs_mapfile": self.docvecs_mapfile, "comment": self.comment, "trim_rule": self.trim_rule,
                "size": self.size, "alpha": self.alpha, "window": self.window, "min_count": self.min_count,
                "max_vocab_size": self.max_vocab_size, "sample": self.sample, "seed": self.seed,
                "workers": self.workers, "min_alpha": self.min_alpha, "hs": self.hs,
                "negative": self.negative, "cbow_mean": self.cbow_mean, "hashfxn": self.hashfxn,
                "iter": self.iter, "sorted_vocab": self.sorted_vocab, "batch_words": self.batch_words}

    def set_params(self, **parameters):
        """
        Set all parameters.
        """
        super(D2VTransformer, self).set_params(**parameters)
        return self

    def fit(self, X, y=None):
        """
        Fit the model according to the given training data.
        Calls gensim.models.Doc2Vec
        """
        self.gensim_model = models.Doc2Vec(documents=X, dm_mean=self.dm_mean, dm=self.dm,
            dbow_words=self.dbow_words, dm_concat=self.dm_concat, dm_tag_count=self.dm_tag_count,
            docvecs=self.docvecs, docvecs_mapfile=self.docvecs_mapfile, comment=self.comment,
            trim_rule=self.trim_rule, size=self.size, alpha=self.alpha, window=self.window,
            min_count=self.min_count, max_vocab_size=self.max_vocab_size, sample=self.sample,
            seed=self.seed, workers=self.workers, min_alpha=self.min_alpha, hs=self.hs,
            negative=self.negative, cbow_mean=self.cbow_mean, hashfxn=self.hashfxn,
            iter=self.iter, sorted_vocab=self.sorted_vocab, batch_words=self.batch_words)
        return self

    def transform(self, docs):
        """
        Return the vector representations for the input documents.
        The input `docs` should be a list of lists, e.g. [['calculus', 'mathematical'], ['geometry', 'operations', 'curves']],
        or a single document, e.g. ['calculus', 'mathematical'].
        """
        if self.gensim_model is None:
            raise NotFittedError("This model has not been fitted yet. Call 'fit' with appropriate arguments before using this method.")

        # accept a single document (list of tokens) as well as a list of documents
        check = lambda x: [x] if isinstance(x[0], string_types) else x
        docs = check(docs)
        X = [[] for _ in range(0, len(docs))]

        for k, v in enumerate(docs):
            doc_vec = self.gensim_model.infer_vector(v)
            X[k] = doc_vec

        return np.reshape(np.array(X), (len(docs), self.gensim_model.vector_size))

    def partial_fit(self, X):
        raise NotImplementedError("'partial_fit' has not been implemented for D2VTransformer")
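For orientation, a minimal usage sketch of the wrapper above, based on the tests added later in this PR (the module path comes from the test imports; the toy documents and the size value are illustrative):

from gensim.models import doc2vec
from gensim.sklearn_integration.d2vmodel import D2VTransformer

# toy corpus: each document is a list of tokens
texts = [['calculus', 'mathematical'], ['geometry', 'operations', 'curves']]
tagged = [doc2vec.TaggedDocument(words, [i]) for i, words in enumerate(texts)]

d2v = D2VTransformer(min_count=1, size=10)
d2v.fit(tagged)                 # trains a gensim Doc2Vec model internally
vectors = d2v.transform(texts)  # numpy array of shape (len(texts), size)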
gensim/sklearn_integration/sklearn_wrapper_gensim_text2bow.py (74 additions, 0 deletions)
@@ -0,0 +1,74 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2011 Radim Rehurek <radimrehurek@seznam.cz>
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html

"""
Scikit learn interface for gensim for easy use of gensim with scikit-learn.
Follows scikit-learn API conventions.
"""

from six import string_types
from sklearn.base import TransformerMixin, BaseEstimator
from sklearn.exceptions import NotFittedError

from gensim.corpora import Dictionary
from gensim.sklearn_integration import BaseSklearnWrapper


class SklText2BowModel(BaseSklearnWrapper, TransformerMixin, BaseEstimator):
    """
    Base Text2Bow module
    """

    def __init__(self, prune_at=2000000):
        """
        Sklearn wrapper for Text2Bow model.
        """
        self.gensim_model = None
        self.prune_at = prune_at

    def get_params(self, deep=True):
        """
        Return all parameters as a dictionary.
        """
        return {"prune_at": self.prune_at}

    def set_params(self, **parameters):
        """
        Set all parameters.
        """
        super(SklText2BowModel, self).set_params(**parameters)
        return self

    def fit(self, X, y=None):
        """
        Fit the model according to the given training data.
        """
        self.gensim_model = Dictionary(documents=X, prune_at=self.prune_at)
        return self

    def transform(self, docs):
        """
        Return the BOW format for the input documents.
        """
        if self.gensim_model is None:
            raise NotFittedError("This model has not been fitted yet. Call 'fit' with appropriate arguments before using this method.")

        # accept a single document (list of tokens) as well as a list of documents
        check = lambda x: [x] if isinstance(x[0], string_types) else x
        docs = check(docs)
        X = [[] for _ in range(0, len(docs))]

        for k, v in enumerate(docs):
            bow_val = self.gensim_model.doc2bow(v)
            X[k] = bow_val

        return X

    def partial_fit(self, X):
        if self.gensim_model is None:
            self.gensim_model = Dictionary(prune_at=self.prune_at)

        self.gensim_model.add_documents(X)
        return self
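Similarly, a minimal usage sketch for the Text2Bow wrapper (class and module names as added in this file; the toy texts are illustrative):

from gensim.sklearn_integration.sklearn_wrapper_gensim_text2bow import SklText2BowModel

texts = [['calculus', 'mathematical'], ['geometry', 'operations', 'curves']]

text2bow = SklText2BowModel(prune_at=2000000)
text2bow.fit(texts)              # builds a gensim Dictionary from the tokenized texts
bow = text2bow.transform(texts)  # list of bag-of-words documents, e.g. [[(0, 1), (1, 1)], ...]

# partial_fit grows the dictionary with additional documents
text2bow.partial_fit([['natural', 'nuclear']])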
Changes to the sklearn integration tests:
@@ -21,7 +21,9 @@
from gensim.sklearn_integration.sklearn_wrapper_gensim_ldaseqmodel import SklLdaSeqModel
from gensim.sklearn_integration.sklearn_wrapper_gensim_w2vmodel import SklW2VModel
from gensim.sklearn_integration.sklearn_wrapper_gensim_atmodel import SklATModel
from gensim.sklearn_integration.d2vmodel import D2VTransformer
from gensim.corpora import mmcorpus, Dictionary
from gensim.models import doc2vec
from gensim import matutils

module_path = os.path.dirname(__file__)  # needed because sample data files are located in the same folder

@@ -97,6 +99,9 @@
    ['advances', 'in', 'the', 'understanding', 'of', 'electromagnetism', 'or', 'nuclear', 'physics', 'led', 'directly', 'to', 'the', 'development', 'of', 'new', 'products', 'that', 'have', 'dramatically', 'transformed', 'modern', 'day', 'society']
]

d2v_sentences = [doc2vec.TaggedDocument(words, [i]) for i, words in enumerate(w2v_texts)]


class TestSklLdaModelWrapper(unittest.TestCase):
    def setUp(self):
        numpy.random.seed(0)  # set fixed seed to get similar values every time
@@ -537,5 +542,73 @@ def testPipeline(self):
        self.assertEqual(len(ret_val), len(author_list))


class TestD2VTransformerWrapper(unittest.TestCase):
    def setUp(self):
        numpy.random.seed(0)
        self.model = D2VTransformer(min_count=1)
        self.model.fit(d2v_sentences)

    def testTransform(self):
        # transform multiple documents
        docs = []
        docs.append(w2v_texts[0])
        docs.append(w2v_texts[1])
        docs.append(w2v_texts[2])
        matrix = self.model.transform(docs)
        self.assertEqual(matrix.shape[0], 3)
        self.assertEqual(matrix.shape[1], self.model.size)

        # transform one document
        doc = w2v_texts[0]
        matrix = self.model.transform(doc)
        self.assertEqual(matrix.shape[0], 1)
        self.assertEqual(matrix.shape[1], self.model.size)

    def testSetGetParams(self):
Review comment: please add checking with the original model too for each "getset" test (same as previous PR)
Reply: Done.
        # updating only one param
        self.model.set_params(negative=20)
        model_params = self.model.get_params()
        self.assertEqual(model_params["negative"], 20)

    def testPipeline(self):
        numpy.random.seed(0)  # set fixed seed to get similar values every time
        model = D2VTransformer(min_count=1)
        model.fit(d2v_sentences)

        class_dict = {'mathematics': 1, 'physics': 0}
        train_data = [
            (['calculus', 'mathematical'], 'mathematics'), (['geometry', 'operations', 'curves'], 'mathematics'),
            (['natural', 'nuclear'], 'physics'), (['science', 'electromagnetism', 'natural'], 'physics')
        ]
        train_input = list(map(lambda x: x[0], train_data))
        train_target = list(map(lambda x: class_dict[x[1]], train_data))

        clf = linear_model.LogisticRegression(penalty='l2', C=0.1)
        clf.fit(model.transform(train_input), train_target)
        text_w2v = Pipeline([('features', model), ('classifier', clf)])
        score = text_w2v.score(train_input, train_target)
        self.assertGreater(score, 0.40)

    def testPersistence(self):
        model_dump = pickle.dumps(self.model)
        model_load = pickle.loads(model_dump)

        doc = w2v_texts[0]
        loaded_transformed_vecs = model_load.transform(doc)

        # sanity check for transformation operation
        self.assertEqual(loaded_transformed_vecs.shape[0], 1)
        self.assertEqual(loaded_transformed_vecs.shape[1], model_load.size)

        # comparing the original and loaded models
        original_transformed_vecs = self.model.transform(doc)
        passed = numpy.allclose(sorted(loaded_transformed_vecs), sorted(original_transformed_vecs), atol=1e-1)
        self.assertTrue(passed)

    def testModelNotFitted(self):
        d2vmodel_wrapper = D2VTransformer(min_count=1)
        self.assertRaises(NotFittedError, d2vmodel_wrapper.transform, 1)


if __name__ == '__main__':
    unittest.main()
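Regarding the review request above to also compare against the original model in each get/set test, a rough sketch of what such a check might look like (it assumes the trained gensim Doc2Vec exposes the updated parameter as an attribute, e.g. gensim_model.negative):

    def testSetGetParams(self):
        # updating only one param
        self.model.set_params(negative=20)
        model_params = self.model.get_params()
        self.assertEqual(model_params["negative"], 20)

        # re-fit so the change propagates to the underlying gensim model,
        # then verify the original model reflects the updated value
        self.model.fit(d2v_sentences)
        self.assertEqual(self.model.gensim_model.negative, 20)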
Review comment: Please call them python lists