Skip to content

Commit

Permalink
revert test_wikicorpus
Browse files Browse the repository at this point in the history
  • Loading branch information
markroxor committed Oct 21, 2016
1 parent 9bf5cae commit 0a353bb
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 20 deletions.
7 changes: 5 additions & 2 deletions gensim/models/doc2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,7 @@ def repeat(self, word_count):

class Doc2Vec(Word2Vec):
"""Class for training, using and evaluating neural networks described in http://arxiv.org/pdf/1405.4053v2.pdf"""
def __init__(self, documents=None,dm_mean=None,
def __init__(self, documents=None, dm_mean=None,
dm=1, dbow_words=0, dm_concat=0, dm_tag_count=1,
docvecs=None, docvecs_mapfile=None, comment=None, trim_rule=None, **kwargs):
"""
Expand Down Expand Up @@ -601,9 +601,12 @@ def __init__(self, documents=None,dm_mean=None,
"""

super(Doc2Vec, self).__init__(
sg=(1 + dm) % 2, dm_mean=dm_mean,
sg=(1 + dm) % 2,
null_word=dm_concat, **kwargs)

if dm_mean is not None:
self.cbow_mean = dm_mean

self.dbow_words = dbow_words
self.dm_concat = dm_concat
self.dm_tag_count = dm_tag_count
Expand Down
8 changes: 2 additions & 6 deletions gensim/models/word2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ def __init__(
self, sentences=None, size=100, alpha=0.025, window=5, min_count=5,
max_vocab_size=None, sample=1e-3, seed=1, workers=3, min_alpha=0.0001,
sg=0, hs=0, negative=5, cbow_mean=1, hashfxn=hash, iter=5, null_word=0,
trim_rule=None, sorted_vocab=1, batch_words=MAX_WORDS_IN_BATCH,**kwargs):
trim_rule=None, sorted_vocab=1, batch_words=MAX_WORDS_IN_BATCH):
"""
Initialize the model from an iterable of `sentences`. Each sentence is a
list of words (unicode strings) that will be used for training.
Expand Down Expand Up @@ -451,6 +451,7 @@ def __init__(
self.min_alpha = float(min_alpha)
self.hs = hs
self.negative = negative
self.cbow_mean = int(cbow_mean)
self.hashfxn = hashfxn
self.iter = iter
self.null_word = null_word
Expand All @@ -459,11 +460,6 @@ def __init__(
self.sorted_vocab = sorted_vocab
self.batch_words = batch_words

if "dm_mean" in kwargs and kwargs["dm_mean"] is not None:
self.cbow_mean = int(kwargs["dm_mean"])
else:
self.cbow_mean = int(cbow_mean)

if sentences is not None:
if isinstance(sentences, GeneratorType):
raise TypeError("You can't pass a generator as the sentences argument. Try an iterator.")
Expand Down
16 changes: 7 additions & 9 deletions gensim/test/test_doc2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,10 @@ def test_int_doctags(self):
"""Test doc2vec doctag alternatives"""
corpus = DocsLeeCorpus()

size = 300
model = doc2vec.Doc2Vec(min_count=1, size=size)
model = doc2vec.Doc2Vec(min_count=1)
model.build_vocab(corpus)
self.assertEqual(len(model.docvecs.doctag_syn0), size)
self.assertEqual(model.docvecs[0].shape, (size,))
self.assertEqual(len(model.docvecs.doctag_syn0), 300)
self.assertEqual(model.docvecs[0].shape, (100,))
self.assertRaises(KeyError, model.__getitem__, '_*0')

def test_missing_string_doctag(self):
Expand All @@ -107,13 +106,12 @@ def test_string_doctags(self):
# force duplicated tags
corpus = corpus[0:10] + corpus

size = 300
model = doc2vec.Doc2Vec(size=size, min_count=1)
model = doc2vec.Doc2Vec(min_count=1)
model.build_vocab(corpus)

self.assertEqual(len(model.docvecs.doctag_syn0), size)
self.assertEqual(model.docvecs[0].shape, (size,))
self.assertEqual(model.docvecs['_*0'].shape, (size,))
self.assertEqual(len(model.docvecs.doctag_syn0), 300)
self.assertEqual(model.docvecs[0].shape, (100,))
self.assertEqual(model.docvecs['_*0'].shape, (100,))
self.assertTrue(all(model.docvecs['_*0'] == model.docvecs[0]))
self.assertTrue(max(d.offset for d in model.docvecs.doctags.values()) < len(model.docvecs.doctags))
self.assertTrue(max(model.docvecs._int_index(str_key) for str_key in model.docvecs.doctags.keys()) < len(model.docvecs.doctag_syn0))
Expand Down
4 changes: 1 addition & 3 deletions gensim/test/test_wikicorpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,10 @@ def setUp(self):
wc = WikiCorpus(datapath(FILENAME))

def test_get_texts_returns_generator_of_lists(self):

logger.debug("Current Python Version is "+str(sys.version_info))
if sys.version_info < (2, 7, 0):
return

wc = WikiCorpus(datapath(FILENAME))
l = wc.get_texts()
self.assertEqual(type(l), types.GeneratorType)
Expand All @@ -49,7 +48,6 @@ def test_first_element(self):
2) autism
"""

if sys.version_info < (2, 7, 0):
return
wc = WikiCorpus(datapath(FILENAME))
Expand Down

0 comments on commit 0a353bb

Please sign in to comment.