This repo is part of the project "Reproduce Neural ODE and SDE", undertaken during the HuggingFace Flax/JAX Community Week 2021. The full submission can be found here.
We used the model proposed in the paper Score-Based Generative Modeling through Stochastic Differential Equations to generate bird calls using Mel spectrograms. We took the notebook provided in the original repo and fine-tuned it on our Mel spectrogram dataset.
To run the sampler or the training script, the following dependencies must be installed.
For JAX installation, please follow the instructions here,
or simply type
pip install jax jaxlib
For Flax installation,
pip install flax
- Librosa
- Soundfile
- Numpy
- Matplotlib
- torch
- torchvision
- tensorflow
- tqdm
- scipy
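A quick way to confirm that the packages listed above are importable is sketched below. It uses only the Python standard library. The import names are assumed to match the list (for example, `soundfile` for Soundfile), and the helper name `missing_modules` is hypothetical, not part of this repo.

```python
import importlib.util

# Import names assumed from the dependency list above.
REQUIRED = [
    "jax", "flax", "librosa", "soundfile", "numpy", "matplotlib",
    "torch", "torchvision", "tensorflow", "tqdm", "scipy",
]


def missing_modules(names):
    """Return the names for which no importable top-level module is found."""
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_modules(REQUIRED)
if missing:
    print("Missing packages:", ", ".join(missing))
else:
    print("All listed packages are importable.")
```

This only checks that each module can be located; it does not verify versions or GPU support.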
This is the code for the score-based SDE bird call generation model.
core-sde-sampler.py
runs the sampler. The sampler uses pretrained weights, stored in the ckpt.flax file, to generate bird calls.
A sample generated bird call:
Sample.3.mp4
To use different sample generation parameters, change the argument values. For example,
python main.py --sigma=25 --num_steps=500 --signal_to_noise_ratio=0.10 --etol=1e-5 --sample_batch_size=128 --sample_no=47
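To illustrate how flags of this form are typically consumed, here is a minimal `argparse` sketch. The flag names are taken from the example command; the types, default values, and the `build_parser` helper are assumptions for illustration, and the actual main.py may define them differently.

```python
import argparse


def build_parser():
    # Flag names come from the example command above; the types and
    # defaults here are assumptions, not the repo's actual settings.
    parser = argparse.ArgumentParser(description="Sampling options (illustrative sketch)")
    parser.add_argument("--sigma", type=float, default=25.0)
    parser.add_argument("--num_steps", type=int, default=500)
    parser.add_argument("--signal_to_noise_ratio", type=float, default=0.10)
    parser.add_argument("--etol", type=float, default=1e-5)
    parser.add_argument("--sample_batch_size", type=int, default=128)
    parser.add_argument("--sample_no", type=int, default=47)
    return parser


# Parse an explicit argument list instead of sys.argv so the sketch runs anywhere.
args = build_parser().parse_args(["--sigma=25", "--num_steps=500", "--sample_batch_size=128"])
print(args.sigma, args.num_steps, args.sample_batch_size)
```

With a parser like this, each flag and its value should be written as a single token (`--sample_no=47`) or as two tokens separated by a space (`--sample_no 47`), without spaces around the `=` sign.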
To generate the audio files, these dependencies are required:
pip install librosa
pip install soundfile
To train the model from scratch, please generate the dataset using this notebook. The dataset is generated on Kaggle, so for training your Kaggle username and API key must be entered in the specified section of the script.
python main.py --sigma=35 --n_epochs=1000 --batch_size=512 --lr=1e-3 --num_steps=500 --signal_to_noise_ratio=0.15 --etol=1e-5 --sample_batch_size=64 --sample_no=23
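For context on the `--sigma` flag: in the tutorial notebook of the original Score SDE repo, sigma parameterizes the SDE dx = sigma^t dw, whose perturbation kernel has standard deviation sqrt((sigma^(2t) - 1) / (2 ln sigma)). The sketch below assumes that main.py keeps this convention, which is not confirmed here, and shows in plain Python how a larger sigma increases the noise level reached at t = 1.

```python
import math


def diffusion_coeff(t, sigma):
    """Diffusion coefficient g(t) = sigma**t of the SDE dx = sigma**t dw."""
    return sigma ** t


def marginal_prob_std(t, sigma):
    """Standard deviation of p(x(t) | x(0)) for the SDE above.

    Assumes the convention of the original tutorial notebook.
    """
    return math.sqrt((sigma ** (2.0 * t) - 1.0) / (2.0 * math.log(sigma)))


# Compare the two sigma values used in the example commands.
for sigma in (25.0, 35.0):
    print(f"sigma={sigma}: std at t=1 is {marginal_prob_std(1.0, sigma):.3f}")
```

Under this assumption, training with a larger sigma perturbs the spectrograms with more noise at the end of the forward process, so sampling should use the same sigma as the checkpoint was trained with.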
We have also provided the Colab notebook we used to train the model. More generated bird call samples can be found there.