From 3d2bb1a8f19960dd38de90e7ca3d71a59cd62597 Mon Sep 17 00:00:00 2001
From: Yuta Hayashibe
Date: Fri, 27 Oct 2023 17:49:49 +0900
Subject: [PATCH 1/5] Add LossRecorder and use moving average in all places

---
 fine_tune.py          |  9 ++++-----
 library/train_util.py | 17 +++++++++++++++++
 sdxl_train.py         |  9 ++++-----
 train_db.py           | 14 ++++----------
 train_network.py      | 14 ++++----------
 5 files changed, 33 insertions(+), 30 deletions(-)

diff --git a/fine_tune.py b/fine_tune.py
index 2ecb4ff36..0de4aff19 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -295,7 +295,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
         for m in training_models:
             m.train()
 
-        loss_total = 0
+        loss_recorder = train_util.LossRecorder()
         for step, batch in enumerate(train_dataloader):
             current_step.value = global_step
             with accelerator.accumulate(training_models[0]):  # 複数モデルに対応していない模様だがとりあえずこうしておく
@@ -405,9 +405,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
                     )
                 accelerator.log(logs, step=global_step)
 
-            # TODO moving averageにする
-            loss_total += current_loss
-            avr_loss = loss_total / (step + 1)
+            loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
+            avr_loss: float = loss_recorder.get_moving_average()
             logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
 
@@ -415,7 +414,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
                 break
 
         if args.logging_dir is not None:
-            logs = {"loss/epoch": loss_total / len(train_dataloader)}
+            logs = {"loss/epoch": loss_recorder.get_moving_average()}
             accelerator.log(logs, step=epoch + 1)
 
         accelerator.wait_for_everyone()
diff --git a/library/train_util.py b/library/train_util.py
index 51610e700..7f7190b37 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -4685,3 +4685,20 @@ def __call__(self, examples):
         dataset.set_current_epoch(self.current_epoch.value)
         dataset.set_current_step(self.current_step.value)
         return examples[0]
+
+
+class LossRecorder:
+    def __init__(self):
+        self.loss_list: List[float] = []
+        self.loss_total: float = 0.0
+
+    def add(self, *, epoch:int, step: int, loss: float) -> None:
+        if epoch == 0:
+            self.loss_list.append(loss)
+        else:
+            self.loss_total -= self.loss_list[step]
+            self.loss_list[step] = loss
+        self.loss_total += loss
+
+    def get_moving_average(self) -> float:
+        return self.loss_total / len(self.loss_list)
diff --git a/sdxl_train.py b/sdxl_train.py
index 7bde3cab7..5e5d528d2 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -459,7 +459,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
         for m in training_models:
             m.train()
 
-        loss_total = 0
+        loss_recorder = train_util.LossRecorder()
         for step, batch in enumerate(train_dataloader):
             current_step.value = global_step
             with accelerator.accumulate(training_models[0]):  # 複数モデルに対応していない模様だがとりあえずこうしておく
@@ -632,9 +632,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
 
                 accelerator.log(logs, step=global_step)
 
-            # TODO moving averageにする
-            loss_total += current_loss
-            avr_loss = loss_total / (step + 1)
+            loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
+            avr_loss: float = loss_recorder.get_moving_average()
             logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
 
@@ -642,7 +641,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
                 break
 
         if args.logging_dir is not None:
-            logs = {"loss/epoch": loss_total / len(train_dataloader)}
+            logs = {"loss/epoch": loss_recorder.get_moving_average()}
             accelerator.log(logs, step=epoch + 1)
 
         accelerator.wait_for_everyone()
diff --git a/train_db.py b/train_db.py
index a1b9cac8b..221a1e45e 100644
--- a/train_db.py
+++ b/train_db.py
@@ -264,8 +264,7 @@ def train(args):
             init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
 
-    loss_list = []
-    loss_total = 0.0
+    loss_recorder = train_util.LossRecorder()
     for epoch in range(num_train_epochs):
         accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
         current_epoch.value = epoch + 1
@@ -392,13 +391,8 @@ def train(args):
                     )
                 accelerator.log(logs, step=global_step)
 
-            if epoch == 0:
-                loss_list.append(current_loss)
-            else:
-                loss_total -= loss_list[step]
-                loss_list[step] = current_loss
-                loss_total += current_loss
-            avr_loss = loss_total / len(loss_list)
+            loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
+            avr_loss: float = loss_recorder.get_moving_average()
             logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
 
@@ -406,7 +400,7 @@ def train(args):
                 break
 
         if args.logging_dir is not None:
-            logs = {"loss/epoch": loss_total / len(loss_list)}
+            logs = {"loss/epoch": loss_recorder.get_moving_average()}
             accelerator.log(logs, step=epoch + 1)
 
         accelerator.wait_for_everyone()
diff --git a/train_network.py b/train_network.py
index 2232a384a..aeefe2a58 100644
--- a/train_network.py
+++ b/train_network.py
@@ -703,8 +703,7 @@ def train(self, args):
                 "network_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
             )
 
-        loss_list = []
-        loss_total = 0.0
+        loss_recorder = train_util.LossRecorder()
         del train_dataset_group
 
         # callback for step start
@@ -854,13 +853,8 @@ def remove_model(old_ckpt_name):
                             remove_model(remove_ckpt_name)
 
                 current_loss = loss.detach().item()
-                if epoch == 0:
-                    loss_list.append(current_loss)
-                else:
-                    loss_total -= loss_list[step]
-                    loss_list[step] = current_loss
-                    loss_total += current_loss
-                avr_loss = loss_total / len(loss_list)
+                loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
+                avr_loss: float = loss_recorder.get_moving_average()
                 logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
                 progress_bar.set_postfix(**logs)
 
@@ -875,7 +869,7 @@ def remove_model(old_ckpt_name):
                     break
 
             if args.logging_dir is not None:
-                logs = {"loss/epoch": loss_total / len(loss_list)}
+                logs = {"loss/epoch": loss_recorder.get_moving_average()}
                 accelerator.log(logs, step=epoch + 1)
 
             accelerator.wait_for_everyone()

From efef5c8ead18d98770350540914bb14545509482 Mon Sep 17 00:00:00 2001
From: Yuta Hayashibe
Date: Fri, 27 Oct 2023 17:59:58 +0900
Subject: [PATCH 2/5] Show "avr_loss" instead of "loss" because it is moving average

---
 fine_tune.py     | 2 +-
 sdxl_train.py    | 2 +-
 train_db.py      | 2 +-
 train_network.py | 2 +-
 4 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/fine_tune.py b/fine_tune.py
index 0de4aff19..c5e99ad4e 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -407,7 +407,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
 
             loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
             avr_loss: float = loss_recorder.get_moving_average()
-            logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
+            logs = {"avr_loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
 
             if global_step >= args.max_train_steps:
diff --git a/sdxl_train.py b/sdxl_train.py
index 5e5d528d2..096c89e9a 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -634,7 +634,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
 
             loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
             avr_loss: float = loss_recorder.get_moving_average()
-            logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
+            logs = {"avr_loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
 
             if global_step >= args.max_train_steps:
diff --git a/train_db.py b/train_db.py
index 221a1e45e..112303498 100644
--- a/train_db.py
+++ b/train_db.py
@@ -393,7 +393,7 @@ def train(args):
 
             loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
             avr_loss: float = loss_recorder.get_moving_average()
-            logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
+            logs = {"avr_loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
 
             if global_step >= args.max_train_steps:
diff --git a/train_network.py b/train_network.py
index aeefe2a58..58f7e4451 100644
--- a/train_network.py
+++ b/train_network.py
@@ -855,7 +855,7 @@ def remove_model(old_ckpt_name):
                 current_loss = loss.detach().item()
                 loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
                 avr_loss: float = loss_recorder.get_moving_average()
-                logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
+                logs = {"avr_loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
                 progress_bar.set_postfix(**logs)
 
                 if args.scale_weight_norms:

From 0d21925bdff76c5a6e7bef1e649a43bfc1c3fd0c Mon Sep 17 00:00:00 2001
From: Yuta Hayashibe
Date: Fri, 27 Oct 2023 18:14:27 +0900
Subject: [PATCH 3/5] Use @property

---
 fine_tune.py          | 4 ++--
 library/train_util.py | 3 ++-
 sdxl_train.py         | 4 ++--
 train_db.py           | 4 ++--
 train_network.py      | 4 ++--
 5 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/fine_tune.py b/fine_tune.py
index c5e99ad4e..27d647392 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -406,7 +406,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
                 accelerator.log(logs, step=global_step)
 
             loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
-            avr_loss: float = loss_recorder.get_moving_average()
+            avr_loss: float = loss_recorder.moving_average
             logs = {"avr_loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
 
@@ -414,7 +414,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
                 break
 
         if args.logging_dir is not None:
-            logs = {"loss/epoch": loss_recorder.get_moving_average()}
+            logs = {"loss/epoch": loss_recorder.moving_average}
             accelerator.log(logs, step=epoch + 1)
 
         accelerator.wait_for_everyone()
diff --git a/library/train_util.py b/library/train_util.py
index 7f7190b37..e86293e31 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -4700,5 +4700,6 @@ def add(self, *, epoch:int, step: int, loss: float) -> None:
             self.loss_list[step] = loss
         self.loss_total += loss
 
-    def get_moving_average(self) -> float:
+    @property
+    def moving_average(self) -> float:
         return self.loss_total / len(self.loss_list)
diff --git a/sdxl_train.py b/sdxl_train.py
index 096c89e9a..9017d7b8c 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -633,7 +633,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
                 accelerator.log(logs, step=global_step)
 
             loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
-            avr_loss: float = loss_recorder.get_moving_average()
+            avr_loss: float = loss_recorder.moving_average
             logs = {"avr_loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
 
@@ -641,7 +641,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
                 break
 
         if args.logging_dir is not None:
-            logs = {"loss/epoch": loss_recorder.get_moving_average()}
+            logs = {"loss/epoch": loss_recorder.moving_average}
             accelerator.log(logs, step=epoch + 1)
 
         accelerator.wait_for_everyone()
diff --git a/train_db.py b/train_db.py
index 112303498..aa741794a 100644
--- a/train_db.py
+++ b/train_db.py
@@ -392,7 +392,7 @@ def train(args):
                 accelerator.log(logs, step=global_step)
 
             loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
-            avr_loss: float = loss_recorder.get_moving_average()
+            avr_loss: float = loss_recorder.moving_average
             logs = {"avr_loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
 
@@ -400,7 +400,7 @@ def train(args):
                 break
 
         if args.logging_dir is not None:
-            logs = {"loss/epoch": loss_recorder.get_moving_average()}
+            logs = {"loss/epoch": loss_recorder.moving_average}
             accelerator.log(logs, step=epoch + 1)
 
         accelerator.wait_for_everyone()
diff --git a/train_network.py b/train_network.py
index 58f7e4451..c81aeff8c 100644
--- a/train_network.py
+++ b/train_network.py
@@ -854,7 +854,7 @@ def remove_model(old_ckpt_name):
 
                 current_loss = loss.detach().item()
                 loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
-                avr_loss: float = loss_recorder.get_moving_average()
+                avr_loss: float = loss_recorder.moving_average
                 logs = {"avr_loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
                 progress_bar.set_postfix(**logs)
 
@@ -869,7 +869,7 @@ def remove_model(old_ckpt_name):
                     break
 
             if args.logging_dir is not None:
-                logs = {"loss/epoch": loss_recorder.get_moving_average()}
+                logs = {"loss/epoch": loss_recorder.moving_average}
                 accelerator.log(logs, step=epoch + 1)
 
             accelerator.wait_for_everyone()

From 9d00c8eea2aa3ce9ce292a6eef2bb862a2ec9213 Mon Sep 17 00:00:00 2001
From: Yuta Hayashibe
Date: Fri, 27 Oct 2023 18:31:36 +0900
Subject: [PATCH 4/5] Use LossRecorder

---
 sdxl_train_control_net_lllite.py     | 16 +++++-----------
 sdxl_train_control_net_lllite_old.py | 16 +++++-----------
 train_controlnet.py                  | 16 +++++-----------
 3 files changed, 15 insertions(+), 33 deletions(-)

diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py
index 0df61e848..8e9752520 100644
--- a/sdxl_train_control_net_lllite.py
+++ b/sdxl_train_control_net_lllite.py
@@ -350,8 +350,7 @@ def train(args):
             "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
         )
 
-    loss_list = []
-    loss_total = 0.0
+    loss_recorder = train_util.LossRecorder()
    del train_dataset_group
 
     # function for saving/removing
@@ -500,14 +499,9 @@ def remove_model(old_ckpt_name):
                         remove_model(remove_ckpt_name)
 
             current_loss = loss.detach().item()
-            if epoch == 0:
-                loss_list.append(current_loss)
-            else:
-                loss_total -= loss_list[step]
-                loss_list[step] = current_loss
-                loss_total += current_loss
-            avr_loss = loss_total / len(loss_list)
-            logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
+            loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
+            avr_loss: float = loss_recorder.moving_average
+            logs = {"avr_loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
 
             if args.logging_dir is not None:
@@ -518,7 +512,7 @@ def remove_model(old_ckpt_name):
                 break
 
         if args.logging_dir is not None:
-            logs = {"loss/epoch": loss_total / len(loss_list)}
+            logs = {"loss/epoch": loss_recorder.moving_average}
             accelerator.log(logs, step=epoch + 1)
 
         accelerator.wait_for_everyone()
diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py
index 79920a972..066aca594 100644
--- a/sdxl_train_control_net_lllite_old.py
+++ b/sdxl_train_control_net_lllite_old.py
@@ -323,8 +323,7 @@ def train(args):
             "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
         )
 
-    loss_list = []
-    loss_total = 0.0
+    loss_recorder = train_util.LossRecorder()
     del train_dataset_group
 
     # function for saving/removing
@@ -470,14 +469,9 @@ def remove_model(old_ckpt_name):
                         remove_model(remove_ckpt_name)
 
             current_loss = loss.detach().item()
-            if epoch == 0:
-                loss_list.append(current_loss)
-            else:
-                loss_total -= loss_list[step]
-                loss_list[step] = current_loss
-                loss_total += current_loss
-            avr_loss = loss_total / len(loss_list)
-            logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
+            loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
+            avr_loss: float = loss_recorder.moving_average
+            logs = {"avr_loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
 
             if args.logging_dir is not None:
@@ -488,7 +482,7 @@ def remove_model(old_ckpt_name):
                 break
 
         if args.logging_dir is not None:
-            logs = {"loss/epoch": loss_total / len(loss_list)}
+            logs = {"loss/epoch": loss_recorder.moving_average}
             accelerator.log(logs, step=epoch + 1)
 
         accelerator.wait_for_everyone()
diff --git a/train_controlnet.py b/train_controlnet.py
index 5bc8d399c..bbd915cb3 100644
--- a/train_controlnet.py
+++ b/train_controlnet.py
@@ -337,8 +337,7 @@ def train(args):
            init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers("controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
 
-    loss_list = []
-    loss_total = 0.0
+    loss_recorder = train_util.LossRecorder()
     del train_dataset_group
 
     # function for saving/removing
@@ -500,14 +499,9 @@ def remove_model(old_ckpt_name):
                         remove_model(remove_ckpt_name)
 
             current_loss = loss.detach().item()
-            if epoch == 0:
-                loss_list.append(current_loss)
-            else:
-                loss_total -= loss_list[step]
-                loss_list[step] = current_loss
-                loss_total += current_loss
-            avr_loss = loss_total / len(loss_list)
-            logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
+            loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
+            avr_loss: float = loss_recorder.moving_average
+            logs = {"avr_loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
 
             if args.logging_dir is not None:
@@ -518,7 +512,7 @@ def remove_model(old_ckpt_name):
                 break
 
         if args.logging_dir is not None:
-            logs = {"loss/epoch": loss_total / len(loss_list)}
+            logs = {"loss/epoch": loss_recorder.moving_average}
             accelerator.log(logs, step=epoch + 1)
 
         accelerator.wait_for_everyone()

From 63992b81c840ea42b53d70d611ef27ff85ae397e Mon Sep 17 00:00:00 2001
From: Yuta Hayashibe
Date: Fri, 27 Oct 2023 21:13:29 +0900
Subject: [PATCH 5/5] Fix initialize place of loss_recorder

---
 fine_tune.py  | 2 +-
 sdxl_train.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/fine_tune.py b/fine_tune.py
index 27d647392..afec7d273 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -288,6 +288,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
             init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
 
+    loss_recorder = train_util.LossRecorder()
     for epoch in range(num_train_epochs):
         accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
         current_epoch.value = epoch + 1
@@ -295,7 +296,6 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
         for m in training_models:
             m.train()
 
-        loss_recorder = train_util.LossRecorder()
         for step, batch in enumerate(train_dataloader):
             current_step.value = global_step
             with accelerator.accumulate(training_models[0]):  # 複数モデルに対応していない模様だがとりあえずこうしておく
diff --git a/sdxl_train.py b/sdxl_train.py
index 9017d7b8c..f681f28fc 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -452,6 +452,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
             init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
 
+    loss_recorder = train_util.LossRecorder()
     for epoch in range(num_train_epochs):
         accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
         current_epoch.value = epoch + 1
@@ -459,7 +460,6 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
         for m in training_models:
             m.train()
 
-        loss_recorder = train_util.LossRecorder()
         for step, batch in enumerate(train_dataloader):
             current_step.value = global_step
             with accelerator.accumulate(training_models[0]):  # 複数モデルに対応していない模様だがとりあえずこうしておく
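
For reference, below is a minimal, self-contained sketch of how the LossRecorder introduced by this series behaves once all five patches are applied. The class body mirrors the final state of library/train_util.py; the driver loop and its loss values are hypothetical, added only to illustrate the windowing: during epoch 0 each add() grows the window by one slot, and from epoch 1 onward each add() overwrites the slot recorded for the same step of the previous epoch, so moving_average is always the mean of the most recent loss seen for every step. Patch 5 matters for exactly this behavior: constructing the recorder outside the epoch loop is what lets the window survive from one epoch to the next instead of being reset every epoch.

from typing import List


class LossRecorder:
    """Mirror of the class added to library/train_util.py by this series (final form)."""

    def __init__(self):
        self.loss_list: List[float] = []
        self.loss_total: float = 0.0

    def add(self, *, epoch: int, step: int, loss: float) -> None:
        if epoch == 0:
            # Epoch 0: grow the window by one slot per step.
            self.loss_list.append(loss)
        else:
            # Later epochs: replace the loss recorded for this step in the
            # previous epoch, keeping loss_total in sync with loss_list.
            self.loss_total -= self.loss_list[step]
            self.loss_list[step] = loss
        self.loss_total += loss

    @property
    def moving_average(self) -> float:
        return self.loss_total / len(self.loss_list)


if __name__ == "__main__":
    # Hypothetical losses: two epochs of three steps each.
    recorder = LossRecorder()
    for epoch, losses in enumerate([[0.9, 0.8, 0.7], [0.4, 0.3, 0.2]]):
        for step, loss in enumerate(losses):
            recorder.add(epoch=epoch, step=step, loss=loss)
        print(f"epoch {epoch}: avr_loss={recorder.moving_average:.4f}")
    # epoch 0: avr_loss=0.8000  -- mean over the growing window
    # epoch 1: avr_loss=0.3000  -- window now holds the epoch-1 values

Two properties of the final code are worth noting. moving_average divides by len(self.loss_list), so it assumes add() has been called at least once, and add() with epoch > 0 indexes a slot that must already exist from epoch 0. Both assumptions hold in the training loops patched above, which call add() exactly once per step starting from epoch 0.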