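# Gradio demo for "Invisible Stitch: Generating Smooth 3D Scenes with Depth
# Inpainting". Given a single input image and a text prompt, the demo
# hallucinates a geometrically coherent 3D scene, optimizes a Gaussian
# splatting representation of it, and returns the scene as a .ply file.
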
import os
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import skimage
from PIL import Image
import gradio as gr
from utils.render import PointsRendererWithMasks, render
from utils.ops import snap_high_gradients_to_nn, project_points, get_pointcloud, merge_pointclouds, outpaint_with_depth_estimation
from utils.gs import gs_options, read_cameras_from_optimization_bundle, Scene, run_gaussian_splatting, get_blank_gs_bundle
from pytorch3d.utils import opencv_from_cameras_projection
from utils.ops import focal2fov, fov2focal
from utils.models import infer_with_zoe_dc
from utils.scene import GaussianModel
from utils.demo import downsample_point_cloud
from typing import Iterable, Tuple, Dict, Optional
import itertools
from pytorch3d.structures import Pointclouds
from pytorch3d.renderer import (
    look_at_view_transform,
    PerspectiveCameras,
)
from pytorch3d.io import IO
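
# Note: this local definition shadows the `get_blank_gs_bundle` imported from
# `utils.gs` above. The bundle is a plain dict holding the camera intrinsics,
# image dimensions, point cloud buffers, and per-view frames consumed by the
# Gaussian splatting optimization.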
def get_blank_gs_bundle(h, w):
    return {
        "camera_angle_x": focal2fov(torch.tensor([w], dtype=torch.float32), w),
        "W": w,
        "H": h,
        "pcd_points": None,
        "pcd_colors": None,
        "frames": [],
    }
def extrapolate_point_cloud(
    prompt: str,
    image_size: Tuple[int, int],
    look_at_params: Iterable[Tuple[float, float, float, Tuple[float, float, float]]],
    point_cloud: Optional[Pointclouds] = None,
    dry_run: bool = False,
    discard_mask: bool = False,
    initial_image: Optional[Image.Image] = None,
    depth_scaling: float = 1,
    **render_kwargs
):
    w, h = image_size
    optimization_bundle_frames = []

    for azim, elev, dist, at in look_at_params:
        R, T = look_at_view_transform(device=device, azim=azim, elev=elev, dist=dist, at=at)
        cameras = PerspectiveCameras(
            R=R,
            T=T,
            focal_length=torch.tensor([w], dtype=torch.float32),
            principal_point=(((h - 1) / 2, (w - 1) / 2),),
            image_size=(image_size,),
            device=device,
            in_ndc=False,
        )
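
        # Either re-render the existing point cloud from the new pose and
        # outpaint whatever is disoccluded, or bootstrap the very first view
        # from the initial image alone.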
        if point_cloud is not None:
            images, masks, depths = render(cameras, point_cloud, **render_kwargs)

            if not dry_run:
                # erode the valid-depth mask to drop unreliable depth values at the boundary
                eroded_mask = skimage.morphology.binary_erosion((depths[0] > 0).cpu().numpy(), footprint=None)  # alternatively: footprint=skimage.morphology.disk(1)
                eroded_depth = depths[0].clone()
                eroded_depth[torch.from_numpy(eroded_mask).to(depths.device) <= 0] = 0

                outpainted_img, aligned_depth = outpaint_with_depth_estimation(
                    images[0], masks[0], eroded_depth, h, w, pipe, zoe_dc_model, prompt, cameras,
                    dilation_size=2, depth_scaling=depth_scaling,
                    generator=torch.Generator(device=pipe.device).manual_seed(0),
                )
                aligned_depth = torch.from_numpy(aligned_depth).to(device)
            else:
                # in a dry run, we do not actually outpaint the image
                outpainted_img = Image.fromarray((255 * images[0].cpu().numpy()).astype(np.uint8))
        else:
            assert initial_image is not None
            assert not dry_run

            # jumpstart the point cloud with a regular depth estimation
            t_initial_image = torch.from_numpy(np.asarray(initial_image) / 255.).permute(2, 0, 1).float()
            aligned_depth = infer_with_zoe_dc(zoe_dc_model, t_initial_image, torch.zeros(h, w))
            outpainted_img = initial_image
            images = [t_initial_image.to(device)]
            masks = [torch.ones(h, w, dtype=torch.bool).to(device)]

        if not dry_run:
            # snap high depth gradients to their nearest neighbor, which eliminates "noodle" artifacts at depth discontinuities
            aligned_depth = snap_high_gradients_to_nn(aligned_depth, threshold=12).cpu()
            xy_depth_world = project_points(cameras, aligned_depth)

        c2w = cameras.get_world_to_view_transform().get_matrix()[0]

        optimization_bundle_frames.append({
            "image": outpainted_img,
            "mask": masks[0].cpu().numpy(),
            "transform_matrix": c2w.tolist(),
            "azim": azim,
            "elev": elev,
            "dist": dist,
        })

        if discard_mask:
            optimization_bundle_frames[-1].pop("mask")

        if not dry_run:
            optimization_bundle_frames[-1]["center_point"] = xy_depth_world[0].mean(dim=0).tolist()
            optimization_bundle_frames[-1]["depth"] = aligned_depth.cpu().numpy()
            optimization_bundle_frames[-1]["mean_depth"] = aligned_depth.mean().item()
        else:
            # in a dry run, we do not modify the point cloud
            continue

        rgb = (torch.from_numpy(np.asarray(outpainted_img).copy()).reshape(-1, 3).float() / 255).to(device)

        if point_cloud is None:
            point_cloud = get_pointcloud(xy_depth_world[0], device=device, features=rgb)
        else:
            # PyTorch3D's mask might be slightly too big (by subpixels), so we erode it a little to avoid seams.
            # In theory, 1 pixel would be sufficient, but we use 2 to be safe.
            masks[0] = torch.from_numpy(skimage.morphology.binary_erosion(masks[0].cpu().numpy(), footprint=skimage.morphology.disk(2))).to(device)
            partial_outpainted_point_cloud = get_pointcloud(xy_depth_world[0][~masks[0].view(-1)], device=device, features=rgb[~masks[0].view(-1)])
            point_cloud = merge_pointclouds([point_cloud, partial_outpainted_point_cloud])

    return optimization_bundle_frames, point_cloud
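
# Builds the initial point cloud from the input image: one central view plus
# one view rotated by step_size degrees to each side.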
def generate_point_cloud(initial_image: Image.Image, prompt: str):
    image_size = initial_image.size
    w, h = image_size

    optimization_bundle = get_blank_gs_bundle(h, w)

    step_size = 25
    azim_steps = [0, step_size, -step_size]
    look_at_params = [(azim, 0, 0.01, torch.zeros((1, 3))) for azim in azim_steps]

    optimization_bundle["frames"], point_cloud = extrapolate_point_cloud(
        prompt, image_size, look_at_params,
        discard_mask=True, initial_image=initial_image, depth_scaling=0.5, fill_point_cloud_holes=True,
    )

    optimization_bundle["pcd_points"] = point_cloud.points_padded()[0].cpu().numpy()
    optimization_bundle["pcd_colors"] = point_cloud.features_padded()[0].cpu().numpy()

    return optimization_bundle, point_cloud
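
# Densifies the set of views by jittering the camera around each primary frame
# (a 3x3 grid of azimuth/elevation offsets, rendered as dry runs). Note that
# this helper is not invoked by the demo below; as the interface description
# states, supporting views are not inpainted to keep the demo fast.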
def supplement_point_cloud(optimization_bundle: Dict, point_cloud: Pointclouds, prompt: str):
    w, h = optimization_bundle["W"], optimization_bundle["H"]

    supporting_frames = []

    for frame in optimization_bundle["frames"]:
        # skip frames that are themselves supporting views
        if frame.get("supporting", False):
            continue

        center_point = torch.tensor(frame["center_point"]).to(device)
        mean_depth = frame["mean_depth"]
        azim, elev = frame["azim"], frame["elev"]

        azim_jitters = torch.linspace(-5, 5, 3).tolist()
        elev_jitters = torch.linspace(-5, 5, 3).tolist()

        # build the cartesian product of azimuth and elevation jitters
        camera_jitters = [
            {"azim": azim + azim_jitter, "elev": elev + elev_jitter}
            for azim_jitter, elev_jitter in itertools.product(azim_jitters, elev_jitters)
        ]
        look_at_params = [
            (camera_jitter["azim"], camera_jitter["elev"], mean_depth, center_point.unsqueeze(0))
            for camera_jitter in camera_jitters
        ]

        local_supporting_frames, point_cloud = extrapolate_point_cloud(
            prompt, (w, h), look_at_params, point_cloud,
            dry_run=True, depth_scaling=0.5, antialiasing=3,
        )

        for local_supporting_frame in local_supporting_frames:
            local_supporting_frame["supporting"] = True

        supporting_frames.extend(local_supporting_frames)

    optimization_bundle["pcd_points"] = point_cloud.points_padded()[0].cpu().numpy()
    optimization_bundle["pcd_colors"] = point_cloud.features_padded()[0].cpu().numpy()

    return optimization_bundle, point_cloud
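
# End-to-end pipeline behind the Gradio interface: preprocess the image,
# generate the point cloud, optimize a Gaussian splatting scene on top of it,
# and save the result as a .ply file for the Model3D viewer.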
def generate_scene(img: Image.Image, prompt: str):
    assert isinstance(img, Image.Image)

    # resize the image, maintaining its aspect ratio, so that the longest side is 720 pixels
    max_size = 720
    img.thumbnail((max_size, max_size))

    # crop to ensure the image dimensions are divisible by 8
    img = img.crop((0, 0, img.width - img.width % 8, img.height - img.height % 8))

    # derive a short, unique run identifier from the current timestamp
    from hashlib import sha1
    from datetime import datetime
    run_id = sha1(datetime.now().isoformat().encode()).hexdigest()[:6]
    run_name = f"gradio_{run_id}"

    gs_optimization_bundle, point_cloud = generate_point_cloud(img, prompt)

    # optional point cloud downsampling, commented out in this demo:
    # downsampled_point_cloud = downsample_point_cloud(gs_optimization_bundle, device=device)
    # gs_optimization_bundle["pcd_points"] = downsampled_point_cloud.points_padded()[0].cpu().numpy()
    # gs_optimization_bundle["pcd_colors"] = downsampled_point_cloud.features_padded()[0].cpu().numpy()

    scene = Scene(gs_optimization_bundle, GaussianModel(gs_options.sh_degree), gs_options)
    scene = run_gaussian_splatting(scene, gs_optimization_bundle)

    # coordinate system transformation (flip the y and z axes)
    scene.gaussians._xyz = scene.gaussians._xyz.detach()
    scene.gaussians._xyz[:, 1] = -scene.gaussians._xyz[:, 1]
    scene.gaussians._xyz[:, 2] = -scene.gaussians._xyz[:, 2]

    os.makedirs("outputs", exist_ok=True)
    save_path = os.path.join("outputs", f"{run_name}.ply")
    scene.gaussians.save_ply(save_path)

    return save_path
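
# Entry point: load the depth-completion checkpoint from the Hugging Face Hub
# and the inpainting pipeline, then serve the Gradio interface.
#
# A minimal sketch of programmatic use, bypassing the UI (this assumes the
# model setup below has already run in the same process):
#
#   from PIL import Image
#   img = Image.open("examples/photo-1546975490-e8b92a360b24.jpeg")
#   ply_path = generate_scene(img, "a warm living room with plants")
#   print(ply_path)  # e.g. outputs/gradio_1a2b3c.ply (the hash suffix varies)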
if __name__ == "__main__":
    # device, zoe_dc_model, and pipe are module-level globals used by the functions above
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    from utils.models import get_zoe_dc_model, get_sd_pipeline
    from huggingface_hub import hf_hub_download

    zoe_dc_model = get_zoe_dc_model(ckpt_path=hf_hub_download(repo_id="paulengstler/invisible-stitch", filename="invisible-stitch.pt")).to(device)
    pipe = get_sd_pipeline(device)
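
    # The interface wires generate_scene to an image upload, a prompt textbox,
    # and a 3D model viewer for the resulting .ply file.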
    demo = gr.Interface(
        fn=generate_scene,
        inputs=[
            gr.Image(label="Input Image", sources=["upload", "clipboard"], type="pil"),
            gr.Textbox(label="Scene Hallucination Prompt"),
        ],
        outputs=gr.Model3D(label="Generated Scene"),
        allow_flagging="never",
        title="Invisible Stitch: Generating Smooth 3D Scenes with Depth Inpainting",
        description="Hallucinate geometrically coherent 3D scenes from a single input image in less than 30 seconds.<br /> [Project Page](https://research.paulengstler.com/invisible-stitch) | [GitHub](https://github.com/paulengstler/invisible-stitch) | [Paper](#) <br /><br />To keep this demo snappy, we have limited its functionality. Scenes are generated at a low resolution without densification, supporting views are not inpainted, and we do not optimize the resulting point cloud. Imperfections are to be expected, in particular around object borders. Please allow a couple of seconds for the generated scene to be downloaded (about 40 megabytes).",
        article="Please consider running this demo locally to obtain high-quality results (see the GitHub repository).<br /><br />Here are some observations we made that might help you get better results:<ul><li>Use generic prompts that match the surroundings of your input image.</li><li>Ensure that the borders of your input image are free from partially visible objects.</li><li>Keep your prompts simple and avoid adding specific details.</li></ul>",
        examples=[
            ["examples/photo-1667788000333-4e36f948de9a.jpeg", "a street with traditional buildings in Kyoto, Japan"],
            ["examples/photo-1628624747186-a941c476b7ef.jpeg", "a suburban street in North Carolina on a bright, sunny day"],
            ["examples/photo-1469559845082-95b66baaf023.jpeg", "a view of Zion National Park"],
            ["examples/photo-1514984879728-be0aff75a6e8.jpeg", "a close-up view of a muddy path in a forest"],
            ["examples/photo-1618197345638-d2df92b39fe1.jpeg", "a close-up view of a white linen bed in a minimalistic room"],
            ["examples/photo-1546975490-e8b92a360b24.jpeg", "a warm living room with plants"],
            ["examples/photo-1499916078039-922301b0eb9b.jpeg", "a cozy bedroom on a bright day"],
        ],
    )

    demo.queue().launch(share=True)