diff --git a/parlai/agents/drqa/drqa.py b/parlai/agents/drqa/drqa.py index f8d52dfc847..2aacd12431f 100644 --- a/parlai/agents/drqa/drqa.py +++ b/parlai/agents/drqa/drqa.py @@ -147,7 +147,9 @@ def _init_from_saved(self, fname): map_location=lambda storage, loc: storage) # TODO expand dict and embeddings for new data - self.word_dict = saved_params['word_dict'] + loaded_words = saved_params['word_dict'] + self.word_dict.copy_dict(loaded_words) + self.feature_dict = saved_params['feature_dict'] self.state_dict = saved_params['state_dict'] config.override_args(self.opt, saved_params['config']) diff --git a/parlai/core/dict.py b/parlai/core/dict.py index 021a863b62e..517fd424523 100644 --- a/parlai/core/dict.py +++ b/parlai/core/dict.py @@ -251,6 +251,14 @@ def __setitem__(self, key, value): self.tok2ind[key] = index self.ind2tok[index] = key + def copy_dict(self, dictionary): + """Overwrite own state with any state in the other dictionary. + This allows loading of the contents of another dictionary while keeping + the current dictionary version. + """ + for k, v in vars(dictionary).items(): + setattr(self, k, v) + def freqs(self): return self.freq