release training code (#99)
Co-authored-by: SuBazinga <suqingkun@gmail.com>
ShenhaoZhu and subazinga authored May 1, 2024
1 parent ad5c78f commit 9a88e62
Showing 13 changed files with 3,053 additions and 0 deletions.
75 changes: 75 additions & 0 deletions configs/train/stage1.yaml
@@ -0,0 +1,75 @@
exp_name: 'stage1'
output_dir: './exp_output'
seed: 42
resume_from_checkpoint: ''

checkpointing_steps: 2000
save_model_epoch_interval: 20

data:
  train_bs: 4
  video_folder: ''  # Your data root folder
  guids:
    - 'depth'
    - 'normal'
    - 'semantic_map'
    - 'dwpose'
  image_size: 768
  bbox_crop: false
  bbox_resize_ratio: [0.9, 1.5]
  aug_type: "Resize"
  data_parts:
    - "all"
  sample_margin: 30

validation:
  validation_steps: 1000
  ref_images:
    - validation_data/ref_images/val-0.png
  guidance_folders:
    - validation_data/guid_sequences/0
  guidance_indexes: [0, 30, 60, 90, 120]

solver:
  gradient_accumulation_steps: 1
  mixed_precision: 'fp16'
  enable_xformers_memory_efficient_attention: True
  gradient_checkpointing: False
  max_train_steps: 100000  # 50000
  max_grad_norm: 1.0
  # lr
  learning_rate: 1.0e-5
  scale_lr: False
  lr_warmup_steps: 1
  lr_scheduler: 'constant'

  # optimizer
  use_8bit_adam: False
  adam_beta1: 0.9
  adam_beta2: 0.999
  adam_weight_decay: 1.0e-2
  adam_epsilon: 1.0e-8

noise_scheduler_kwargs:
  num_train_timesteps: 1000
  beta_start: 0.00085
  beta_end: 0.012
  beta_schedule: "scaled_linear"
  steps_offset: 1
  clip_sample: false

guidance_encoder_kwargs:
  guidance_embedding_channels: 320
  guidance_input_channels: 3
  block_out_channels: [16, 32, 96, 256]

base_model_path: 'pretrained_models/stable-diffusion-v1-5'
vae_model_path: 'pretrained_models/sd-vae-ft-mse'
image_encoder_path: 'pretrained_models/image_encoder'

weight_dtype: 'fp16' # [fp16, fp32]
uncond_ratio: 0.1
noise_offset: 0.05
snr_gamma: 5.0
enable_zero_snr: True
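
Note: the keys in noise_scheduler_kwargs line up with the constructor arguments of the diffusers DDIM/DDPM schedulers. A minimal sketch of how this config might be loaded and turned into a training scheduler, assuming OmegaConf and a recent diffusers release (the variable names and the enable_zero_snr wiring are illustrative assumptions, not necessarily what the training script does):

from diffusers import DDIMScheduler
from omegaconf import OmegaConf

# Illustrative: load the stage-1 config and build the noise scheduler.
cfg = OmegaConf.load("configs/train/stage1.yaml")
sched_kwargs = OmegaConf.to_container(cfg.noise_scheduler_kwargs)

# Assumption: enable_zero_snr corresponds to zero-terminal-SNR rescaling
# (Lin et al., 2024), exposed in diffusers as rescale_betas_zero_snr.
if cfg.enable_zero_snr:
    sched_kwargs["rescale_betas_zero_snr"] = True

noise_scheduler = DDIMScheduler(**sched_kwargs)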

103 changes: 103 additions & 0 deletions configs/train/stage2.yaml
@@ -0,0 +1,103 @@
exp_name: 'stage2'
output_dir: './exp_output'
seed: 42
resume_from_checkpoint: ''

stage1_ckpt_step: 'latest'
stage1_ckpt_dir: '' # stage1 checkpoint folder

checkpointing_steps: 2000
save_model_epoch_interval: 20

data:
  train_bs: 1
  video_folder: ''  # Your data root folder
  guids:
    - 'depth'
    - 'normal'
    - 'semantic_map'
    - 'dwpose'
  image_size: 512
  bbox_crop: false
  bbox_resize_ratio: [0.9, 1.5]
  aug_type: "Resize"
  data_parts:
    - "all"
  sample_frames: 24
  sample_rate: 4

validation:
  validation_steps: 1000
  clip_length: 24
  ref_images:
    - validation_data/ref_images/val-1.png
  guidance_folders:
    - validation_data/guid_sequences/0
  guidance_indexes: [0, 30, 60, 90, 120]

solver:
  gradient_accumulation_steps: 1
  mixed_precision: 'fp16'
  enable_xformers_memory_efficient_attention: True
  gradient_checkpointing: True
  max_train_steps: 50000
  max_grad_norm: 1.0
  # lr
  learning_rate: 1.0e-5
  scale_lr: False
  lr_warmup_steps: 1
  lr_scheduler: 'constant'

  # optimizer
  use_8bit_adam: True
  adam_beta1: 0.9
  adam_beta2: 0.999
  adam_weight_decay: 1.0e-2
  adam_epsilon: 1.0e-8

noise_scheduler_kwargs:
  num_train_timesteps: 1000
  beta_start: 0.00085
  beta_end: 0.012
  beta_schedule: "linear"
  steps_offset: 1
  clip_sample: false

guidance_encoder_kwargs:
  guidance_embedding_channels: 320
  guidance_input_channels: 3
  block_out_channels: [16, 32, 96, 256]

unet_additional_kwargs:
  use_inflated_groupnorm: true
  unet_use_cross_frame_attention: false
  unet_use_temporal_attention: false
  use_motion_module: true
  motion_module_resolutions:
    - 1
    - 2
    - 4
    - 8
  motion_module_mid_block: true
  motion_module_decoder_only: false
  motion_module_type: Vanilla
  motion_module_kwargs:
    num_attention_heads: 8
    num_transformer_block: 1
    attention_block_types:
      - Temporal_Self
      - Temporal_Self
    temporal_position_encoding: true
    temporal_position_encoding_max_len: 32
    temporal_attention_dim_div: 1

base_model_path: 'pretrained_models/stable-diffusion-v1-5'
vae_model_path: 'pretrained_models/sd-vae-ft-mse'
image_encoder_path: 'pretrained_models/image_encoder'
mm_path: './pretrained_models/mm_sd_v15_v2.ckpt'

weight_dtype: 'fp16' # [fp16, fp32]
uncond_ratio: 0.1
noise_offset: 0.05
snr_gamma: 5.0
enable_zero_snr: True
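
Both stages set snr_gamma: 5.0, which conventionally denotes Min-SNR loss weighting (Hang et al., 2023): the per-timestep diffusion loss is scaled by min(SNR(t), gamma) / SNR(t) so that easy, high-SNR timesteps do not dominate training. A hedged sketch of that weighting for epsilon-prediction; the function names are illustrative, not taken from this commit:

import torch

def compute_snr(scheduler, timesteps):
    # SNR(t) = alpha_bar_t / (1 - alpha_bar_t) for the DDPM forward process.
    alphas_cumprod = scheduler.alphas_cumprod.to(timesteps.device)[timesteps]
    return alphas_cumprod / (1.0 - alphas_cumprod)

def min_snr_weights(scheduler, timesteps, snr_gamma=5.0):
    # Clamp SNR at gamma, then normalize by the true SNR.
    snr = compute_snr(scheduler, timesteps)
    return torch.clamp(snr, max=snr_gamma) / snr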
Empty file added datasets/__init__.py
Empty file.
76 changes: 76 additions & 0 deletions datasets/data_utils.py
@@ -0,0 +1,76 @@
import os
import json
import random
from typing import List
import csv
import glob
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torchvision.transforms as transforms
from decord import VideoReader
from PIL import Image
from torch.utils.data import Dataset
from transformers import CLIPImageProcessor
from tqdm import tqdm


def process_bbox(bbox, H, W, scale=1.):
    # Expand a bbox (xmin, ymin, xmax, ymax) into a square of side
    # max(width, height) * scale around the same center, clipped to (H, W).
    x_min, y_min, x_max, y_max = bbox
    width = x_max - x_min
    height = y_max - y_min

    side_length = max(width, height)

    center_x = (x_min + x_max) / 2
    center_y = (y_min + y_max) / 2

    scaled_side_length = side_length * scale
    scaled_xmin = center_x - scaled_side_length / 2
    scaled_xmax = center_x + scaled_side_length / 2
    scaled_ymin = center_y - scaled_side_length / 2
    scaled_ymax = center_y + scaled_side_length / 2

    scaled_xmin = int(max(0, scaled_xmin))
    scaled_xmax = int(min(W, scaled_xmax))
    scaled_ymin = int(max(0, scaled_ymin))
    scaled_ymax = int(min(H, scaled_ymax))

    return scaled_xmin, scaled_ymin, scaled_xmax, scaled_ymax

def crop_bbox(img, bbox, do_resize=False, size=512):
    if isinstance(img, (Path, str)):
        img = Image.open(img)
    cropped_img = img.crop(bbox)
    if do_resize:
        cropped_W, cropped_H = cropped_img.size
        ratio = size / max(cropped_W, cropped_H)
        # PIL.Image.resize requires integer pixel dimensions.
        new_W = int(cropped_W * ratio)
        new_H = int(cropped_H * ratio)
        cropped_img = cropped_img.resize((new_W, new_H))

    return cropped_img

def mask_to_bbox(mask_path):
    # Tight bbox around the nonzero region of the mask's first channel.
    mask = np.array(Image.open(mask_path))[..., 0]
    rows = np.any(mask, axis=1)
    cols = np.any(mask, axis=0)

    ymin, ymax = np.where(rows)[0][[0, -1]]
    xmin, xmax = np.where(cols)[0][[0, -1]]
    return xmin, ymin, xmax, ymax

def mask_to_bkgd(img_path, mask_path):
    # Zero out every pixel outside the mask, keeping only the foreground.
    img = Image.open(img_path)
    img_array = np.array(img)

    mask = Image.open(mask_path).convert("RGB")
    mask_array = np.array(mask)

    img_array = np.where(mask_array > 0, img_array, 0)
    return Image.fromarray(img_array)
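
A short usage sketch of the helpers above: derive a square crop around a subject mask and composite out the background (file paths are placeholders, not files shipped with the repo):

# Illustrative usage; paths are placeholders.
bbox = mask_to_bbox("frame_0_mask.png")               # (xmin, ymin, xmax, ymax)
square = process_bbox(bbox, H=1024, W=768, scale=1.2)
crop = crop_bbox("frame_0.png", square, do_resize=True, size=512)
fg_only = mask_to_bkgd("frame_0.png", "frame_0_mask.png")  # background zeroed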
