Skip to content

Commit

Permalink
Fix E731, E402, refactor tests & sklearn API code. Partial fix #1644 (#1689)
Browse files Browse the repository at this point in the history

* Remove ignore of E731

* create datapath() in gensim.test.utils

* remove E402 from pylint ignore

* move get_tmpfile() to test.utils

* move get_tmpfile to utils and delete testfile

* fix build of travis

* remove unnecessary brackets

* move common texts to utils.py

* add docstrings to gensim.test.utils.

* fix building
  • Loading branch information
horpto authored and menshikh-iv committed Nov 7, 2017
1 parent 97820d0 commit 2bb8e1d
Show file tree
Hide file tree
Showing 51 changed files with 410 additions and 880 deletions.
4 changes: 3 additions & 1 deletion gensim/matutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
from six.moves import xrange, zip as izip


blas = lambda name, ndarray: scipy.linalg.get_blas_funcs((name,), (ndarray,))[0]
def blas(name, ndarray):
    """Return the BLAS routine `name` specialised for the dtype of `ndarray`.

    `name` is the BLAS function name without its type prefix (e.g. 'axpy' or
    'nrm2'); `ndarray` is the array used by `scipy.linalg.get_blas_funcs` to
    pick the matching typed variant.

    A single name is requested, so the first (and only) function of the
    returned sequence is handed back directly.
    """
    return scipy.linalg.get_blas_funcs((name,), (ndarray,))[0]


logger = logging.getLogger(__name__)

Expand Down
9 changes: 6 additions & 3 deletions gensim/similarities/docsim.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,18 +336,21 @@ def __getitem__(self, query):
# the following uses a lot of lazy evaluation and (optionally) parallel
# processing, to improve query latency and minimize memory footprint.
offsets = numpy.cumsum([0] + [len(shard) for shard in self.shards])
convert = lambda doc, shard_no: [(doc_index + offsets[shard_no], sim) for doc_index, sim in doc]

def convert(shard_no, doc):
    """Shift the shard-local document indices in `doc` by the shard's offset.

    `doc` is an iterable of (doc_index, sim) pairs coming from shard number
    `shard_no`. Each index is increased by `offsets[shard_no]`, where
    `offsets` comes from the enclosing scope (cumulative shard lengths).
    Returns a new list of (shifted_index, sim) tuples; `sim` is untouched.
    """
    return [(doc_index + offsets[shard_no], sim) for doc_index, sim in doc]

is_corpus, query = utils.is_corpus(query)
is_corpus = is_corpus or hasattr(query, 'ndim') and query.ndim > 1 and query.shape[0] > 1
if not is_corpus:
# user asked for num_best most similar and query is a single doc
results = (convert(result, shard_no) for shard_no, result in enumerate(shard_results))
results = (convert(shard_no, result) for shard_no, result in enumerate(shard_results))
result = heapq.nlargest(self.num_best, itertools.chain(*results), key=lambda item: item[1])
else:
# the trickiest combination: returning num_best results when query was a corpus
results = []
for shard_no, result in enumerate(shard_results):
shard_result = [convert(doc, shard_no) for doc in result]
shard_result = [convert(shard_no, doc) for doc in result]
results.append(shard_result)
result = []
for parts in izip(*results):
Expand Down
16 changes: 5 additions & 11 deletions gensim/sklearn_api/atmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,17 +76,11 @@ def transform(self, author_names):
"This model has not been fitted yet. Call 'fit' with appropriate arguments before using this method."
)

check = lambda x: [x] if not isinstance(x, list) else x
author_names = check(author_names)
X = [[] for _ in range(0, len(author_names))]

for k, v in enumerate(author_names):
transformed_author = self.gensim_model[v]
# returning dense representation for compatibility with sklearn but we should go back to sparse representation in the future
probs_author = matutils.sparse2full(transformed_author, self.num_topics)
X[k] = probs_author

return np.reshape(np.array(X), (len(author_names), self.num_topics))
if not isinstance(author_names, list):
author_names = [author_names]
# Return a dense representation for compatibility with sklearn; we should go back to a sparse representation in the future.
topics = [matutils.sparse2full(self.gensim_model[author_name], self.num_topics) for author_name in author_names]
return np.reshape(np.array(topics), (len(author_names), self.num_topics))

