From cf3f66dd379b5d047ef6d5735aa42c36fca16a18 Mon Sep 17 00:00:00 2001 From: Bubbliiiing <3323290568@qq.com> Date: Tue, 7 Jun 2022 14:40:04 +0800 Subject: [PATCH] update train.py --- train.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/train.py b/train.py index 857a218..c9a7fbe 100644 --- a/train.py +++ b/train.py @@ -253,6 +253,8 @@ wanted_step = 5e4 if optimizer_type == "sgd" else 1.5e4 total_step = num_train // Unfreeze_batch_size * UnFreeze_Epoch if total_step <= wanted_step: + if num_train // Unfreeze_batch_size == 0: + raise ValueError('数据集过小,无法进行训练,请扩充数据集。') wanted_epoch = wanted_step // (num_train // Unfreeze_batch_size) + 1 print("\n\033[1;33;44m[Warning] 使用%s优化器时,建议将训练总步长设置到%d以上。\033[0m"%(optimizer_type, wanted_step)) print("\033[1;33;44m[Warning] 本次运行的总训练数据量为%d,Unfreeze_batch_size为%d,共训练%d个Epoch,计算出总训练步长为%d。\033[0m"%(num_train, Unfreeze_batch_size, UnFreeze_Epoch, total_step))