Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DRAFT PR] added AIG for SD and documentation #286

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 156 additions & 1 deletion docs/user-guide/draftp.rst
Original file line number Diff line number Diff line change
Expand Up @@ -173,4 +173,159 @@ DRaFT+ Results

Once you have completed fine-tuning Stable Diffusion with DRaFT+, you can run inference on your saved model using the `sd_infer.py <https://github.com/NVIDIA/NeMo/blob/main/examples/multimodal/text_to_image/stable_diffusion/sd_infer.py>`__
and `sd_lora_infer.py <https://github.com/NVIDIA/NeMo/blob/main/examples/multimodal/text_to_image/stable_diffusion/sd_lora_infer.py>`__ scripts from the NeMo codebase. The generated images with the fine-tuned model should have
better prompt alignment and aesthetic quality.
better prompt alignment and aesthetic quality.

User controllable finetuning with Annealed Importance Guidance (AIG)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

AIG provides the inference-time flexibility to interpolate between the base Stable Diffusion model (with low rewards and high diversity) and DRaFT-finetuned model (with high rewards and low diversity) to obtain images with high rewards and high diversity. AIG inference is easily done by specifying comma-separated `weight_type` strategies to interpolate between the base and finetuned model.

.. tab-set::
.. tab-item:: AIG on Stable Diffusion XL
:sync: key2

Weight type of `base` uses the base model for AIG, `draft` uses the finetuned model (no interpolation is done in either case).
Weight type of the form `power_<float>` interpolates using an exponential decay specified in the AIG paper.

To run AIG inference on the terminal directly:

.. code-block:: bash

NUMNODES=1
LR=${LR:=0.00025}
INF_STEPS=${INF_STEPS:=25}
KL_COEF=${KL_COEF:=0.1}
ETA=${ETA:=0.0}
DATASET=${DATASET:="pickapic50k.tar"}
MICRO_BS=${MICRO_BS:=1}
GRAD_ACCUMULATION=${GRAD_ACCUMULATION:=4}
PEFT=${PEFT:="sdlora"}
NUM_DEVICES=${NUM_DEVICES:=8}
GLOBAL_BATCH_SIZE=$((MICRO_BS*NUM_DEVICES*GRAD_ACCUMULATION*NUMNODES))
LOG_WANDB=${LOG_WANDB:="False"}

echo "additional kwargs: ${ADDITIONAL_KWARGS}"

WANDB_NAME=SDXL_Draft_annealing
WEBDATASET_PATH=/path/to/${DATASET}

CONFIG_PATH="/opt/nemo-aligner/examples/mm/stable_diffusion/conf"
CONFIG_NAME=${CONFIG_NAME:="draftp_sdxl"}
UNET_CKPT="/path/to/unet.ckpt"
VAE_CKPT="/path/to/vae.ckpt"
RM_CKPT="/path/to/reward_model.nemo"
PROMPT=${PROMPT:="Bananas growing on an apple tree"}
DIR_SAVE_CKPT_PATH=/path/to/explicit_log_dir

if [ ! -z "${ACT_CKPT}" ]; then
ACT_CKPT="model.activation_checkpointing=$ACT_CKPT "
echo $ACT_CKPT
fi

