
Perceiver (transformer variant) implemented in JAX and Flax


badiadamas/perceiver-jax

 
 


Perceiver - JAX (Flax)

Implementation of Perceiver (Jaegle et al., 2021) in JAX and Flax. It uses ReZero in place of LayerNorm, since ReZero has been shown to speed up convergence in very deep Transformers (Bachlechner et al., 2020).
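ReZero removes the normalization layer and instead scales each residual branch by a learnable scalar that starts at zero, x_{i+1} = x_i + α_i F(x_i), so every block is exactly the identity at initialization. The sketch below applies this to a feed-forward sublayer in plain JAX. The helper names, the GELU activation and the weight initialization are illustrative assumptions, not code from this package.

```python
import jax
import jax.numpy as jnp


def init_rezero_ff(key, d_model, ff_mult=4):
    """Hypothetical helper: parameters for a ReZero-gated feed-forward block."""
    k1, k2 = jax.random.split(key)
    hidden = d_model * ff_mult
    return {
        "w1": jax.random.normal(k1, (d_model, hidden)) / jnp.sqrt(d_model),
        "w2": jax.random.normal(k2, (hidden, d_model)) / jnp.sqrt(hidden),
        "alpha": jnp.zeros(()),  # ReZero gate, initialized to 0; no LayerNorm is used
    }


def rezero_ff(params, x):
    """Returns x + alpha * F(x), where F is a two-layer MLP (an illustrative sublayer)."""
    h = jax.nn.gelu(x @ params["w1"])
    return x + params["alpha"] * (h @ params["w2"])


params = init_rezero_ff(jax.random.PRNGKey(0), d_model=16)
x = jax.random.normal(jax.random.PRNGKey(1), (4, 16))

# At initialization the block passes its input through unchanged.
y = rezero_ff(params, x)
print(jnp.allclose(y, x))

# On the first step only alpha receives a gradient; w1 and w2 are gated off by alpha = 0.
grads = jax.grad(lambda p: rezero_ff(p, x).sum())(params)
print(grads["alpha"], jnp.abs(grads["w1"]).max(), jnp.abs(grads["w2"]).max())
```

Because the gradient reaches α first, training decides how quickly each block departs from the identity, which is what makes very deep stacks easy to optimize without normalization.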

Install

```shell
pip install perceiver-jax
```

Also install the jaxlib build that matches your hardware (CPU, NVIDIA GPU or TPU), following the JAX installation instructions.
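To check which backend JAX actually initialized after installation, you can query it directly. If only the CPU build of jaxlib is present, JAX falls back to the CPU, so this is a quick way to catch a mismatched install.

```python
import jax

# Name of the platform JAX is using by default, e.g. "cpu", "gpu" or "tpu".
backend = jax.default_backend()
# All devices visible to JAX on that platform.
devices = jax.devices()
print(f"JAX backend: {backend}; devices: {devices}")
```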

Usage

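```python
import jax

from perceiver_jax import Perceiver

model = Perceiver(
    n_fourier_features=6,
    depth=8,
    n_latents=512,  # if input length is much smaller than this, reconsider using this architecture
    latent_n_heads=8,
    latent_head_features=64,
    cross_n_heads=2,
    cross_head_features=128,
    attn_dropout=0.,
    ff_mult=4,
    ff_dropout=0.,
    tie_layer_weights=False,
)

# Use independent keys for the data, parameter initialization and dropout.
key = jax.random.PRNGKey(42)
data_key, params_key, dropout_key = jax.random.split(key, 3)

# (batch, sequence length, channels): a 224 x 224 RGB image flattened to 50,176 elements.
input_batch = jax.random.normal(data_key, (1, 224 * 224, 3))

# init_with_output returns both the model output and the initialized variables.
y, variables = model.init_with_output({'params': params_key, 'dropout': dropout_key}, input_batch)
```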
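The comment on `n_latents` follows from how Perceiver scales: a fixed array of M latents cross-attends to the N input elements, so the score matrix is M × N instead of N × N, and each latent self-attention layer then costs M × M regardless of N. For the 224 × 224 input above, N = 50,176, so with `n_latents=512` the cross-attention scores have about 25.7 million entries, roughly 98 times fewer than the ≈2.5 billion of full self-attention over the pixels. When N is close to or below M, the bottleneck saves nothing. The sketch below shows a single-head version of this input path in plain JAX, including Fourier position features of the kind `n_fourier_features` presumably controls. It is not this package's implementation: the 1-D positions over the flattened sequence, the frequency range, `max_freq`, the single head and all function names are illustrative assumptions.

```python
import math

import jax
import jax.numpy as jnp


def fourier_encode(pos, n_bands, max_freq):
    """Positions in [-1, 1], shape (N,) -> features of shape (N, 2 * n_bands + 1).

    sin/cos at n_bands frequencies linearly spaced in [1, max_freq / 2], plus the raw
    position (the scheme described in the Perceiver paper; max_freq is an assumption).
    """
    freqs = jnp.linspace(1.0, max_freq / 2, n_bands)            # (n_bands,)
    angles = pos[:, None] * freqs[None, :] * math.pi             # (N, n_bands)
    return jnp.concatenate([jnp.sin(angles), jnp.cos(angles), pos[:, None]], axis=-1)


def cross_attend(latents, inputs, w_q, w_k, w_v):
    """Single-head cross-attention: M latent queries attend over N inputs."""
    q = latents @ w_q                                            # (M, d)
    k = inputs @ w_k                                             # (N, d)
    v = inputs @ w_v                                             # (N, d)
    scores = q @ k.T / jnp.sqrt(q.shape[-1])                     # (M, N): linear in N
    return jax.nn.softmax(scores, axis=-1) @ v                   # (M, d)


def encode(key, x, n_latents=64, d=32, n_bands=6, max_freq=10.0):
    """Hypothetical encoder front end: x of shape (N, C), no batch axis -> (n_latents, d)."""
    n = x.shape[0]
    pos = jnp.linspace(-1.0, 1.0, n)
    x = jnp.concatenate([x, fourier_encode(pos, n_bands, max_freq)], axis=-1)
    c_in = x.shape[-1]  # C + 2 * n_bands + 1, read from the data
    k_lat, k_q, k_k, k_v = jax.random.split(key, 4)
    latents = jax.random.normal(k_lat, (n_latents, d))  # learned parameters in a real model
    w_q = jax.random.normal(k_q, (d, d)) / jnp.sqrt(d)
    w_k = jax.random.normal(k_k, (c_in, d)) / jnp.sqrt(c_in)
    w_v = jax.random.normal(k_v, (c_in, d)) / jnp.sqrt(c_in)
    return cross_attend(latents, x, w_q, w_k, w_v)


x = jax.random.normal(jax.random.PRNGKey(0), (32 * 32, 3))  # small flattened 32 x 32 RGB input
z = encode(jax.random.PRNGKey(1), x)
print(z.shape)  # (64, 32)
```

Doubling the input length doubles the size of the score matrix in `cross_attend` but leaves the shape of the latent array, and so the cost of every later latent layer, unchanged.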

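The parametrization is slightly simpler than in the PyTorch version: Flax infers input feature dimensions (here, the 3 input channels) from the example batch passed at initialization, so they don't need to be given to the constructor.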

Acknowledgements

Thanks to lucidrains for his PyTorch implementation, on which this one is heavily based.

Citations

```bibtex
@misc{jaegle2021perceiver,
    title         = {Perceiver: General Perception with Iterative Attention},
    author        = {Andrew Jaegle and Felix Gimeno and Andrew Brock and Andrew Zisserman and Oriol Vinyals and Joao Carreira},
    year          = {2021},
    eprint        = {2103.03206},
    archivePrefix = {arXiv},
    primaryClass  = {cs.CV}
}

@misc{bachlechner2020rezero,
    title         = {ReZero is All You Need: Fast Convergence at Large Depth},
    author        = {Thomas Bachlechner and Bodhisattwa Prasad Majumder and Huanru Henry Mao and Garrison W. Cottrell and Julian McAuley},
    year          = {2020},
    eprint        = {2003.04887},
    archivePrefix = {arXiv},
    primaryClass  = {cs.LG}
}
```
