Skip to content

Commit

Permalink
Store all LmDataset parameters
Browse files: browse the repository at this point in the history
  • Loading branch information
JackTemaki committed May 24, 2024
1 parent 2aabab9 commit 867a5d2
Showing 1 changed file with 16 additions and 7 deletions.
23 changes: 16 additions & 7 deletions returnn/datasets/lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def __init__(
"""
super(LmDataset, self).__init__(**kwargs)


if callable(corpus_file):
corpus_file = corpus_file()
if callable(orth_symbols_file):
Expand All @@ -113,12 +114,27 @@ def __init__(

print("LmDataset, loading file", corpus_file, file=log.v4)

# All constructor parameters need to be stored as attributes for the pickling logic
self.corpus_file = corpus_file
self.skip_empty_lines = skip_empty_lines
self.orth_symbols_file = orth_symbols_file
self.orth_symbols_map_file = orth_symbols_map_file
self.orth_replace_map_file = orth_replace_map_file
self.word_based = word_based
self.word_end_symbol = word_end_symbol
self.seq_end_symbol = seq_end_symbol
self.unknown_symbol = unknown_symbol
self.parse_orth_opts = parse_orth_opts or {}
self.parse_orth_opts.setdefault("word_based", self.word_based)
self.phone_info = phone_info
self.add_random_phone_seqs = add_random_phone_seqs
self.auto_replace_unknown_symbol = auto_replace_unknown_symbol
self.log_auto_replace_unknown_symbols = log_auto_replace_unknown_symbols
self.log_skipped_seqs = log_skipped_seqs
self.error_on_invalid_seq = error_on_invalid_seq
self.add_delayed_seq_data = add_delayed_seq_data
self.delayed_seq_data_start_symbol = delayed_seq_data_start_symbol

if self.word_end_symbol and not self.word_based: # Character-based modeling and word_end_symbol is specified.
# In this case, sentences end with self.word_end_symbol followed by the self.seq_end_symbol.
self.parse_orth_opts.setdefault(
Expand Down Expand Up @@ -215,15 +231,8 @@ def __init__(
self.num_inputs = num_labels
self.seq_order = None
self._tag_prefix = "line-" # sequence tag is "line-n", where n is the line number (to be compatible with translation) # nopep8
self.auto_replace_unknown_symbol = auto_replace_unknown_symbol
self.log_auto_replace_unknown_symbols = log_auto_replace_unknown_symbols
self.log_skipped_seqs = log_skipped_seqs
self.error_on_invalid_seq = error_on_invalid_seq
self.add_random_phone_seqs = add_random_phone_seqs
for i in range(add_random_phone_seqs):
self.num_outputs["random%i" % i] = self.num_outputs["data"]
self.add_delayed_seq_data = add_delayed_seq_data
self.delayed_seq_data_start_symbol = delayed_seq_data_start_symbol
if add_delayed_seq_data:
self.num_outputs["delayed"] = self.num_outputs["data"]
self.labels["delayed"] = self.labels["data"]
Expand Down

0 comments on commit 867a5d2

Please sign in to comment.