This repository contains the official code for the paper "Enhancing Large Vision Language Models with Self-Training on Image Comprehension".
Authors (*Equal Contribution): Yihe Deng*, Pan Lu*, Fan Yin, Ziniu Hu, Sheng Shen, James Zou, Kai-Wei Chang, Wei Wang
Citation: If you find this repo useful for your research, please consider citing the paper
```
@misc{deng2024enhancing,
      title={Enhancing Large Vision Language Models with Self-Training on Image Comprehension},
      author={Yihe Deng and Pan Lu and Fan Yin and Ziniu Hu and Sheng Shen and James Zou and Kai-Wei Chang and Wei Wang},
      year={2024},
      eprint={2405.19716},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
```
- [05/30/2024] Our paper is released on arXiv: https://arxiv.org/abs/2405.19716.
Left: Accuracy improvement of STIC compared to the original LLaVA-v1.6 (Mistral-7B) on seven benchmarks. Right: Response examples from the original LLaVA-v1.6 and STIC (LLaVA-v1.6).
To tackle the data acquisition bottleneck in multi-modality, we propose Self-Training on Image Comprehension (STIC). Inspired by the recent success of self-training methods for LLMs, our method leverages self-generated data to improve the downstream performance of LVLMs. Different from the text-only domain, the unique vision modality of LVLMs introduces new challenges: an LVLM must understand the input image content before reasoning and responding to any related textual queries about the image. Therefore, the proposed STIC approach is a novel two-stage self-training method that targets both image perception and reasoning over images and texts.
Framework overview of STIC, a two-stage self-training algorithm focusing on the image comprehension capability of the LVLMs.
The overall framework is summarized in the figure above. STIC specifically emphasizes image comprehension self-training of LVLMs, where the model generates its own preference dataset focused on image description.
- The self-generated dispreferred response is obtained by gathering model responses from either
  - prompts likely to elicit inaccurate responses, or
  - corrupted images.
- The preferred responses are collected via a detailed prompt that guides the model through a step-by-step image description process.
The Figure below shows examples of such generated responses.
Examples of the self-generated preferred and dispreferred responses used to construct the preference data.
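Concretely, each self-training example pairs a preferred description, elicited with a detailed step-by-step prompt, against a dispreferred one, elicited with a misleading prompt or a corrupted image. A minimal sketch of how such a pair could be assembled is shown below; the `generate` and `corrupt_image` helpers and the prompt wordings are hypothetical placeholders, and the actual prompts and corruptions are defined in `stic/generate_pref.py`.

```python
import random

def build_preference_pair(model, image, generate, corrupt_image):
    """Sketch of assembling one preference pair for stage 1.

    `generate(model, image, prompt) -> str` and `corrupt_image(image)` are
    hypothetical helpers wrapping the LVLM call and the image corruption.
    """
    # Preferred response: a detailed, step-by-step description prompt.
    good_prompt = ("Describe the image in detail, reasoning step by step about "
                   "the objects, their attributes, and their relationships.")
    chosen = generate(model, image, good_prompt)

    # Dispreferred response: either a misleading prompt on the clean image,
    # or the normal prompt on a corrupted image.
    if random.random() < 0.5:
        bad_prompt = "Describe the image, including objects that may not actually exist."
        rejected = generate(model, image, bad_prompt)
    else:
        rejected = generate(model, corrupt_image(image), good_prompt)

    return {"image": image, "chosen": chosen, "rejected": rejected}
```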
During fine-tuning, we consider a DPO loss with an additional regularization term that explicitly emphasizes the preferred response. In stage 2, we allow the model to self-improve its reasoning ability based on its own extracted image information by reusing a small amount of existing instruction fine-tuning data and appending its self-generated image descriptions to the prompts. We refer to this second stage as description-infused fine-tuning.
Notably, STIC does not require pre-labeled information of the images.
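For reference, a minimal sketch of the stage-1 objective as described above: the standard DPO loss plus an extra term that raises the log-likelihood of the preferred response. The exact form and weighting are defined in the released `stic/dpo_trainer.py`; `alpha` and `beta` below are illustrative placeholders.

```python
import torch
import torch.nn.functional as F

def regularized_dpo_loss(policy_chosen_logps, policy_rejected_logps,
                         ref_chosen_logps, ref_rejected_logps,
                         beta=0.1, alpha=1.0):
    """Sketch: DPO loss plus a regularizer emphasizing the preferred response."""
    # Standard DPO term: log-ratio margin between chosen and rejected responses.
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = ref_chosen_logps - ref_rejected_logps
    dpo_loss = -F.logsigmoid(beta * (pi_logratios - ref_logratios))

    # Additional term that directly increases the likelihood of the preferred response.
    reg = -alpha * policy_chosen_logps

    return (dpo_loss + reg).mean()
```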
The following instructions set up the environment on Linux.
- Create a virtual environment with Conda and activate it.
```bash
conda create -n stic python=3.10 -y
conda activate stic
```
- Install packages
```bash
pip install --upgrade pip
pip install -e .
pip install -e ".[train]"
pip install flash-attn --no-build-isolation
pip install trl
```
- Modify the TRL library to adjust DPO for LVLMs: replace `dpo_trainer.py` with `stic/dpo_trainer.py`. The original file can be found at the following directory, where `username` should be replaced according to the user's setup.
```bash
rm /home/username/miniconda3/envs/stic/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py
cp ./stic/dpo_trainer.py /home/username/miniconda3/envs/stic/lib/python3.10/site-packages/trl/trainer/
```
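Optionally (our suggestion, not part of the original setup), you can verify that the replacement took effect by checking which file Python actually loads:

```python
# Sanity check: print the path of the dpo_trainer module the environment loads.
import importlib

m = importlib.import_module("trl.trainer.dpo_trainer")
print(m.__file__)  # should point inside the stic conda environment's trl installation
```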
- (For stage 1 fine-tuning) Download the 5k unlabeled image data for stage 1 from HuggingFace to your desired directory. Alternatively, one could also download the entire train2014 split from MSCOCO.
```bash
wget http://images.cocodataset.org/zips/train2014.zip
unzip train2014.zip
```
- (For stage 2 fine-tuning) [TODO: we will soon upload the specific 6k images for stage 2 to huggingface.] Download the image data for stage 2 to your desired directory and organize the data as follows.
```bash
wget http://images.cocodataset.org/zips/train2017.zip
wget https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip
wget https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip
wget https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip
```
```
├── coco
│   └── train2017
├── textvqa
│   └── train_images
└── vg
    ├── VG_100K
    └── VG_100K_2
```
Download the 5k instruction fine-tuning data here and put it in `./data`.
We provide both the self-constructed preference data and the description-infused instruction data on HuggingFace.
| Datasets | Download |
| --- | --- |
| Stage 1. Self-Constructed Preference Data | 🤗 HuggingFace |
| Stage 2. Description-Infused Instruction Data | 🤗 HuggingFace |
| Models (LoRA) | Download |
| --- | --- |
| Stage 1. Image Comprehension Self-Training | 🤗 HuggingFace |
| Stage 2. Description-Infused Fine-tuning | 🤗 HuggingFace |
Note: Skip to step 2 if using our provided preference data on Huggingface. Skip to step 4 if using our provided model checkpoint for stage 1 and description-infused data on Huggingface.
```bash
python stic/generate_pref.py [options]
```
Options
- `--model-path`: path to the target LVLM model for training (local or HuggingFace)
  - default: `liuhaotian/llava-v1.6-mistral-7b`
- `--image-dir`: local directory of the unlabeled images
  - example: `/data/username/MSCOCO/train2014`
- `--save-dir`: local directory/filename that will save the self-constructed preference data
  - default: `pref_data_mscoco.jsonl`, which saves to the current directory
Example script:
```bash
CUDA_VISIBLE_DEVICES=0 bash scripts/generate_pref.sh
```
Note: parallelize the generation tasks across multiple GPUs for faster generation (one possible approach is sketched below).
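One possible way to parallelize, sketched under the assumption that `generate_pref.py` only reads a single `--image-dir` (if the script provides sharding flags, prefer those): split the image folder into per-GPU shards via symlinks and launch one process per GPU. Paths, GPU count, and shard names below are placeholders.

```python
import os
import subprocess

image_dir = "/data/username/MSCOCO/train2014"  # placeholder path
num_gpus = 4
images = sorted(os.listdir(image_dir))

procs = []
for gpu in range(num_gpus):
    # Build a per-GPU shard directory of symlinks to every num_gpus-th image.
    shard_dir = f"./shards/gpu{gpu}"
    os.makedirs(shard_dir, exist_ok=True)
    for name in images[gpu::num_gpus]:
        link = os.path.join(shard_dir, name)
        if not os.path.exists(link):
            os.symlink(os.path.join(image_dir, name), link)

    # Launch one generation process per GPU on its own shard.
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=str(gpu))
    procs.append(subprocess.Popen(
        ["python", "stic/generate_pref.py",
         "--image-dir", shard_dir,
         "--save-dir", f"pref_data_gpu{gpu}.jsonl"],
        env=env,
    ))

for p in procs:
    p.wait()
```

The per-GPU `.jsonl` outputs can then be concatenated (for example, `cat pref_data_gpu*.jsonl > pref_data_mscoco.jsonl`) before the conversion step below.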
```bash
python stic/convert_jsonl_to_json.py --input pref_data_mscoco.jsonl
```
or directly download the JSON file from HuggingFace.
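For reference, a minimal sketch of what this JSONL-to-JSON conversion amounts to (an illustration of the expected behavior, not the script's exact code), assuming one JSON object per line:

```python
import json

# Read one JSON object per line and write them out as a single JSON list.
records = []
with open("pref_data_mscoco.jsonl") as f:
    for line in f:
        line = line.strip()
        if line:
            records.append(json.loads(line))

with open("pref_data_mscoco.json", "w") as f:
    json.dump(records, f, indent=2)
```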
```bash
bash scripts/dpo_finetune.sh
```
Options (change the necessary arguments in the shell script)
- `--data_path`: path to the input preference data (local or HuggingFace)
  - example: `pref_data_mscoco.json`
- `--image_folder`: local directory of the unlabeled images
  - example: `/data/username/MSCOCO/train2014`
- `--output_dir`: the directory that will hold the LoRA weights after fine-tuning
  - example: `/data/username/checkpoints/llava_stic_stage1`
Ensure the global batch size (number_of_devices * batch_size * gradient_accumulation_steps) is equal to our setting of 8 (for example, 4 devices * batch size 1 * gradient accumulation steps 2).
```bash
python stic/generate_des_stage2.py [options]
```
Options
- `--model-path`: path to the target LVLM model for training (local or HuggingFace)
  - default: `liuhaotian/llava-v1.6-mistral-7b`
- `--adapter-path`: path to the LoRA weights after stage 1 (local or HuggingFace)
  - example: `checkpoints/llava_coco_test`
- `--image-dir`: local directory of the images for instruction fine-tuning
  - example: `/data/username/image_data`
- `--save-dir`: local directory/filename that will save the self-generated image descriptions
  - default: `image_description.jsonl`, which saves to the current directory
Example script:
```bash
CUDA_VISIBLE_DEVICES=0 bash scripts/generate_des_stage2.sh
```
Parallelize the task across multiple GPUs to speed up the process. Lastly, combine the descriptions with the instruction fine-tuning data.
```bash
python stic/add_des_to_data.py
```
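For reference, a rough sketch of what this description-infusion step does conceptually: prepend each self-generated image description to the corresponding instruction prompt. The file names and field names below ("image", "description", "conversations") are illustrative assumptions, not the repo's exact schema.

```python
import json

# Load the self-generated descriptions produced in the previous step.
descriptions = {}
with open("image_description.jsonl") as f:
    for line in f:
        rec = json.loads(line)
        descriptions[rec["image"]] = rec["description"]

# Load the instruction fine-tuning data (hypothetical filename).
with open("data/instruction_data.json") as f:
    examples = json.load(f)

# Prepend the model's own description to the first user turn of each example.
for ex in examples:
    desc = descriptions.get(ex["image"], "")
    first_turn = ex["conversations"][0]
    first_turn["value"] = "Image description: " + desc + "\n" + first_turn["value"]

with open("data/description_infused_data.json", "w") as f:
    json.dump(examples, f, indent=2)
```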
```bash
bash scripts/finetune_lora.sh
```
Note: change the argument for `--load_peft` to change the LoRA weights that stage 2 will start from (for example, `--load_peft STIC-LVLM/llava-v1.6-mistral-7b-STIC-stage1`).
Step 4. Evaluation (Please find more evaluation details in Evaluation.)
(We note that our evaluation scripts follow the ones released for LLaVA-1.5, as the new evaluation scripts for LLaVA-1.6 had not been released at the time of this work. This may result in some differences from the officially reported values for LLaVA-1.6, similar to what is described in this issue. Nevertheless, we use the same evaluation scripts with and without STIC to ensure a fair comparison.)
Take MMBench as an example. (Run `pip install openpyxl` before evaluating with MMBench.)
Option 1. Evaluating the model performance without DaR (Describe and Respond)
```bash
python llava/eval/model_vqa_mmbench.py [options]
```
- `--load-peft`: path to the LoRA weights fine-tuned by STIC (local or HuggingFace)
  - default: `None`, which evaluates the original LVLM model from `--model-path`
  - example: `ydeng9/llava-v1.6-mistral-7b-STIC` to use the provided LoRA weights
- `--model-path`: path to the target LVLM model for evaluation (local or HuggingFace)
  - default: `liuhaotian/llava-v1.6-mistral-7b`
- `--answers-file`: local path to save the answers
  - example: `answers/mmbench_answers.jsonl`
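For reference, a rough sketch of what evaluating with `--load-peft` amounts to: load the base LVLM and attach the STIC LoRA adapter before inference. The actual loading logic lives in the evaluation script; the code below is illustrative only.

```python
# Illustrative only: load the base LVLM with LLaVA's builder, then attach the
# STIC LoRA adapter with peft before running evaluation.
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path
from peft import PeftModel

model_path = "liuhaotian/llava-v1.6-mistral-7b"
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path, None, get_model_name_from_path(model_path)
)
model = PeftModel.from_pretrained(model, "ydeng9/llava-v1.6-mistral-7b-STIC")
model.eval()
```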
Option 2. Evaluating the model performance with DaR (Describe and Respond)
```bash
python llava/eval/model_vqa_mmbench_dar.py [options]
```
Arguments are the same as for `model_vqa_mmbench.py`.
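Conceptually, DaR queries the model twice: first for a self-generated description of the image, then for the answer with that description prepended to the question. A minimal sketch follows; the `generate` helper and the description prompt are hypothetical placeholders, and the exact prompting is defined in `model_vqa_mmbench_dar.py`.

```python
# Hedged sketch of Describe-and-Respond (DaR) prompting at inference time.
# `generate(model, image, prompt) -> str` is a hypothetical helper wrapping the LVLM call.
def describe_and_respond(model, image, question, generate):
    # Step 1: the model describes the image in its own words.
    description = generate(model, image, "Describe the image in detail.")
    # Step 2: the self-generated description is prepended to the actual question.
    prompted_question = f"Image description: {description}\n{question}"
    return generate(model, image, prompted_question)
```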
Example script:
```bash
CUDA_VISIBLE_DEVICES=0 bash scripts/eval/mmbench_dar.sh
```
This repo is built upon LLaVA and POVID. We thank all the authors for their great work.