EVAL_SCRIPT=${EVAL_SCRIPT:-"anneal_sdxl.py"}
export DEVICE="0,1,2,3,4,5,6,7" && echo "Running DRaFT+ on ${DEVICE}" && export HYDRA_FULL_ERROR=1
set -x
CUDA_VISIBLE_DEVICES="${DEVICE}" torchrun --nproc_per_node=$NUM_DEVICES /opt/nemo-aligner/examples/mm/stable_diffusion/${EVAL_SCRIPT} \
--config-path=${CONFIG_PATH} \
--config-name=${CONFIG_NAME} \
model.optim.lr=${LR} \
model.optim.weight_decay=0.0005 \
model.optim.sched.warmup_steps=0 \
model.sampling.base.steps=${INF_STEPS} \
model.kl_coeff=${KL_COEF} \
model.truncation_steps=1 \
trainer.draftp_sd.max_epochs=5 \
trainer.draftp_sd.max_steps=10000 \
trainer.draftp_sd.save_interval=200 \
trainer.draftp_sd.val_check_interval=20 \
trainer.draftp_sd.gradient_clip_val=10.0 \
model.micro_batch_size=${MICRO_BS} \
model.global_batch_size=${GLOBAL_BATCH_SIZE} \
model.peft.peft_scheme=${PEFT} \
model.data.webdataset.local_root_path=$WEBDATASET_PATH \
rm.model.restore_from_path=${RM_CKPT} \
trainer.devices=${NUM_DEVICES} \
trainer.num_nodes=${NUMNODES} \
rm.trainer.devices=${NUM_DEVICES} \
rm.trainer.num_nodes=${NUMNODES} \
+prompt="${PROMPT}" \
exp_manager.create_wandb_logger=${LOG_WANDB} \
model.first_stage_config.from_pretrained=${VAE_CKPT} \
model.first_stage_config.from_NeMo=True \
model.unet_config.from_pretrained=${UNET_CKPT} \
model.unet_config.from_NeMo=True \
$ACT_CKPT \
exp_manager.wandb_logger_kwargs.name=${WANDB_NAME} \
exp_manager.resume_if_exists=True \
exp_manager.explicit_log_dir=${DIR_SAVE_CKPT_PATH} \
exp_manager.wandb_logger_kwargs.project=${PROJECT} +weight_type='draft,base,power_2.0'

.. tab-item:: AIG on Stable Diffusion v1.1 - v1.5
:sync: key

Weight type of `base` uses the base model for AIG, `draft` uses the finetuned model (no interpolation is done in either case).
Weight type of the form `power_<float>` interpolates using an exponential decay specified in the AIG paper.

To run AIG inference on the terminal directly:

.. code-block:: bash

LR=${LR:=0.00025}
INF_STEPS=${INF_STEPS:=25}
KL_COEF=${KL_COEF:=0.1}
ETA=${ETA:=0.0}
DATASET=${DATASET:="pickapic50k.tar"}
MICRO_BS=${MICRO_BS:=2}
GRAD_ACCUMULATION=${GRAD_ACCUMULATION:=4}
PEFT=${PEFT:="sdlora"}
NUM_DEVICES=8
GLOBAL_BATCH_SIZE=$((MICRO_BS*NUM_DEVICES*GRAD_ACCUMULATION))

WANDB_NAME=SD_DRaFT_annealing
WEBDATASET_PATH=/path/to/${DATASET}

CONFIG_PATH="/opt/nemo-aligner/examples/mm/stable_diffusion/conf"
CONFIG_NAME="draftp_sd"
UNET_CKPT="/path/to/unet.ckpt"
VAE_CKPT="/path/to/vae.ckpt"
RM_CKPT="/path/to/rewardmodel.nemo"

# change this as an end-user
PROMPT=${PROMPT:-"Bananas growing on an apple tree"}

EVAL_SCRIPT=${EVAL_SCRIPT:-"anneal_sd.py"}
set -x
DEVICE="0,1,2,3,4,5,6,7"
echo "Running DRaFT on ${DEVICE}"
export HYDRA_FULL_ERROR=1 \
&& MASTER_PORT=15003 CUDA_VISIBLE_DEVICES="${DEVICE}" torchrun --nproc_per_node=${NUM_DEVICES} /opt/nemo-aligner/examples/mm/stable_diffusion/${EVAL_SCRIPT} \
--config-path=${CONFIG_PATH} \
--config-name=${CONFIG_NAME} \
model.optim.lr=${LR} \
model.optim.weight_decay=0.005 \
model.optim.sched.warmup_steps=0 \
model.infer.inference_steps=${INF_STEPS} \
model.infer.eta=0.0 \
model.kl_coeff=${KL_COEF} \
model.truncation_steps=1 \
trainer.draftp_sd.max_epochs=1 \
trainer.draftp_sd.max_steps=4000 \
trainer.draftp_sd.save_interval=100 \
model.unet_config.from_pretrained=${UNET_CKPT} \
model.first_stage_config.from_pretrained=${VAE_CKPT} \
model.micro_batch_size=${MICRO_BS} \
model.global_batch_size=${GLOBAL_BATCH_SIZE} \
model.peft.peft_scheme=${PEFT} \
model.data.webdataset.local_root_path=$WEBDATASET_PATH \
rm.model.restore_from_path=${RM_CKPT} \
+prompt="${PROMPT}" \
trainer.draftp_sd.val_check_interval=20 \
trainer.draftp_sd.gradient_clip_val=10.0 \
trainer.devices=${NUM_DEVICES} \
rm.trainer.devices=${NUM_DEVICES} \
exp_manager.create_wandb_logger=True \
exp_manager.wandb_logger_kwargs.name=${WANDB_NAME} \
exp_manager.resume_if_exists=True \
exp_manager.explicit_log_dir=${DIR_SAVE_CKPT_PATH} \
exp_manager.wandb_logger_kwargs.project=${PROJECT} +weight_type='draft,base,power_2.0'

