inpainting stuff
pesser committed Sep 5, 2022
1 parent 693e713 commit bbb5298
Showing 5 changed files with 184 additions and 47 deletions.
5 changes: 4 additions & 1 deletion ldm/models/diffusion/ddpm.py
@@ -1238,7 +1238,10 @@ def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
     def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
         if ddim:
             ddim_sampler = DDIMSampler(self)
-            shape = (self.channels, self.image_size, self.image_size)
+            if "shape" in kwargs:
+                shape = kwargs.pop("shape")
+            else:
+                shape = (self.channels, self.image_size, self.image_size)
             samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size,
                                                          shape, cond, verbose=False, **kwargs)
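The new keyword lets callers override the default square latent shape. A minimal sketch of the resulting call pattern, assuming a LatentDiffusion-style model with an f=8 autoencoder (the helper name and the sizes below are illustrative, not part of the commit):

import torch

def sample_nonsquare(model, c, batch_size, H=512, W=768, ddim_steps=50):
    # Hedged sketch: "model" is assumed to expose the patched sample_log;
    # "c" is conditioning for batch_size samples. An f=8 autoencoder maps
    # HxW pixels to (H//8)x(W//8) latents, so 512x768 pixels need a 64x96 latent.
    with torch.no_grad():
        samples, intermediates = model.sample_log(
            c, batch_size, ddim=True, ddim_steps=ddim_steps,
            shape=(model.channels, H // 8, W // 8),  # consumed via kwargs.pop("shape")
        )
    return samples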

149 changes: 104 additions & 45 deletions scripts/demo/inpainting.py
@@ -1,3 +1,4 @@
+import math
 import streamlit as st
 import torch
 import cv2
@@ -48,6 +49,7 @@ def sample(
                 ddim_steps=ddim_steps, eta=ddim_eta,
                 unconditional_guidance_scale=unconditional_guidance_scale,
                 unconditional_conditioning=uc_full,
+                shape=(self.channels, H//8, W//8)
             )
             samples = self.decode_first_stage(samples_cfg)
         else:
Expand All @@ -61,10 +63,6 @@ def np2batch(
image,
mask,
txt):
print("###")
print(image.shape)
print(mask.shape)
print("###")
# image hwc in -1 1
image = torch.from_numpy(image).to(dtype=torch.float32)/127.5-1.0

@@ -119,6 +117,7 @@ def run(
         #ckpt="/fsx/robin/stable-diffusion/stable-diffusion/logs/2022-07-28T07-44-05_v1-finetune-for-inpainting-laion-aesthetic-larger-masks/checkpoints/last.ckpt",
         ckpt="/fsx/robin/stable-diffusion/stable-diffusion/logs/2022-08-01T08-52-14_v1-finetune-for-inpainting-laion-aesthetic-larger-masks-and-ucfg/checkpoints/last.ckpt",
 ):
+    st.set_page_config(layout="wide")
     st.title("Stable Inpainting")
     state = init()

@@ -131,15 +130,50 @@ def run(
     if uploaded_file is not None:
         image = Image.open(io.BytesIO(uploaded_file.getvalue())).convert("RGB")
         width, height = image.size
-        smaller = min(width, height)
-        crop = (
-            (width-smaller)//2,
-            (height-smaller)//2,
-            (width-smaller)//2+smaller,
-            (height-smaller)//2+smaller,
-        )
-        image = image.crop(crop)
-        image = image.resize((512, 512))
+        orig_width, orig_height = image.size
+        resize = st.selectbox("Resize", ["padtop", "crop", "keepar"])
+        if resize=="crop":
+            smaller = min(width, height)
+            crop = (
+                (width-smaller)//2,
+                (height-smaller)//2,
+                (width-smaller)//2+smaller,
+                (height-smaller)//2+smaller,
+            )
+            image = image.crop(crop)
+            image = image.resize((512, 512))
+        elif resize=="padtop":
+            pad = max(width, height)-min(width, height)
+            padh = max(0, width - height)
+            padw = max(0, height - width)
+
+            full = np.zeros((height+padh, width+padw, 3), dtype=np.uint8)
+            print(full.shape)
+            image = np.array(image)
+            full[padh:, padw:, :] = image
+            image = full
+            image = Image.fromarray(image)
+            image = image.resize((512, 512))
+            invalidh = int(math.ceil(512/(height+padh)*padh))+1
+        elif resize=="keepar":
+            target_size = 512
+            ar = height/width
+            if width < height:
+                target_width = 512
+                target_height = target_width/width*height
+            else:
+                target_height = 512
+                target_width = target_height/height*width
+
+            mod = 16
+            target_height = mod*round(target_height/mod)
+            target_width = mod*round(target_width/mod)
+
+            image = image.resize((target_width, target_height))
+
+        width, height = image.size
+        print(width, height)
+
         #st.write("Uploaded Image")
         #st.image(image)
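To make the "padtop" arithmetic concrete, a hedged walk-through for a hypothetical 640x480 (landscape) upload; the numbers follow directly from the lines above:

import math

width, height = 640, 480                 # hypothetical upload size
padh = max(0, width - height)            # 160: pad rows added at the top
padw = max(0, height - width)            # 0: no horizontal padding for landscape input
# the photo lands in the bottom rows of a 640x640 black canvas,
# which is then resized to 512x512 for the model
invalidh = int(math.ceil(512 / (height + padh) * padh)) + 1   # 129 padded rows at 512 res
# later, mask[:invalidh, :] = True forces the padded strip to be inpainted,
# and the final sample is resized to (orig_width+padw, orig_height+padh)
# before the top padh rows / left padw columns are cropped away again.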

@@ -150,11 +184,10 @@ def run(
             stroke_width=stroke_width,
             stroke_color="rgb(0, 0, 0)",
             background_color="rgb(0, 0, 0)",
-            background_image=image if image is not None else Image.fromarray(255*np.ones((512,512,3),
-                                                                                          dtype=np.uint8)),
+            background_image=image,
             update_streamlit=False,
-            height=image.size[1] if image is not None else 512,
-            width=image.size[0] if image is not None else 512,
+            height=height,
+            width=width,
             drawing_mode="freedraw",
             point_display_radius=0,
             key="canvas",
@@ -163,6 +196,8 @@ def run(
         mask = canvas_result.image_data
         mask = np.array(mask)[:,:,[3,3,3]]
         mask = mask > 127
+        if resize == "padtop":
+            mask[:invalidh, :] = True

         # visualize
         bdry = cv2.dilate(mask.astype(np.uint8), np.ones((3,3), dtype=np.uint8))
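The mask comes from the drawable canvas's RGBA buffer: the alpha channel is replicated to three channels and thresholded. A small self-contained sketch of that step, where the synthetic array stands in for canvas_result.image_data:

import numpy as np

rgba = np.zeros((4, 4, 4), dtype=np.uint8)   # stand-in for canvas_result.image_data
rgba[1:3, 1:3, 3] = 255                      # pretend the user painted a 2x2 stroke
mask = rgba[:, :, [3, 3, 3]] > 127           # HxWx3 boolean mask, True where painted
print(mask[:, :, 0].astype(np.uint8))        # shows the painted 2x2 block as ones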
@@ -180,34 +215,58 @@ def run(
         t_total = int(st.number_input("Diffusion steps", value=50))

         if st.button("Sample"):
-            st.text("Sampling")
-            batch_progress = st.progress(0)
-            batch_total = 3
-            t_progress = st.progress(0)
-            result = st.empty()
-            #canvas = make_canvas(2, 3)
-            def callback(x, batch, t):
-                #result.text(f"{batch}, {t}")
-                batch_progress.progress(min(1.0, (batch+1)/batch_total))
-                t_progress.progress(min(1.0, (t+1)/t_total))
-                update_canvas(canvas, x, batch)
-                result.image(canvas)
-
-            samples = sample(
-                state["model"],
-                prompt,
-                n_runs=3,
-                n_samples=2,
-                H=512,
-                W=512,
-                scale=scale,
-                ddim_steps=t_total,
-                callback=callback,
-                image=np.array(image),
-                mask=np.array(mask),
-            )
-            st.text("Samples")
-            st.image(samples[0])
+            with torch.inference_mode():
+                with torch.autocast("cuda"):
+                    st.text("Sampling")
+                    batch_progress = st.progress(0)
+                    batch_total = 3
+                    t_progress = st.progress(0)
+                    result = st.empty()
+                    #canvas = make_canvas(2, 3)
+                    def callback(x, batch, t):
+                        #result.text(f"{batch}, {t}")
+                        batch_progress.progress(min(1.0, (batch+1)/batch_total))
+                        t_progress.progress(min(1.0, (t+1)/t_total))
+                        update_canvas(canvas, x, batch)
+                        result.image(canvas)
+
+                    samples = sample(
+                        state["model"],
+                        prompt,
+                        n_runs=3,
+                        n_samples=2,
+                        H=height,
+                        W=width,
+                        scale=scale,
+                        ddim_steps=t_total,
+                        callback=callback,
+                        image=np.array(image),
+                        mask=np.array(mask),
+                    )
+                    st.text("Samples")
+                    st.image(samples[0])
+
+                    orig = samples[0]
+
+                    if resize=="padtop":
+                        orig = Image.fromarray(orig)
+                        orig = orig.resize((orig_width+padw, orig_height+padh))
+                        orig = np.array(orig)
+                        orig = orig[padh:, padw:]
+                    else:
+                        orig = Image.fromarray(orig)
+                        orig = orig.resize((orig_width, orig_height))
+                        orig = np.array(orig)
+
+                    orig = Image.fromarray(orig).save("tmp.png")
+                    with open("tmp.png", "rb") as f:
+                        st.download_button(
+                            "Original Image",
+                            data=f,
+                            file_name=prompt.replace(" ", "_")+".png",
+                            mime=f"image/png",
+                        )
+


 if __name__ == "__main__":
2 changes: 1 addition & 1 deletion scripts/inpaint_sd.py
@@ -180,7 +180,7 @@ def make_batch_sd(

         image = torch.clamp((batch["jpg"]+1.0)/2.0,
                             min=0.0, max=1.0)
-        mask = torch.clamp((batch["mask"]+1.0)/2.0,
+        mask = torch.clamp(batch["mask"],
                            min=0.0, max=1.0)
         predicted_image = torch.clamp((x_samples_ddim+1.0)/2.0,
                                       min=0.0, max=1.0)
32 changes: 32 additions & 0 deletions scripts/slurm/eval_inpainting/launcher.sh
@@ -0,0 +1,32 @@
#!/bin/bash

# mpi version for node rank
H=`hostname`
THEID=`echo -e $HOSTNAMES | python3 -c "import sys;[sys.stdout.write(str(i)) for i,line in enumerate(next(sys.stdin).split(' ')) if line.strip() == '$H'.strip()]"`
export NODE_RANK=${THEID}
echo THEID=$THEID

echo "##########################################"
echo MASTER_ADDR=${MASTER_ADDR}
echo MASTER_PORT=${MASTER_PORT}
echo NODE_RANK=${NODE_RANK}
echo WORLD_SIZE=${WORLD_SIZE}
echo CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES}
echo SLURM_PROCID=${SLURM_PROCID}
echo "##########################################"
# debug environment worked great so we stick with it
# no magic there, just a miniconda python=3.9, pytorch=1.12, cudatoolkit=11.3
# env with pip dependencies from stable diffusion's requirements.txt
eval "$(/fsx/stable-diffusion/debug/miniconda3/bin/conda shell.bash hook)"
#conda activate stable
conda activate torch111
cd /fsx/stable-diffusion/stable-diffusion

#/bin/bash /fsx/stable-diffusion/stable-diffusion/scripts/test_gpu.sh

EXTRA="--indir /fsx/stable-diffusion/data/eval-inpainting/random_thick_512 --worldsize 8 --rank ${SLURM_PROCID}"
EXTRA="${EXTRA} --ckpt ${1} --outdir /fsx/stable-diffusion/stable-diffusion/inpainting-eval-results/${2}"

echo "Running ${EXTRA}"
cd /fsx/stable-diffusion/stable-diffusion/
python scripts/inpaint_sd.py ${EXTRA}
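The THEID one-liner finds this node's position in the space-separated HOSTNAMES list exported by the sbatch script. An equivalent Python sketch, with invented host names:

# Hedged illustration of the NODE_RANK derivation; host names are made up.
hostnames = "gpu-node-1 gpu-node-2 gpu-node-3"   # contents of $HOSTNAMES
current_host = "gpu-node-2"                      # what `hostname` would print on this node
node_rank = [i for i, name in enumerate(hostnames.split(" "))
             if name.strip() == current_host.strip()][0]
print(node_rank)   # 1 -> exported as NODE_RANK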
43 changes: 43 additions & 0 deletions scripts/slurm/eval_inpainting/sbatch.sh
@@ -0,0 +1,43 @@
#!/bin/bash
#SBATCH --partition=gpu
#SBATCH --job-name=stable-diffusion-inpainting-eval
#SBATCH --nodes 1
#SBATCH --ntasks-per-node 8
#SBATCH --cpus-per-gpu=4
#SBATCH --gpus-per-task=1
#SBATCH --exclusive
#SBATCH --output=%x_%j.out
#SBATCH --comment=stablediffusion
#SBATCH --no-requeue

module load intelmpi
source /opt/intel/mpi/latest/env/vars.sh
export LD_LIBRARY_PATH=/opt/aws-ofi-nccl/lib:/opt/amazon/efa/lib64:/usr/local/cuda-11.0/efa/lib:/usr/local/cuda-11.0/lib:/usr/local/cuda-11.0/lib64:/usr/local/cuda-11.0:/opt/nccl/build/lib:/opt/aws-ofi-nccl-install/lib:/opt/aws-ofi-nccl/lib:$LD_LIBRARY_PATH
export NCCL_PROTO=simple
export PATH=/opt/amazon/efa/bin:$PATH
export LD_PRELOAD="/opt/nccl/build/lib/libnccl.so"
export FI_EFA_FORK_SAFE=1
export FI_LOG_LEVEL=1
export FI_EFA_USE_DEVICE_RDMA=1 # use for p4dn
export NCCL_DEBUG=info
export PYTHONFAULTHANDLER=1
export CUDA_LAUNCH_BLOCKING=0
export OMPI_MCA_mtl_base_verbose=1
export FI_EFA_ENABLE_SHM_TRANSFER=0
export FI_PROVIDER=efa
export FI_EFA_TX_MIN_CREDITS=64
export NCCL_TREE_THRESHOLD=0

# sent to sub script
export HOSTNAMES=`scontrol show hostnames "$SLURM_JOB_NODELIST"`
export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
export MASTER_PORT=12802
export COUNT_NODE=`scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l`
export WORLD_SIZE=$COUNT_NODE

echo go $COUNT_NODE
echo $HOSTNAMES
echo $WORLD_SIZE

echo "Starting"
srun --comment stablediffusion --mpi=pmix_v3 /fsx/stable-diffusion/stable-diffusion/scripts/slurm/eval_inpainting/launcher.sh $ckpt $outdir
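Usage note (an assumption, not stated in the script): the srun line forwards $ckpt and $outdir, which launcher.sh consumes as its two positional arguments, so the job is presumably submitted with those variables set in the environment, e.g. ckpt=/path/to/last.ckpt outdir=my-eval sbatch sbatch.sh, relying on sbatch's default behavior of exporting the submitter's environment.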