def partial_fit(self, X, author2doc=None, doc2author=None):
"""
Expand Down
13 changes: 4 additions & 9 deletions gensim/sklearn_api/d2vmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,7 @@ def transform(self, docs):
)

# The input is expected as a list of documents (each a list of string tokens); a single document is wrapped into a list.
check = lambda x: [x] if isinstance(x[0], string_types) else x
docs = check(docs)
X = [[] for _ in range(0, len(docs))]

for k, v in enumerate(docs):
doc_vec = self.gensim_model.infer_vector(v)
X[k] = doc_vec

return np.reshape(np.array(X), (len(docs), self.gensim_model.vector_size))
if isinstance(docs[0], string_types):
docs = [docs]
vectors = [self.gensim_model.infer_vector(doc) for doc in docs]
return np.reshape(np.array(vectors), (len(docs), self.gensim_model.vector_size))
27 changes: 12 additions & 15 deletions gensim/sklearn_api/hdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,21 +77,18 @@ def transform(self, docs):
)

# The input is expected as an array of arrays; a single document (a list of tuples) is wrapped into a list.
check = lambda x: [x] if isinstance(x[0], tuple) else x
docs = check(docs)
X = [[] for _ in range(0, len(docs))]

max_num_topics = 0
for k, v in enumerate(docs):
X[k] = self.gensim_model[v]
max_num_topics = max(max_num_topics, max(x[0] for x in X[k]) + 1)

for k, v in enumerate(X):
# returning dense representation for compatibility with sklearn but we should go back to sparse representation in the future
dense_vec = matutils.sparse2full(v, max_num_topics)
X[k] = dense_vec

return np.reshape(np.array(X), (len(docs), max_num_topics))
if isinstance(docs[0], tuple):
docs = [docs]
distribution, max_num_topics = [], 0

for doc in docs:
topicd = self.gensim_model[doc]
distribution.append(topicd)
max_num_topics = max(max_num_topics, max(topic[0] for topic in topicd) + 1)

# Return a dense representation for compatibility with sklearn; we should go back to a sparse representation in the future.
distribution = [matutils.sparse2full(t, max_num_topics) for t in distribution]
return np.reshape(np.array(distribution), (len(docs), max_num_topics))

def partial_fit(self, X):
"""
Expand Down
15 changes: 5 additions & 10 deletions gensim/sklearn_api/ldamodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,16 +83,11 @@ def transform(self, docs):
raise NotFittedError("This model has not been fitted yet. Call 'fit' with appropriate arguments before using this method.")

# The input is expected as an array of arrays; a single document (a list of tuples) is wrapped into a list.
check = lambda x: [x] if isinstance(x[0], tuple) else x
docs = check(docs)
X = [[] for _ in range(0, len(docs))]

for k, v in enumerate(docs):
doc_topics = self.gensim_model[v]
# returning dense representation for compatibility with sklearn but we should go back to sparse representation in the future
probs_docs = matutils.sparse2full(doc_topics, self.num_topics)
X[k] = probs_docs
return np.reshape(np.array(X), (len(docs), self.num_topics))
if isinstance(docs[0], tuple):
docs = [docs]
# Return a dense representation for compatibility with sklearn; we should go back to a sparse representation in the future.
distribution = [matutils.sparse2full(self.gensim_model[doc], self.num_topics) for doc in docs]
return np.reshape(np.array(distribution), (len(docs), self.num_topics))

def partial_fit(self, X):
"""
Expand Down
13 changes: 4 additions & 9 deletions gensim/sklearn_api/ldaseqmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,7 @@ def transform(self, docs):
raise NotFittedError("This model has not been fitted yet. Call 'fit' with appropriate arguments before using this method.")

# The input is expected as an array of arrays; a single document (a list of tuples) is wrapped into a list.
check = lambda x: [x] if isinstance(x[0], tuple) else x
docs = check(docs)
X = [[] for _ in range(0, len(docs))]

for k, v in enumerate(docs):
transformed_author = self.gensim_model[v]
X[k] = transformed_author

return np.reshape(np.array(X), (len(docs), self.num_topics))
if isinstance(docs[0], tuple):
docs = [docs]
proportions = [self.gensim_model[doc] for doc in docs]
return np.reshape(np.array(proportions), (len(docs), self.num_topics))
14 changes: 5 additions & 9 deletions gensim/sklearn_api/lsimodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,11 @@ def transform(self, docs):
)

# The input is expected as an array of arrays; a single document (a list of tuples) is wrapped into a list.
check = lambda x: [x] if isinstance(x[0], tuple) else x
docs = check(docs)
X = [[] for i in range(0, len(docs))]
for k, v in enumerate(docs):
doc_topics = self.gensim_model[v]
# returning dense representation for compatibility with sklearn but we should go back to sparse representation in the future
probs_docs = matutils.sparse2full(doc_topics, self.num_topics)
X[k] = probs_docs
return np.reshape(np.array(X), (len(docs), self.num_topics))
if isinstance(docs[0], tuple):
docs = [docs]
# Return a dense representation for compatibility with sklearn; we should go back to a sparse representation in the future.
distribution = [matutils.sparse2full(self.gensim_model[doc], self.num_topics) for doc in docs]
return np.reshape(np.array(distribution), (len(docs), self.num_topics))

def partial_fit(self, X):
"""
Expand Down
12 changes: 3 additions & 9 deletions gensim/sklearn_api/phrases.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,9 @@ def transform(self, docs):
raise NotFittedError("This model has not been fitted yet. Call 'fit' with appropriate arguments before using this method.")