215 changes: 215 additions & 0 deletions examples/mm/stable_diffusion/anneal_sd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from copy import deepcopy
from functools import partial

import numpy as np
import torch
import torch.distributed
import torch.multiprocessing as mp
from megatron.core import parallel_state
from megatron.core.tensor_parallel.random import get_cuda_rng_tracker, get_data_parallel_rng_tracker_name
from megatron.core.utils import divide
from omegaconf.omegaconf import OmegaConf, open_dict
from packaging.version import Version
from PIL import Image
from torch import nn

from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronStableDiffusionTrainerBuilder
from nemo.collections.nlp.parts.peft_config import PEFT_CONFIG_MAP
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager
from nemo_aligner.algorithms.supervised import SupervisedTrainer
from nemo_aligner.data.mm import text_webdataset
from nemo_aligner.data.nlp.builders import build_dataloader
from nemo_aligner.models.mm.stable_diffusion.image_text_rms import get_reward_model
from nemo_aligner.models.mm.stable_diffusion.megatron_sd_draftp_model import MegatronSDDRaFTPModel
from nemo_aligner.utils.distributed import Timer
from nemo_aligner.utils.train_script_utils import (
CustomLoggerWrapper,
add_custom_checkpoint_callback,
extract_optimizer_scheduler_from_ptl_model,
init_distributed,
init_peft,
init_using_ptl,
retrieve_custom_trainer_state_dict,
temp_pop_from_config,
)

mp.set_start_method("spawn", force=True)


def resolve_and_create_trainer(cfg, pop_trainer_key):
"""resolve the cfg, remove the key before constructing the PTL trainer
and then restore it after
"""
OmegaConf.resolve(cfg)
with temp_pop_from_config(cfg.trainer, pop_trainer_key):
return MegatronStableDiffusionTrainerBuilder(cfg).create_trainer()


@hydra_runner(config_path="conf", config_name="draftp_sd")
def main(cfg) -> None:

logging.info("\n\n************** Experiment configuration ***********")
logging.info(f"\n{OmegaConf.to_yaml(cfg)}")

# set cuda device for each process
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)

cfg.exp_manager.create_wandb_logger = False

if Version(torch.__version__) >= Version("1.12"):
torch.backends.cuda.matmul.allow_tf32 = True
cfg.model.data.train.dataset_path = [cfg.model.data.webdataset.local_root_path for _ in range(cfg.trainer.devices)]
cfg.model.data.validation.dataset_path = [
cfg.model.data.webdataset.local_root_path for _ in range(cfg.trainer.devices)
]

trainer = resolve_and_create_trainer(cfg, "draftp_sd")
exp_manager(trainer, cfg.exp_manager)
logger = CustomLoggerWrapper(trainer.loggers)
# Instatiating the model here
ptl_model = MegatronSDDRaFTPModel(cfg.model, trainer).to(torch.cuda.current_device())
init_peft(ptl_model, cfg.model)

trainer_restore_path = trainer.ckpt_path

if trainer_restore_path is not None:
custom_trainer_state_dict = retrieve_custom_trainer_state_dict(trainer)
consumed_samples = custom_trainer_state_dict["consumed_samples"]
else:
custom_trainer_state_dict = None
consumed_samples = 0

init_distributed(trainer, ptl_model, cfg.model.get("transformer_engine", False))

train_ds, validation_ds = text_webdataset.build_train_valid_datasets(
cfg.model.data, consumed_samples=consumed_samples
)
validation_ds = [d["captions"] for d in list(validation_ds)]

val_dataloader = build_dataloader(
cfg,
dataset=validation_ds,
consumed_samples=consumed_samples,
mbs=cfg.model.micro_batch_size,
gbs=cfg.model.global_batch_size,
load_gbs=True,
)

