This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Add start of sentence token to dict.py #221

Merged: 4 commits, Jul 19, 2017
33 changes: 24 additions & 9 deletions parlai/core/dict.py
@@ -71,8 +71,9 @@ class DictionaryAgent(Agent):
     default_maxngram = -1
     default_minfreq = 0
     default_null = '__NULL__'
-    default_eos = '__EOS__'
+    default_end = '__END__'
     default_unk = '__UNK__'
+    default_start = '__START__'
 
     @staticmethod
     def add_cmdline_args(argparser):
@@ -101,11 +102,14 @@ def add_cmdline_args(argparser):
             '--dict-nulltoken', default=DictionaryAgent.default_null,
             help='empty token, can be used for padding or just empty values')
         dictionary.add_argument(
-            '--dict-eostoken', default=DictionaryAgent.default_eos,
+            '--dict-endtoken', default=DictionaryAgent.default_end,
             help='token for end of sentence markers, if needed')
         dictionary.add_argument(
             '--dict-unktoken', default=DictionaryAgent.default_unk,
             help='token to return for unavailable words')
+        dictionary.add_argument(
+            '--dict-starttoken', default=DictionaryAgent.default_start,
+            help='token for starting sentence generation, if needed')
         dictionary.add_argument(
             '--dict-maxexs', default=100000, type=int,
             help='max number of examples to build dict on')
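
With this hunk, the dictionary exposes four special-token flags: --dict-nulltoken, --dict-endtoken, --dict-unktoken, and the new --dict-starttoken. As a minimal sketch (not part of the PR), the opt keys argparse produces from these flags, assuming its usual dash-to-underscore dest conversion, look like:

    opt = {
        'dict_nulltoken': '__NULL__',    # padding or empty values
        'dict_endtoken': '__END__',      # end of sentence marker
        'dict_unktoken': '__UNK__',      # unavailable words
        'dict_starttoken': '__START__',  # start of sentence generation
    }

These are the keys the constructor reads in the next hunk.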
@@ -115,8 +119,9 @@ def __init__(self, opt, shared=None):
         # initialize fields
         self.opt = copy.deepcopy(opt)
         self.null_token = opt['dict_nulltoken']
-        self.eos_token = opt['dict_eostoken']
+        self.end_token = opt['dict_endtoken']
         self.unk_token = opt['dict_unktoken']
+        self.start_token = opt['dict_starttoken']
         self.max_ngram_size = opt['dict_max_ngram_size']
 
         if shared:
@@ -132,18 +137,24 @@ def __init__(self, opt, shared=None):
                 self.tok2ind[self.null_token] = 0
                 self.ind2tok[0] = self.null_token
 
-            if self.eos_token:
-                # set special end of sentence token
+            if self.end_token:
+                # set special end of sentence word token
                 index = len(self.tok2ind)
-                self.tok2ind[self.eos_token] = index
-                self.ind2tok[index] = self.eos_token
+                self.tok2ind[self.end_token] = index
+                self.ind2tok[index] = self.end_token
 
             if self.unk_token:
                 # set special unknown word token
                 index = len(self.tok2ind)
                 self.tok2ind[self.unk_token] = index
                 self.ind2tok[index] = self.unk_token
 
+            if self.start_token:
+                # set special start of sentence word token
+                index = len(self.tok2ind)
+                self.tok2ind[self.start_token] = index
+                self.ind2tok[index] = self.start_token
+
             if opt.get('dict_file') and os.path.isfile(opt['dict_file']):
                 # load pre-existing dictionary
                 self.load(opt['dict_file'])
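
Since each special token is registered at index = len(self.tok2ind), the insertion order above pins the indices in a freshly built dictionary: null gets 0, end gets 1, unk gets 2, and the new start token gets 3. A standalone sketch of that logic (simplified from the diff; the real constructor also handles shared state and loading from dict_file):

    tok2ind, ind2tok = {}, {}
    for token in ['__NULL__', '__END__', '__UNK__', '__START__']:
        index = len(tok2ind)          # next free index
        tok2ind[token] = index
        ind2tok[index] = token
    assert tok2ind == {'__NULL__': 0, '__END__': 1, '__UNK__': 2, '__START__': 3}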
@@ -164,13 +175,17 @@ def __init__(self, opt, shared=None):

         if not shared:
 
+            if self.start_token:
+                # fix count for start of sentence token to one billion and three
+                self.freq[self.start_token] = 1000000003
+
             if self.null_token:
                 # fix count for null token to one billion and two
                 self.freq[self.null_token] = 1000000002
 
-            if self.eos_token:
+            if self.end_token:
                 # fix count for end of sentence token to one billion and one
-                self.freq[self.eos_token] = 1000000001
+                self.freq[self.end_token] = 1000000001
 
             if self.unk_token:
                 # fix count for unknown token to one billion
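
The inflated counts give every special token a distinct, near-maximal frequency, so any vocabulary listing sorted by descending count keeps them at the top (start, then null, end, unk) no matter what the training data contains. A hedged sketch of the effect, assuming the dictionary is sorted by frequency when written out:

    freq = {
        'hello': 5,
        'world': 3,
        '__START__': 1000000003,
        '__NULL__': 1000000002,
        '__END__': 1000000001,
        '__UNK__': 1000000000,
    }
    # Sorting by descending count puts the special tokens at the head of the vocab.
    ordered = sorted(freq, key=freq.get, reverse=True)
    assert ordered[:4] == ['__START__', '__NULL__', '__END__', '__UNK__']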