Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Show moving average loss in the progress bar #899

Merged
merged 5 commits into from
Oct 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions fine_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,14 +288,14 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)

loss_recorder = train_util.LossRecorder()
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1

for m in training_models:
m.train()

loss_total = 0
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
Expand Down Expand Up @@ -405,17 +405,16 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
)
accelerator.log(logs, step=global_step)

# TODO moving averageにする
loss_total += current_loss
avr_loss = loss_total / (step + 1)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)

if global_step >= args.max_train_steps:
break

if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(train_dataloader)}
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)

accelerator.wait_for_everyone()
Expand Down
18 changes: 18 additions & 0 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4685,3 +4685,21 @@ def __call__(self, examples):
dataset.set_current_epoch(self.current_epoch.value)
dataset.set_current_step(self.current_step.value)
return examples[0]


class LossRecorder:
def __init__(self):
self.loss_list: List[float] = []
self.loss_total: float = 0.0

def add(self, *, epoch:int, step: int, loss: float) -> None:
if epoch == 0:
self.loss_list.append(loss)
else:
self.loss_total -= self.loss_list[step]
self.loss_list[step] = loss
self.loss_total += loss

@property
def moving_average(self) -> float:
return self.loss_total / len(self.loss_list)
11 changes: 5 additions & 6 deletions sdxl_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,14 +452,14 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)

loss_recorder = train_util.LossRecorder()
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1

for m in training_models:
m.train()

loss_total = 0
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
Expand Down Expand Up @@ -632,17 +632,16 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

accelerator.log(logs, step=global_step)

# TODO moving averageにする
loss_total += current_loss
avr_loss = loss_total / (step + 1)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)

if global_step >= args.max_train_steps:
break

if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(train_dataloader)}
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)

accelerator.wait_for_everyone()
Expand Down
16 changes: 5 additions & 11 deletions sdxl_train_control_net_lllite.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,8 +350,7 @@ def train(args):
"lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
)

loss_list = []
loss_total = 0.0
loss_recorder = train_util.LossRecorder()
del train_dataset_group

# function for saving/removing
Expand Down Expand Up @@ -500,14 +499,9 @@ def remove_model(old_ckpt_name):
remove_model(remove_ckpt_name)

current_loss = loss.detach().item()
if epoch == 0:
loss_list.append(current_loss)
else:
loss_total -= loss_list[step]
loss_list[step] = current_loss
loss_total += current_loss
avr_loss = loss_total / len(loss_list)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)

if args.logging_dir is not None:
Expand All @@ -518,7 +512,7 @@ def remove_model(old_ckpt_name):
break

if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(loss_list)}
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)

accelerator.wait_for_everyone()
Expand Down
16 changes: 5 additions & 11 deletions sdxl_train_control_net_lllite_old.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,8 +323,7 @@ def train(args):
"lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
)

loss_list = []
loss_total = 0.0
loss_recorder = train_util.LossRecorder()
del train_dataset_group

# function for saving/removing
Expand Down Expand Up @@ -470,14 +469,9 @@ def remove_model(old_ckpt_name):
remove_model(remove_ckpt_name)

current_loss = loss.detach().item()
if epoch == 0:
loss_list.append(current_loss)
else:
loss_total -= loss_list[step]
loss_list[step] = current_loss
loss_total += current_loss
avr_loss = loss_total / len(loss_list)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)

if args.logging_dir is not None:
Expand All @@ -488,7 +482,7 @@ def remove_model(old_ckpt_name):
break

if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(loss_list)}
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)

accelerator.wait_for_everyone()
Expand Down
16 changes: 5 additions & 11 deletions train_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,8 +337,7 @@ def train(args):
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)

loss_list = []
loss_total = 0.0
loss_recorder = train_util.LossRecorder()
del train_dataset_group

# function for saving/removing
Expand Down Expand Up @@ -500,14 +499,9 @@ def remove_model(old_ckpt_name):
remove_model(remove_ckpt_name)

current_loss = loss.detach().item()
if epoch == 0:
loss_list.append(current_loss)
else:
loss_total -= loss_list[step]
loss_list[step] = current_loss
loss_total += current_loss
avr_loss = loss_total / len(loss_list)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)

if args.logging_dir is not None:
Expand All @@ -518,7 +512,7 @@ def remove_model(old_ckpt_name):
break

if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(loss_list)}
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)

accelerator.wait_for_everyone()
Expand Down
16 changes: 5 additions & 11 deletions train_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,7 @@ def train(args):
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)

loss_list = []
loss_total = 0.0
loss_recorder = train_util.LossRecorder()
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
Expand Down Expand Up @@ -392,21 +391,16 @@ def train(args):
)
accelerator.log(logs, step=global_step)

if epoch == 0:
loss_list.append(current_loss)
else:
loss_total -= loss_list[step]
loss_list[step] = current_loss
loss_total += current_loss
avr_loss = loss_total / len(loss_list)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)

if global_step >= args.max_train_steps:
break

if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(loss_list)}
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)

accelerator.wait_for_everyone()
Expand Down
16 changes: 5 additions & 11 deletions train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,8 +703,7 @@ def train(self, args):
"network_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
)

loss_list = []
loss_total = 0.0
loss_recorder = train_util.LossRecorder()
del train_dataset_group

# callback for step start
Expand Down Expand Up @@ -854,14 +853,9 @@ def remove_model(old_ckpt_name):
remove_model(remove_ckpt_name)

current_loss = loss.detach().item()
if epoch == 0:
loss_list.append(current_loss)
else:
loss_total -= loss_list[step]
loss_list[step] = current_loss
loss_total += current_loss
avr_loss = loss_total / len(loss_list)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)

if args.scale_weight_norms:
Expand All @@ -875,7 +869,7 @@ def remove_model(old_ckpt_name):
break

if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(loss_list)}
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)

accelerator.wait_for_everyone()
Expand Down