Add support for domain-specific models:
1. ChartAssistant (`chartast.py`)
2. ChartInstruct (`chartinstruct.py`)
3. ChartLlama (`chartllama.py`)
4. CogAgent (`cogagent.py`)
5. DocOwl1.5 (`docowl15.py`)
6. TextMonkey (`textmonkey.py`)
7. TinyChart (`tinychart.py`)
8. UniChart (`unichart.py`)
9. UReader (`ureader.py`)
Colin Wang committed Aug 18, 2024
1 parent 4e52142 commit 091caa8
Showing 9 changed files with 597 additions and 0 deletions.
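Each of the modules below exposes the same entry point: generate_response(queries, model_path) reads queries[k]['question'] and queries[k]['figure_path'] for every key and writes the model output back into queries[k]['response']. A minimal driver sketch, assuming the package is importable as generate_lib; the query dict, figure path, and choice of module are placeholders, not part of this commit:

# Hypothetical driver (not part of this commit); module choice and paths are placeholders.
from generate_lib import chartinstruct

queries = {
    "demo_0": {
        "question": "Which group has the largest value?",
        "figure_path": "figures/demo_0.png",  # placeholder image path
    },
}

# Fills queries["demo_0"]["response"] in place.
chartinstruct.generate_response(queries, "ahmed-masry/ChartInstruct-LLama2")
print(queries["demo_0"]["response"])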
98 changes: 98 additions & 0 deletions src/generate_lib/chartast.py
@@ -0,0 +1,98 @@
# Adapted from https://github.com/OpenGVLab/ChartAst/blob/main/accessory/single_turn_eval.py
# This has support for the ChartAssistant model

import os
vlm_codebase = os.environ['VLM_CODEBASE_DIR']

import sys
sys.path.append(vlm_codebase + '/ChartAst/accessory')

os.environ['MP'] = '1'
os.environ['WORLD_SIZE'] = '1'

import torch
from tqdm import tqdm
import torch.distributed as dist


sys.path.append(os.path.abspath(__file__).rsplit('/', 2)[0])
from fairscale.nn.model_parallel import initialize as fs_init
from model.meta import MetaModel
from util.tensor_parallel import load_tensor_parallel_model_list
from util.misc import init_distributed_mode
from PIL import Image

import torchvision.transforms as transforms

try:
from torchvision.transforms import InterpolationMode

BICUBIC = InterpolationMode.BICUBIC
except ImportError:
BICUBIC = Image.BICUBIC


class PadToSquare:
def __init__(self, background_color):
"""
        pad an image to square (borrowed from LLaVA)
:param background_color: rgb values for padded pixels, normalized to [0, 1]
"""
self.bg_color = tuple(int(x * 255) for x in background_color)

def __call__(self, img: Image.Image):
width, height = img.size
if width == height:
return img
elif width > height:
result = Image.new(img.mode, (width, width), self.bg_color)
result.paste(img, (0, (width - height) // 2))
return result
else:
result = Image.new(img.mode, (height, height), self.bg_color)
result.paste(img, ((height - width) // 2, 0))
return result

def T_padded_resize(size=448):
t = transforms.Compose([
PadToSquare(background_color=(0.48145466, 0.4578275, 0.40821073)),
transforms.Resize(
size, interpolation=transforms.InterpolationMode.BICUBIC
),
transforms.ToTensor(),
transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])])
return t

def generate_response(queries, model_path):
init_distributed_mode()
fs_init.initialize_model_parallel(dist.get_world_size())
model = MetaModel('llama_ens5', model_path + '/params.json', model_path + '/tokenizer.model', with_visual=True)
print(f"load pretrained from {model_path}")
load_tensor_parallel_model_list(model, model_path)
model.bfloat16().cuda()
max_gen_len = 512
gen_t = 0.9
top_p = 0.5

for k in tqdm(queries):
question = queries[k]['question']
img_path = queries[k]['figure_path']

prompt = f"""Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\nPlease answer my question based on the chart: {question}\n\n### Response:"""

image = Image.open(img_path).convert('RGB')
transform_val = T_padded_resize(448)
image = transform_val(image).unsqueeze(0)
image = image.cuda()

with torch.cuda.amp.autocast(dtype=torch.bfloat16):
response = model.generate([prompt], image, max_gen_len=max_gen_len, temperature=gen_t, top_p=top_p)
response = response[0].split('###')[0]
print(response)
queries[k]['response'] = response
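As a quick sanity check of the preprocessing defined above, padding a non-square image and resizing to 448 should produce a 3x448x448 tensor. A standalone sketch using a synthetic image (the dummy chart is an assumption, not benchmark data):

# Standalone check of PadToSquare + T_padded_resize on a synthetic, non-square image.
from PIL import Image

dummy = Image.new("RGB", (600, 300), (255, 255, 255))  # placeholder "chart"
tensor = T_padded_resize(448)(dummy)
print(tensor.shape)  # expected: torch.Size([3, 448, 448])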
46 changes: 46 additions & 0 deletions src/generate_lib/chartinstruct.py
@@ -0,0 +1,46 @@
# Adapted from https://huggingface.co/ahmed-masry/ChartInstruct-LLama2, https://huggingface.co/ahmed-masry/ChartInstruct-FlanT5-XL
# This has support for two ChartInstruct models, LLama2 and FlanT5

from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration, AutoModelForSeq2SeqLM
import torch
from tqdm import tqdm

def generate_response(queries, model_path):
if "LLama2" in model_path:
print("Using LLama2 model")
model = LlavaForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float16)
elif "FlanT5" in model_path:
print("Using FlanT5 model")
model = AutoModelForSeq2SeqLM.from_pretrained(model_path, torch_dtype=torch.float16, trust_remote_code=True)
else:
raise ValueError(f"Model {model_path} not supported")
processor = AutoProcessor.from_pretrained(model_path)


device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

for k in tqdm(queries):
image_path = queries[k]['figure_path']
input_prompt = queries[k]['question']
input_prompt = f"<image>\n Question: {input_prompt} Answer: "

image = Image.open(image_path).convert('RGB')
inputs = processor(text=input_prompt, images=image, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}

        # change type of pixel_values in inputs to fp16.
inputs['pixel_values'] = inputs['pixel_values'].to(torch.float16)
if "LLama2" in model_path:
prompt_length = inputs['input_ids'].shape[1]

        # Generate; the decoder-only LLama2 variant echoes the prompt tokens back,
        # so strip them off before decoding. The seq2seq FlanT5 variant returns
        # only newly generated tokens.
        generate_ids = model.generate(**inputs, num_beams=4, max_new_tokens=512)
        if "LLama2" in model_path:
            generate_ids = generate_ids[:, prompt_length:]
        output_text = processor.batch_decode(generate_ids, skip_special_tokens=True,
                                             clean_up_tokenization_spaces=False)[0]
print(output_text)
queries[k]['response'] = output_text
160 changes: 160 additions & 0 deletions src/generate_lib/chartllama.py
@@ -0,0 +1,160 @@
# Adapted from https://github.com/tingxueronghua/ChartLlama-code/blob/main/model_vqa_lora.py
# This has support for the Chartllama model

### HEADER START ###
import os
vlm_codebase = os.environ['VLM_CODEBASE_DIR']

import sys
sys.path.append(vlm_codebase + '/ChartLlama-code')
### HEADER END ###

import argparse
import torch
import os
import json
from tqdm import tqdm
import shortuuid
import warnings
import shutil

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model import *
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
from torch.utils.data import Dataset, DataLoader

from PIL import Image
import math

def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda"):
kwargs = {"device_map": device_map}

if load_8bit:
kwargs['load_in_8bit'] = True
elif load_4bit:
kwargs['load_in_4bit'] = True
kwargs['quantization_config'] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type='nf4'
)
else:
kwargs['torch_dtype'] = torch.float16

# Load LLaVA model
if model_base is None:
raise ValueError('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.')
if model_base is not None:
lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
print('Loading LLaVA from base model...')
model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
        token_num, token_dim = model.lm_head.out_features, model.lm_head.in_features
        if model.lm_head.weight.shape[0] != token_num:
            model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
            model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))

print('Loading additional LLaVA weights...')
if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
else:
# this is probably from HF Hub
from huggingface_hub import hf_hub_download
def load_from_hf(repo_id, filename, subfolder=None):
cache_file = hf_hub_download(
repo_id=repo_id,
filename=filename,
subfolder=subfolder)
return torch.load(cache_file, map_location='cpu')
non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
if any(k.startswith('model.model.') for k in non_lora_trainables):
non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
model.load_state_dict(non_lora_trainables, strict=False)

from peft import PeftModel
print('Loading LoRA weights...')
model = PeftModel.from_pretrained(model, model_path)
print('Merging LoRA weights...')
model = model.merge_and_unload()
print('Model is loaded...')

image_processor = None

mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
if mm_use_im_patch_token:
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
if mm_use_im_start_end:
tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
model.resize_token_embeddings(len(tokenizer))

vision_tower = model.get_vision_tower()
if not vision_tower.is_loaded:
vision_tower.load_model()
vision_tower.to(device=device, dtype=torch.float16)
image_processor = vision_tower.image_processor

if hasattr(model.config, "max_sequence_length"):
context_len = model.config.max_sequence_length
else:
context_len = 2048

return tokenizer, model, image_processor, context_len


def generate_response(queries, model_path):
disable_torch_init()
    # model_path packs the base checkpoint and the LoRA directory as "<base>::<lora>"
    base_model_path, model_path = model_path.split('::')
tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, base_model_path, None)
conv_mode = "vicuna_v1"

def process(image, question, tokenizer, image_processor, model_config):
qs = question.replace(DEFAULT_IMAGE_TOKEN, '').strip()
if model.config.mm_use_im_start_end:
qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
else:
qs = DEFAULT_IMAGE_TOKEN + '\n' + qs

conv = conv_templates[conv_mode].copy()
conv.append_message(conv.roles[0], qs)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

image_tensor = process_images([image], image_processor, model_config)[0]

input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')

return input_ids, image_tensor

for k in tqdm(queries):
image_path = queries[k]['figure_path']
image = Image.open(image_path).convert('RGB')
question = queries[k]['question']

input_ids, image_tensor = process(image, question, tokenizer, image_processor, model.config)
stop_str = conv_templates[conv_mode].sep if conv_templates[conv_mode].sep_style != SeparatorStyle.TWO else conv_templates[conv_mode].sep2
input_ids = input_ids.to(device='cuda', non_blocking=True).unsqueeze(0) # added the unsqueeze(0) to make it batch size 1
image_tensor = image_tensor.unsqueeze(0) # added the unsqueeze(0) to make it batch size 1
with torch.inference_mode():
output_ids = model.generate(
input_ids,
images=image_tensor.to(dtype=torch.float16, device='cuda', non_blocking=True),
do_sample=False,
max_new_tokens=1636,
use_cache=True
)
input_token_len = input_ids.shape[1]
n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
if n_diff_input_output > 0:
print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
outputs = outputs.strip()
if outputs.endswith(stop_str):
outputs = outputs[:-len(stop_str)]
outputs = outputs.strip()
queries[k]['response'] = outputs
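Note that generate_response above splits model_path on '::', so callers pass the base LLaVA checkpoint and the ChartLlama LoRA directory packed into one string. A hedged invocation sketch; both paths are placeholders and the queries dict follows the structure from the driver sketch near the top of this page:

# Hypothetical call; both checkpoint paths are placeholders.
from generate_lib import chartllama

chartllama.generate_response(
    queries,  # same {'question', 'figure_path'} structure as in the driver sketch above
    "/checkpoints/llava-v1.5-13b::/checkpoints/ChartLlama-13b",  # base model :: LoRA weights
)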
49 changes: 49 additions & 0 deletions src/generate_lib/cogagent.py
@@ -0,0 +1,49 @@
# Adapted from https://huggingface.co/THUDM/cogagent-vqa-hf
# This has support for the CogAgent model

import torch
from PIL import Image
from transformers import AutoModelForCausalLM, LlamaTokenizer
from tqdm import tqdm

def generate_response(queries, model_path):
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
torch_type = torch.bfloat16
tokenizer_path, model_path = model_path.split('::')
tokenizer = LlamaTokenizer.from_pretrained(tokenizer_path)
model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
load_in_4bit=False,
trust_remote_code=True
).to('cuda').eval()

for k in tqdm(queries):
image_path = queries[k]['figure_path']
image = Image.open(image_path).convert('RGB')
query = f"Human:{queries[k]['question']}"
history = []

input_by_model = model.build_conversation_input_ids(tokenizer, query=query, history=history, images=[image])
inputs = {
'input_ids': input_by_model['input_ids'].unsqueeze(0).to(DEVICE),
'token_type_ids': input_by_model['token_type_ids'].unsqueeze(0).to(DEVICE),
'attention_mask': input_by_model['attention_mask'].unsqueeze(0).to(DEVICE),
'images': [[input_by_model['images'][0].to(DEVICE).to(torch_type)]],
}
if 'cross_images' in input_by_model and input_by_model['cross_images']:
inputs['cross_images'] = [[input_by_model['cross_images'][0].to(DEVICE).to(torch_type)]]

# add any transformers params here.
gen_kwargs = {"max_length": 2048,
"temperature": 0.9,
"do_sample": False}
with torch.no_grad():
outputs = model.generate(**inputs, **gen_kwargs)
outputs = outputs[:, inputs['input_ids'].shape[1]:]
response = tokenizer.decode(outputs[0])
response = response.split("</s>")[0]
print("\nCog:", response)
print('model_answer:', response)
queries[k]['response'] = response
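CogAgent uses the same '::' packing, but here the first half names the tokenizer rather than a base model; the upstream cogagent-vqa-hf card pairs the model with a Vicuna tokenizer. A sketch with that pairing assumed (only THUDM/cogagent-vqa-hf comes from the comment above):

# Hypothetical call; the tokenizer id follows the upstream model card and may need adjusting.
from generate_lib import cogagent

cogagent.generate_response(
    queries,  # same structure as the driver sketch near the top of this page
    "lmsys/vicuna-7b-v1.5::THUDM/cogagent-vqa-hf",  # tokenizer :: model
)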
25 changes: 25 additions & 0 deletions src/generate_lib/docowl15.py
@@ -0,0 +1,25 @@
# Adapted from https://github.com/X-PLUG/mPLUG-DocOwl/blob/main/DocOwl1.5/docowl_infer.py
# This has support for the DocOwl model

### HEADER START ###
import os
vlm_codebase = os.environ['VLM_CODEBASE_DIR']

import sys
sys.path.append(vlm_codebase + '/mPLUG-DocOwl/DocOwl1.5')
### HEADER END ###

from docowl_infer import DocOwlInfer
from tqdm import tqdm

def generate_response(queries, model_path):
docowl = DocOwlInfer(ckpt_path=model_path, anchors='grid_9', add_global_img=True)
print('load model from ', model_path)
# infer the test samples one by one
for k in tqdm(queries):
image = queries[k]['figure_path']
question = queries[k]['question']
model_answer = docowl.inference(image, question)
print('model_answer:', model_answer)
queries[k]['response'] = model_answer
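Three of the new modules (chartast.py, chartllama.py, docowl15.py) locate their upstream code through the VLM_CODEBASE_DIR environment variable and append the corresponding repository to sys.path at import time, so the variable must be set before the module is imported. A minimal setup sketch; the directory layout and checkpoint path are assumptions, not part of this commit:

# Hypothetical setup; VLM_CODEBASE_DIR must contain the cloned ChartAst/,
# ChartLlama-code/, and mPLUG-DocOwl/ checkouts referenced above.
import os

os.environ["VLM_CODEBASE_DIR"] = "/opt/vlm_codebases"  # placeholder directory
from generate_lib import docowl15  # import only after the variable is set

docowl15.generate_response(queries, "/checkpoints/DocOwl1.5-Omni")  # placeholder checkpoint; queries as above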