To train the video model:
- separate_grasped_model_seg: for segmentation images
- separate_grasped_model: for RGB images
python fitvid/scripts/train_fitvid.py --output_dir /home/arpit/test_projects/fitvid/run_test_seg --dataset_file /home/arpit/test_projects/OmniGibson/dynamics_model_dataset_seg/dataset.hdf5 --wandb_online
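Before kicking off a long run, it can help to verify that the HDF5 dataset file is readable. Below is a minimal sketch using h5py; it assumes nothing about the file's internal layout and simply prints whatever keys and shapes are present:
```python
# Sanity-check the training dataset (path taken from the command above).
import h5py

path = "/home/arpit/test_projects/OmniGibson/dynamics_model_dataset_seg/dataset.hdf5"
with h5py.File(path, "r") as f:
    # Walk every group/dataset and print its name and shape.
    def show(name, obj):
        shape = getattr(obj, "shape", None)
        print(name, shape if shape is not None else "(group)")
    f.visititems(show)
```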
To train the grasped model:
- separate_grasped_model_seg: for segmentation images
- separate_grasped_model: for RGB images
python fitvid/scripts/train_grasped_model.py --output_dir run_seg_grasped --dataset_file /home/arpit/test_projects/OmniGibson/dynamics_model_dataset_seg/dataset.hdf5 --pretrained_video_model /home/arpit/test_projects/fitvid/run_test_seg/model_epoch50_seg --wandb_online
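The --pretrained_video_model argument points at a saved video-model checkpoint (model_epoch50_seg above). If that checkpoint is a standard Flax msgpack checkpoint (an assumption; the repo's save format isn't shown here), it can be inspected like this:
```python
# Hypothetical inspection of the pretrained video-model checkpoint.
# ASSUMPTION: the file is a Flax msgpack checkpoint; adjust if the repo
# saves in another format (e.g. pickle or torch.save).
from flax.training import checkpoints

ckpt_path = "/home/arpit/test_projects/fitvid/run_test_seg/model_epoch50_seg"
# target=None returns the raw restored pytree (nested dicts of arrays).
state = checkpoints.restore_checkpoint(ckpt_path, target=None)
print(type(state), list(state.keys()) if hasattr(state, "keys") else state)
```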
An implementation of the FitVid video prediction model in JAX/Flax.
If you find this code useful, please cite it in your paper:
@article{babaeizadeh2021fitvid,
  title={FitVid: Overfitting in Pixel-Level Video Prediction},
  author={Babaeizadeh, Mohammad and Saffar, Mohammad Taghi and Nair, Suraj
          and Levine, Sergey and Finn, Chelsea and Erhan, Dumitru},
  journal={arXiv preprint arXiv:2106.13195},
  year={2021}
}
FitVid is a new architecture for conditional variational video prediction. It has ~300 million parameters and can be trained with minimal training tricks.
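Concretely, "conditional variational" means the model is trained with a CVAE-style objective: pixel reconstruction of future frames plus a KL term between a posterior inferred from those frames and a learned prior. A minimal sketch of that objective in JAX (the function and argument names are illustrative, not the repo's API):
```python
import jax.numpy as jnp

def kl_divergence(mu_q, logvar_q, mu_p, logvar_p):
    # KL(N(mu_q, var_q) || N(mu_p, var_p)) per dimension, summed over features.
    var_q, var_p = jnp.exp(logvar_q), jnp.exp(logvar_p)
    kl = 0.5 * (logvar_p - logvar_q + (var_q + (mu_q - mu_p) ** 2) / var_p - 1.0)
    return jnp.sum(kl, axis=-1)

def elbo_loss(pred_frames, target_frames, mu_q, logvar_q, mu_p, logvar_p, beta=1e-4):
    # Pixel reconstruction (MSE) plus a beta-weighted KL between the
    # frame-conditioned posterior and the learned prior.
    recon = jnp.mean((pred_frames - target_frames) ** 2)
    kl = jnp.mean(kl_divergence(mu_q, logvar_q, mu_p, logvar_p))
    return recon + beta * kl
```
At test time the posterior is unavailable, so latents are sampled from the learned prior instead, which is what makes the prediction stochastic.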
[Sample video predictions on Human3.6M and RoboNet]
For more samples, please visit the FitVid website: https://sites.google.com/view/fitvidpaper
Get dependencies:
pip3 install --user tensorflow
pip3 install --user tensorflow_addons
pip3 install --user flax
pip3 install --user ffmpeg
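A quick check that the Python-side dependencies import cleanly (the printed versions will vary by install):
```python
# Verify the dependencies installed above.
import tensorflow as tf
import tensorflow_addons as tfa
import flax

print("tensorflow:", tf.__version__)
print("tensorflow_addons:", tfa.__version__)
print("flax:", flax.__version__)
```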
Train on RoboNet:
python -m fitvid.train --output_dir /tmp/output
Disclaimer: Not an official Google product.