Merge pull request #899 from shirayu/use_moving_average
Show moving average loss in the progress bar
kohya-ss authored Oct 29, 2023
2 parents fb97a7a + 63992b8 commit 9d6a5a0
Showing 8 changed files with 53 additions and 67 deletions.
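In every changed training script the edit follows the same pattern: the per-script `loss_list`/`loss_total` bookkeeping is replaced by the new shared `train_util.LossRecorder`, and the progress bar postfix now reports its `moving_average` under the key `avr_loss` instead of a cumulative `loss`. A minimal sketch of the resulting loop shape, assuming the repository's `library.train_util` import path and a tqdm progress bar; the dummy dataloader and random losses stand in for the real training step:

```python
import random

from tqdm import tqdm

import library.train_util as train_util  # provides LossRecorder after this commit

num_train_epochs = 2
train_dataloader = list(range(50))  # stand-in for the real DataLoader

progress_bar = tqdm(total=num_train_epochs * len(train_dataloader), desc="steps")
loss_recorder = train_util.LossRecorder()

for epoch in range(num_train_epochs):
    for step, batch in enumerate(train_dataloader):
        current_loss = random.random()  # stand-in for loss.detach().item()

        # record this step's loss and show the rolling average in the progress bar
        loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
        avr_loss: float = loss_recorder.moving_average
        logs = {"avr_loss": avr_loss}
        progress_bar.update(1)
        progress_bar.set_postfix(**logs)

    # per-epoch value sent to the tracker (previously loss_total / len(train_dataloader))
    epoch_logs = {"loss/epoch": loss_recorder.moving_average}
```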
11 changes: 5 additions & 6 deletions fine_tune.py
@@ -289,14 +289,14 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)

loss_recorder = train_util.LossRecorder()
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1

for m in training_models:
m.train()

loss_total = 0
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(training_models[0]):  # this does not seem to support multiple models, but leave it like this for now
@@ -408,17 +408,16 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
)
accelerator.log(logs, step=global_step)

# TODO switch to a moving average
loss_total += current_loss
avr_loss = loss_total / (step + 1)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)

if global_step >= args.max_train_steps:
break

if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(train_dataloader)}
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)

accelerator.wait_for_everyone()
18 changes: 18 additions & 0 deletions library/train_util.py
@@ -4697,3 +4697,21 @@ def __call__(self, examples):
dataset.set_current_epoch(self.current_epoch.value)
dataset.set_current_step(self.current_step.value)
return examples[0]


class LossRecorder:
def __init__(self):
self.loss_list: List[float] = []
self.loss_total: float = 0.0

def add(self, *, epoch: int, step: int, loss: float) -> None:
if epoch == 0:
self.loss_list.append(loss)
else:
self.loss_total -= self.loss_list[step]
self.loss_list[step] = loss
self.loss_total += loss

@property
def moving_average(self) -> float:
return self.loss_total / len(self.loss_list)
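As written, `moving_average` is the mean of the latest loss recorded for each step index: during epoch 0 the window grows one entry per step (a plain running average), and from epoch 1 onward each `add` call overwrites the value stored for that step in the previous epoch, so the window stays roughly one epoch wide. At the end of a full epoch it therefore matches that epoch's mean step loss (assuming every epoch has the same number of steps), which is presumably why the `loss/epoch` log switches to it as well. A small self-contained check of that behaviour; the class body is copied from this diff and the loss values are made up:

```python
from typing import List


class LossRecorder:  # copied from library/train_util.py as added in this commit
    def __init__(self):
        self.loss_list: List[float] = []
        self.loss_total: float = 0.0

    def add(self, *, epoch: int, step: int, loss: float) -> None:
        if epoch == 0:
            self.loss_list.append(loss)
        else:
            self.loss_total -= self.loss_list[step]
            self.loss_list[step] = loss
        self.loss_total += loss

    @property
    def moving_average(self) -> float:
        return self.loss_total / len(self.loss_list)


recorder = LossRecorder()

# epoch 0: the window grows, so this is an ordinary running average
recorder.add(epoch=0, step=0, loss=4.0)
recorder.add(epoch=0, step=1, loss=2.0)
assert recorder.moving_average == 3.0  # (4.0 + 2.0) / 2

# epoch 1: each step replaces its previous value, keeping the window at 2 entries
recorder.add(epoch=1, step=0, loss=1.0)
assert recorder.moving_average == 1.5  # (1.0 + 2.0) / 2
recorder.add(epoch=1, step=1, loss=1.0)
assert recorder.moving_average == 1.0  # (1.0 + 1.0) / 2
```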
11 changes: 5 additions & 6 deletions sdxl_train.py
@@ -451,14 +451,14 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)

loss_recorder = train_util.LossRecorder()
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1

for m in training_models:
m.train()

loss_total = 0
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(training_models[0]):  # this does not seem to support multiple models, but leave it like this for now
@@ -633,17 +633,16 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

accelerator.log(logs, step=global_step)

# TODO switch to a moving average
loss_total += current_loss
avr_loss = loss_total / (step + 1)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)

if global_step >= args.max_train_steps:
break

if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(train_dataloader)}
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)

accelerator.wait_for_everyone()
16 changes: 5 additions & 11 deletions sdxl_train_control_net_lllite.py
@@ -351,8 +351,7 @@ def train(args):
"lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
)

loss_list = []
loss_total = 0.0
loss_recorder = train_util.LossRecorder()
del train_dataset_group

# function for saving/removing
@@ -503,14 +502,9 @@ def remove_model(old_ckpt_name):
remove_model(remove_ckpt_name)

current_loss = loss.detach().item()
if epoch == 0:
loss_list.append(current_loss)
else:
loss_total -= loss_list[step]
loss_list[step] = current_loss
loss_total += current_loss
avr_loss = loss_total / len(loss_list)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)

if args.logging_dir is not None:
@@ -521,7 +515,7 @@ def remove_model(old_ckpt_name):
break

if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(loss_list)}
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)

accelerator.wait_for_everyone()
16 changes: 5 additions & 11 deletions sdxl_train_control_net_lllite_old.py
@@ -324,8 +324,7 @@ def train(args):
"lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
)

loss_list = []
loss_total = 0.0
loss_recorder = train_util.LossRecorder()
del train_dataset_group

# function for saving/removing
@@ -473,14 +472,9 @@ def remove_model(old_ckpt_name):
remove_model(remove_ckpt_name)

current_loss = loss.detach().item()
if epoch == 0:
loss_list.append(current_loss)
else:
loss_total -= loss_list[step]
loss_list[step] = current_loss
loss_total += current_loss
avr_loss = loss_total / len(loss_list)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)

if args.logging_dir is not None:
@@ -491,7 +485,7 @@ def remove_model(old_ckpt_name):
break

if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(loss_list)}
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)

accelerator.wait_for_everyone()
16 changes: 5 additions & 11 deletions train_controlnet.py
@@ -337,8 +337,7 @@ def train(args):
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)

loss_list = []
loss_total = 0.0
loss_recorder = train_util.LossRecorder()
del train_dataset_group

# function for saving/removing
@@ -500,14 +499,9 @@ def remove_model(old_ckpt_name):
remove_model(remove_ckpt_name)

current_loss = loss.detach().item()
if epoch == 0:
loss_list.append(current_loss)
else:
loss_total -= loss_list[step]
loss_list[step] = current_loss
loss_total += current_loss
avr_loss = loss_total / len(loss_list)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)

if args.logging_dir is not None:
@@ -518,7 +512,7 @@ def remove_model(old_ckpt_name):
break

if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(loss_list)}
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)

accelerator.wait_for_everyone()
16 changes: 5 additions & 11 deletions train_db.py
@@ -265,8 +265,7 @@ def train(args):
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)

loss_list = []
loss_total = 0.0
loss_recorder = train_util.LossRecorder()
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
@@ -395,21 +394,16 @@ def train(args):
)
accelerator.log(logs, step=global_step)

if epoch == 0:
loss_list.append(current_loss)
else:
loss_total -= loss_list[step]
loss_list[step] = current_loss
loss_total += current_loss
avr_loss = loss_total / len(loss_list)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)

if global_step >= args.max_train_steps:
break

if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(loss_list)}
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)

accelerator.wait_for_everyone()
16 changes: 5 additions & 11 deletions train_network.py
@@ -710,8 +710,7 @@ def train(self, args):
"network_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
)

loss_list = []
loss_total = 0.0
loss_recorder = train_util.LossRecorder()
del train_dataset_group

# callback for step start
@@ -863,14 +862,9 @@ def remove_model(old_ckpt_name):
remove_model(remove_ckpt_name)

current_loss = loss.detach().item()
if epoch == 0:
loss_list.append(current_loss)
else:
loss_total -= loss_list[step]
loss_list[step] = current_loss
loss_total += current_loss
avr_loss = loss_total / len(loss_list)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)

if args.scale_weight_norms:
@@ -884,7 +878,7 @@ def remove_model(old_ckpt_name):
break

if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(loss_list)}
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)

accelerator.wait_for_everyone()