From 60f543f00765e3ffdad0d7fc44b2466e48503e0f Mon Sep 17 00:00:00 2001
From: ooooo <3164076421@qq.com>
Date: Thu, 2 Nov 2023 19:07:57 +0800
Subject: [PATCH 1/3] try to add early stop

---
 paddleseg/core/train.py | 20 +++++++++++++++-----
 tools/train.py          |  6 ++++++
 2 files changed, 21 insertions(+), 5 deletions(-)

diff --git a/paddleseg/core/train.py b/paddleseg/core/train.py
index 33459db7f6..ee3160d0e0 100644
--- a/paddleseg/core/train.py
+++ b/paddleseg/core/train.py
@@ -65,6 +65,7 @@ def train(model,
           save_dir='output',
           iters=10000,
           batch_size=2,
+          early_stop=0,
           resume_model=None,
           save_interval=1000,
           log_iters=10,
@@ -117,6 +118,7 @@ def train(model,
     local_rank = paddle.distributed.ParallelEnv().local_rank
 
     start_iter = 0
+    stop_count = 0
     if resume_model is not None:
         start_iter = resume(model, optimizer, resume_model)
 
@@ -307,7 +309,7 @@ def train(model,
                 update_ema_model(ema_model, model, step=iter)
 
             if (iter % save_interval == 0 or
-                    iter == iters) and (val_dataset is not None):
+                    iter == iters) and (val_dataset is not None) and early_stop:
                 num_workers = 1 if num_workers > 0 else 0
 
                 if test_config is None:
@@ -332,7 +334,8 @@ def train(model,
 
                 model.train()
 
-            if (iter % save_interval == 0 or iter == iters) and local_rank == 0:
+            if (iter % save_interval == 0 or
+                    iter == iters) and local_rank == 0 and early_stop:
                 current_save_dir = os.path.join(save_dir,
                                                 "iter_{}".format(iter))
                 if not os.path.isdir(current_save_dir):
@@ -360,9 +363,16 @@ def train(model,
                         paddle.save(
                             model.state_dict(),
                             os.path.join(best_model_dir, 'model.pdparams'))
-                    logger.info(
-                        '[EVAL] The model with the best validation mIoU ({:.4f}) was saved at iter {}.'
-                        .format(best_mean_iou, best_model_iter))
+                    elif mean_iou < best_mean_iou:
+                        stop_count += 1
+                        if stop_count >= early_stop:
+                            logger.info(
+                                'Early stopping at iter {}. The best mean IoU is {:.4f}.'
+                                .format(iter, best_mean_iou))
+                    else:
+                        logger.info(
+                            '[EVAL] The model with the best validation mIoU ({:.4f}) was saved at iter {}.'
+                            .format(best_mean_iou, best_model_iter))
                     if use_ema:
                         if ema_mean_iou > best_ema_mean_iou:
                             best_ema_mean_iou = ema_mean_iou
diff --git a/tools/train.py b/tools/train.py
index b9ce6cf7af..6572628af5 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -82,6 +82,11 @@ def parse_args():
         help='Maximum number of checkpoints to save.',
         type=int,
         default=5)
+    parser.add_argument(
+        '--early_stop',
+        help='Whether to early stop when loss is not decreasing and max numbers.',
+        type=int,
+        default=0)
 
     # Other params
     parser.add_argument(
@@ -187,6 +192,7 @@ def main(args):
         save_dir=args.save_dir,
         iters=cfg.iters,
         batch_size=cfg.batch_size,
+        early_stop=args.early_stop,
         resume_model=args.resume_model,
         save_interval=args.save_interval,
         log_iters=args.log_iters,

From 206028e2c7fd0ae59475dad1f6776a5043cec285 Mon Sep 17 00:00:00 2001
From: ooooo <3164076421@qq.com>
Date: Thu, 2 Nov 2023 21:11:41 +0800
Subject: [PATCH 2/3] update

---
 paddleseg/core/train.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/paddleseg/core/train.py b/paddleseg/core/train.py
index ee3160d0e0..8c1b1ae634 100644
--- a/paddleseg/core/train.py
+++ b/paddleseg/core/train.py
@@ -309,7 +309,7 @@ def train(model,
                 update_ema_model(ema_model, model, step=iter)
 
             if (iter % save_interval == 0 or
-                    iter == iters) and (val_dataset is not None) and early_stop:
+                    iter == iters) and (val_dataset is not None):
                 num_workers = 1 if num_workers > 0 else 0
 
                 if test_config is None:
@@ -334,8 +334,7 @@ def train(model,
 
                 model.train()
 
-            if (iter % save_interval == 0 or
-                    iter == iters) and local_rank == 0 and early_stop:
+            if (iter % save_interval == 0 or iter == iters) and local_rank == 0:
                 current_save_dir = os.path.join(save_dir,
                                                 "iter_{}".format(iter))
                 if not os.path.isdir(current_save_dir):
@@ -357,6 +356,7 @@ def train(model,
 
                 if val_dataset is not None:
                     if mean_iou > best_mean_iou:
+                        stop_count = 0
                         best_mean_iou = mean_iou
                         best_model_iter = iter
                         best_model_dir = os.path.join(save_dir, "best_model")
@@ -365,7 +365,7 @@ def train(model,
                             os.path.join(best_model_dir, 'model.pdparams'))
                     elif mean_iou < best_mean_iou:
                         stop_count += 1
-                        if stop_count >= early_stop:
+                        if early_stop > 0 and stop_count >= early_stop:
                             logger.info(
                                 'Early stopping at iter {}. The best mean IoU is {:.4f}.'
                                 .format(iter, best_mean_iou))

From 98fd7435b5483b6472bc1750c3834140ee6bafb2 Mon Sep 17 00:00:00 2001
From: ooooo <3164076421@qq.com>
Date: Fri, 3 Nov 2023 11:24:49 +0800
Subject: [PATCH 3/3] fix

---
 paddleseg/core/train.py | 14 +++++++++++---
 tools/train.py          | 10 +++++-----
 2 files changed, 16 insertions(+), 8 deletions(-)

diff --git a/paddleseg/core/train.py b/paddleseg/core/train.py
index 8c1b1ae634..b12a2f08ea 100644
--- a/paddleseg/core/train.py
+++ b/paddleseg/core/train.py
@@ -65,7 +65,7 @@ def train(model,
           save_dir='output',
           iters=10000,
           batch_size=2,
-          early_stop=0,
+          early_stop_interval=None,
           resume_model=None,
           save_interval=1000,
           log_iters=10,
@@ -119,6 +119,7 @@ def train(model,
 
     start_iter = 0
     stop_count = 0
+    stop_status = False
     if resume_model is not None:
         start_iter = resume(model, optimizer, resume_model)
 
@@ -173,7 +174,7 @@ def train(model,
     save_models = deque()
     batch_start = time.time()
     iter = start_iter
-    while iter < iters:
+    while iter < iters and not stop_status:
         if iter == start_iter and use_ema:
             init_ema_params(ema_model, model)
         for data in loader:
@@ -365,7 +366,9 @@ def train(model,
                             os.path.join(best_model_dir, 'model.pdparams'))
                     elif mean_iou < best_mean_iou:
                         stop_count += 1
-                        if early_stop > 0 and stop_count >= early_stop:
+
+                        if early_stop_interval is not None and stop_count >= early_stop_interval:
+                            stop_status = True
                             logger.info(
                                 'Early stopping at iter {}. The best mean IoU is {:.4f}.'
                                 .format(iter, best_mean_iou))
@@ -373,6 +376,7 @@ def train(model,
                         logger.info(
                             '[EVAL] The model with the best validation mIoU ({:.4f}) was saved at iter {}.'
                             .format(best_mean_iou, best_model_iter))
+
                     if use_ema:
                         if ema_mean_iou > best_ema_mean_iou:
                             best_ema_mean_iou = ema_mean_iou
@@ -395,6 +399,10 @@ def train(model,
                                                   ema_mean_iou, iter)
                             log_writer.add_scalar('Evaluate/Ema_Acc', ema_acc,
                                                   iter)
+
+            if stop_status:
+                break
+
             batch_start = time.time()
 
     # Calculate flops.
diff --git a/tools/train.py b/tools/train.py
index 6572628af5..b28651035b 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -83,10 +83,10 @@ def parse_args():
         type=int,
         default=5)
     parser.add_argument(
-        '--early_stop',
-        help='Whether to early stop when loss is not decreasing and max numbers.',
-        type=int,
-        default=0)
+        '--early_stop_intervals',
+        help='Stop training early if the validation mIoU has not improved for this many consecutive evaluations. Disabled by default.',
+        type=int,
+        default=None)
 
     # Other params
     parser.add_argument(
@@ -192,7 +192,7 @@ def main(args):
         save_dir=args.save_dir,
         iters=cfg.iters,
         batch_size=cfg.batch_size,
-        early_stop=args.early_stop,
+        early_stop_interval=args.early_stop_intervals,
         resume_model=args.resume_model,
         save_interval=args.save_interval,
         log_iters=args.log_iters,
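
A minimal, self-contained sketch of the early-stopping rule the series converges on (plain Python with illustrative names such as should_stop and miou_history, not the PaddleSeg API): the patience counter resets whenever the best validation mIoU improves, increments on each evaluation that comes in below the best, and training stops once it reaches early_stop_interval; None keeps early stopping disabled.

# Illustrative sketch only -- not PaddleSeg code. It mirrors the
# patience-counter logic of the final patch.
def should_stop(miou_history, early_stop_interval=None):
    if early_stop_interval is None or not miou_history:
        return False
    best_miou = miou_history[0]
    stop_count = 0
    for miou in miou_history[1:]:
        if miou > best_miou:
            best_miou = miou
            stop_count = 0  # improvement: reset the patience counter
        elif miou < best_miou:
            stop_count += 1  # one more evaluation without improvement
        if stop_count >= early_stop_interval:
            return True
    return False


if __name__ == '__main__':
    history = [0.71, 0.80, 0.79, 0.78]
    print(should_stop(history, early_stop_interval=2))  # True: two evals below the 0.80 best
    print(should_stop(history))                         # False: early stopping disabled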