Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow indexing with np.int64 in doc2vec - #1231 #1254

Merged
merged 4 commits into from
May 2, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions gensim/corpora/indexedcorpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"""

import logging
import shelve
import six

import numpy

Expand Down Expand Up @@ -124,7 +124,7 @@ def __getitem__(self, docno):

if isinstance(docno, (slice, list, numpy.ndarray)):
return utils.SlicedCorpus(self, docno)
elif isinstance(docno, (int, numpy.integer)):
elif isinstance(docno, six.integer_types + (numpy.integer,)):
return self.docbyoffset(self.index[docno])
else:
raise ValueError('Unrecognised value for docno, use either a single integer, a slice or a numpy.ndarray')
Expand Down
7 changes: 1 addition & 6 deletions gensim/models/atmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,8 @@
# are included in the code where this is the case, for example in the log_perplexity
# and do_estep methods.

import pdb
from pdb import set_trace as st
from pprint import pprint

import logging
import numpy as np # for arrays, array broadcasting etc.
import numbers
from copy import deepcopy
from shutil import copyfile
from os.path import isfile
Expand Down Expand Up @@ -391,7 +386,7 @@ def inference(self, chunk, author2doc, doc2author, rhot, collect_sstats=False, c
doc_no = d
# Get the IDs and counts of all the words in the current document.
# TODO: this is duplication of code in LdaModel. Refactor.
if doc and not isinstance(doc[0][0], six.integer_types):
if doc and not isinstance(doc[0][0], six.integer_types + (np.integer,)):
# make sure the term IDs are ints, otherwise np will get upset
ids = [int(id) for id, _ in doc]
else:
Expand Down
20 changes: 10 additions & 10 deletions gensim/models/doc2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,17 @@
from collections import namedtuple, defaultdict
from timeit import default_timer

from numpy import zeros, random, sum as np_sum, add as np_add, concatenate, \
from numpy import zeros, sum as np_sum, add as np_add, concatenate, \
repeat as np_repeat, array, float32 as REAL, empty, ones, memmap as np_memmap, \
sqrt, newaxis, ndarray, dot, vstack, dtype, divide as np_divide
sqrt, newaxis, ndarray, dot, vstack, dtype, divide as np_divide, integer


from gensim.utils import call_on_class_only
from gensim import utils, matutils # utility fnc for pickling, common scipy operations etc
from gensim.models.word2vec import Word2Vec, train_cbow_pair, train_sg_pair, train_batch_sg
from gensim.models.keyedvectors import KeyedVectors
from six.moves import xrange, zip
from six import string_types, integer_types, itervalues
from six import string_types, integer_types

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -297,7 +297,7 @@ def __init__(self, mapfile_path=None):

def note_doctag(self, key, document_no, document_length):
"""Note a document tag during initial corpus scan, for structure sizing."""
if isinstance(key, int):
if isinstance(key, integer_types + (integer,)):
self.max_rawint = max(self.max_rawint, key)
else:
if key in self.doctags:
Expand All @@ -319,7 +319,7 @@ def trained_item(self, indexed_tuple):

def _int_index(self, index):
"""Return int index for either string or int index"""
if isinstance(index, int):
if isinstance(index, integer_types + (integer,)):
return index
else:
return self.max_rawint + 1 + self.doctags[index].offset
Expand Down Expand Up @@ -347,7 +347,7 @@ def __getitem__(self, index):
If a list, return designated tags' vector representations as a
2D numpy array: #tags x #vector_size.
"""
if isinstance(index, string_types + (int,)):
if isinstance(index, string_types + integer_types + (integer,)):
return self.doctag_syn0[self._int_index(index)]

return vstack([self[i] for i in index])
Expand All @@ -356,7 +356,7 @@ def __len__(self):
return self.count

def __contains__(self, index):
if isinstance(index, int):
if isinstance(index, integer_types + (integer,)):
return index < self.count
else:
return index in self.doctags
Expand Down Expand Up @@ -439,17 +439,17 @@ def most_similar(self, positive=[], negative=[], topn=10, clip_start=0, clip_end
self.init_sims()
clip_end = clip_end or len(self.doctag_syn0norm)

if isinstance(positive, string_types + integer_types) and not negative:
if isinstance(positive, string_types + integer_types + (integer,)) and not negative:
# allow calls like most_similar('dog'), as a shorthand for most_similar(['dog'])
positive = [positive]

# add weights for each doc, if not already present; default to 1.0 for positive and -1.0 for negative docs
positive = [
(doc, 1.0) if isinstance(doc, string_types + (ndarray,) + integer_types)
(doc, 1.0) if isinstance(doc, string_types + integer_types + (ndarray, integer))
else doc for doc in positive
]
negative = [
(doc, -1.0) if isinstance(doc, string_types + (ndarray,) + integer_types)
(doc, -1.0) if isinstance(doc, string_types + integer_types + (ndarray, integer))
else doc for doc in negative
]

Expand Down
2 changes: 1 addition & 1 deletion gensim/models/ldamodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ def inference(self, chunk, collect_sstats=False):
# Lee&Seung trick which speeds things up by an order of magnitude, compared
# to Blei's original LDA-C code, cool!).
for d, doc in enumerate(chunk):
if len(doc) > 0 and not isinstance(doc[0][0], six.integer_types):
if len(doc) > 0 and not isinstance(doc[0][0], six.integer_types + (np.integer,)):
# make sure the term IDs are ints, otherwise np will get upset
ids = [int(id) for id, _ in doc]
else:
Expand Down
1 change: 1 addition & 0 deletions gensim/test/test_corpora.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ def test_indexing(self):

for idx, doc in enumerate(docs):
self.assertEqual(doc, corpus[idx])
self.assertEqual(doc, corpus[np.int64(idx)])

self.assertEqual(docs, list(corpus[:]))
self.assertEqual(docs[0:], list(corpus[0:]))
Expand Down
5 changes: 3 additions & 2 deletions gensim/test/test_doc2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def test_int_doctags(self):
model.build_vocab(corpus)
self.assertEqual(len(model.docvecs.doctag_syn0), 300)
self.assertEqual(model.docvecs[0].shape, (100,))
self.assertEqual(model.docvecs[np.int64(0)].shape, (100,))
self.assertRaises(KeyError, model.__getitem__, '_*0')

def test_missing_string_doctag(self):
Expand Down Expand Up @@ -164,7 +165,7 @@ def test_similarity_unseen_docs(self):
def model_sanity(self, model, keep_training=True):
"""Any non-trivial model on DocsLeeCorpus can pass these sanity checks"""
fire1 = 0 # doc 0 sydney fires
fire2 = 8 # doc 8 sydney fires
fire2 = np.int64(8) # doc 8 sydney fires
tennis1 = 6 # doc 6 tennis

# inferred vector should be top10 close to bulk-trained one
Expand Down Expand Up @@ -304,7 +305,7 @@ def test_mixed_tag_types(self):
model = doc2vec.Doc2Vec()
model.build_vocab(mixed_tag_corpus)
expected_length = len(sentences) + len(model.docvecs.doctags) # 9 sentences, 7 unique first tokens
self.assertEquals(len(model.docvecs.doctag_syn0), expected_length)
self.assertEqual(len(model.docvecs.doctag_syn0), expected_length)

def models_equal(self, model, model2):
# check words/hidden-weights
Expand Down