Fix Dictionary save_as_text method #56 + fix lint errors #1402

Merged 5 commits on Jun 15, 2017

Changes from all commits
23 changes: 16 additions & 7 deletions gensim/corpora/dictionary.py
@@ -24,13 +24,13 @@

from gensim import utils

if sys.version_info[0] >= 3:
unicode = str

from six import PY3, iteritems, iterkeys, itervalues, string_types
from six.moves import xrange
from six.moves import zip as izip

if sys.version_info[0] >= 3:
unicode = str


logger = logging.getLogger('gensim.corpora.dictionary')

@@ -180,7 +180,7 @@ def filter_extremes(self, no_below=5, no_above=0.5, keep_n=100000, keep_tokens=N
2. more than `no_above` documents (fraction of total corpus size, *not*
absolute number).
3. if tokens are given in keep_tokens (list of strings), they will be kept regardless of
the `no_below` and `no_above` settings
4. after (1), (2) and (3), keep only the first `keep_n` most frequent tokens (or
keep all if `None`).

@@ -196,8 +196,7 @@ def filter_extremes(self, no_below=5, no_above=0.5, keep_n=100000, keep_tokens=N
keep_ids = [self.token2id[v] for v in keep_tokens if v in self.token2id]
good_ids = (
v for v in itervalues(self.token2id)
if no_below <= self.dfs.get(v, 0) <= no_above_abs
or v in keep_ids
if no_below <= self.dfs.get(v, 0) <= no_above_abs or v in keep_ids
)
else:
good_ids = (
@@ -232,7 +231,7 @@ def filter_n_most_frequent(self, remove_n):
# do the actual filtering, then rebuild dictionary to remove gaps in ids
most_frequent_words = [(self[id], self.dfs.get(id, 0)) for id in most_frequent_ids]
logger.info("discarding %i tokens: %s...", len(most_frequent_ids), most_frequent_words[:10])

self.filter_tokens(bad_ids=most_frequent_ids)
logger.info("resulting dictionary: %s" % self)

@@ -282,6 +281,7 @@ def compactify(self):
def save_as_text(self, fname, sort_by_word=True):
"""
Save this Dictionary to a text file, in format:
`num_docs`
`id[TAB]word_utf8[TAB]document frequency[NEWLINE]`. Sorted by word,
or by decreasing word frequency.

@@ -290,6 +290,8 @@
"""
logger.info("saving dictionary mapping to %s", fname)
with utils.smart_open(fname, 'wb') as fout:
numdocs_line = "%d\n" % self.num_docs
fout.write(utils.to_utf8(numdocs_line))
if sort_by_word:
for token, tokenid in sorted(iteritems(self.token2id)):
line = "%i\t%s\t%i\n" % (tokenid, token, self.dfs.get(tokenid, 0))
@@ -354,6 +356,13 @@ def load_from_text(fname):
with utils.smart_open(fname) as f:
for lineno, line in enumerate(f):
line = utils.to_unicode(line)
if lineno == 0:
if line.strip().isdigit():
# Older versions of save_as_text may not write num_docs on first line.
result.num_docs = int(line.strip())
continue
else:
logging.warning("Text does not contain num_docs on the first line.")
Owner:

Is this really a warning?

There's no ambiguity (3 columns vs 1 column), so I think INFO (or even DEBUG) should be sufficient.

Contributor Author:

The warning is not about the number ambiguity, but about the fact that you are using the old serialization format.
I put a warning there because it may cause errors somewhere down the pipeline (as mentioned in the original issue #56). If it does not make sense, I can change it, but I think a warning is the best fit.

Owner (@piskvorky), Jun 21, 2017:

Perhaps, but the current message doesn't help the users.

What "text"? Why should it contain "num_docs"? What should the user do about it?

This should be a backward compatible change, so "errors down the pipeline" are not acceptable. CC @menshikh-iv

Contributor Author:

Something like: "This dictionary was serialized using the old format. Please set dictionary.num_docs manually, or some dictionary methods may not work as intended."
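
For context on the "errors down the pipeline" mentioned in this thread, here is a minimal sketch (not part of the PR) of what can go wrong when `num_docs` is left at its default after loading a legacy file. It assumes the existing `filter_extremes` behavior of scaling the fractional `no_above` threshold by `self.num_docs`:

```python
# Minimal sketch, not from this PR: with num_docs == 0 (the value a legacy
# text file leaves behind), filter_extremes() computes an absolute no_above
# threshold of int(0.5 * 0) == 0 and silently drops the entire vocabulary.
from gensim.corpora import Dictionary

d = Dictionary([["first", "word"], ["word", "second"]])
print(d.num_docs)      # 2 when the dictionary is built from documents

d.num_docs = 0         # simulate a dictionary loaded from a legacy text file
d.filter_extremes(no_below=1, no_above=0.5)
print(d.token2id)      # expected: {} -- every token filtered out
```

Whatever the final log level, having `load_from_text` populate `num_docs` for the new format (as this PR does) avoids that silent failure.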

try:
wordid, word, docfreq = line[:-1].split('\t')
except Exception:
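
For reference, an illustrative round trip through the updated text format (not part of the diff; the sample documents and temporary path are made up for the example):

```python
# Illustrative only: shows the num_docs header that save_as_text() now writes
# and that load_from_text() reads back. The token ids below are what this
# particular input happens to produce; they are not guaranteed in general.
import os
import tempfile

from gensim.corpora import Dictionary

texts = [["human", "interface"], ["survey", "interface"], ["human", "survey"]]
d = Dictionary(texts)

path = os.path.join(tempfile.gettempdir(), 'dict_format_demo.txt')
d.save_as_text(path)
# The saved file now starts with the document count, then id<TAB>word<TAB>df lines:
#   3
#   0    human        2
#   1    interface    2
#   2    survey       2

loaded = Dictionary.load_from_text(path)
print(loaded.num_docs)   # 3 -- preserved; a legacy file without the header leaves it at 0
```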
116 changes: 88 additions & 28 deletions gensim/test/test_corpora_dictionary.py
@@ -120,53 +120,52 @@ def testFilter(self):
d.filter_extremes(no_below=2, no_above=1.0, keep_n=4)
expected = {0: 3, 1: 3, 2: 3, 3: 3}
self.assertEqual(d.dfs, expected)

def testFilterKeepTokens_keepTokens(self):
# provide keep_tokens argument, keep the tokens given
d = Dictionary(self.texts)
d.filter_extremes(no_below=3, no_above=1.0, keep_tokens=['human', 'survey'])
expected = set(['graph', 'trees', 'human', 'system', 'user', 'survey'])
self.assertEqual(set(d.token2id.keys()), expected)

def testFilterKeepTokens_unchangedFunctionality(self):
# do not provide keep_tokens argument, filter_extremes functionality is unchanged
d = Dictionary(self.texts)
d.filter_extremes(no_below=3, no_above=1.0)
expected = set(['graph', 'trees', 'system', 'user'])
self.assertEqual(set(d.token2id.keys()), expected)

def testFilterKeepTokens_unseenToken(self):
# do provide keep_tokens argument with unseen tokens, filter_extremes functionality is unchanged
d = Dictionary(self.texts)
d.filter_extremes(no_below=3, no_above=1.0, keep_tokens=['unknown_token'])
expected = set(['graph', 'trees', 'system', 'user'])
self.assertEqual(set(d.token2id.keys()), expected)

def testFilterMostFrequent(self):
d = Dictionary(self.texts)
d.filter_n_most_frequent(4)
expected = {0: 2, 1: 2, 2: 2, 3: 2, 4: 2, 5: 2, 6: 2, 7: 2}
self.assertEqual(d.dfs, expected)

def testFilterTokens(self):
self.maxDiff = 10000
d = Dictionary(self.texts)

removed_word = d[0]
d.filter_tokens([0])

expected = {'computer': 0, 'eps': 8, 'graph': 10, 'human': 1,
'interface': 2, 'minors': 11, 'response': 3, 'survey': 4,
'system': 5, 'time': 6, 'trees': 9, 'user': 7}
expected = {
'computer': 0, 'eps': 8, 'graph': 10, 'human': 1,
'interface': 2, 'minors': 11, 'response': 3, 'survey': 4,
'system': 5, 'time': 6, 'trees': 9, 'user': 7}
del expected[removed_word]
self.assertEqual(sorted(d.token2id.keys()), sorted(expected.keys()))

expected[removed_word] = len(expected)
d.add_documents([[removed_word]])
self.assertEqual(sorted(d.token2id.keys()), sorted(expected.keys()))


def test_doc2bow(self):
d = Dictionary([["žluťoučký"], ["žluťoučký"]])

@@ -179,6 +178,66 @@ def test_doc2bow(self):
# unicode must be converted to utf8
self.assertEqual(d.doc2bow([u'\u017elu\u0165ou\u010dk\xfd']), [(0, 1)])

def test_saveAsText(self):
"""`Dictionary` can be saved as textfile. """
tmpf = get_tmpfile('save_dict_test.txt')
small_text = [
["prvé", "slovo"],
["slovo", "druhé"],
["druhé", "slovo"]]

d = Dictionary(small_text)

d.save_as_text(tmpf)
with open(tmpf) as file:
serialized_lines = file.readlines()
self.assertEqual(serialized_lines[0], "3\n")
self.assertEqual(len(serialized_lines), 4)
# We do not know which word will have which index
self.assertEqual(serialized_lines[1][1:], "\tdruhé\t2\n")
self.assertEqual(serialized_lines[2][1:], "\tprvé\t1\n")
self.assertEqual(serialized_lines[3][1:], "\tslovo\t3\n")

d.save_as_text(tmpf, sort_by_word=False)
with open(tmpf) as file:
serialized_lines = file.readlines()
self.assertEqual(serialized_lines[0], "3\n")
self.assertEqual(len(serialized_lines), 4)
self.assertEqual(serialized_lines[1][1:], "\tslovo\t3\n")
self.assertEqual(serialized_lines[2][1:], "\tdruhé\t2\n")
self.assertEqual(serialized_lines[3][1:], "\tprvé\t1\n")

def test_loadFromText_legacy(self):
"""
`Dictionary` can be loaded from textfile in legacy format.
Legacy format does not have num_docs on the first line.
"""
tmpf = get_tmpfile('load_dict_test_legacy.txt')
no_num_docs_serialization = "1\tprvé\t1\n2\tslovo\t2\n"
with open(tmpf, "w") as file:
file.write(no_num_docs_serialization)

d = Dictionary.load_from_text(tmpf)
self.assertEqual(d.token2id[u"prvé"], 1)
self.assertEqual(d.token2id[u"slovo"], 2)
self.assertEqual(d.dfs[1], 1)
self.assertEqual(d.dfs[2], 2)
self.assertEqual(d.num_docs, 0)

def test_loadFromText(self):
"""`Dictionary` can be loaded from textfile."""
tmpf = get_tmpfile('load_dict_test.txt')
no_num_docs_serialization = "2\n1\tprvé\t1\n2\tslovo\t2\n"
with open(tmpf, "w") as file:
file.write(no_num_docs_serialization)

d = Dictionary.load_from_text(tmpf)
self.assertEqual(d.token2id[u"prvé"], 1)
self.assertEqual(d.token2id[u"slovo"], 2)
self.assertEqual(d.dfs[1], 1)
self.assertEqual(d.dfs[2], 2)
self.assertEqual(d.num_docs, 2)

def test_saveAsText_and_loadFromText(self):
"""`Dictionary` can be saved as textfile and loaded again from textfile. """
tmpf = get_tmpfile('dict_test.txt')
@@ -194,24 +253,25 @@ def test_saveAsText_and_loadFromText(self):
def test_from_corpus(self):
"""build `Dictionary` from an existing corpus"""

documents = ["Human machine interface for lab abc computer applications",
"A survey of user opinion of computer system response time",
"The EPS user interface management system",
"System and human system engineering testing of EPS",
"Relation of user perceived response time to error measurement",
"The generation of random binary unordered trees",
"The intersection graph of paths in trees",
"Graph minors IV Widths of trees and well quasi ordering",
"Graph minors A survey"]
documents = [
"Human machine interface for lab abc computer applications",
"A survey of user opinion of computer system response time",
"The EPS user interface management system",
"System and human system engineering testing of EPS",
"Relation of user perceived response time to error measurement",
"The generation of random binary unordered trees",
"The intersection graph of paths in trees",
"Graph minors IV Widths of trees and well quasi ordering",
"Graph minors A survey"]
stoplist = set('for a of the and to in'.split())
texts = [[word for word in document.lower().split() if word not in stoplist]
for document in documents]
texts = [
[word for word in document.lower().split() if word not in stoplist]
for document in documents]

# remove words that appear only once
all_tokens = sum(texts, [])
tokens_once = set(word for word in set(all_tokens) if all_tokens.count(word) == 1)
texts = [[word for word in text if word not in tokens_once]
for text in texts]
texts = [[word for word in text if word not in tokens_once] for text in texts]

dictionary = Dictionary(texts)
corpus = [dictionary.doc2bow(text) for text in texts]
@@ -260,7 +320,7 @@ def test_dict_interface(self):
self.assertTrue(isinstance(d.keys(), list))
self.assertTrue(isinstance(d.values(), list))

#endclass TestDictionary
# endclass TestDictionary


if __name__ == '__main__':