Skip to content

Commit

Permalink
Fix D2VTransformer.fit_transform. Fix piskvorky#1834 (piskvorky#1845)
Browse files Browse the repository at this point in the history
* Fix D2VTransformer.fit_transform (piskvorky#1834)

* Add check for TaggedDocument

* Add test for D2VTransformer fit_transform

* Add test and check d2vtransformer
  • Loading branch information
karshd3v authored and sj29-innovate committed Feb 21, 2018
1 parent c8e7325 commit 53ac6b2
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 1 deletion.
7 changes: 6 additions & 1 deletion gensim/sklearn_api/d2vmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from sklearn.exceptions import NotFittedError

from gensim import models
from gensim.models import doc2vec


class D2VTransformer(TransformerMixin, BaseEstimator):
Expand Down Expand Up @@ -63,8 +64,12 @@ def fit(self, X, y=None):
Fit the model according to the given training data.
Calls gensim.models.Doc2Vec
"""
if isinstance(X[0], doc2vec.TaggedDocument):
d2v_sentences = X
else:
d2v_sentences = [doc2vec.TaggedDocument(words, [i]) for i, words in enumerate(X)]
self.gensim_model = models.Doc2Vec(
documents=X, dm_mean=self.dm_mean, dm=self.dm,
documents=d2v_sentences, dm_mean=self.dm_mean, dm=self.dm,
dbow_words=self.dbow_words, dm_concat=self.dm_concat, dm_tag_count=self.dm_tag_count,
docvecs=self.docvecs, docvecs_mapfile=self.docvecs_mapfile, comment=self.comment,
trim_rule=self.trim_rule, size=self.size, alpha=self.alpha, window=self.window,
Expand Down
15 changes: 15 additions & 0 deletions gensim/test/test_sklearn_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,6 +831,21 @@ def testTransform(self):
self.assertEqual(matrix.shape[0], 1)
self.assertEqual(matrix.shape[1], self.model.size)

def testFitTransform(self):
    """Check that D2VTransformer.fit_transform returns one vector per document.

    Covers a corpus of several tokenised documents and a corpus holding a
    single tokenised document. In both cases the result must have one row
    per input document and ``model.size`` columns.
    """
    model = D2VTransformer(min_count=1)

    # fit and transform multiple documents
    docs = [w2v_texts[0], w2v_texts[1], w2v_texts[2]]
    matrix = model.fit_transform(docs)
    self.assertEqual(matrix.shape[0], 3)
    self.assertEqual(matrix.shape[1], model.size)

    # fit and transform one document
    # fit() tags every element of its input as a separate document, so a
    # single tokenised document has to be wrapped in a one-element corpus;
    # passing the bare word list would tag each word as its own document.
    docs = [w2v_texts[0]]
    matrix = model.fit_transform(docs)
    self.assertEqual(matrix.shape[0], 1)
    self.assertEqual(matrix.shape[1], model.size)

def testSetGetParams(self):
# updating only one param
self.model.set_params(negative=20)
Expand Down

0 comments on commit 53ac6b2

Please sign in to comment.