init_using_ptl(trainer, ptl_model, val_dataloader, validation_ds)

optimizer, scheduler = extract_optimizer_scheduler_from_ptl_model(ptl_model)

ckpt_callback = add_custom_checkpoint_callback(trainer, ptl_model)

logger.log_hyperparams(OmegaConf.to_container(cfg))

reward_model = get_reward_model(cfg.rm, mbs=cfg.model.micro_batch_size, gbs=cfg.model.global_batch_size)
ptl_model.reward_model = reward_model

ckpt_callback = add_custom_checkpoint_callback(trainer, ptl_model)
timer = Timer(cfg.exp_manager.get("max_time_per_run", "0:12:00:00"))

draft_p_trainer = SupervisedTrainer(
cfg=cfg.trainer.draftp_sd,
model=ptl_model,
optimizer=optimizer,
scheduler=scheduler,
train_dataloader=val_dataloader,
val_dataloader=val_dataloader,
test_dataloader=[],
logger=logger,
ckpt_callback=ckpt_callback,
run_timer=timer,
)

if custom_trainer_state_dict is not None:
draft_p_trainer.load_state_dict(custom_trainer_state_dict)

# Run annealed guidance
if cfg.get("prompt") is not None:
logging.info(f"Override val dataset with custom prompt: {cfg.prompt}")
val_dataloader = [[cfg.prompt]]

wt_types = cfg.get("weight_type", None)
if wt_types is None:
wt_types = ["base", "draft", "linear", "power_2", "power_4", "step_0.6"]
else:
wt_types = wt_types.split(",") if isinstance(wt_types, str) else wt_types
logging.info(f"Running on types: {wt_types}")

# run for all weight types
for wt_type in wt_types:
global_idx = 0
if wt_type is None or wt_type == "base":
# dummy function that assigns a value of 0 all the time
logging.info("using the base model")
wt_draft = lambda sigma, sigma_next, i, total: 0
else:
if wt_type == "linear":
wt_draft = lambda sigma, sigma_next, i, total: i * 1.0 / total
elif wt_type == "draft":
wt_draft = lambda sigma, sigma_next, i, total: 1
elif wt_type.startswith("power"): # its of the form power_{power}
pow = float(wt_type.split("_")[1])
wt_draft = lambda sigma, sigma_next, i, total: (i * 1.0 / total) ** pow
elif wt_type.startswith("step"): # use a step function (step_{p})
frac = float(wt_type.split("_")[1])
wt_draft = lambda sigma, sigma_next, i, total: float((i * 1.0 / total) >= frac)
else:
raise ValueError(f"invalid weighing type: {wt_type}")
logging.info(f"using weighing type for annealed outputs: {wt_type}.")

# initialize generator
gen = torch.Generator(device="cpu")
gen.manual_seed((1243 + 1247837 * local_rank) % (int(2 ** 32 - 1)))
os.makedirs(f"./annealed_outputs_sd_{wt_type}/", exist_ok=True)

for batch in val_dataloader:
batch_size = len(batch)
with get_cuda_rng_tracker().fork(get_data_parallel_rng_tracker_name()):
latents = torch.randn(
[
batch_size,
ptl_model.in_channels,
ptl_model.height // ptl_model.downsampling_factor,
ptl_model.width // ptl_model.downsampling_factor,
],
generator=gen,
).to(torch.cuda.current_device())
images = ptl_model.annealed_guidance(batch, latents, weighing_fn=wt_draft)
images = (
images.permute(0, 2, 3, 1).detach().float().cpu().numpy().astype(np.uint8)
) # outputs are already scaled from [0, 255]
# save to pil
for i in range(images.shape[0]):
i = i + global_idx
img_path = f"annealed_outputs_sd_{wt_type}/img_{i:05d}_{local_rank:02d}.png"
prompt_path = f"annealed_outputs_sd_{wt_type}/prompt_{i:05d}_{local_rank:02d}.txt"
Image.fromarray(images[i]).save(img_path)
with open(prompt_path, "w") as fi:
fi.write(batch[i])
# increment global index
global_idx += batch_size
logging.info("Saved all images.")


if __name__ == "__main__":
main()
Loading