From 867a5d2601a744b3abaa7a40e85cee06cede75d7 Mon Sep 17 00:00:00 2001 From: Nick Rossenbach Date: Fri, 17 May 2024 21:11:45 +0200 Subject: [PATCH] Store all LmDataset parameters --- returnn/datasets/lm.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/returnn/datasets/lm.py b/returnn/datasets/lm.py index 2ca62769f6..955630504f 100644 --- a/returnn/datasets/lm.py +++ b/returnn/datasets/lm.py @@ -102,6 +102,7 @@ def __init__( """ super(LmDataset, self).__init__(**kwargs) + if callable(corpus_file): corpus_file = corpus_file() if callable(orth_symbols_file): @@ -113,12 +114,27 @@ def __init__( print("LmDataset, loading file", corpus_file, file=log.v4) + # Need to store all those for the pickling logic + self.corpus_file = corpus_file + self.skip_empty_lines = skip_empty_lines + self.orth_symbols_file = orth_symbols_file + self.orth_symbols_map_file = orth_symbols_map_file + self.orth_replace_map_file = orth_replace_map_file self.word_based = word_based self.word_end_symbol = word_end_symbol self.seq_end_symbol = seq_end_symbol self.unknown_symbol = unknown_symbol self.parse_orth_opts = parse_orth_opts or {} self.parse_orth_opts.setdefault("word_based", self.word_based) + self.phone_info = phone_info + self.add_random_phone_seqs = add_random_phone_seqs + self.auto_replace_unknown_symbol = auto_replace_unknown_symbol + self.log_auto_replace_unknown_symbols = log_auto_replace_unknown_symbols + self.log_skipped_seqs = log_skipped_seqs + self.error_on_invalid_seq = error_on_invalid_seq + self.add_delayed_seq_data = add_delayed_seq_data + self.delayed_seq_data_start_symbol = delayed_seq_data_start_symbol + if self.word_end_symbol and not self.word_based: # Character-based modeling and word_end_symbol is specified. # In this case, sentences end with self.word_end_symbol followed by the self.seq_end_symbol. self.parse_orth_opts.setdefault( @@ -215,15 +231,8 @@ def __init__( self.num_inputs = num_labels self.seq_order = None self._tag_prefix = "line-" # sequence tag is "line-n", where n is the line number (to be compatible with translation) # nopep8 - self.auto_replace_unknown_symbol = auto_replace_unknown_symbol - self.log_auto_replace_unknown_symbols = log_auto_replace_unknown_symbols - self.log_skipped_seqs = log_skipped_seqs - self.error_on_invalid_seq = error_on_invalid_seq - self.add_random_phone_seqs = add_random_phone_seqs for i in range(add_random_phone_seqs): self.num_outputs["random%i" % i] = self.num_outputs["data"] - self.add_delayed_seq_data = add_delayed_seq_data - self.delayed_seq_data_start_symbol = delayed_seq_data_start_symbol if add_delayed_seq_data: self.num_outputs["delayed"] = self.num_outputs["data"] self.labels["delayed"] = self.labels["data"]