diff --git a/library/train_util.py b/library/train_util.py
index ff161feab..21e7638da 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -2904,6 +2904,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
     parser.add_argument(
         "--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する"
     )  # TODO move to SDXL training, because it is not supported by SD1/2
+    parser.add_argument(
+        "--fp8_base", action="store_true", help="use fp8 for base model / base modelにfp8を使う"
+    )
     parser.add_argument(
         "--ddp_timeout",
         type=int,
diff --git a/train_network.py b/train_network.py
index c2b7fbdef..5f28a5e0d 100644
--- a/train_network.py
+++ b/train_network.py
@@ -390,16 +390,36 @@ def train(self, args):
             accelerator.print("enable full bf16 training.")
             network.to(weight_dtype)
 
+        unet_weight_dtype = te_weight_dtype = weight_dtype
+        # Experimental Feature: Put base model into fp8 to save vram
+        if args.fp8_base:
+            assert (
+                torch.__version__ >= '2.1.0'
+            ), "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。"
+            assert (
+                args.mixed_precision != 'no'
+            ), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。"
+            accelerator.print("enable fp8 training.")
+            unet_weight_dtype = torch.float8_e4m3fn
+            te_weight_dtype = torch.float8_e4m3fn
+
         unet.requires_grad_(False)
-        unet.to(dtype=weight_dtype)
+        unet.to(dtype=unet_weight_dtype)
         for t_enc in text_encoders:
             t_enc.requires_grad_(False)
+            t_enc.to(dtype=te_weight_dtype)
+            # nn.Embedding not support FP8
+            t_enc.text_model.embeddings.to(dtype=(
+                weight_dtype
+                if te_weight_dtype == torch.float8_e4m3fn
+                else te_weight_dtype
+            ))
 
         # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
         if train_unet:
             unet = accelerator.prepare(unet)
         else:
-            unet.to(accelerator.device, dtype=weight_dtype)  # move to device because unet is not prepared by accelerator
+            unet.to(accelerator.device, dtype=unet_weight_dtype)  # move to device because unet is not prepared by accelerator
         if train_text_encoder:
             if len(text_encoders) > 1:
                 text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders]
@@ -421,9 +441,6 @@ def train(self, args):
                 if train_text_encoder:
                     t_enc.text_model.embeddings.requires_grad_(True)
 
-            # set top parameter requires_grad = True for gradient checkpointing works
-            if not train_text_encoder:  # train U-Net only
-                unet.parameters().__next__().requires_grad_(True)
         else:
             unet.eval()
             for t_enc in text_encoders:
@@ -778,10 +795,17 @@ def remove_model(old_ckpt_name):
                         args, noise_scheduler, latents
                     )
 
+                    # ensure the hidden state will require grad
+                    if args.gradient_checkpointing:
+                        for x in noisy_latents:
+                            x.requires_grad_(True)
+                        for t in text_encoder_conds:
+                            t.requires_grad_(True)
+
                     # Predict the noise residual
                     with accelerator.autocast():
                         noise_pred = self.call_unet(
-                            args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype
+                            args, accelerator, unet, noisy_latents.requires_grad_(train_unet), timesteps, text_encoder_conds, batch, weight_dtype
                         )
 
                     if args.v_parameterization: