This repository has been archived by the owner on May 22, 2024. It is now read-only.

Commit

Merge pull request #5 from dariopavllo/main
Update documentation, add demo inference
djnewtan committed Mar 20, 2023
2 parents 29a5929 + 5a518d6 commit fddf72a
Showing 12 changed files with 425 additions and 165 deletions.
43 changes: 40 additions & 3 deletions README.md
@@ -2,15 +2,42 @@
*This is not an officially supported Google product.*

This repository contains the code for the paper
> Dario Pavllo, David Joseph Tan, Marie-Julie Rakotosaona, Federico Tombari. [Shape, Pose, and Appearance from a Single Image via Bootstrapped Radiance Field Inversion](https://arxiv.org/abs/2211.11674). In arXiv, 2022.
> Dario Pavllo, David Joseph Tan, Marie-Julie Rakotosaona, Federico Tombari. [Shape, Pose, and Appearance from a Single Image via Bootstrapped Radiance Field Inversion](https://arxiv.org/abs/2211.11674). In IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 2023.
Our approach recovers an SDF-parameterized 3D shape, pose, and appearance from a single image of an object, without exploiting multiple views during training. More specifically, we leverage an unconditional 3D-aware generator, to which we apply a hybrid inversion scheme: a model produces a first guess of the solution, which is then refined via optimization.

![](images/teaser.jpg)
![](images/anim.gif)

# Setup
Please follow the instructions in [SETUP.md](SETUP.md).

# Demo
After setting up the pretrained models, you can quickly visualize some results by specifying `--inv_export_demo_sample`, e.g.
```
python run.py --resume_from g_p3d_car_pretrained --inv_export_demo_sample --gpus 4 --batch_size 16
python run.py --resume_from g_cub_pretrained --inv_export_demo_sample --gpus 4 --batch_size 16
python run.py --resume_from g_shapenet_chairs_pretrained --inv_export_demo_sample --gpus 4 --batch_size 16
```
This will run the inversion procedure on a random batch from the test set, and save the resulting images to `outputs/`. You can vary the number of GPUs using `--gpus` (default: 4) and the total batch size using `--batch_size` (default: 32).
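
For example, on a machine with fewer GPUs you might lower both settings (the values below are purely illustrative):
```
python run.py --resume_from g_cub_pretrained --inv_export_demo_sample --gpus 1 --batch_size 8
```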

# Inference on a custom image
You can try out the model on a custom image by specifying `--inv_manual_input_path <URL or path>`. Internally, it uses detectron2 (which you need to install) to extract the segmentation mask from the image. For an example on CUB birds, try out the following:
```
python run.py --resume_from g_cub_pretrained --inv_manual_input_path https://upload.wikimedia.org/wikipedia/commons/a/a7/Pyrrhula_pyrrhula_female_2.jpg
```

You can also increase the number of inversion steps through `--inv_steps` (from the default 30).
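
As a sketch, a slower but potentially higher-quality run on the same example might combine the two flags like this (the step count here is arbitrary):
```
python run.py --resume_from g_cub_pretrained --inv_manual_input_path https://upload.wikimedia.org/wikipedia/commons/a/a7/Pyrrhula_pyrrhula_female_2.jpg --inv_steps 100
```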

# Evaluation
You can evaluate the reconstruction process quantitatively using the `--run_inversion` flag, e.g.
```
python run.py --resume_from g_p3d_car_pretrained --run_inversion
python run.py --resume_from g_cub_pretrained --run_inversion
python run.py --resume_from g_shapenet_chairs_pretrained --run_inversion
```
This command will first look for a pre-trained encoder, which we provide for all the datasets used in this work (if it is missing, one will be trained from scratch). Afterwards, it will invert the full test set (if available) and produce a report with the metrics shown in the paper. As before, you can vary the batch size and number of GPUs using `--batch_size` and `--gpus`; these settings do not affect the results. You can also compute results in feed-forward mode by specifying `--inv_encoder_only`, which produces the numbers labeled as N=0 in the paper. For `p3d_car`, you can evaluate on our custom ImageNet test set by specifying `--inv_use_imagenet_testset` (otherwise, the official test set is used).
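
As a sketch, the two variants described above would be invoked roughly as follows (the flags are taken from the description; the exact combination is not verified here):
```
python run.py --resume_from g_p3d_car_pretrained --run_inversion --inv_encoder_only
python run.py --resume_from g_p3d_car_pretrained --run_inversion --inv_use_imagenet_testset
```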

# Training
The unconditional generator can be trained as follows:
```
@@ -22,11 +49,21 @@ Afterwards, the hybrid inversion procedure can be launched via:
```
python run.py --resume_from EXPERIMENT_NAME --run_inversion
```
where `EXPERIMENT_NAME` is the name of the experiment produced by the previous step (you can also find it in `gan_checkpoints/`). This will first train the encoder, save it to `coords_checkpoints/`, and finally launch the actual inversion procedure (whose outputs and TensorBoard logs are exported to `reports/`). You can also compute results in feed-forward mode by specifying `--inv_encoder_only`, which produces the numbers labeled as N=0 in the paper. For `p3d_car`, you can evaluate on our custom ImageNet test set by specifying `--inv_use_imagenet_testset` (otherwise, the official test set is used).
where `EXPERIMENT_NAME` is the name of the experiment produced by the previous step (you can also find it in `gan_checkpoints/`). This will first train the encoder, save it to `coords_checkpoints/`, and finally launch the actual inversion procedure (whose outputs and TensorBoard logs are exported to `reports/`). Once trained, the encoder will be cached and reused in subsequent calls.

The full list of arguments can be found in [arguments.py](arguments.py).

*More details coming soon.*
# Citation
If you use this work in your research, consider citing our paper:
```
@inproceedings{pavllo2023shape,
title={Shape, Pose, and Appearance from a Single Image via Bootstrapped Radiance Field Inversion},
author={Pavllo, Dario and Tan, David Joseph and Rakotosaona, Marie-Julie and Tombari, Federico},
booktitle={IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
year={2023}
}
```


# License
This code is licensed under the Apache 2.0 License. See [LICENSE](LICENSE) for more details.
10 changes: 10 additions & 0 deletions arguments.py
@@ -171,6 +171,9 @@ def parse_args():
type=int,
default=5,
help='Gain to use for inversion')
parser.add_argument('--inv_steps',
type=int,
help='Specify max number of inversion steps manually')
parser.add_argument('--inv_no_split',
action='store_true',
help='Do not split latent code for inversion')
@@ -183,6 +186,13 @@
parser.add_argument('--inv_encoder_only',
action='store_true',
help='Do not apply inversion (show result with N=0)')
parser.add_argument('--inv_export_demo_sample',
action='store_true',
help='Export demo image on the first batch')
parser.add_argument('--inv_manual_input_path',
type=str,
help='Path or URL for demo inference')


# Coord regressor params
parser.add_argument('--coord_resume_from', type=str)
67 changes: 42 additions & 25 deletions data/datasets.py
@@ -259,33 +259,47 @@ def mirror_image(self, img, mask, sfm_pose, bbox=None):
else:
return img_flip, mask_flip, sfm_pose

def forward_img(self, idx):
idx_ = idx
if self.add_mirrored and idx >= len(self.detections):
idx_ -= len(self.detections)
mirrored = True
def forward_img(self, idx, manual_image=None):
if manual_image is None:
idx_ = idx
if self.add_mirrored and idx >= len(self.detections):
idx_ -= len(self.detections)
mirrored = True
else:
mirrored = False
item = self.detections[idx_]

img_path_rel = os.path.join(self.root_dir,
item['image_path'].replace('datasets/', ''))
img_path = img_path_rel
mask = pycocotools.mask.decode(item['mask'])
bbox = item['bbox'].flatten()

img = skimage.io.imread(img_path) / 255.0
# Some are grayscale:
if len(img.shape) == 2:
img = np.repeat(np.expand_dims(img, 2), 3, axis=2)
mask = np.expand_dims(mask, 2)

# Sfm pose layout:
# Focal, Translation xyz, Rot
sfm_pose = [
self.poses['f'][idx_].numpy(), self.poses['t'][idx_].numpy(),
self.poses['R'][idx_].numpy()
]
else:
img = manual_image['image']
mask = manual_image['mask']
bbox = manual_image['bbox']
mirrored = False
item = self.detections[idx_]
img_path_rel = ''

img_path_rel = os.path.join(self.root_dir,
item['image_path'].replace('datasets/', ''))
img_path = img_path_rel
mask = pycocotools.mask.decode(item['mask'])
bbox = item['bbox'].flatten()

img = skimage.io.imread(img_path) / 255.0
# Some are grayscale:
if len(img.shape) == 2:
img = np.repeat(np.expand_dims(img, 2), 3, axis=2)
mask = np.expand_dims(mask, 2)

# Sfm pose layout:
# Focal, Translation xyz, Rot
sfm_pose = [
self.poses['f'][idx_].numpy(), self.poses['t'][idx_].numpy(),
self.poses['R'][idx_].numpy()
]
# Dummy pose
sfm_pose = [
np.zeros((1,), dtype=np.float32),
np.zeros((3,), dtype=np.float32),
np.zeros((4,), dtype=np.float32),
]

crop = self.crop # ImageNet / P3D

@@ -464,7 +478,10 @@ def normalize_kp(self, sfm_pose, img_h, img_w):
sfm_pose[1][1] = 2.0 * (sfm_pose[1][1] / img_h) - 1
return sfm_pose

def forward_img(self, idx):
def forward_img(self, idx, manual_image=None):
if manual_image is not None:
return super().forward_img(idx, manual_image)

idx_ = idx
if self.add_mirrored and idx >= len(self.anno):
idx_ -= len(self.anno)
66 changes: 52 additions & 14 deletions data/loaders.py
@@ -137,6 +137,17 @@ def get_dataset_loaders():
'imagenet_elephant': load_custom,
}

def get_coco_mapping():
return {
'p3d_car': 2,
'cub': 14,
'imagenet_car': 2,
'imagenet_airplane': 4,
'imagenet_motorcycle': 3,
'imagenet_zebra': 22,
'imagenet_elephant': 20,
}


class DatasetSplitView:

@@ -199,17 +210,33 @@ def autodetect_dataset(experiment_name):
)


def load_dataset(args, device):
def load_dataset(args, device, manual_image=None):
override_default_args(args)
dataset_config = get_dataset_config(args.dataset)
loader = get_dataset_loaders()[args.dataset]
if manual_image is not None:
extra_kwargs = {'manual_image': manual_image}
args.augment_p = 0
else:
extra_kwargs = {}
train_split, train_eval_split, test_split = loader(dataset_config, args,
device)
device, **extra_kwargs)

return dataset_config, train_split, train_eval_split, test_split


def load_custom(dataset_config, args, device):
def insert_manual_image(dataset, split, manual_image):
img, mask, _, _, _, _, _, bbox, _ = dataset.forward_img(None, manual_image)
mask = mask[None, :, :]
img = img * 2 - 1
img *= mask
img = np.concatenate((img, mask), axis=0)
img = torch.FloatTensor(img).permute(1, 2, 0)
split.images[0] = img
if split.bbox[0] is not None and split.bbox[0].shape[-1] == 4:
split.bbox[0] = torch.FloatTensor(bbox)

def load_custom(dataset_config, args, device, manual_image=None):
if args.dataset.startswith('p3d_') or args.dataset.startswith('imagenet_'):
dataset_inst = lambda *fn_args, **fn_kwargs: datasets.CustomDataset(
args.dataset, *fn_args, **fn_kwargs, root_dir=args.data_path)
@@ -281,17 +308,18 @@ def load_custom(dataset_config, args, device):
F.avg_pool2d(sample['img'], 2).clamp(-1, 1).permute(0, 2, 3, 1))
else:
all_images.append(sample['img'].clamp(-1, 1).permute(0, 2, 3, 1))
all_poses.append(sample['pose'])
all_focal.append(sample['focal'])
all_bbox.append(sample['normalized_bbox'])
all_classes.append(sample['class'])
# We clone the tensors to avoid issues with shared memory
all_poses.append(sample['pose'].clone())
all_focal.append(sample['focal'].clone())
all_bbox.append(sample['normalized_bbox'].clone())
all_classes.append(sample['class'].clone())

for i, sample in enumerate(tqdm(loader_fid)):
all_images_fid.append(sample['img'].clamp(-1, 1).permute(0, 2, 3, 1))
all_poses_fid.append(sample['pose'])
all_focal_fid.append(sample['focal'])
all_bbox_fid.append(sample['normalized_bbox'])
all_classes_fid.append(sample['class'])
all_poses_fid.append(sample['pose'].clone())
all_focal_fid.append(sample['focal'].clone())
all_bbox_fid.append(sample['normalized_bbox'].clone())
all_classes_fid.append(sample['class'].clone())

if dataset_config['views_per_object_test'] and (args.use_encoder or
args.run_inversion):
@@ -302,15 +330,19 @@
for i, sample in enumerate(tqdm(loader_test)):
all_images_test.append(sample['img'].clamp(-1,
1).permute(0, 2, 3, 1))
all_poses_test.append(sample['pose'])
all_focal_test.append(sample['focal'])
all_bbox_test.append(sample['normalized_bbox'])
all_poses_test.append(sample['pose'].clone())
all_focal_test.append(sample['focal'].clone())
all_bbox_test.append(sample['normalized_bbox'].clone())
test_split.images = torch.cat(all_images_test, dim=0)
test_split.tform_cam2world = torch.cat(all_poses_test, dim=0)
test_split.focal_length = torch.cat(all_focal_test, dim=0).squeeze(1)
test_split.bbox = torch.cat(all_bbox_test, dim=0)
print('Loaded test split with shape', test_split.images.shape)

if manual_image is not None:
# Replace first image with supplied image (demo inference)
insert_manual_image(dataset_test, test_split, manual_image)

train_split.images = torch.cat(all_images, dim=0)
train_eval_split.images = torch.cat(all_images_fid, dim=0)
all_images = None # Free up memory
@@ -323,12 +355,18 @@
train_split.bbox = torch.cat(all_bbox, dim=0)
train_split.classes = torch.cat(all_classes, dim=0)
train_split.num_classes = train_split.classes.max().item() + 1
if manual_image is not None:
# Replace first image with supplied image (demo inference)
insert_manual_image(dataset, train_split, manual_image)

train_eval_split.tform_cam2world = torch.cat(all_poses_fid, dim=0)
train_eval_split.focal_length = torch.cat(all_focal_fid, dim=0).squeeze(1)
train_eval_split.bbox = torch.cat(all_bbox_fid, dim=0)
train_eval_split.classes = torch.cat(all_classes_fid, dim=0)
train_eval_split.num_classes = train_split.num_classes
if manual_image is not None:
# Replace first image with supplied image (demo inference)
insert_manual_image(dataset_fid, train_eval_split, manual_image)

if args.dataset == 'cub':
# Ortho camera
Binary file added images/anim.gif
4 changes: 2 additions & 2 deletions lib/metrics.py
@@ -61,7 +61,7 @@ def ssim(pred, target, reduction='mean'):
skimage.metrics.structural_similarity(pred,
target,
channel_axis=0,
range=1.)
data_range=1.)
]).to(device)
else:
pred = pred.cpu().numpy()
@@ -72,7 +72,7 @@
skimage.metrics.structural_similarity(pred_elem,
target_elem,
channel_axis=0,
range=1.))
data_range=1.))
return torch.FloatTensor(similarities).to(device)


2 changes: 1 addition & 1 deletion lib/nerf_utils.py
@@ -205,7 +205,7 @@ def sample_pdf(bins,

u = u.contiguous()
cdf = cdf.contiguous()
inds = torch.searchsorted(cdf, u, side='right')
inds = torch.searchsorted(cdf, u, right=True)
below = torch.max(torch.zeros_like(inds - 1), inds - 1)
above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
inds_g = torch.stack((below, above), dim=-1)
5 changes: 3 additions & 2 deletions lib/pose_utils.py
@@ -40,8 +40,9 @@ def quaternion_rotate_vector(q, v):

def quaternion_to_matrix(q):
"""Converts a unit quaternion to a rotation matrix."""
return quaternion_rotate_vector(q,
torch.eye(3, device=q.device).unsqueeze(0))
return quaternion_rotate_vector(
q,
torch.eye(3, device=q.device).unsqueeze(0).expand(q.shape[0], -1, -1))


def pose_to_matrix(z0, t2, s, q, camera_flipped: bool):
47 changes: 47 additions & 0 deletions lib/utils.py
@@ -113,6 +113,53 @@ def restore_random_state(rng_state, data_sampler, rng, gpu_ids):
data_sampler.set_state(rng_state['data_sampler_state'], rng)


def load_manual_image(path_or_url, coco_class_id):
# On-demand imports
import urllib
import PIL
import detectron2
import detectron2.config
import detectron2.model_zoo
import detectron2.engine

if path_or_url.startswith('http'):
with urllib.request.urlopen(path_or_url) as response:
demo_input_img = PIL.Image.open(io.BytesIO(response.read()))
else:
demo_input_img = PIL.Image.open(path_or_url)
demo_input_img = np.array(demo_input_img)
assert len(demo_input_img.shape) == 3
assert demo_input_img.shape[-1] in [3, 4] # RGB(A)
demo_input_img = demo_input_img[:, :, :3]

# In the paper we use PointRend to extract masks,
# but for a demo a simpler model is also fine.
# cfg_file = 'COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml'
cfg_file = 'COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml'

cfg = detectron2.config.get_cfg()
cfg.merge_from_file(detectron2.model_zoo.get_config_file(cfg_file))
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5 # Detection threshold
cfg.MODEL.WEIGHTS = detectron2.model_zoo.get_checkpoint_url(cfg_file)
predictor = detectron2.engine.DefaultPredictor(cfg)

# Detectron expects BGR format
outputs = predictor(demo_input_img[:, :, ::-1])['instances']
outputs = outputs[outputs.pred_classes == coco_class_id]
if len(outputs) == 0:
raise RuntimeError('Could not detect any object in the provided image')

# Extract largest detected object
outputs = outputs[outputs.pred_masks.sum(dim=[1, 2]).argmax().item()]

manual_image = {
'image': demo_input_img.astype(np.float32) / 255,
'mask': outputs.pred_masks[0].cpu().float().unsqueeze(-1),
'bbox': outputs.pred_boxes.tensor[0].cpu().tolist(),
}
return manual_image


class EndlessSampler:

def __init__(self, dataset_size, rng):