Skip to content

Commit

Permalink
Merge branch 'main' into handle-dora
Browse files Browse the repository at this point in the history
  • Loading branch information
sayakpaul authored Oct 6, 2024
2 parents 97d13a5 + 99f6082 commit 114ceff
Show file tree
Hide file tree
Showing 37 changed files with 2,726 additions and 170 deletions.
3 changes: 3 additions & 0 deletions docs/source/en/api/pipelines/pag.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ Since RegEx is supported as a way for matching layer identifiers, it is crucial

## StableDiffusionControlNetPAGPipeline
[[autodoc]] StableDiffusionControlNetPAGPipeline

## StableDiffusionControlNetPAGInpaintPipeline
[[autodoc]] StableDiffusionControlNetPAGInpaintPipeline
- all
- __call__

Expand Down
1 change: 1 addition & 0 deletions docs/source/en/api/schedulers/overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ Many schedulers are implemented from the [k-diffusion](https://github.com/crowso
| sgm_uniform | init with `timestep_spacing="trailing"` |
| simple | init with `timestep_spacing="trailing"` |
| exponential | init with `timestep_spacing="linspace"`, `use_exponential_sigmas=True` |
| beta | init with `timestep_spacing="linspace"`, `use_beta_sigmas=True` |

All schedulers are built from the base [`SchedulerMixin`] class which implements low level utilities shared by all schedulers.

Expand Down
8 changes: 3 additions & 5 deletions examples/cogvideo/train_cogvideox_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,7 @@
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from diffusers.optimization import get_scheduler
from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
from diffusers.training_utils import (
cast_training_params,
clear_objs_and_retain_memory,
)
from diffusers.training_utils import cast_training_params, free_memory
from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft, export_to_video, is_wandb_available
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.torch_utils import is_compiled_module
Expand Down Expand Up @@ -726,7 +723,8 @@ def log_validation(
}
)

clear_objs_and_retain_memory([pipe])
del pipe
free_memory()

return videos

Expand Down
8 changes: 5 additions & 3 deletions examples/controlnet/train_controlnet_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
from diffusers.models.controlnet_flux import FluxControlNetModel
from diffusers.optimization import get_scheduler
from diffusers.pipelines.flux.pipeline_flux_controlnet import FluxControlNetPipeline
from diffusers.training_utils import clear_objs_and_retain_memory, compute_density_for_timestep_sampling
from diffusers.training_utils import compute_density_for_timestep_sampling, free_memory
from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
Expand Down Expand Up @@ -193,7 +193,8 @@ def log_validation(
else:
logger.warning(f"image logging not implemented for {tracker.name}")

clear_objs_and_retain_memory([pipeline])
del pipeline
free_memory()
return image_logs


Expand Down Expand Up @@ -1103,7 +1104,8 @@ def compute_embeddings(batch, proportion_empty_prompts, flux_controlnet_pipeline
compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint, batch_size=50
)

clear_objs_and_retain_memory([text_encoders, tokenizers])
del text_encoders, tokenizers, text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two
free_memory()

# Then get the training dataset ready to be passed to the dataloader.
train_dataset = prepare_train_dataset(train_dataset, accelerator)
Expand Down
13 changes: 6 additions & 7 deletions examples/controlnet/train_controlnet_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,7 @@
StableDiffusion3ControlNetPipeline,
)
from diffusers.optimization import get_scheduler
from diffusers.training_utils import (
clear_objs_and_retain_memory,
compute_density_for_timestep_sampling,
compute_loss_weighting_for_sd3,
)
from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, free_memory
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.torch_utils import is_compiled_module
Expand Down Expand Up @@ -174,7 +170,8 @@ def log_validation(controlnet, args, accelerator, weight_dtype, step, is_final_v
else:
logger.warning(f"image logging not implemented for {tracker.name}")

clear_objs_and_retain_memory(pipeline)
del pipeline
free_memory()

if not is_final_validation:
controlnet.to(accelerator.device)
Expand Down Expand Up @@ -1131,7 +1128,9 @@ def compute_text_embeddings(batch, text_encoders, tokenizers):
new_fingerprint = Hasher.hash(args)
train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint)

clear_objs_and_retain_memory(text_encoders + tokenizers)
del text_encoder_one, text_encoder_two, text_encoder_three
del tokenizer_one, tokenizer_two, tokenizer_three
free_memory()

train_dataloader = torch.utils.data.DataLoader(
train_dataset,
Expand Down
11 changes: 7 additions & 4 deletions examples/dreambooth/train_dreambooth_lora_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@
from diffusers.training_utils import (
_set_state_dict_into_text_encoder,
cast_training_params,
clear_objs_and_retain_memory,
compute_density_for_timestep_sampling,
compute_loss_weighting_for_sd3,
free_memory,
)
from diffusers.utils import (
check_min_version,
Expand Down Expand Up @@ -1437,7 +1437,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):

# Clear the memory here
if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
clear_objs_and_retain_memory([tokenizers, text_encoders, text_encoder_one, text_encoder_two])
del text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two
free_memory()

# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
# pack the statically computed variables appropriately here. This is so that we don't
Expand Down Expand Up @@ -1480,7 +1481,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)

if args.validation_prompt is None:
clear_objs_and_retain_memory([vae])
del vae
free_memory()

# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
Expand Down Expand Up @@ -1817,7 +1819,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
torch_dtype=weight_dtype,
)
if not args.train_text_encoder:
clear_objs_and_retain_memory([text_encoder_one, text_encoder_two])
del text_encoder_one, text_encoder_two
free_memory()

