
The official CLIP training codebase of Inf-CL: "Breaking the Memory Barrier: Near Infinite Batch Size Scaling for Contrastive Loss". A super memory-efficient CLIP training scheme.


DAMO-NLP-SG/Inf-CLIP


If our project helps you, please give us a star ⭐ on GitHub to support us. 🙏🙏


💡 Some other multimodal foundation model projects from our team may interest you ✨.

VCD: Mitigating Object Hallucinations in Large Vision-Language Models through Visual Contrastive Decoding
Sicong Leng, Hang Zhang, Guanzheng Chen, Xin Li, Shijian Lu, Chunyan Miao, Lidong Bing

VideoLLaMA 2: Advancing Spatial-Temporal Modeling and Audio Understanding in Video-LLMs
Zesen Cheng, Sicong Leng, Hang Zhang, Yifei Xin, Xin Li, Guanzheng Chen, Yongxin Zhu, Wenqi Zhang, Ziyang Luo, Deli Zhao, Lidong Bing

The Curse of Multi-Modalities: Evaluating Hallucinations of Large Multimodal Models across Language, Visual, and Audio
Sicong Leng, Yun Xing, Zesen Cheng, Yang Zhou, Hang Zhang, Xin Li, Deli Zhao, Shijian Lu, Chunyan Miao, Lidong Bing

📰 News

  • [2024.10.18] Released the training and evaluation code of Inf-CLIP.

🛠️ Requirements and Installation

Basic Dependencies:

  • Python >= 3.8
  • PyTorch >= 2.0.0
  • CUDA Version >= 11.8
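To confirm these minimums on a given machine, a small check like the following may help. It is a hypothetical helper, not shipped with the repository. It uses only `sys` and, if PyTorch is installed, `torch.__version__` and `torch.version.cuda`; note that the latter is the CUDA version PyTorch was built against, which can differ from the system toolkit.

```python
import sys


def parse_version(text):
    """Turn a version string such as '2.1.0+cu118' or '11.8' into a tuple of ints."""
    core = text.split("+")[0]  # drop local build tags such as '+cu118'
    parts = []
    for piece in core.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break  # stop at suffixes such as 'rc1' (so '2.0.0rc1' reads as 2.0.0)
            digits += ch
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


def meets(found, required):
    """True if version string `found` is >= version string `required`."""
    a, b = parse_version(found), parse_version(required)
    width = max(len(a), len(b))
    # Pad with zeros so that '2.0' compares equal to '2.0.0'.
    return a + (0,) * (width - len(a)) >= b + (0,) * (width - len(b))


def check_environment():
    python = ".".join(map(str, sys.version_info[:3]))
    print(f"Python {python}: {'ok' if meets(python, '3.8') else 'too old (need >= 3.8)'}")
    try:
        import torch
    except ImportError:
        print("PyTorch is not installed (need >= 2.0.0)")
        return
    status = "ok" if meets(torch.__version__, "2.0.0") else "too old (need >= 2.0.0)"
    print(f"PyTorch {torch.__version__}: {status}")
    cuda = torch.version.cuda  # None for CPU-only builds
    if cuda is None:
        print("This PyTorch build has no CUDA support (need CUDA >= 11.8)")
    else:
        print(f"CUDA {cuda}: {'ok' if meets(cuda, '11.8') else 'too old (need >= 11.8)'}")


if __name__ == "__main__":
    check_environment()
```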

[Remote] Install Inf-CL:

# install from PyPI
pip install inf_cl -i https://pypi.org/simple

[Local] Install Inf-CL (run from the root of a cloned Inf-CLIP repository):

pip install -e .

Install required packages:

git clone https://github.com/DAMO-NLP-SG/Inf-CLIP
cd Inf-CLIP
pip install -r requirements.txt

⭐ Features

inf_cl is the Triton implementation of the Inf-CL loss.

inf_clip is the CLIP training codebase built on the Inf-CL loss, together with other training features.
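The memory saving of Inf-CL comes from splitting the contrastive loss computation into small tiles, so the full B × B similarity matrix is never materialized. The sketch below illustrates that idea on a single device in plain PyTorch, using an online (running) log-sum-exp over column tiles. It computes the forward value only and makes its own simplifications. It is not the Triton kernel or the multi-GPU scheme in inf_cl, and `tiled_contrastive_loss` is a hypothetical name.

```python
import torch
import torch.nn.functional as F


@torch.no_grad()  # forward value only; a real implementation also needs a tiled backward
def tiled_contrastive_loss(q, k, scale, tile=1024):
    """Mean InfoNCE loss with diagonal labels (query i matches key i).

    Each row block streams over column blocks and keeps a running max and a
    running sum of exponentials, so only one (tile x tile) logit block exists
    at a time instead of the full (B x B) matrix.
    """
    assert q.shape == k.shape, "diagonal labels need as many keys as queries"
    n = q.shape[0]
    row_losses = []
    for r0 in range(0, n, tile):
        q_blk = q[r0:r0 + tile]
        rows = q_blk.shape[0]
        running_max = torch.full((rows,), float("-inf"), dtype=q.dtype, device=q.device)
        running_sum = torch.zeros(rows, dtype=q.dtype, device=q.device)
        for c0 in range(0, n, tile):
            logits = scale * q_blk @ k[c0:c0 + tile].T  # one (rows x tile) block
            new_max = torch.maximum(running_max, logits.max(dim=1).values)
            # Rescale the old sum to the new max, then add this block's terms.
            running_sum = running_sum * torch.exp(running_max - new_max) + torch.exp(
                logits - new_max[:, None]
            ).sum(dim=1)
            running_max = new_max
        lse = running_max + torch.log(running_sum)  # log sum_j exp(s * q_i . k_j)
        pos = scale * (q_blk * k[r0:r0 + rows]).sum(dim=1)  # s * q_i . k_i
        row_losses.append(lse - pos)  # per-row cross-entropy
    return torch.cat(row_losses).mean()


if __name__ == "__main__":
    torch.manual_seed(0)
    q = F.normalize(torch.randn(300, 64, dtype=torch.float64), dim=-1)
    k = F.normalize(torch.randn(300, 64, dtype=torch.float64), dim=-1)
    dense = F.cross_entropy((1 / 0.07) * q @ k.T, torch.arange(300))
    tiled = tiled_contrastive_loss(q, k, 1 / 0.07, tile=64)
    print(f"dense: {dense.item():.10f}  tiled: {tiled.item():.10f}")
```

The tile size only trades peak memory against the number of small matrix multiplies; the result matches the dense cross-entropy up to floating-point rounding.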

🔑 Usage

A simple example of how to use our Inf-CL loss for contrastive learning. Launch it with:

torchrun --nproc_per_node 2 tests/example.py

The example script:
import os

import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F

from inf_cl import cal_inf_loss


def create_cl_tensors(world_size, device):
    # Parameters
    dtype = torch.float32
    num_heads = 3         # q/k feature dim is num_heads * d_model = 768
    seq_length_q = 32768  # Global number of queries (e.g., images) across all ranks
    seq_length_k = 32768  # Global number of keys (e.g., texts) across all ranks
    d_model = 256         # Dimension of each head

    # Randomly initialize this rank's shard of the inputs
    q = torch.rand((seq_length_q // world_size, num_heads * d_model), dtype=dtype, device=device)
    k = torch.rand((seq_length_k // world_size, num_heads * d_model), dtype=dtype, device=device)
    l = torch.ones([], dtype=dtype, device=device) * np.log(1 / 0.07)

    q = F.normalize(q, p=2, dim=-1).requires_grad_()  # Query
    k = F.normalize(k, p=2, dim=-1).requires_grad_()  # Key
    l = l.requires_grad_()  # Log of the logit scale (learnable)

    return q, k, l


if __name__ == "__main__":
    # torchrun sets RANK, WORLD_SIZE, LOCAL_RANK and the rendezvous address
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)  # bind this process to its own GPU
    dist.init_process_group("nccl")

    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = torch.device("cuda", local_rank)

    # Taking image-text contrastive learning as an example: q holds this rank's
    # image features, k holds its text features, and l is the log logit scale.
    q, k, l = create_cl_tensors(world_size, device)

    # By default the labels are the diagonal: the i-th query matches the i-th key.
    # labels = torch.arange(q.shape[0])
    loss = cal_inf_loss(q, k, scale=l.exp())

    print(f"rank {rank}: loss = {loss}")

    dist.destroy_process_group()
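CLIP-style training usually averages an image-to-text and a text-to-image term. For a small-scale sanity check on one device, the standard symmetric CLIP loss can be written naively as below. `clip_loss_reference` is a hypothetical helper and not part of inf_cl. It materializes the full logit matrix, so it only suits small batches of already-gathered, L2-normalized features. Which direction(s) `cal_inf_loss` computes is not stated above, so compare it against the matching term (e.g. `loss_i2t` for `cal_inf_loss(q, k, ...)` with diagonal labels) rather than assuming.

```python
import torch
import torch.nn.functional as F


def clip_loss_reference(image_feats, text_feats, logit_scale):
    """Naive symmetric CLIP loss; builds the full (B x B) logit matrix."""
    logits = logit_scale * image_feats @ text_feats.T                 # (B, B)
    labels = torch.arange(logits.shape[0], device=logits.device)     # diagonal positives
    loss_i2t = F.cross_entropy(logits, labels)                        # image -> text
    loss_t2i = F.cross_entropy(logits.T, labels)                      # text -> image
    return (loss_i2t + loss_t2i) / 2


if __name__ == "__main__":
    torch.manual_seed(0)
    img = F.normalize(torch.randn(16, 32), dim=-1)
    txt = F.normalize(torch.randn(16, 32), dim=-1)
    print(clip_loss_reference(img, txt, logit_scale=1 / 0.07))
```

Two useful edge cases: perfectly matched one-hot features give a loss near 0, and identical features for every sample give exactly log(B), because all logits in a row are equal.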

🚀 Main Results

Memory Cost

* denotes that the "data offload" strategy is used.
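For a sense of scale, here is a back-of-the-envelope estimate of what a naive implementation needs just to hold the similarity matrix. B × B fp32 logits take 4 GiB at B = 32,768 and 256 GiB at B = 262,144, before counting gradients or softmax buffers. `logits_gib` is a hypothetical helper. The per-rank split assumes each GPU holds its local rows against all global columns, as in a naive all-gather setup. Reading the script names bs16k/bs32k/bs256k as powers of two is an assumption.

```python
def logits_gib(global_batch, world_size=1, bytes_per_elem=4):
    """GiB for one rank's slice of a dense similarity (logit) matrix.

    Each rank is assumed to hold global_batch / world_size local rows against
    all global_batch columns. Only one copy of the logits is counted: no
    gradients, softmax outputs, or other intermediate buffers.
    """
    rows = global_batch // world_size
    return rows * global_batch * bytes_per_elem / 2**30


if __name__ == "__main__":
    # 16k / 32k / 256k, taken here as powers of two (an assumption)
    for bs in (16_384, 32_768, 262_144):
        print(f"batch {bs:>7,}: {logits_gib(bs):6.1f} GiB on one GPU, "
              f"{logits_gib(bs, world_size=8):5.1f} GiB per GPU across 8 GPUs")
```

Because the matrix grows quadratically with the batch size, even splitting the rows across 8 GPUs still leaves 32 GiB of logits per GPU at B = 262,144.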

Max Supported Batch Size

Speed

Batch Size Scaling

Training on a larger data scale requires a larger batch size.

🏗️ Training & Evaluation

Quick Start

To facilitate further development on top of our codebase, we provide a quick-start guide on how to use Inf-CLIP to train a customized CLIP model and evaluate it on mainstream CLIP benchmarks.

  1. Training Data Structure:
Inf-CLIP
├── datasets
│   ├── cc3m/ # https://github.com/rom1504/img2dataset/blob/main/dataset_examples/cc3m.md
│   │   ├── 0000.tar
│   │   ├── 0001.tar
│   │   ├── ...
│   │   └── 0301.tar
│   ├── cc12m/ # https://github.com/rom1504/img2dataset/blob/main/dataset_examples/cc12m.md
│   │   ├── 0000.tar
│   │   ├── 0001.tar
│   │   ├── ...
│   │   └── 1044.tar
│   └── laion400m/ # https://github.com/rom1504/img2dataset/blob/main/dataset_examples/laion400m.md
│       ├── 00000.tar
│       ├── 00001.tar
│       ├── ...
│       └── 41407.tar
  2. Training commands:
bash scripts/cc3m/lit_vit-b-32_bs16k.sh
bash scripts/cc12m/lit_vit-b-32_bs32k.sh
bash scripts/laion400m/lit_vit-b-32_bs256k.sh
  3. Evaluation Data Structure:
Inf-CLIP
├── datasets
│   ├── imagenet-1k/ # download val_images.tar.gz of ImageNet
│   │   └── val/
│   │       ├── n01440764
│   │       ├── n01443537
│   │       ├── ...
│   │       └── n15075141
│   └── clip-benchmark/ # bash datasets/benchmarks_download.sh
│       ├── wds_mscoco_captions
│       ├── wds_flickr8k
│       ├── wds_flickr30k
│       ├── wds_imagenet1k
│       ├── wds_imagenetv2
│       ├── wds_imagenet_sketch
│       ├── wds_imagenet-a
│       ├── wds_imagenet-r
│       ├── wds_imagenet-o
│       └── wds_objectnet
  4. Evaluation commands:
# ImageNet evaluation
bash scripts/imagenet_eval.sh
# evaluation on all benchmarks
bash scripts/benchmarks_eval.sh
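Before launching the scripts, it can save time to check that the data matches the layouts in steps 1 and 3. The sketch below is a hypothetical helper, not part of the repository. The shard ranges and directory names are transcribed from the trees in steps 1 and 3; an img2dataset download may yield a different number of shards. It also prints each shard set in the brace-expansion form (e.g. `datasets/cc3m/{0000..0301}.tar`) that webdataset-style loaders such as OpenCLIP's `--train-data` accept. Whether the Inf-CLIP scripts take their data paths in this form is not stated here, so check the scripts themselves.

```python
import re
from pathlib import Path

# (first index, last index, zero-padding width), transcribed from step 1.
TRAIN_SHARDS = {
    "cc3m": (0, 301, 4),
    "cc12m": (0, 1044, 4),
    "laion400m": (0, 41407, 5),
}

# Evaluation directories listed in step 3.
EVAL_DIRS = ["imagenet-1k/val"] + [
    f"clip-benchmark/wds_{name}"
    for name in (
        "mscoco_captions", "flickr8k", "flickr30k", "imagenet1k", "imagenetv2",
        "imagenet_sketch", "imagenet-a", "imagenet-r", "imagenet-o", "objectnet",
    )
]


def missing_shards(root, name, first, last):
    """Indices in [first, last] with no numeric <index>.tar file under root/name."""
    found = {
        int(path.stem)
        for path in Path(root, name).glob("*.tar")
        if re.fullmatch(r"\d+", path.stem)
    }
    return [i for i in range(first, last + 1) if i not in found]


def shard_pattern(root, name, first, last, width):
    """Brace-expansion shard list, e.g. datasets/cc3m/{0000..0301}.tar."""
    return f"{Path(root, name).as_posix()}/{{{first:0{width}d}..{last:0{width}d}}}.tar"


def missing_eval_dirs(root):
    """Entries of EVAL_DIRS that are not existing directories under root."""
    return [d for d in EVAL_DIRS if not Path(root, d).is_dir()]


if __name__ == "__main__":
    root = "datasets"
    for name, (first, last, width) in TRAIN_SHARDS.items():
        gaps = missing_shards(root, name, first, last)
        pattern = shard_pattern(root, name, first, last, width)
        print(f"{name}: {len(gaps)} missing shard(s); pattern {pattern}")
    print("missing evaluation dirs:", missing_eval_dirs(root) or "none")
```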

📑 Citation

If you find Inf-CLIP useful for your research and applications, please cite using this BibTeX:

@article{damovl2024infcl,
  title={Breaking the Memory Barrier: Near Infinite Batch Size Scaling for Contrastive Loss},
  author={Zesen Cheng and Hang Zhang and Kehan Li and Sicong Leng and Zhiqiang Hu and Fei Wu and Deli Zhao and Xin Li and Lidong Bing},
  journal={arXiv preprint arXiv:2410.17243},
  year={2024},
  url={https://arxiv.org/abs/2410.17243}
}

👍 Acknowledgement

The codebase of Inf-CLIP is adapted from OpenCLIP. We are also grateful to the projects from which Inf-CL arose.

🔒 License

This project is released under the Apache 2.0 license as found in the LICENSE file. The service is a research preview intended for non-commercial use ONLY, subject to the model licenses of CLIP, the Terms of Use of the data generated by OpenAI, and LAION. Please get in touch with us if you find any potential violations.
