-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Andrii Zadaianchuk
committed
Apr 10, 2024
1 parent
dbbb63d
commit 524797e
Showing
9 changed files
with
274 additions
and
369 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# Inference configuration consumed by the single-video inference script.
# Path to the trained checkpoint; the script loads its "state_dict" entry.
checkpoint: path/to/videosaur-movi-c.ckpt
# Model configuration used to build the network before loading weights.
model_config: configs/videosaur/movi_c.yml
input:
  # Video file to run inference on.
  path: docs/static/videos/video.mp4
  # Options handed to build_inference_transform; input_size is also used to
  # resize the frames kept for visualization.
  # NOTE(review): nesting of use_movi_normalization/type under transforms is
  # reconstructed from how the script reads the config - confirm.
  transforms:
    use_movi_normalization: true
    type: video
    input_size: 224
output:
  # Where the video with overlaid masks is written; leave empty to skip saving.
  save_video_path: docs/static/videos/video_masks.mp4
Binary file not shown.
Binary file not shown.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import argparse | ||
import torch | ||
from torchvision.io import read_video | ||
from omegaconf import OmegaConf | ||
from videosaur import configuration, models | ||
from videosaur.data.transforms import CropResize, Normalize | ||
import os | ||
import numpy as np | ||
import imageio | ||
from torchvision import transforms as tvt | ||
from videosaur.visualizations import mix_inputs_with_masks | ||
from videosaur.data.transforms import build_inference_transform | ||
|
||
|
||
def load_model_from_checkpoint(checkpoint_path: str, config_path: str):
    """Build the model described by a config file and load checkpoint weights.

    Args:
        checkpoint_path: Path to a checkpoint readable by ``torch.load`` that
            stores the weights under the ``"state_dict"`` key.
        config_path: Path to the model configuration, passed to
            ``configuration.load_config``.

    Returns:
        A ``(model, config)`` tuple: the model switched to eval mode and the
        loaded configuration.
    """
    cfg = configuration.load_config(config_path)
    net = models.build(cfg.model, cfg.optimizer)

    # Deserialize onto the CPU regardless of where the checkpoint was saved.
    cpu = torch.device('cpu')
    weights = torch.load(checkpoint_path, map_location=cpu)['state_dict']
    net.load_state_dict(weights)

    net.eval()
    return net, cfg
|
||
def prepare_inputs(video_path: str, transfom_config=None):
    """Read a video file and assemble the input dict passed to the model.

    Args:
        video_path: Path to the video file, decoded with
            ``torchvision.io.read_video``.
        transfom_config: Optional transform configuration. When given, its
            ``input_size`` is used to resize the visualization frames and the
            whole config is passed to ``build_inference_transform``. When
            omitted (or empty), no resizing or model transform is applied.

    Returns:
        A dict with a batch dimension added to both entries:
        ``"video"`` (the model input) and ``"video_visualization"`` (frames
        scaled to [0, 1] that are kept for rendering).
    """
    # Load video and scale pixel values from [0, 255] to [0, 1].
    video, _, _ = read_video(video_path)
    video = video.float() / 255.0

    # Assumes read_video yields frames as (T, H, W, C) - TODO confirm.
    # Resize expects the two spatial dimensions last.
    video_vis = video.permute(0, 3, 1, 2)

    if transfom_config:
        # Previously input_size was read unconditionally, so calling this
        # function with the default ``transfom_config=None`` raised
        # AttributeError before any work was done.
        size = transfom_config.input_size
        video_vis = tvt.Resize((size, size))(video_vis)

        tfs = build_inference_transform(transfom_config)
        # The transform pipeline is fed the channel axis first; afterwards
        # the first two axes are swapped back.
        video = video.permute(3, 0, 1, 2)
        video = tfs(video).permute(1, 0, 2, 3)
    # NOTE(review): without a config, ``video`` keeps the layout produced by
    # read_video; confirm whether the model accepts that layout.

    # Visualization frames have their first two axes swapped.
    video_vis = video_vis.permute(1, 0, 2, 3)

    # Add batch dimension.
    inputs = {
        "video": video.unsqueeze(0),
        "video_visualization": video_vis.unsqueeze(0),
    }
    return inputs
|
||
|
||
|
||
|
||
def main(config):
    """Run inference on one video and optionally save the masked result.

    Args:
        config: Inference configuration providing ``checkpoint``,
            ``model_config``, ``input.path``, ``input.transforms`` and
            ``output.save_video_path``.
    """
    # Load the model from checkpoint.
    model, _ = load_model_from_checkpoint(config.checkpoint, config.model_config)

    # Prepare the video dict.
    inputs = prepare_inputs(config.input.path, config.input.transforms)

    # Perform inference without tracking gradients.
    with torch.no_grad():
        outputs = model(inputs)

    save_video_path = config.output.save_video_path
    if save_video_path:
        # A bare file name has an empty dirname, and os.makedirs("") raises
        # FileNotFoundError, so only create the directory when there is one.
        save_dir = os.path.dirname(save_video_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

        masked_video_frames = mix_inputs_with_masks(inputs, outputs)
        # The context manager closes the writer on exit; no explicit close()
        # is needed afterwards.
        with imageio.get_writer(save_video_path, fps=10) as writer:
            for frame in masked_video_frames:
                writer.append_data(frame)

    print("Inference completed.")
|
||
if __name__ == "__main__":
    # Script entry point: read the path of the inference config from the
    # command line, load it with OmegaConf and run inference.
    parser = argparse.ArgumentParser(
        description="Perform inference on a single MP4 video.",
    )
    parser.add_argument(
        "--config",
        default="configs/inference/movi_c.yml",
        help="Configuration to run",
    )
    args = parser.parse_args()

    config = OmegaConf.load(args.config)
    main(config)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters