Add inference script
Andrii Zadaianchuk committed Apr 10, 2024
1 parent dbbb63d commit 524797e
Showing 9 changed files with 274 additions and 369 deletions.
21 changes: 15 additions & 6 deletions README.md
@@ -47,18 +47,27 @@ If you want to continue training from a previous run, you can use the `--continu
poetry run python -m videosaur.train --continue <path_to_log_dir_or_checkpoint_file> configs/videosaur/movi_c.yml
```

### Inference
If you want to run one of the released checkpoints (see below) on your own video, you can use the inference script with the corresponding config file:

```
poetry run python -m videosaur.inference --config configs/inference/movi_c.yml
```
In the released config, change `checkpoint: path/to/videosaur-movi-c.ckpt` to the actual path of your downloaded checkpoint.
For other video formats, you may need to adjust the corresponding transformations in the `build_inference_transform` function.
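
For example, a complete inference config might look like the following sketch (the checkpoint, input, and output paths are illustrative and should point to your own files):

```
checkpoint: checkpoints/videosaur-movi-c.ckpt   # path to the downloaded checkpoint
model_config: configs/videosaur/movi_c.yml      # training config of the released model
input:
  path: my_video.mp4                            # video to run the model on
  transforms:
    use_movi_normalization: true                # set to false to use ImageNet normalization instead
    type: video
    input_size: 224
output:
  save_video_path: outputs/my_video_masks.mp4   # where the mask visualization is written
```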

## Results

### VideoSAUR

We list the results you should roughly be able to obtain with the configs included in this repository:

| Dataset | Model Variant | Video ARI | Video mBO | Config |
|--------------|------------------|------------|-----------|-----------------------------|
| MOVi-C | ViT-B/8, DINO | 64.8 | 38.9 | videosaur/movi_c.yml |
| MOVi-E | ViT-B/8, DINO | 73.9 | 35.6 | videosaur/movi_e.yml |
| YT-VIS 2021 | ViT-B/16, DINO | 39.5 | 29.1 | videosaur/ytvis.yml |
| YT-VIS 2021 | ViT-B/14, DINOv2 | 39.7 | 35.6 | videosaur/ytvis_dinov2.yml |
| Dataset | Model Variant | Video ARI | Video mBO | Config | Checkpoint Link |
|--------------|------------------|-----------|-----------|-----------------------------|------------------------------------------------------------------------------------------------------------|
| MOVi-C | ViT-B/8, DINO | 64.8 | 38.9 | videosaur/movi_c.yml | [Checkpoint](https://huggingface.co/andriizadaianchuk/videosaur-movi-c/resolve/main/videosaur-movi-c.ckpt) |
| MOVi-E | ViT-B/8, DINO | 73.9 | 35.6 | videosaur/movi_e.yml | [Checkpoint](https://huggingface.co/andriizadaianchuk/videosaur-movi-e/resolve/main/videosaur-movi-e.ckpt) |
| YT-VIS 2021 | ViT-B/16, DINO | 39.5 | 29.1 | videosaur/ytvis.yml | [Checkpoint](https://huggingface.co/andriizadaianchuk/videosaur-ytvis/resolve/main/videosaur-ytvis.ckpt) |
| YT-VIS 2021 | ViT-B/14, DINOv2 | 39.7 | 35.6 | videosaur/ytvis_dinov2.yml | [Checkpoint](https://huggingface.co/andriizadaianchuk/videosaur-ytvis-dinov2-518/resolve/main/videosaur_dinov2.ckpt) |

### DINOSAUR

10 changes: 10 additions & 0 deletions configs/inference/movi_c.yml
@@ -0,0 +1,10 @@
checkpoint: path/to/videosaur-movi-c.ckpt
model_config: configs/videosaur/movi_c.yml
input:
  path: docs/static/videos/video.mp4
  transforms:
    use_movi_normalization: true
    type: video
    input_size: 224
output:
  save_video_path: docs/static/videos/video_masks.mp4
Binary file added docs/static/videos/video.mp4
Binary file not shown.
Binary file added docs/static/videos/video_masks.mp4
Binary file not shown.
418 changes: 56 additions & 362 deletions poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -27,6 +27,7 @@ matplotlib = {version = "^3.6.2", optional=true}
moviepy = "^1.0.3"
pycocotools = {version = "^2.0.6", optional=true}
requests = "^2.31.0"
av = "^12.0.0"

[tool.poetry.extras]
tensorflow = ["tensorflow-cpu", "tensorflow_datasets"]
33 changes: 33 additions & 0 deletions videosaur/data/transforms.py
@@ -197,6 +197,39 @@ def build(config):

    return transforms

def build_inference_transform(config):
    """Builds the transform for inference.

    Modify if needed to match the preprocessing required for your video.
    """
    use_movi_normalization = config.get("use_movi_normalization", True)
    size = config.get("input_size", 224)
    dataset_type = config.get("dataset_type", "video")

    resize_input = CropResize(
        dataset_type=dataset_type,
        crop_type="central",
        size=size,
        resize_mode="bilinear",
        clamp_zero_one=False,
    )

    if use_movi_normalization:
        normalize = Normalize(
            dataset_type=dataset_type, mean=MOVI_DEFAULT_MEAN, std=MOVI_DEFAULT_STD
        )
    else:
        normalize = Normalize(
            dataset_type=dataset_type, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
        )
    input_transform = tvt.Compose(
        [
            resize_input,
            normalize,
        ]
    )
    return input_transform


def _to_2tuple(val):
    if val is None:
69 changes: 69 additions & 0 deletions videosaur/inference.py
@@ -0,0 +1,69 @@
import argparse
import torch
from torchvision.io import read_video
from omegaconf import OmegaConf
from videosaur import configuration, models
from videosaur.data.transforms import CropResize, Normalize
import os
import numpy as np
import imageio
from torchvision import transforms as tvt
from videosaur.visualizations import mix_inputs_with_masks
from videosaur.data.transforms import build_inference_transform


def load_model_from_checkpoint(checkpoint_path: str, config_path: str):
    config = configuration.load_config(config_path)
    model = models.build(config.model, config.optimizer)
    checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()
    return model, config


def prepare_inputs(video_path: str, transform_config=None):
    # Load the video as a (T, H, W, C) float tensor in [0, 1]
    video, _, _ = read_video(video_path)
    video = video.float() / 255.0

    # Resize a copy of the video to the model input size for visualization
    video_vis = video.permute(0, 3, 1, 2)
    video_vis = tvt.Resize((transform_config.input_size, transform_config.input_size))(video_vis)
    video_vis = video_vis.permute(1, 0, 2, 3)

    if transform_config:
        tfs = build_inference_transform(transform_config)
        video = video.permute(3, 0, 1, 2)
        video = tfs(video).permute(1, 0, 2, 3)
    # Add batch dimension
    inputs = {
        "video": video.unsqueeze(0),
        "video_visualization": video_vis.unsqueeze(0),
    }
    return inputs


def main(config):
    # Load the model from checkpoint
    model, _ = load_model_from_checkpoint(config.checkpoint, config.model_config)
    # Prepare the input dict
    inputs = prepare_inputs(config.input.path, config.input.transforms)
    # Perform inference
    with torch.no_grad():
        outputs = model(inputs)
    if config.output.save_video_path:
        # Save the mask visualization as a video
        save_dir = os.path.dirname(config.output.save_video_path)
        os.makedirs(save_dir, exist_ok=True)
        masked_video_frames = mix_inputs_with_masks(inputs, outputs)
        with imageio.get_writer(config.output.save_video_path, fps=10) as writer:
            for frame in masked_video_frames:
                writer.append_data(frame)
    print("Inference completed.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Perform inference on a single MP4 video.")
    parser.add_argument("--config", default="configs/inference/movi_c.yml", help="Configuration to run")
    args = parser.parse_args()
    config = OmegaConf.load(args.config)
    main(config)
91 changes: 90 additions & 1 deletion videosaur/visualizations.py
@@ -5,7 +5,8 @@

import numpy as np
import torch
from PIL import ImageColor
from PIL import ImageColor, Image
from videosaur.data.transforms import Resize

CMAP_STYLE = "tab"

@@ -252,3 +253,91 @@ def bitget(byteval, idx):
    cmap = cmap / 255 if normalized else cmap

    return cmap

def create_grid_frame(frames, grid_size=(2, 6), image_size=(224, 224), padding=2):
    # Initialize an empty frame with padding
    grid_frame = np.zeros(
        (
            grid_size[0] * (image_size[0] + padding) - padding,
            grid_size[1] * (image_size[1] + padding) - padding,
        ),
        dtype=np.float64,
    )

    for index, frame in enumerate(frames):
        row = index // grid_size[1]
        col = index % grid_size[1]
        start_row = row * (image_size[0] + padding)
        start_col = col * (image_size[1] + padding)
        grid_frame[start_row : start_row + image_size[0], start_col : start_col + image_size[1]] = frame

    return grid_frame

def create_grid_frame_rgb(frames, grid_size=(2, 6), image_size=(224, 224), padding=2):
    """Create a grid frame from individual RGB frames.

    Args:
        frames (list of np.ndarray): List of frames, each of shape (height, width, 3).
        grid_size (tuple): The grid size as (rows, columns).
        image_size (tuple): The size of each image in the grid as (height, width).
        padding (int): The padding size between images in the grid.

    Returns:
        np.ndarray: An image of the grid.
    """
    # Initialize an empty frame with padding for RGB channels
    grid_frame = np.zeros(
        (
            grid_size[0] * (image_size[0] + padding) - padding,
            grid_size[1] * (image_size[1] + padding) - padding,
            3,  # depth of 3 for RGB
        ),
        dtype=np.float32,
    )

    for index, frame in enumerate(frames):
        if frame.ndim < 3:
            raise ValueError("All frames must have 3 dimensions (height, width, channels)")
        if frame.shape[2] != 3:
            raise ValueError("All frames must be RGB with 3 channels")

        row = index // grid_size[1]
        col = index % grid_size[1]
        start_row = row * (image_size[0] + padding)
        start_col = col * (image_size[1] + padding)
        end_row = start_row + image_size[0]
        end_col = start_col + image_size[1]

        # Frames are expected to already match the grid cell size
        assert frame.shape[:2] == image_size

        grid_frame[start_row:end_row, start_col:end_col, :] = frame

    return grid_frame

def mix_inputs_with_masks(inputs, outputs, softmasks=True):
    b, f, n_slots, hw = outputs["decoder"]["masks"].shape
    h = int(np.sqrt(hw))
    w = h
    masks_video = outputs["decoder"]["masks"].reshape(b, f, n_slots, h, w)
    assert b == 1, "Batch size must be 1 for visualization"
    masks_video = masks_video.squeeze(0)

    # Resize masks to the visualization resolution (224x224)
    resizer = Resize(224, mode="bilinear")
    masks_video = resizer(masks_video)

    if not softmasks:
        # Convert soft masks to hard (one-hot) masks via argmax over slots
        ind = torch.argmax(masks_video, dim=1, keepdim=True)
        masks_video = torch.zeros_like(masks_video)
        masks_video.scatter_(1, ind, 1)

    # Create a grid of video frames multiplied with the per-slot masks
    masked_video_frames = []
    for t in range(masks_video.shape[0]):  # iterate over time steps
        frames = [masks_video[t, i] for i in range(masks_video.shape[1])]  # all slot masks at this step
        # Include one mask of ones so the first grid cell shows the original video
        frames = [np.ones_like(frames[0])] + frames
        grid_frame = create_grid_frame(frames)
        # Convert the grayscale mask grid to RGB
        grid_frame_rgb = np.repeat(grid_frame[:, :, np.newaxis], 3, axis=2)
        video = inputs["video_visualization"]
        video_frames = [video[0, :, t].permute(1, 2, 0).numpy() for _ in range(masks_video.shape[1] + 1)]
        grid_video = create_grid_frame_rgb(video_frames)
        masked_video = (grid_video * grid_frame_rgb * 255).astype(np.uint8)
        masked_video_frames.append(masked_video)
    return masked_video_frames