# input as python lists
check = lambda x: [x] if isinstance(x[0], string_types) else x
docs = check(docs)
X = [[] for _ in range(0, len(docs))]

for k, v in enumerate(docs):
phrase_tokens = self.gensim_model[v]
X[k] = phrase_tokens

return X
if isinstance(docs[0], string_types):
docs = [docs]
return [self.gensim_model[doc] for doc in docs]

def partial_fit(self, X):
if self.gensim_model is None:
Expand Down
16 changes: 5 additions & 11 deletions gensim/sklearn_api/rpmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,8 @@ def transform(self, docs):
)

# The input is expected as an array of arrays; a single document (a list of tuples) is wrapped into a list.
check = lambda x: [x] if isinstance(x[0], tuple) else x
docs = check(docs)
X = [[] for _ in range(0, len(docs))]

for k, v in enumerate(docs):
transformed_doc = self.gensim_model[v]
# returning dense representation for compatibility with sklearn but we should go back to sparse representation in the future
probs_docs = matutils.sparse2full(transformed_doc, self.num_topics)
X[k] = probs_docs

return np.reshape(np.array(X), (len(docs), self.num_topics))
if isinstance(docs[0], tuple):
docs = [docs]
# Return a dense representation for compatibility with sklearn; we should go back to a sparse representation in the future.
presentation = [matutils.sparse2full(self.gensim_model[doc], self.num_topics) for doc in docs]
return np.reshape(np.array(presentation), (len(docs), self.num_topics))
14 changes: 4 additions & 10 deletions gensim/sklearn_api/text2bow.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,10 @@ def transform(self, docs):
)

# input as python lists
check = lambda x: [x] if isinstance(x, string_types) else x
docs = check(docs)
tokenized_docs = [list(self.tokenizer(x)) for x in docs]
X = [[] for _ in range(0, len(tokenized_docs))]

for k, v in enumerate(tokenized_docs):
bow_val = self.gensim_model.doc2bow(v)
X[k] = bow_val

return X
if isinstance(docs, string_types):
docs = [docs]
tokenized_docs = (list(self.tokenizer(doc)) for doc in docs)
return [self.gensim_model.doc2bow(doc) for doc in tokenized_docs]

def partial_fit(self, X):
if self.gensim_model is None:
Expand Down
12 changes: 3 additions & 9 deletions gensim/sklearn_api/tfidf.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,6 @@ def transform(self, docs):
)

# input as python lists
check = lambda x: [x] if isinstance(x[0], tuple) else x
docs = check(docs)
X = [[] for _ in range(0, len(docs))]

for k, v in enumerate(docs):
transformed_doc = self.gensim_model[v]
X[k] = transformed_doc

return X
if isinstance(docs[0], tuple):
docs = [docs]
return [self.gensim_model[doc] for doc in docs]
13 changes: 4 additions & 9 deletions gensim/sklearn_api/w2vmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,15 +75,10 @@ def transform(self, words):
)

# The input is expected as a list of words; a single word (a string) is wrapped into a list.
check = lambda x: [x] if isinstance(x, six.string_types) else x
words = check(words)
X = [[] for _ in range(0, len(words))]

for k, v in enumerate(words):
word_vec = self.gensim_model[v]
X[k] = word_vec

return np.reshape(np.array(X), (len(words), self.size))
if isinstance(words, six.string_types):
words = [words]
vectors = [self.gensim_model[word] for word in words]
return np.reshape(np.array(vectors), (len(words), self.size))

def partial_fit(self, X):
raise NotImplementedError(
Expand Down
Loading

0 comments on commit 2bb8e1d

Please sign in to comment.