diff --git a/examples/bert/bert_classifier.py b/examples/bert/bert_classifier.py index 6bb9ea3beec92..ef43ae2076e66 100644 --- a/examples/bert/bert_classifier.py +++ b/examples/bert/bert_classifier.py @@ -103,7 +103,7 @@ def mnli_line_processor(line_id, line): batch_size=config.batch_size, line_processor=mnli_line_processor) - dev_dataloader = BertDataLoader( + test_dataloader = BertDataLoader( "./data/glue_data/MNLI/dev_matched.tsv", tokenizer, ["contradiction", "entailment", "neutral"], max_seq_length=config.max_seq_len, diff --git a/examples/bert_leveldb/bert_classifier.py b/examples/bert_leveldb/bert_classifier.py index 891901388d742..11bc85758ebbe 100644 --- a/examples/bert_leveldb/bert_classifier.py +++ b/examples/bert_leveldb/bert_classifier.py @@ -105,7 +105,7 @@ def mnli_line_processor(line_id, line): mode="leveldb", phase="train") - dev_dataloader = BertDataLoader( + test_dataloader = BertDataLoader( "./data/glue_data/MNLI/dev_matched.tsv", tokenizer, ["contradiction", "entailment", "neutral"], max_seq_length=config.max_seq_len, diff --git a/hapi/text/bert/optimization.py b/hapi/text/bert/optimization.py index 2bf6b7f262127..b2ba8f65a7447 100755 --- a/hapi/text/bert/optimization.py +++ b/hapi/text/bert/optimization.py @@ -130,6 +130,18 @@ def exclude_from_weight_decay(self, name): return True return False + def state_dict(self): + return self.optimizer.state_dict() + + def set_dict(self, state_dict): + return self.optimizer.set_dict(state_dict) + + def get_opti_var_name_list(self): + return self.optimizer.get_opti_var_name_list() + + def current_step_lr(self): + return self.optimizer.current_step_lr() + def minimize(self, loss, use_data_parallel=False, model=None): param_list = dict()