- Finish implementing the `UNet2D` model in `modeling_unet2d.py`. Port the weights of any existing LDM UNet from diffusers and verify equivalence. I've added the skeleton of the modules we need to implement to that file.
- Adapt the `PNDMScheduler` from `diffusers` for JAX: use `jnp` arrays and make it stateless.
- Add the KL module from [here](https://github.dev/CompVis/stable-diffusion) in `modeling_vae.py`. For inference we don't really need it, but it would be nice to have for completeness. Port the weights of any existing KL VAE and verify equivalence.
- Add an inference loop in `pipeline_stable_diffusion`. We should be able to `jit`/`pmap` the loop to deploy on TPUs.
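Porting UNet weights from PyTorch is mostly a matter of changing array layouts: PyTorch stores a `Conv2d` weight as `(out, in, kH, kW)` and a `Linear` weight as `(out, in)`, while Flax's `nn.Conv` and `nn.Dense` expect `(kH, kW, in, out)` and `(in, out)`. The sketch below assumes NHWC activations on the JAX side (the Flax default); the helpers `torch_conv_to_flax`, `torch_linear_to_flax` and `conv_outputs_match` are hypothetical names. To illustrate "verify equivalence", it runs the same convolution in both layouts with `jax.lax.conv_general_dilated` and compares the outputs.

```python
import numpy as np
import jax
import jax.numpy as jnp


def torch_conv_to_flax(weight: np.ndarray) -> jnp.ndarray:
    """Hypothetical helper: PyTorch Conv2d weight (O, I, kH, kW) -> Flax kernel (kH, kW, I, O)."""
    return jnp.transpose(jnp.asarray(weight), (2, 3, 1, 0))


def torch_linear_to_flax(weight: np.ndarray) -> jnp.ndarray:
    """Hypothetical helper: PyTorch Linear weight (out, in) -> Flax Dense kernel (in, out)."""
    return jnp.asarray(weight).T


def conv_outputs_match(weight: np.ndarray, x_nchw: np.ndarray, atol: float = 1e-4) -> bool:
    """Run one convolution in PyTorch layout and in Flax layout, then compare."""
    # Reference: NCHW input with an OIHW kernel, i.e. PyTorch's layout.
    ref = jax.lax.conv_general_dilated(
        jnp.asarray(x_nchw), jnp.asarray(weight),
        window_strides=(1, 1), padding="SAME",
        dimension_numbers=("NCHW", "OIHW", "NCHW"),
    )
    # Ported: NHWC input with the converted HWIO kernel, i.e. Flax's layout.
    x_nhwc = jnp.transpose(jnp.asarray(x_nchw), (0, 2, 3, 1))
    out = jax.lax.conv_general_dilated(
        x_nhwc, torch_conv_to_flax(weight),
        window_strides=(1, 1), padding="SAME",
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )
    # Move the ported output back to NCHW before comparing element-wise.
    return bool(jnp.allclose(ref, jnp.transpose(out, (0, 3, 1, 2)), atol=atol))


rng = np.random.default_rng(0)
w = rng.standard_normal((8, 4, 3, 3)).astype(np.float32)   # (out, in, kH, kW)
x = rng.standard_normal((2, 4, 16, 16)).astype(np.float32)  # (N, C, H, W)
```

In the real port, the reference side would be the diffusers PyTorch module's own forward pass on the same input rather than a simulated one. Normalization layers also need renaming: PyTorch's `weight` for `GroupNorm`/`LayerNorm` corresponds to Flax's `scale` parameter.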
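The PyTorch `PNDMScheduler` keeps its history of model outputs (for example `ets` and `counter`) as mutable attributes, which `jit` cannot trace as side effects. One way to make it stateless is to move that history into an explicit state value that is passed in and returned. The sketch below covers only the linear-multistep (Adams-Bashforth) combination used by the PLMS step, with a fixed-size buffer so shapes stay static under `jit`. `PLMSState`, `init_plms_state` and `plms_combine` are hypothetical names; the sketch simplifies the PyTorch scheduler's special handling of the first steps and omits the update that turns the combined output into the previous sample.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

# Adams-Bashforth coefficients for orders 1-4 (newest output first).
# Each row sums to 1, so a constant model output passes through unchanged.
PLMS_COEFFS = jnp.array([
    [1.0, 0.0, 0.0, 0.0],
    [3.0 / 2, -1.0 / 2, 0.0, 0.0],
    [23.0 / 12, -16.0 / 12, 5.0 / 12, 0.0],
    [55.0 / 24, -59.0 / 24, 37.0 / 24, -9.0 / 24],
])


class PLMSState(NamedTuple):
    """Hypothetical scheduler state: what the PyTorch scheduler keeps on `self`."""
    ets: jnp.ndarray    # (4, *sample_shape) past model outputs, newest first
    count: jnp.ndarray  # int32 scalar: number of model outputs seen so far


def init_plms_state(sample_shape) -> PLMSState:
    return PLMSState(ets=jnp.zeros((4, *sample_shape)), count=jnp.array(0, jnp.int32))


@jax.jit
def plms_combine(state: PLMSState, model_output: jnp.ndarray):
    """Pure function: returns the combined output and a new state instead of mutating."""
    # Shift the fixed-size buffer so model_output becomes the newest entry.
    ets = jnp.concatenate([model_output[None], state.ets[:-1]], axis=0)
    # Use the highest order the history supports, capped at 4.
    coeffs = PLMS_COEFFS[jnp.minimum(state.count, 3)]
    combined = jnp.tensordot(coeffs, ets, axes=1)
    return combined, PLMSState(ets=ets, count=state.count + 1)
```

Because `PLMSState` is a `NamedTuple`, JAX treats it as a pytree, so it can be carried through `jit`, `lax.fori_loop` or `pmap` without extra registration.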
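The "KL" in the CompVis autoencoder refers to its Gaussian posterior: the encoder produces moments that are split into a mean and a log-variance, and training adds a KL penalty toward a standard normal. A minimal JAX sketch of that distribution follows, assuming NHWC layout (so the split is on the last axis rather than PyTorch's channel axis 1). The class name `DiagonalGaussian` is hypothetical, and the encoder/decoder networks are not shown.

```python
import jax
import jax.numpy as jnp


class DiagonalGaussian:
    """Sketch of a KL-VAE posterior; assumes NHWC moments = concat(mean, logvar) on the last axis."""

    def __init__(self, moments: jnp.ndarray):
        self.mean, logvar = jnp.split(moments, 2, axis=-1)
        # Clamp the log-variance so exp() stays finite.
        self.logvar = jnp.clip(logvar, -30.0, 20.0)
        self.std = jnp.exp(0.5 * self.logvar)
        self.var = jnp.exp(self.logvar)

    def sample(self, key) -> jnp.ndarray:
        # Reparameterisation: randomness comes from an explicit PRNG key, not global state.
        return self.mean + self.std * jax.random.normal(key, self.mean.shape)

    def mode(self) -> jnp.ndarray:
        return self.mean

    def kl(self) -> jnp.ndarray:
        # KL(N(mean, var) || N(0, 1)) summed over H, W, C: one value per batch element.
        return 0.5 * jnp.sum(self.mean**2 + self.var - 1.0 - self.logvar, axis=(1, 2, 3))
```

For inference only `mode()` or `sample()` feeds the decoder; `kl()` matters only if the VAE is trained.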
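To `jit`/`pmap` the whole sampling loop, the loop itself has to be traced rather than driven step by step from Python, for example with `jax.lax.fori_loop`. In the sketch below, `toy_unet_apply` is a hypothetical stand-in for the real UNet and the update `x - eps` is a placeholder for a scheduler step; only the loop structure and the `jit`/`pmap` wrapping are meant to carry over.

```python
from functools import partial

import jax
import jax.numpy as jnp


def toy_unet_apply(params, latents, t):
    """Hypothetical stand-in for the UNet2D apply function; 'predicts' noise."""
    return params["scale"] * latents


def denoise(params, latents, num_steps: int):
    """Whole sampling loop as one traced computation, so jit/pmap compile it once."""
    def body(i, x):
        t = num_steps - 1 - i                # hypothetical countdown of timesteps
        eps = toy_unet_apply(params, x, t)
        return x - eps                       # placeholder for a real scheduler step
    return jax.lax.fori_loop(0, num_steps, body, latents)


# Single device: num_steps is a Python int, so it is marked static.
denoise_jit = jax.jit(denoise, static_argnums=2)

# Multiple devices: broadcast params (in_axes=None), shard latents on their leading axis.
p_denoise = jax.pmap(partial(denoise, num_steps=3), in_axes=(None, 0))
```

With `pmap`, the leading axis of the latents must equal `jax.local_device_count()` (8 on a v3-8 TPU, for instance). A real pipeline would also split a PRNG key per device to draw the initial noise.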