diff --git a/deepray/core/base_trainer.py b/deepray/core/base_trainer.py
index 269bf63..a00bcfb 100644
--- a/deepray/core/base_trainer.py
+++ b/deepray/core/base_trainer.py
@@ -596,18 +596,32 @@ def fit(
     self._checkpoints, self._managers = {}, {}
     for name, model in self._model.items():
       if "main" in name:
-        _checkpoint = tf.train.Checkpoint(root=model, optimizer=self.optimizer)
+        _checkpoint = tf.train.Checkpoint(model=model, optimizer=self.optimizer)
         self._checkpoints[name] = _checkpoint
         self._managers[name] = tf.train.CheckpointManager(
             _checkpoint, os.path.join(FLAGS.model_dir, f'ckpt_{name}'), max_to_keep=3
         )
       else:
-        _checkpoint = tf.train.Checkpoint(root=model)
+        _checkpoint = tf.train.Checkpoint(model=model)
         self._checkpoints[name] = _checkpoint
         self._managers[name] = tf.train.CheckpointManager(
             _checkpoint, os.path.join(FLAGS.model_dir, f'ckpt_{name}'), max_to_keep=3
         )
 
+    if FLAGS.init_checkpoint:
+      for (name, ckpt), init_ckpt in zip(self._checkpoints.items(), FLAGS.init_checkpoint):
+        if init_ckpt:
+          logging.info(f'Checkpoint file {init_ckpt} found and restoring from initial checkpoint for {name} model.')
+          ckpt.restore(init_ckpt).assert_existing_objects_matched()
+          logging.info('Loading from checkpoint file completed')
+
+    if FLAGS.init_weights:
+      for (name, _model), init_weight in zip(self._model.items(), FLAGS.init_weights):
+        if init_weight:
+          logging.info(f'variables file {init_weight} found and restoring from initial variables for {name} model.')
+          _model.load_weights(os.path.join(init_weight, "variables"))
+          logging.info('Loading from weights file completed')
+
     if FLAGS.num_accumulation_steps > 1:
       self.accum_gradients = GradientAccumulator()
 
diff --git a/deepray/utils/export/export.py b/deepray/utils/export/export.py
index bba151c..d0b810f 100644
--- a/deepray/utils/export/export.py
+++ b/deepray/utils/export/export.py
@@ -145,7 +145,8 @@ def helper(name, _model: tf.keras.Model, _checkpoint_dir):
   # for opt_de_var in opt_de_vars:
   #   opt_de_var.save_to_file_system(dirpath=de_dir, proc_size=get_world_size(), proc_rank=get_rank())
 
-  logging.info(f"save pb model to:{_savedmodel_dir}, without optimizer & traces")
+  if is_main_process():
+    logging.info(f"save pb model to: {_savedmodel_dir}, without optimizer & traces")
 
   if isinstance(model, dict):
     for name, _model in model.items():
diff --git a/deepray/utils/flags/_base.py b/deepray/utils/flags/_base.py
index 367fd1d..1f0dcaa 100644
--- a/deepray/utils/flags/_base.py
+++ b/deepray/utils/flags/_base.py
@@ -89,9 +89,9 @@ def define_base(
     key_flags.append("num_accumulation_steps")
 
   if init_checkpoint:
-    flags.DEFINE_string('init_checkpoint', '', 'Initial checkpoint (usually from a pre-trained BERT model).')
+    flags.DEFINE_list('init_checkpoint', '', 'Initial checkpoint (usually from a pre-trained BERT model).')
     key_flags.append("init_checkpoint")
-    flags.DEFINE_string("init_weights", '', "Initial weights for the main model.")
+    flags.DEFINE_list("init_weights", '', "Initial weights for the main model.")
     key_flags.append("init_weights")
 
   if save_checkpoint_steps:
diff --git a/modelzoo/LanguageModeling/BERT/run_squad.py b/modelzoo/LanguageModeling/BERT/run_squad.py
index 84b0a30..d8c5705 100644
--- a/modelzoo/LanguageModeling/BERT/run_squad.py
+++ b/modelzoo/LanguageModeling/BERT/run_squad.py
@@ -241,6 +241,7 @@ def _get_squad_model():
       callbacks=custom_callbacks,
   )
   trainer.fit(train_input=train_input,)
+  export.export_to_savedmodel(model=trainer.models)
 
 
 def predict_squad(input_meta_data):
diff --git a/modelzoo/LanguageModeling/BERT/scripts/finetune_train_benchmark.sh b/modelzoo/LanguageModeling/BERT/scripts/finetune_train_benchmark.sh
index ba910cc..14c8087 100644
--- a/modelzoo/LanguageModeling/BERT/scripts/finetune_train_benchmark.sh
+++ b/modelzoo/LanguageModeling/BERT/scripts/finetune_train_benchmark.sh
@@ -74,7 +74,7 @@ $mpi_command python run_squad.py \
       --train_data=${SQUAD_DIR}/squad_${SQUAD_VERSION}_train.tf_record \
       --vocab_file=${BERT_BASE_DIR}/vocab.txt \
       --config_file=$BERT_BASE_DIR/bert_config.json \
-      --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
+      --init_checkpoint=",$BERT_BASE_DIR/bert_model.ckpt" \
      --batch_size=$batch_size \
      --model_dir=${MODEL_DIR} \
      --run_eagerly=false \
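
Usage note: because `init_checkpoint` and `init_weights` are now defined with `flags.DEFINE_list`, each flag is parsed as a comma-separated list and zipped positionally against the trainer's models; an empty element (as in the leading comma of `--init_checkpoint=",$BERT_BASE_DIR/bert_model.ckpt"` above) skips restoration for the model at that position. A minimal standalone sketch of that pairing logic, with hypothetical model names and paths:

# Sketch of the zip-based pairing used in fit(); the model names and
# checkpoint path below are hypothetical stand-ins.
checkpoints = {"deep": object(), "main": object()}  # stand-ins for tf.train.Checkpoint objects

# abseil parses --init_checkpoint=",/bert/bert_model.ckpt" into:
init_checkpoint = ["", "/bert/bert_model.ckpt"]

for (name, ckpt), init_ckpt in zip(checkpoints.items(), init_checkpoint):
  if init_ckpt:
    # the real code calls ckpt.restore(init_ckpt).assert_existing_objects_matched()
    print(f"restoring {name} from {init_ckpt}")
  else:
    print(f"no initial checkpoint for {name}; variables stay freshly initialized")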