Skip to content

Commit

Permalink
revert test_wikicorpus
Browse files Browse the repository at this point in the history
  • Loading branch information
markroxor committed Oct 21, 2016
1 parent 9bf5cae commit 718fbc6
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 12 deletions.
16 changes: 7 additions & 9 deletions gensim/test/test_doc2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,10 @@ def test_int_doctags(self):
"""Test doc2vec doctag alternatives"""
corpus = DocsLeeCorpus()

size = 300
model = doc2vec.Doc2Vec(min_count=1, size=size)
model = doc2vec.Doc2Vec(min_count=1)
model.build_vocab(corpus)
self.assertEqual(len(model.docvecs.doctag_syn0), size)
self.assertEqual(model.docvecs[0].shape, (size,))
self.assertEqual(len(model.docvecs.doctag_syn0), 300)
self.assertEqual(model.docvecs[0].shape, (100,))
self.assertRaises(KeyError, model.__getitem__, '_*0')

def test_missing_string_doctag(self):
Expand All @@ -107,13 +106,12 @@ def test_string_doctags(self):
# force duplicated tags
corpus = corpus[0:10] + corpus

size = 300
model = doc2vec.Doc2Vec(size=size, min_count=1)
model = doc2vec.Doc2Vec(min_count=1)
model.build_vocab(corpus)

self.assertEqual(len(model.docvecs.doctag_syn0), size)
self.assertEqual(model.docvecs[0].shape, (size,))
self.assertEqual(model.docvecs['_*0'].shape, (size,))
self.assertEqual(len(model.docvecs.doctag_syn0), 300)
self.assertEqual(model.docvecs[0].shape, (100,))
self.assertEqual(model.docvecs['_*0'].shape, (100,))
self.assertTrue(all(model.docvecs['_*0'] == model.docvecs[0]))
self.assertTrue(max(d.offset for d in model.docvecs.doctags.values()) < len(model.docvecs.doctags))
self.assertTrue(max(model.docvecs._int_index(str_key) for str_key in model.docvecs.doctags.keys()) < len(model.docvecs.doctag_syn0))
Expand Down
4 changes: 1 addition & 3 deletions gensim/test/test_wikicorpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,10 @@ def setUp(self):
wc = WikiCorpus(datapath(FILENAME))

def test_get_texts_returns_generator_of_lists(self):

logger.debug("Current Python Version is "+str(sys.version_info))
if sys.version_info < (2, 7, 0):
return

wc = WikiCorpus(datapath(FILENAME))
l = wc.get_texts()
self.assertEqual(type(l), types.GeneratorType)
Expand All @@ -49,7 +48,6 @@ def test_first_element(self):
2) autism
"""

if sys.version_info < (2, 7, 0):
return
wc = WikiCorpus(datapath(FILENAME))
Expand Down

0 comments on commit 718fbc6

Please sign in to comment.