diff --git a/paddleseg/core/train.py b/paddleseg/core/train.py
index 33459db7f6..b12a2f08ea 100644
--- a/paddleseg/core/train.py
+++ b/paddleseg/core/train.py
@@ -65,6 +65,7 @@ def train(model,
           save_dir='output',
           iters=10000,
           batch_size=2,
+          early_stop_interval=None,
           resume_model=None,
           save_interval=1000,
           log_iters=10,
@@ -117,6 +118,8 @@ def train(model,
     local_rank = paddle.distributed.ParallelEnv().local_rank
 
     start_iter = 0
+    stop_count = 0  # evaluations scoring below the best mIoU since it was set
+    stop_status = False  # becomes True once early stopping is triggered
     if resume_model is not None:
         start_iter = resume(model, optimizer, resume_model)
 
@@ -171,7 +174,7 @@ def train(model,
     save_models = deque()
     batch_start = time.time()
     iter = start_iter
-    while iter < iters:
+    while iter < iters and not stop_status:
         if iter == start_iter and use_ema:
             init_ema_params(ema_model, model)
         for data in loader:
@@ -354,15 +357,28 @@ def train(model,
 
                 if val_dataset is not None:
                     if mean_iou > best_mean_iou:
+                        stop_count = 0
                         best_mean_iou = mean_iou
                         best_model_iter = iter
                         best_model_dir = os.path.join(save_dir, "best_model")
                         paddle.save(
                             model.state_dict(),
                             os.path.join(best_model_dir, 'model.pdparams'))
-                    logger.info(
-                        '[EVAL] The model with the best validation mIoU ({:.4f}) was saved at iter {}.'
-                        .format(best_mean_iou, best_model_iter))
+                    elif mean_iou < best_mean_iou:
+                        stop_count += 1
+
+                    # A non-positive interval leaves early stopping off.
+                    if (early_stop_interval is not None and
+                            0 < early_stop_interval <= stop_count):
+                        stop_status = True
+                        logger.info(
+                            'Early stopping at iter {}. The best mean IoU is {:.4f}.'
+                            .format(iter, best_mean_iou))
+                    else:
+                        logger.info(
+                            '[EVAL] The model with the best validation mIoU ({:.4f}) was saved at iter {}.'
+                            .format(best_mean_iou, best_model_iter))
+
                     if use_ema:
                         if ema_mean_iou > best_ema_mean_iou:
                             best_ema_mean_iou = ema_mean_iou
@@ -385,6 +401,10 @@ def train(model,
                                                   ema_mean_iou, iter)
                             log_writer.add_scalar('Evaluate/Ema_Acc', ema_acc,
                                                   iter)
+
+            if stop_status:
+                break
+
             batch_start = time.time()
 
     # Calculate flops.
diff --git a/tools/train.py b/tools/train.py
index b9ce6cf7af..b28651035b 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -82,6 +82,12 @@ def parse_args():
         help='Maximum number of checkpoints to save.',
         type=int,
         default=5)
+    # Defaults to None so that early stopping stays off unless requested.
+    parser.add_argument(
+        '--early_stop_intervals',
+        help='Early Stop at args number of save intervals.',
+        type=int,
+        default=None)
 
     # Other params
     parser.add_argument(
@@ -187,6 +193,7 @@ def main(args):
         save_dir=args.save_dir,
         iters=cfg.iters,
         batch_size=cfg.batch_size,
+        early_stop_interval=args.early_stop_intervals,
         resume_model=args.resume_model,
         save_interval=args.save_interval,
         log_iters=args.log_iters,