Skip to content

Commit

Permalink
Fix issue with max_train_steps
Browse files Browse the repository at this point in the history
  • Loading branch information
bmaltais committed Nov 17, 2024
1 parent d47e3e6 commit a6f0ff7
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 54 deletions.
34 changes: 17 additions & 17 deletions kohya_gui/dreambooth_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,23 +781,23 @@ def train_model(

log.info(f"Regularization factor: {reg_factor}")

if max_train_steps == 0:
# calculate max_train_steps
max_train_steps = int(
math.ceil(
float(total_steps)
/ int(train_batch_size)
/ int(gradient_accumulation_steps)
* int(epoch)
* int(reg_factor)
)
)
max_train_steps_info = f"max_train_steps ({total_steps} / {train_batch_size} / {gradient_accumulation_steps} * {epoch} * {reg_factor}) = {max_train_steps}"
else:
if max_train_steps == 0:
max_train_steps_info = f"Max train steps: 0. sd-scripts will therefore default to 1600. Please specify a different value if required."
else:
max_train_steps_info = f"Max train steps: {max_train_steps}"
# if max_train_steps == 0:
# # calculate max_train_steps
# max_train_steps = int(
# math.ceil(
# float(total_steps)
# / int(train_batch_size)
# / int(gradient_accumulation_steps)
# * int(epoch)
# * int(reg_factor)
# )
# )
# max_train_steps_info = f"max_train_steps ({total_steps} / {train_batch_size} / {gradient_accumulation_steps} * {epoch} * {reg_factor}) = {max_train_steps}"
# else:
# if max_train_steps == 0:
# max_train_steps_info = f"Max train steps: 0. sd-scripts will therefore default to 1600. Please specify a different value if required."
# else:
# max_train_steps_info = f"Max train steps: {max_train_steps}"

log.info(f"Total steps: {total_steps}")

Expand Down
20 changes: 10 additions & 10 deletions kohya_gui/finetune_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,16 +846,16 @@ def train_model(
repeats = int(image_num) * int(dataset_repeats)
log.info(f"repeats = {str(repeats)}")

if max_train_steps == 0:
# calculate max_train_steps
max_train_steps = int(
math.ceil(
float(repeats)
/ int(train_batch_size)
/ int(gradient_accumulation_steps)
* int(epoch)
)
)
# if max_train_steps == 0:
# # calculate max_train_steps
# max_train_steps = int(
# math.ceil(
# float(repeats)
# / int(train_batch_size)
# / int(gradient_accumulation_steps)
# * int(epoch)
# )
# )

        # Divide by two because flip augmentation creates two copies of the source images
if flip_aug and max_train_steps:
Expand Down
12 changes: 4 additions & 8 deletions kohya_gui/lora_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -1076,7 +1076,7 @@ def train_model(

log.info(f"Regularization factor: {reg_factor}")

if max_train_steps == 0:
if (max_train_steps == 0) and (stop_text_encoder_training != 0):
# calculate max_train_steps
max_train_steps = int(
math.ceil(
Expand All @@ -1094,13 +1094,9 @@ def train_model(
else:
max_train_steps_info = f"Max train steps: {max_train_steps}"

# calculate stop encoder training
if stop_text_encoder_training == 0:
stop_text_encoder_training = 0
else:
stop_text_encoder_training = math.ceil(
float(max_train_steps) / 100 * int(stop_text_encoder_training)
)
stop_text_encoder_training = math.ceil(
float(max_train_steps) / 100 * int(stop_text_encoder_training)
) if stop_text_encoder_training != 0 else 0

# Calculate lr_warmup_steps
if lr_warmup_steps > 0:
Expand Down
21 changes: 2 additions & 19 deletions kohya_gui/textual_inversion_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,22 +664,9 @@ def train_model(
log.info(f"Regularization factor: {reg_factor}")

if max_train_steps == 0:
# calculate max_train_steps
max_train_steps = int(
math.ceil(
float(total_steps)
/ int(train_batch_size)
/ int(gradient_accumulation_steps)
* int(epoch)
* int(reg_factor)
)
)
max_train_steps_info = f"max_train_steps ({total_steps} / {train_batch_size} / {gradient_accumulation_steps} * {epoch} * {reg_factor}) = {max_train_steps}"
max_train_steps_info = f"Max train steps: 0. sd-scripts will therefore default to 1600. Please specify a different value if required."
else:
if max_train_steps == 0:
max_train_steps_info = f"Max train steps: 0. sd-scripts will therefore default to 1600. Please specify a different value if required."
else:
max_train_steps_info = f"Max train steps: {max_train_steps}"
max_train_steps_info = f"Max train steps: {max_train_steps}"

# calculate stop encoder training
if stop_text_encoder_training_pct == 0:
Expand Down Expand Up @@ -1076,10 +1063,6 @@ def list_embedding_files(path):
step=1,
label="Vectors",
)
# max_train_steps = gr.Textbox(
# label='Max train steps',
# placeholder='(Optional) Maximum number of steps',
# )
template = gr.Dropdown(
label="Template",
choices=[
Expand Down

2 comments on commit a6f0ff7

@bbecausereasonss
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to have caused my issue with max_train_steps instead of fixing it.

#2969

@FurkanGozukara
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@bmaltais sadly this broke the training :(

Please sign in to comment.