Gated Delta Networks: Improving Mamba2 with Delta Rule


Official PyTorch implementation of Gated Delta Networks: Improving Mamba2 with Delta Rule.


Songlin Yang, Jan Kautz and Ali Hatamizadeh.

For business inquiries, please visit our website and submit the form: NVIDIA Research Licensing

🌟 Why Gated DeltaNet?

Gated DeltaNet introduces a novel approach to linear transformers by combining:

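  • 🧠 Smart Memory Management: a gating mechanism (as in Mamba2) that adaptively decays the memory state, deciding what to keep and what to forget
  • Precise Updates: the delta rule, which makes targeted memory updates by overwriting the value stored for a key instead of only accumulating onto it
  • 💻 Hardware Efficiency: a parallel training algorithm optimized for modern hardware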
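To make the first two bullets concrete, here is a minimal pure-Python sketch of the gated delta rule recurrence as stated in the paper, S_t = S_{t−1}(α_t(I − β_t k_t k_tᵀ)) + β_t v_t k_tᵀ with output o_t = S_t q_t, where α_t is the gate (decay) and β_t the delta-rule step size. It assumes L2-normalized queries and keys. The function names are hypothetical, and this naive token-by-token loop is for illustration only; it is not the repository's hardware-efficient parallel implementation.

```python
import math


def l2_normalize(x):
    """Scale a vector to unit L2 norm (a zero vector is returned unchanged)."""
    norm = math.sqrt(sum(v * v for v in x))
    return list(x) if norm == 0.0 else [v / norm for v in x]


def gated_delta_step(S, q, k, v, alpha, beta):
    """One step of the gated delta rule (hypothetical helper).

    S is the d_v x d_k memory as a list of rows; q and k have length d_k and
    v has length d_v. alpha in [0, 1] is the gate, beta in [0, 1] the step size.
    Returns the updated memory S_t and the output S_t q.
    """
    k = l2_normalize(k)
    q = l2_normalize(q)
    d_k = len(k)
    # Value the decayed memory currently associates with key k: alpha * S_{t-1} k
    pred = [alpha * sum(row[j] * k[j] for j in range(d_k)) for row in S]
    # S_t = alpha * S_{t-1} - beta * (alpha * S_{t-1} k - v) k^T
    # (algebraically equal to S_{t-1}(alpha(I - beta k k^T)) + beta v k^T)
    new_S = [
        [alpha * S[i][j] - beta * (pred[i] - v[i]) * k[j] for j in range(d_k)]
        for i in range(len(S))
    ]
    # Read the updated memory with the query
    out = [sum(row[j] * q[j] for j in range(d_k)) for row in new_S]
    return new_S, out


def gated_delta_scan(qs, ks, vs, alphas, betas, d_k, d_v):
    """Run the recurrence over a sequence, starting from an all-zero memory."""
    S = [[0.0] * d_k for _ in range(d_v)]
    outs = []
    for q, k, v, a, b in zip(qs, ks, vs, alphas, betas):
        S, o = gated_delta_step(S, q, k, v, a, b)
        outs.append(o)
    return outs, S


# Write [1, 2] and then [3, 4] under the same key, then decay without writing.
key = [1.0, 0.0]
outs, _ = gated_delta_scan(
    qs=[key, key, key],
    ks=[key, key, key],
    vs=[[1.0, 2.0], [3.0, 4.0], [0.0, 0.0]],
    alphas=[1.0, 1.0, 0.5],
    betas=[1.0, 1.0, 0.0],
    d_k=2,
    d_v=2,
)
print(outs)  # [[1.0, 2.0], [3.0, 4.0], [1.5, 2.0]]
```

The example uses boundary values for α and β so the effects are exact: the second write under the same key replaces [1, 2] with [3, 4] instead of accumulating to [4, 6] as a purely additive linear-attention update would, and the third step (α = 0.5, β = 0) shows the gate decaying the stored value to [1.5, 2.0].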

Architecture Overview

Efficiency

Gated DeltaNet achieves strong training throughput compared with models such as Mamba2 and Samba.

Language Modeling and Reasoning

Our model outperforms competitors of various types (e.g., Transformer, RNN, and hybrid models) in perplexity and in zero-shot accuracy on reasoning benchmarks.

Long-context

Gated DeltaNet also achieves favorable perplexity on long-context benchmarks.

📢 Latest Updates

  • 12/09/2024: 🔥 Code Release: Train your own Gated DeltaNet on the SlimPajama dataset
  • Watch this space for more exciting updates!

🚀 Getting Started

Training Your Model

Launch your training with our streamlined command:

python ../pretrain.py \
--train_data_dir ${TRAIN_DATA} \
--val_data_dir ${VALIDATION_DATA} \
--output_root ${SAVE_DIR} \
--exp_name ${NAME} \
--model_name ${MODEL} \
--train_config ${CONFIG} \
--eval_iters ${EVAL_ITERS} \
--learning_rate ${LR} \
--micro_batch_size ${MICRO_BATCH_SIZE}

💡 Pro Tip: Add --interactive_job --debug for interactive debugging sessions!
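If you launch runs from Python, a small helper can assemble the command above and append the debugging switches from the Pro Tip. This is a minimal sketch: the flag names are copied from the command, but the helper name, the choice to treat all nine flags as required, and every value in the example are hypothetical placeholders, not defaults from the repository.

```python
import shlex

# Flag names exactly as they appear in the README command.
# Treating all of them as required is an assumption of this sketch.
REQUIRED_FLAGS = [
    "train_data_dir", "val_data_dir", "output_root", "exp_name", "model_name",
    "train_config", "eval_iters", "learning_rate", "micro_batch_size",
]


def build_pretrain_argv(values, debug=False, script="../pretrain.py"):
    """Return the argv list for pretrain.py; `values` maps flag name -> value."""
    missing = [f for f in REQUIRED_FLAGS if f not in values]
    if missing:
        raise ValueError(f"missing values for: {', '.join(missing)}")
    argv = ["python", script]
    for flag in REQUIRED_FLAGS:
        argv += [f"--{flag}", str(values[flag])]
    if debug:
        # Switches from the Pro Tip for an interactive debugging session.
        argv += ["--interactive_job", "--debug"]
    return argv


# All values below are hypothetical placeholders; substitute your own.
example = {
    "train_data_dir": "/path/to/train",
    "val_data_dir": "/path/to/val",
    "output_root": "/path/to/save",
    "exp_name": "my_experiment",
    "model_name": "my_model",
    "train_config": "my_config",
    "eval_iters": 15,
    "learning_rate": 4e-4,
    "micro_batch_size": 8,
}
print(shlex.join(build_pretrain_argv(example, debug=True)))
# To actually launch: subprocess.run(build_pretrain_argv(example), check=True)
```

Returning an argv list rather than a single string lets `subprocess.run` execute it without a shell, while `shlex.join` gives a copy-pasteable version for logs.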

Please see this Slurm script for training the GatedDeltaNet_H1 model with 0.4B parameters on 15B tokens. Training requires 4 nodes and finishes in approximately 4 hours. For this run, the validation loss and perplexity curves (1x and 2x, for length extrapolation) are expected to look as follows:

[Figure: expected validation loss and perplexity curves (1x and 2x, for length extrapolation)]
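Taking the quoted numbers for the GatedDeltaNet_H1 run at face value (0.4B parameters, 15B tokens, 4 nodes, about 4 hours), a quick calculation gives the implied rates: 37.5 tokens per parameter, roughly 1.04M tokens/s in aggregate (about 260K tokens/s per node), and on the order of 3.6 × 10^19 training FLOPs under the common 6·N·D rule of thumb. Because the wall-clock time is approximate, so are these figures; the number of GPUs per node is not stated, so per-GPU rates are not computed. The helper name is hypothetical.

```python
def implied_rates(params, tokens, nodes, hours):
    """Back-of-envelope rates implied by a run's size and wall-clock time."""
    seconds = hours * 3600
    tok_per_s = tokens / seconds
    return {
        "tokens_per_param": tokens / params,
        "tokens_per_s_total": tok_per_s,
        "tokens_per_s_per_node": tok_per_s / nodes,
        # 6*N*D rule of thumb for forward + backward compute; it ignores
        # sequence-mixing FLOPs, so treat it as a rough estimate only.
        "approx_train_flops": 6 * params * tokens,
    }


rates = implied_rates(params=0.4e9, tokens=15e9, nodes=4, hours=4)
for name, value in rates.items():
    print(f"{name}: {value:.3g}")
```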

📜 License

Copyright © 2024, NVIDIA Corporation. All rights reserved.

Licensed under the NVIDIA Source Code License-NC. See LICENSE for details.

🙏 Acknowledgements

Built on the shoulders of giants:

⭐ Support Us

If you find this work useful, please consider:

  • Starring the repository
  • Citing our paper
  • Contributing to the codebase

Join us in pushing the boundaries of linear transformers! 🚀

Citation

If you find Gated DeltaNet to be useful for your work, please consider citing our paper:

@article{yang2024gated,
  title={Gated Delta Networks: Improving Mamba2 with Delta Rule},
  author={Yang, Songlin and Kautz, Jan and Hatamizadeh, Ali},
  journal={arXiv preprint arXiv:2412.06464},
  year={2024}
}