# Save the lora layers
accelerator.wait_for_everyone()
Expand Down
20 changes: 10 additions & 10 deletions examples/dreambooth/train_dreambooth_lora_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@
from diffusers.training_utils import (
_set_state_dict_into_text_encoder,
cast_training_params,
clear_objs_and_retain_memory,
compute_density_for_timestep_sampling,
compute_loss_weighting_for_sd3,
free_memory,
)
from diffusers.utils import (
check_min_version,
Expand Down Expand Up @@ -211,7 +211,8 @@ def log_validation(
}
)

clear_objs_and_retain_memory(objs=[pipeline])
del pipeline
free_memory()

return images

Expand Down Expand Up @@ -1106,7 +1107,8 @@ def main(args):
image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
image.save(image_filename)

clear_objs_and_retain_memory(objs=[pipeline])
del pipeline
free_memory()

# Handle the repository creation
if accelerator.is_main_process:
Expand Down Expand Up @@ -1453,9 +1455,9 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
# Clear the memory here
if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
# Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
clear_objs_and_retain_memory(
objs=[tokenizers, text_encoders, text_encoder_one, text_encoder_two, text_encoder_three]
)
del tokenizers, text_encoders
del text_encoder_one, text_encoder_two, text_encoder_three
free_memory()

# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
# pack the statically computed variables appropriately here. This is so that we don't
Expand Down Expand Up @@ -1791,11 +1793,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
epoch=epoch,
torch_dtype=weight_dtype,
)
objs = []
if not args.train_text_encoder:
objs.extend([text_encoder_one, text_encoder_two, text_encoder_three])

clear_objs_and_retain_memory(objs=objs)
del text_encoder_one, text_encoder_two, text_encoder_three
free_memory()

# Save the lora layers
accelerator.wait_for_everyone()
Expand Down
25 changes: 18 additions & 7 deletions examples/instruct_pix2pix/train_instruct_pix2pix.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,17 +747,22 @@ def collate_fn(examples):
)

# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
num_training_steps_for_scheduler = (
args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
)
else:
num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes

lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * accelerator.num_processes,
num_warmup_steps=num_warmup_steps_for_scheduler,
num_training_steps=num_training_steps_for_scheduler,
)

# Prepare everything with our `accelerator`.
Expand All @@ -782,8 +787,14 @@ def collate_fn(examples):

# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
logger.warning(
f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
f"This inconsistency may result in the learning rate scheduler not functioning properly."
)
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,7 @@
"StableDiffusionAttendAndExcitePipeline",
"StableDiffusionControlNetImg2ImgPipeline",
"StableDiffusionControlNetInpaintPipeline",
"StableDiffusionControlNetPAGInpaintPipeline",
"StableDiffusionControlNetPAGPipeline",
"StableDiffusionControlNetPipeline",
"StableDiffusionControlNetXSPipeline",
Expand Down Expand Up @@ -778,6 +779,7 @@
StableDiffusionAttendAndExcitePipeline,
StableDiffusionControlNetImg2ImgPipeline,
StableDiffusionControlNetInpaintPipeline,
StableDiffusionControlNetPAGInpaintPipeline,
StableDiffusionControlNetPAGPipeline,
StableDiffusionControlNetPipeline,
StableDiffusionControlNetXSPipeline,
Expand Down
41 changes: 39 additions & 2 deletions src/diffusers/loaders/lora_conversion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,10 +516,47 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
f"transformer.single_transformer_blocks.{i}.norm.linear",
)

remaining_keys = list(sds_sd.keys())
te_state_dict = {}
if remaining_keys:
if not all(k.startswith("lora_te1") for k in remaining_keys):
raise ValueError(f"Incompatible keys detected: \n\n {', '.join(remaining_keys)}")
for key in remaining_keys:
if not key.endswith("lora_down.weight"):
continue

lora_name = key.split(".")[0]
lora_name_up = f"{lora_name}.lora_up.weight"
lora_name_alpha = f"{lora_name}.alpha"
diffusers_name = _convert_text_encoder_lora_key(key, lora_name)

if lora_name.startswith(("lora_te_", "lora_te1_")):
down_weight = sds_sd.pop(key)
sd_lora_rank = down_weight.shape[0]
te_state_dict[diffusers_name] = down_weight
te_state_dict[diffusers_name.replace(".down.", ".up.")] = sds_sd.pop(lora_name_up)

if lora_name_alpha in sds_sd:
alpha = sds_sd.pop(lora_name_alpha).item()
scale = alpha / sd_lora_rank

scale_down = scale
scale_up = 1.0
while scale_down * 2 < scale_up:
scale_down *= 2
scale_up /= 2

te_state_dict[diffusers_name] *= scale_down
te_state_dict[diffusers_name.replace(".down.", ".up.")] *= scale_up

if len(sds_sd) > 0:
logger.warning(f"Unsuppored keys for ai-toolkit: {sds_sd.keys()}")
logger.warning(f"Unsupported keys for ai-toolkit: {sds_sd.keys()}")

if te_state_dict:
te_state_dict = {f"text_encoder.{module_name}": params for module_name, params in te_state_dict.items()}

return ait_sd
new_state_dict = {**ait_sd, **te_state_dict}
return new_state_dict

return _convert_sd_scripts_to_ai_toolkit(state_dict)

Expand Down
Loading

0 comments on commit 114ceff

Please sign in to comment.