Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[New features] add early stop in training #3558

Merged
merged 3 commits into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions paddleseg/core/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def train(model,
save_dir='output',
iters=10000,
batch_size=2,
early_stop_interval=None,
resume_model=None,
save_interval=1000,
log_iters=10,
Expand Down Expand Up @@ -117,6 +118,8 @@ def train(model,
local_rank = paddle.distributed.ParallelEnv().local_rank

start_iter = 0
stop_count = 0
stop_status = False
if resume_model is not None:
start_iter = resume(model, optimizer, resume_model)

Expand Down Expand Up @@ -171,7 +174,7 @@ def train(model,
save_models = deque()
batch_start = time.time()
iter = start_iter
while iter < iters:
while iter < iters and not stop_status:
if iter == start_iter and use_ema:
init_ema_params(ema_model, model)
for data in loader:
Expand Down Expand Up @@ -354,15 +357,26 @@ def train(model,

if val_dataset is not None:
if mean_iou > best_mean_iou:
stop_count = 0
best_mean_iou = mean_iou
best_model_iter = iter
best_model_dir = os.path.join(save_dir, "best_model")
paddle.save(
model.state_dict(),
os.path.join(best_model_dir, 'model.pdparams'))
logger.info(
'[EVAL] The model with the best validation mIoU ({:.4f}) was saved at iter {}.'
.format(best_mean_iou, best_model_iter))
elif mean_iou < best_mean_iou:
stop_count += 1

if early_stop_interval is not None and stop_count >= early_stop_interval:
stop_status = True
logger.info(
'Early stopping at iter {}. The best mean IoU is {:.4f}.'
.format(iter, best_mean_iou))
else:
logger.info(
'[EVAL] The model with the best validation mIoU ({:.4f}) was saved at iter {}.'
.format(best_mean_iou, best_model_iter))

if use_ema:
if ema_mean_iou > best_ema_mean_iou:
best_ema_mean_iou = ema_mean_iou
Expand All @@ -385,6 +399,10 @@ def train(model,
ema_mean_iou, iter)
log_writer.add_scalar('Evaluate/Ema_Acc', ema_acc,
iter)

if stop_status:
break

batch_start = time.time()

# Calculate flops.
Expand Down
6 changes: 6 additions & 0 deletions tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@ def parse_args():
help='Maximum number of checkpoints to save.',
type=int,
default=5)
parser.add_argument(
'--early_stop_intervals',
help='Early Stop at args number of save intervals.',
type=None,
default=0)

# Other params
parser.add_argument(
Expand Down Expand Up @@ -187,6 +192,7 @@ def main(args):
save_dir=args.save_dir,
iters=cfg.iters,
batch_size=cfg.batch_size,
early_stop=args.args.early_stop_intervals,
resume_model=args.resume_model,
save_interval=args.save_interval,
log_iters=args.log_iters,
Expand Down