[Experimental Feature] FP8 weight dtype for base model when running train_network (or sdxl_train_network) #1057

Merged: 31 commits, Jan 20, 2024

Commits
ed883f2  Add fp8 support (KohakuBlueleaf, Oct 24, 2023)
c0c7109  remove some debug prints (KohakuBlueleaf, Oct 24, 2023)
9e2b99f  Better implementation for te (KohakuBlueleaf, Oct 24, 2023)
628b44d  Fix some misunderstanding (KohakuBlueleaf, Oct 24, 2023)
65da2cc  as same as unet, add explicit convert (KohakuBlueleaf, Oct 24, 2023)
1623d6c  better impl for convert TE to fp8 (KohakuBlueleaf, Oct 24, 2023)
e2f32f5  fp8 for not only unet (KohakuBlueleaf, Oct 24, 2023)
ed82a1a  Better cache TE and TE lr (KohakuBlueleaf, Oct 25, 2023)
f582142  match arg name (KohakuBlueleaf, Oct 25, 2023)
f52cbc2  Fix with list (KohakuBlueleaf, Oct 26, 2023)
34f1cd6  Add timeout settings (KohakuBlueleaf, Oct 26, 2023)
19d6617  Fix arg style (KohakuBlueleaf, Oct 26, 2023)
e355764  Add custom seperator (KohakuBlueleaf, Oct 30, 2023)
3f670d2  Fix typo (KohakuBlueleaf, Oct 30, 2023)
35c32a6  Fix typo again (KohakuBlueleaf, Oct 30, 2023)
3f0c40a  Merge remote-tracking branch 'upstream/dev' into fp8-experiments (KohakuBlueleaf, Nov 4, 2023)
9be4f44  Fix dtype error (KohakuBlueleaf, Nov 15, 2023)
b003446  Fix gradient problem (KohakuBlueleaf, Nov 15, 2023)
c0f8d28  Fix req grad (KohakuBlueleaf, Nov 30, 2023)
9bb4fcb  Merge remote-tracking branch 'upstream/dev' into fp8-experiments (KohakuBlueleaf, Nov 30, 2023)
67a3ad8  fix merge (KohakuBlueleaf, Nov 30, 2023)
54f5f46  Fix merge (KohakuBlueleaf, Nov 30, 2023)
fa82f2a  Merge remote-tracking branch 'upstream/dev' into fp8-experiments (KohakuBlueleaf, Dec 16, 2023)
3f0414c  Merge pull request #1 from kohya-ss/dev (KohakuBlueleaf, Dec 21, 2023)
4664c2d  Merge branch 'kohya-ss:main' into fp8-experiments (KohakuBlueleaf, Jan 6, 2024)
12151c9  Merge remote-tracking branch 'upstream/dev' into fp8-experiments (KohakuBlueleaf, Jan 6, 2024)
b4b872e  Resolve merge (KohakuBlueleaf, Jan 6, 2024)
9229282  arrangement and document (KohakuBlueleaf, Jan 6, 2024)
3438703  Merge remote-tracking branch 'upstream/dev' into fp8-experiments (KohakuBlueleaf, Jan 17, 2024)
c5b9187  Resolve merge error (KohakuBlueleaf, Jan 17, 2024)
100f852  Add assert for mixed precision (KohakuBlueleaf, Jan 18, 2024)

library/train_util.py: 3 changes (3 additions & 0 deletions)
@@ -2904,6 +2904,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
     parser.add_argument(
         "--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する"
     ) # TODO move to SDXL training, because it is not supported by SD1/2
+    parser.add_argument(
+        "--fp8_base", action="store_true", help="use fp8 for base model / base modelにfp8を使う"
+    )
     parser.add_argument(
         "--ddp_timeout",
         type=int,
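
For reference, a minimal, self-contained sketch (not part of the PR) of how the new store_true flag surfaces after parsing: absent means False, present means True. Only the --fp8_base argument itself comes from the diff above; the parser setup around it is illustrative.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--fp8_base", action="store_true", help="use fp8 for base model / base modelにfp8を使う"
)

args = parser.parse_args(["--fp8_base"])
print(args.fp8_base)  # True; defaults to False when the flag is omitted
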
train_network.py: 36 changes (30 additions & 6 deletions)
@@ -390,16 +390,36 @@ def train(self, args):
             accelerator.print("enable full bf16 training.")
             network.to(weight_dtype)

+        unet_weight_dtype = te_weight_dtype = weight_dtype
+        # Experimental Feature: Put base model into fp8 to save vram
+        if args.fp8_base:
+            assert (
+                torch.__version__ >= '2.1.0'
+            ), "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。"
+            assert (
+                args.mixed_precision != 'no'
+            ), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。"
+            accelerator.print("enable fp8 training.")
+            unet_weight_dtype = torch.float8_e4m3fn
+            te_weight_dtype = torch.float8_e4m3fn
+
         unet.requires_grad_(False)
-        unet.to(dtype=weight_dtype)
+        unet.to(dtype=unet_weight_dtype)
         for t_enc in text_encoders:
             t_enc.requires_grad_(False)
+            t_enc.to(dtype=te_weight_dtype)
+            # nn.Embedding not support FP8
+            t_enc.text_model.embeddings.to(dtype=(
+                weight_dtype
+                if te_weight_dtype == torch.float8_e4m3fn
+                else te_weight_dtype
+            ))

         # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
         if train_unet:
             unet = accelerator.prepare(unet)
         else:
-            unet.to(accelerator.device, dtype=weight_dtype) # move to device because unet is not prepared by accelerator
+            unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator
         if train_text_encoder:
             if len(text_encoders) > 1:
                 text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders]
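
As I understand the hunk above, the frozen base weights are stored in float8_e4m3fn and the mixed-precision autocast context upcasts them to the compute dtype inside each op, which is why --fp8_base asserts torch>=2.1.0 and mixed_precision != 'no'. A standalone sketch of that idea (not code from the PR; the module, shapes, and CUDA device below are placeholders):

import torch
import torch.nn as nn

# placeholder for the frozen base model (U-Net / text encoder in the PR)
frozen = nn.Linear(1024, 1024, bias=False)
frozen.requires_grad_(False)
frozen.to("cuda", dtype=torch.float8_e4m3fn)  # weights stored in fp8, roughly half the memory of fp16

x = torch.randn(4, 1024, device="cuda", dtype=torch.float16)
with torch.autocast(device_type="cuda", dtype=torch.float16):
    y = frozen(x)  # autocast casts the fp8 weight up to fp16 for the matmul
print(y.dtype)  # torch.float16
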
@@ -421,9 +441,6 @@ def train(self, args):
                 if train_text_encoder:
                     t_enc.text_model.embeddings.requires_grad_(True)

-            # set top parameter requires_grad = True for gradient checkpointing works
-            if not train_text_encoder: # train U-Net only
-                unet.parameters().__next__().requires_grad_(True)
         else:
             unet.eval()
             for t_enc in text_encoders:
@@ -778,10 +795,17 @@ def remove_model(old_ckpt_name):
                         args, noise_scheduler, latents
                     )

+                    # ensure the hidden state will require grad
+                    if args.gradient_checkpointing:
+                        for x in noisy_latents:
+                            x.requires_grad_(True)
+                        for t in text_encoder_conds:
+                            t.requires_grad_(True)
+
                     # Predict the noise residual
                     with accelerator.autocast():
                         noise_pred = self.call_unet(
-                            args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype
+                            args, accelerator, unet, noisy_latents.requires_grad_(train_unet), timesteps, text_encoder_conds, batch, weight_dtype
                         )

                     if args.v_parameterization:
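
The requires_grad_(True) calls added in this hunk (and the per-parameter workaround deleted in the previous hunk) matter because reentrant gradient checkpointing only records a backward graph when at least one checkpoint input requires grad; with the base model fully frozen, the latents and text-encoder conds have to carry that flag. A small standalone illustration, with placeholder modules that are not from the repo:

import torch
from torch.utils.checkpoint import checkpoint

frozen = torch.nn.Linear(8, 8).requires_grad_(False)  # stands in for the frozen base model
trainable = torch.nn.Linear(8, 8)                     # stands in for the trained network weights

def block(inp):
    return trainable(frozen(inp))

x = torch.randn(2, 8)
y = checkpoint(block, x, use_reentrant=True)
print(y.requires_grad)  # False: no checkpoint input requires grad, so `trainable` gets no gradients

x.requires_grad_(True)  # what the added lines do for noisy_latents / text_encoder_conds
y = checkpoint(block, x, use_reentrant=True)
print(y.requires_grad)  # True: the backward graph is rebuilt and gradients reach `trainable`
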