From 091caa890c99421e49fdff7485bf08bdfc946a85 Mon Sep 17 00:00:00 2001
From: Colin Wang
Date: Sun, 18 Aug 2024 18:18:17 -0400
Subject: [PATCH] Add support for domain-specific models:

1. ChartAssistant (`chartast.py`)
2. ChartInstruct (`chartinstruct.py`)
3. ChartLlama (`chartllama.py`)
4. CogAgent (`cogagent.py`)
5. DocOwl1.5 (`docowl15.py`)
6. TextMonkey (`textmonkey.py`)
7. TinyChart (`tinychart.py`)
8. UniChart (`unichart.py`)
9. UReader (`ureader.py`)

---
 src/generate_lib/chartast.py      |  98 ++++++++++++++++++
 src/generate_lib/chartinstruct.py |  46 +++++++++
 src/generate_lib/chartllama.py    | 160 ++++++++++++++++++++++++++++++
 src/generate_lib/cogagent.py      |  49 +++++++++
 src/generate_lib/docowl15.py      |  25 +++++
 src/generate_lib/textmonkey.py    |  82 +++++++++++++++
 src/generate_lib/tinychart.py     |  44 ++++++++
 src/generate_lib/unichart.py      |  38 +++++++
 src/generate_lib/ureader.py       |  55 ++++++++++
 9 files changed, 597 insertions(+)
 create mode 100644 src/generate_lib/chartast.py
 create mode 100644 src/generate_lib/chartinstruct.py
 create mode 100644 src/generate_lib/chartllama.py
 create mode 100644 src/generate_lib/cogagent.py
 create mode 100644 src/generate_lib/docowl15.py
 create mode 100644 src/generate_lib/textmonkey.py
 create mode 100644 src/generate_lib/tinychart.py
 create mode 100644 src/generate_lib/unichart.py
 create mode 100644 src/generate_lib/ureader.py

diff --git a/src/generate_lib/chartast.py b/src/generate_lib/chartast.py
new file mode 100644
index 0000000..46709b2
--- /dev/null
+++ b/src/generate_lib/chartast.py
@@ -0,0 +1,98 @@
+# Adapted from https://github.com/OpenGVLab/ChartAst/blob/main/accessory/single_turn_eval.py
+# This has support for the ChartAssistant model
+
+import os
+vlm_codebase = os.environ['VLM_CODEBASE_DIR']
+
+import sys
+sys.path.append(vlm_codebase + '/ChartAst/accessory')
+
+os.environ['MP'] = '1'
+os.environ['WORLD_SIZE'] = '1'
+
+import torch
+from tqdm import tqdm
+import torch.distributed as dist
+
+
+sys.path.append(os.path.abspath(__file__).rsplit('/', 2)[0])
+from fairscale.nn.model_parallel import initialize as fs_init
+from model.meta import MetaModel
+from util.tensor_parallel import load_tensor_parallel_model_list
+from util.misc import init_distributed_mode
+from PIL import Image
+
+import torchvision.transforms as transforms
+
+try:
+    from torchvision.transforms import InterpolationMode
+
+    BICUBIC = InterpolationMode.BICUBIC
+except ImportError:
+    BICUBIC = Image.BICUBIC
+
+from PIL import Image
+import os
+import torch
+
+
+class PadToSquare:
+    def __init__(self, background_color):
+        """
+        pad an image to square (borrowed from LLAVA, thx)
+        :param background_color: rgb values for padded pixels, normalized to [0, 1]
+        """
+        self.bg_color = tuple(int(x * 255) for x in background_color)
+
+    def __call__(self, img: Image.Image):
+        width, height = img.size
+        if width == height:
+            return img
+        elif width > height:
+            result = Image.new(img.mode, (width, width), self.bg_color)
+            result.paste(img, (0, (width - height) // 2))
+            return result
+        else:
+            result = Image.new(img.mode, (height, height), self.bg_color)
+            result.paste(img, ((height - width) // 2, 0))
+            return result
+
+def T_padded_resize(size=448):
+    t = transforms.Compose([
+        PadToSquare(background_color=(0.48145466, 0.4578275, 0.40821073)),
+        transforms.Resize(
+            size, interpolation=transforms.InterpolationMode.BICUBIC
+        ),
+        transforms.ToTensor(),
+        transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])])
+    return t
+
+def generate_response(queries, model_path):
+    init_distributed_mode()
+    fs_init.initialize_model_parallel(dist.get_world_size())
+    model = MetaModel('llama_ens5', model_path + '/params.json', model_path + '/tokenizer.model', with_visual=True)
+    print(f"load pretrained from {model_path}")
+    load_tensor_parallel_model_list(model, model_path)
+    model.bfloat16().cuda()
+    max_gen_len = 512
+    gen_t = 0.9
+    top_p = 0.5
+
+    for k in tqdm(queries):
+        question = queries[k]['question']
+        img_path = queries[k]['figure_path']
+
+        prompt = f"""Below is an instruction that describes a task. "
+                    "Write a response that appropriately completes the request.\n\n"
+                    "### Instruction:\nPlease answer my question based on the chart: {question}\n\n### Response:"""
+
+        image = Image.open(img_path).convert('RGB')
+        transform_val = T_padded_resize(448)
+        image = transform_val(image).unsqueeze(0)
+        image = image.cuda()
+
+        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+            response = model.generate([prompt], image, max_gen_len=max_gen_len, temperature=gen_t, top_p=top_p)
+            response = response[0].split('###')[0]
+        print(response)
+        queries[k]['response'] = response
diff --git a/src/generate_lib/chartinstruct.py b/src/generate_lib/chartinstruct.py
new file mode 100644
index 0000000..546aa4d
--- /dev/null
+++ b/src/generate_lib/chartinstruct.py
@@ -0,0 +1,46 @@
+# Adapted from https://huggingface.co/ahmed-masry/ChartInstruct-LLama2, https://huggingface.co/ahmed-masry/ChartInstruct-FlanT5-XL
+# This has support for two ChartInstruct models, LLama2 and FlanT5
+
+from PIL import Image
+from transformers import AutoProcessor, LlavaForConditionalGeneration, AutoModelForSeq2SeqLM
+import torch
+from tqdm import tqdm
+
+def generate_response(queries, model_path):
+    if "LLama2" in model_path:
+        print("Using LLama2 model")
+        model = LlavaForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float16)
+    elif "FlanT5" in model_path:
+        print("Using FlanT5 model")
+        model = AutoModelForSeq2SeqLM.from_pretrained(model_path, torch_dtype=torch.float16, trust_remote_code=True)
+    else:
+        raise ValueError(f"Model {model_path} not supported")
+    processor = AutoProcessor.from_pretrained(model_path)
+
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    model.to(device)
+
+    for k in tqdm(queries):
+        image_path = queries[k]['figure_path']
+        input_prompt = queries[k]['question']
+        input_prompt = f"<image>\n Question: {input_prompt} Answer: "
+
+        image = Image.open(image_path).convert('RGB')
+        inputs = processor(text=input_prompt, images=image, return_tensors="pt")
+        inputs = {k: v.to(device) for k, v in inputs.items()}
+
+        # change type of pixel_values in inputs to fp16.
+        inputs['pixel_values'] = inputs['pixel_values'].to(torch.float16)
+        if "LLama2" in model_path:
+            prompt_length = inputs['input_ids'].shape[1]
+
+        # move to device
+        inputs = {k: v.to(device) for k, v in inputs.items()}
+
+        # Generate
+        generate_ids = model.generate(**inputs, num_beams=4, max_new_tokens=512)
+        output_text = processor.batch_decode(generate_ids[:, prompt_length:] \
+            if 'LLama2' in model_path else generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        print(output_text)
+        queries[k]['response'] = output_text
diff --git a/src/generate_lib/chartllama.py b/src/generate_lib/chartllama.py
new file mode 100644
index 0000000..8d5eb2f
--- /dev/null
+++ b/src/generate_lib/chartllama.py
@@ -0,0 +1,160 @@
+# Adapted from https://github.com/tingxueronghua/ChartLlama-code/blob/main/model_vqa_lora.py
+# This has support for the ChartLlama model
+
+### HEADER START ###
+import os
+vlm_codebase = os.environ['VLM_CODEBASE_DIR']
+
+import sys
+sys.path.append(vlm_codebase + '/ChartLlama-code')
+### HEADER END ###
+
+import argparse
+import torch
+import os
+import json
+from tqdm import tqdm
+import shortuuid
+import warnings
+import shutil
+
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
+from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
+from llava.conversation import conv_templates, SeparatorStyle
+from llava.model import *
+from llava.utils import disable_torch_init
+from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
+from torch.utils.data import Dataset, DataLoader
+
+from PIL import Image
+import math
+
+def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda"):
+    kwargs = {"device_map": device_map}
+
+    if load_8bit:
+        kwargs['load_in_8bit'] = True
+    elif load_4bit:
+        kwargs['load_in_4bit'] = True
+        kwargs['quantization_config'] = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_compute_dtype=torch.float16,
+            bnb_4bit_use_double_quant=True,
+            bnb_4bit_quant_type='nf4'
+        )
+    else:
+        kwargs['torch_dtype'] = torch.float16
+
+    # Load LLaVA model
+    if model_base is None:
+        raise ValueError('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.')
+    if model_base is not None:
+        lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
+        tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
+        print('Loading LLaVA from base model...')
+        model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
+        token_num, token_dim = model.lm_head.out_features, model.lm_head.in_features
+        if model.lm_head.weight.shape[0] != token_num:
+            model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
+            model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
+
+        print('Loading additional LLaVA weights...')
+        if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
+            non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
+        else:
+            # this is probably from HF Hub
+            from huggingface_hub import hf_hub_download
+            def load_from_hf(repo_id, filename, subfolder=None):
+                cache_file = hf_hub_download(
+                    repo_id=repo_id,
+                    filename=filename,
+                    subfolder=subfolder)
+                return torch.load(cache_file, map_location='cpu')
+            non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
+        non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
+        if any(k.startswith('model.model.') for k in non_lora_trainables):
+            non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
+        model.load_state_dict(non_lora_trainables, strict=False)
+
+        from peft import PeftModel
+        print('Loading LoRA weights...')
+        model = PeftModel.from_pretrained(model, model_path)
+        print('Merging LoRA weights...')
+        model = model.merge_and_unload()
+        print('Model is loaded...')
+
+    image_processor = None
+
+    mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
+    mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
+    if mm_use_im_patch_token:
+        tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
+    if mm_use_im_start_end:
+        tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
+    model.resize_token_embeddings(len(tokenizer))
+
+    vision_tower = model.get_vision_tower()
+    if not vision_tower.is_loaded:
+        vision_tower.load_model()
+    vision_tower.to(device=device, dtype=torch.float16)
+    image_processor = vision_tower.image_processor
+
+    if hasattr(model.config, "max_sequence_length"):
+        context_len = model.config.max_sequence_length
+    else:
+        context_len = 2048
+
+    return tokenizer, model, image_processor, context_len
+
+
+def generate_response(queries, model_path):
+    disable_torch_init()
+    base_model_path, model_path = model_path.split('::')
+    tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, base_model_path, None)
+    conv_mode = "vicuna_v1"
+
+    def process(image, question, tokenizer, image_processor, model_config):
+        qs = question.replace(DEFAULT_IMAGE_TOKEN, '').strip()
+        if model.config.mm_use_im_start_end:
+            qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
+        else:
+            qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
+
+        conv = conv_templates[conv_mode].copy()
+        conv.append_message(conv.roles[0], qs)
+        conv.append_message(conv.roles[1], None)
+        prompt = conv.get_prompt()
+
+        image_tensor = process_images([image], image_processor, model_config)[0]
+
+        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
+
+        return input_ids, image_tensor
+
+    for k in tqdm(queries):
+        image_path = queries[k]['figure_path']
+        image = Image.open(image_path).convert('RGB')
+        question = queries[k]['question']
+
+        input_ids, image_tensor = process(image, question, tokenizer, image_processor, model.config)
+        stop_str = conv_templates[conv_mode].sep if conv_templates[conv_mode].sep_style != SeparatorStyle.TWO else conv_templates[conv_mode].sep2
+        input_ids = input_ids.to(device='cuda', non_blocking=True).unsqueeze(0) # added the unsqueeze(0) to make it batch size 1
+        image_tensor = image_tensor.unsqueeze(0) # added the unsqueeze(0) to make it batch size 1
+        with torch.inference_mode():
+            output_ids = model.generate(
+                input_ids,
+                images=image_tensor.to(dtype=torch.float16, device='cuda', non_blocking=True),
+                do_sample=False,
+                max_new_tokens=1636,
+                use_cache=True
+            )
+        input_token_len = input_ids.shape[1]
+        n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
+        if n_diff_input_output > 0:
+            print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
+        outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
+        outputs = outputs.strip()
+        if outputs.endswith(stop_str):
+            outputs = outputs[:-len(stop_str)]
+        outputs = outputs.strip()
+        queries[k]['response'] = outputs
diff --git a/src/generate_lib/cogagent.py b/src/generate_lib/cogagent.py
new file mode 100644
index 0000000..8279857
--- /dev/null
+++ b/src/generate_lib/cogagent.py
@@ -0,0 +1,49 @@
+# Adapted from https://huggingface.co/THUDM/cogagent-vqa-hf
+# This has support for the CogAgent model
+
+import torch
+from PIL import Image
+from transformers import AutoModelForCausalLM, LlamaTokenizer
+from tqdm import tqdm
+
+def generate_response(queries, model_path):
+    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
+    torch_type = torch.bfloat16
+    tokenizer_path, model_path = model_path.split('::')
+    tokenizer = LlamaTokenizer.from_pretrained(tokenizer_path)
+    model = AutoModelForCausalLM.from_pretrained(
+        model_path,
+        torch_dtype=torch.bfloat16,
+        low_cpu_mem_usage=True,
+        load_in_4bit=False,
+        trust_remote_code=True
+    ).to('cuda').eval()
+
+    for k in tqdm(queries):
+        image_path = queries[k]['figure_path']
+        image = Image.open(image_path).convert('RGB')
+        query = f"Human:{queries[k]['question']}"
+        history = []
+
+        input_by_model = model.build_conversation_input_ids(tokenizer, query=query, history=history, images=[image])
+        inputs = {
+            'input_ids': input_by_model['input_ids'].unsqueeze(0).to(DEVICE),
+            'token_type_ids': input_by_model['token_type_ids'].unsqueeze(0).to(DEVICE),
+            'attention_mask': input_by_model['attention_mask'].unsqueeze(0).to(DEVICE),
+            'images': [[input_by_model['images'][0].to(DEVICE).to(torch_type)]],
+        }
+        if 'cross_images' in input_by_model and input_by_model['cross_images']:
+            inputs['cross_images'] = [[input_by_model['cross_images'][0].to(DEVICE).to(torch_type)]]
+
+        # add any transformers params here.
+        gen_kwargs = {"max_length": 2048,
+                      "temperature": 0.9,
+                      "do_sample": False}
+        with torch.no_grad():
+            outputs = model.generate(**inputs, **gen_kwargs)
+            outputs = outputs[:, inputs['input_ids'].shape[1]:]
+            response = tokenizer.decode(outputs[0])
+            response = response.split("</s>")[0]
+            print("\nCog:", response)
+        print('model_answer:', response)
+        queries[k]['response'] = response
diff --git a/src/generate_lib/docowl15.py b/src/generate_lib/docowl15.py
new file mode 100644
index 0000000..01d39e9
--- /dev/null
+++ b/src/generate_lib/docowl15.py
@@ -0,0 +1,25 @@
+# Adapted from https://github.com/X-PLUG/mPLUG-DocOwl/blob/main/DocOwl1.5/docowl_infer.py
+# This has support for the DocOwl model
+
+### HEADER START ###
+import os
+vlm_codebase = os.environ['VLM_CODEBASE_DIR']
+
+import sys
+sys.path.append(vlm_codebase + '/mPLUG-DocOwl/DocOwl1.5')
+### HEADER END ###
+
+from docowl_infer import DocOwlInfer
+from tqdm import tqdm
+import os
+
+def generate_response(queries, model_path):
+    docowl = DocOwlInfer(ckpt_path=model_path, anchors='grid_9', add_global_img=True)
+    print('load model from ', model_path)
+    # infer the test samples one by one
+    for k in tqdm(queries):
+        image = queries[k]['figure_path']
+        question = queries[k]['question']
+        model_answer = docowl.inference(image, question)
+        print('model_answer:', model_answer)
+        queries[k]['response'] = model_answer
diff --git a/src/generate_lib/textmonkey.py b/src/generate_lib/textmonkey.py
new file mode 100644
index 0000000..9f449c6
--- /dev/null
+++ b/src/generate_lib/textmonkey.py
@@ -0,0 +1,82 @@
+# Adapted from https://github.com/Yuliang-Liu/Monkey/blob/main/demo_textmonkey.py
+# This has support for the TextMonkey model
+
+import os
+vlm_codebase = os.environ['VLM_CODEBASE_DIR']
+
+import sys
+sys.path.append(vlm_codebase + '/Monkey')
+
+import re
+import gradio as gr
+from PIL import Image, ImageDraw, ImageFont
+from monkey_model.modeling_textmonkey import TextMonkeyLMHeadModel
+from monkey_model.tokenization_qwen import QWenTokenizer
+from monkey_model.configuration_monkey import MonkeyConfig
+from tqdm import tqdm
+
+def generate_response(queries, model_path):
+    device_map = "cuda"
+    # Create model
+    config = MonkeyConfig.from_pretrained(
+        model_path,
+        trust_remote_code=True,
+    )
+    model = TextMonkeyLMHeadModel.from_pretrained(model_path,
+                                                  config=config,
+                                                  device_map=device_map, trust_remote_code=True).eval()
+    tokenizer = QWenTokenizer.from_pretrained(model_path,
+                                              trust_remote_code=True)
+    tokenizer.padding_side = 'left'
+    tokenizer.pad_token_id = tokenizer.eod_id
+    tokenizer.IMG_TOKEN_SPAN = config.visual["n_queries"]
+
+    for k in tqdm(queries):
+        input_image = queries[k]['figure_path']
+        input_str = queries[k]['question']
+        input_str = f"<img>{input_image}</img> {input_str}"
+        input_ids = tokenizer(input_str, return_tensors='pt', padding='longest')
+
+        attention_mask = input_ids.attention_mask
+        input_ids = input_ids.input_ids
+
+        pred = model.generate(
+            input_ids=input_ids.cuda(),
+            attention_mask=attention_mask.cuda(),
+            do_sample=False,
+            num_beams=1,
+            max_new_tokens=2048,
+            min_new_tokens=1,
+            length_penalty=1,
+            num_return_sequences=1,
+            output_hidden_states=True,
+            use_cache=True,
+            pad_token_id=tokenizer.eod_id,
+            eos_token_id=tokenizer.eod_id,
+        )
+        response = tokenizer.decode(pred[0][input_ids.size(1):].cpu(), skip_special_tokens=False).strip()
+        image = Image.open(input_image).convert("RGB").resize((1000, 1000))
+        font = ImageFont.truetype('NimbusRoman-Regular.otf', 22)
+        bboxes = re.findall(r'<box>(.*?)</box>', response, re.DOTALL)
+        refs = re.findall(r'<ref>(.*?)</ref>', response, re.DOTALL)
+        if len(refs) != 0:
+            num = min(len(bboxes), len(refs))
+        else:
+            num = len(bboxes)
+        for box_id in range(num):
+            bbox = bboxes[box_id]
+            matches = re.findall(r"\((\d+),(\d+)\)", bbox)
+            draw = ImageDraw.Draw(image)
+            point_x = (int(matches[0][0]) + int(matches[1][0])) / 2
+            point_y = (int(matches[0][1]) + int(matches[1][1])) / 2
+            point_size = 8
+            point_bbox = (point_x - point_size, point_y - point_size, point_x + point_size, point_y + point_size)
+            draw.ellipse(point_bbox, fill=(255, 0, 0))
+            if len(refs) != 0:
+                text = refs[box_id]
+                text_width, text_height = font.getsize(text)
+                draw.text((point_x - text_width // 2, point_y + 8), text, font=font, fill=(255, 0, 0))
+        response = tokenizer.decode(pred[0][input_ids.size(1):].cpu(), skip_special_tokens=True).strip()
+        output_str = response
+        print(f"Answer: {output_str}")
+        queries[k]['response'] = output_str
diff --git a/src/generate_lib/tinychart.py b/src/generate_lib/tinychart.py
new file mode 100644
index 0000000..10677aa
--- /dev/null
+++ b/src/generate_lib/tinychart.py
@@ -0,0 +1,44 @@
+# Adapted from https://github.com/X-PLUG/mPLUG-DocOwl/blob/main/TinyChart/inference.ipynb
+# This has support for the TinyChart model
+
+### HEADER START ###
+import os
+vlm_codebase = os.environ['VLM_CODEBASE_DIR']
+
+import sys
+sys.path.append(vlm_codebase + '/mPLUG-DocOwl/TinyChart')
+### HEADER END ###
+
+from tqdm import tqdm
+import torch
+from PIL import Image
+from tinychart.model.builder import load_pretrained_model
+from tinychart.mm_utils import get_model_name_from_path
+from tinychart.eval.run_tiny_chart import inference_model
+from tinychart.eval.eval_metric import parse_model_output, evaluate_cmds
+
+def generate_response(queries, model_path):
+    tokenizer, model, image_processor, context_len = load_pretrained_model(
+        model_path,
+        model_base=None,
+        model_name=get_model_name_from_path(model_path),
+        device="cuda"  # device="cpu" if running on cpu
+    )
+    for k in tqdm(queries):
+        img_path = queries[k]['figure_path']
+        text = queries[k]['question']
+        response = inference_model([img_path], text, model, tokenizer, image_processor, context_len, conv_mode="phi", max_new_tokens=1024)
+        # print(response)
+        try:
+            response = evaluate_cmds(parse_model_output(response))
+            print('Command successfully executed')
+            print(response)
+        except Exception as e:
+            # if message is NameError: name 'Answer' is not defined, then skip
+            if "Error: name 'Answer' is not defined" in str(e):
+                response = response
+            else:
+                print('Error:', e)
+                response = response
+        response = str(response)
+        queries[k]['response'] = response
diff --git a/src/generate_lib/unichart.py b/src/generate_lib/unichart.py
new file mode 100644
index 0000000..ed1d606
--- /dev/null
+++ b/src/generate_lib/unichart.py
@@ -0,0 +1,38 @@
+# Adapted from https://github.com/vis-nlp/UniChart/blob/main/README.md
+# This has support for the UniChart model
+
+from transformers import DonutProcessor, VisionEncoderDecoderModel
+from PIL import Image
+import torch
+from tqdm import tqdm
+
+def generate_response(queries, model_path):
+    model = VisionEncoderDecoderModel.from_pretrained(model_path)
+    processor = DonutProcessor.from_pretrained(model_path)
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model.to(device)
+
+    for k in tqdm(queries):
+        image_path = queries[k]['figure_path']
+        input_prompt = queries[k]['question']
+        input_prompt = f"<chartqa> {input_prompt} <s_answer>"
+        image = Image.open(image_path).convert("RGB")
+        decoder_input_ids = processor.tokenizer(input_prompt, add_special_tokens=False, return_tensors="pt").input_ids
+        pixel_values = processor(image, return_tensors="pt").pixel_values
+
+        outputs = model.generate(
+            pixel_values.to(device),
+            decoder_input_ids=decoder_input_ids.to(device),
+            max_length=model.decoder.config.max_position_embeddings,
+            early_stopping=True,
+            pad_token_id=processor.tokenizer.pad_token_id,
+            eos_token_id=processor.tokenizer.eos_token_id,
+            use_cache=True,
+            num_beams=4,
+            bad_words_ids=[[processor.tokenizer.unk_token_id]],
+            return_dict_in_generate=True,
+        )
+        sequence = processor.batch_decode(outputs.sequences)[0]
+        sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
+        sequence = sequence.split("<s_answer>")[1].strip()
+        queries[k]['response'] = sequence
diff --git a/src/generate_lib/ureader.py b/src/generate_lib/ureader.py
new file mode 100644
index 0000000..879fd15
--- /dev/null
+++ b/src/generate_lib/ureader.py
@@ -0,0 +1,55 @@
+# Adapted from https://github.com/X-PLUG/mPLUG-DocOwl/blob/main/UReader/pipeline/interface.py
+# This has support for the UReader model
+
+### HEADER START ###
+import os
+vlm_codebase = os.environ['VLM_CODEBASE_DIR']
+
+import sys
+sys.path.append(vlm_codebase + '/mPLUG-DocOwl/UReader')
+
+UREADER_DIR = os.path.join(vlm_codebase, 'mPLUG-DocOwl/UReader/')
+### HEADER END ###
+
+import os
+import torch
+from sconf import Config
+from mplug_owl.modeling_mplug_owl import MplugOwlForConditionalGeneration
+from mplug_owl.tokenization_mplug_owl import MplugOwlTokenizer
+from mplug_owl.processing_mplug_owl import MplugOwlImageProcessor, MplugOwlProcessor
+import torch
+from pipeline.data_utils.processors.builder import build_processors
+from pipeline.data_utils.processors import *
+from pipeline.utils import add_config_args, set_args
+import argparse
+
+from PIL import Image
+from tqdm import tqdm
+
+def generate_response(queries, model_path):
+    config = Config("{}configs/sft/release.yaml".format(UREADER_DIR))
+    args = argparse.ArgumentParser().parse_args([])
+    add_config_args(config, args)
+    set_args(args)
+    model = MplugOwlForConditionalGeneration.from_pretrained(
+        model_path,
+    )
+    model.eval()
+    model.cuda()
+    model.half()
+    image_processor = build_processors(config['valid_processors'])['sft']
+    tokenizer = MplugOwlTokenizer.from_pretrained(model_path)
+    processor = MplugOwlProcessor(image_processor, tokenizer)
+
+    for k in tqdm(queries):
+        image_path = queries[k]['figure_path']
+        images = [Image.open(image_path).convert('RGB')]
+        question = f"Human: <image>\nHuman: {queries[k]['question']}\nAI: "
+        inputs = processor(text=question, images=images, return_tensors='pt')
+        inputs = {k: v.bfloat16() if v.dtype == torch.float else v for k, v in inputs.items()}
+        inputs = {k: v.to(model.device) for k, v in inputs.items()}
+        with torch.no_grad():
+            res = model.generate(**inputs)
+        sentence = tokenizer.decode(res.tolist()[0], skip_special_tokens=True)
+        print('model_answer:', sentence)
+        queries[k]['response'] = sentence
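
Note on usage (not part of the patch): every module added here exposes the same entry point, generate_response(queries, model_path), reads each entry's 'figure_path' and 'question', and writes the answer back into queries[k]['response'] in place. Below is a minimal driver sketch under stated assumptions — the import path, file names, and the example checkpoint id are illustrative, and some modules expect a composite model_path (e.g. "tokenizer_path::model_path" for cogagent.py, "base_model_path::lora_path" for chartllama.py).

    # Hypothetical driver, assumed to run from src/ with generate_lib importable as a package.
    import json
    from generate_lib.unichart import generate_response  # assumed import path

    queries = {
        "q1": {
            "figure_path": "images/example_chart.png",       # placeholder image path
            "question": "Which category has the highest value?",
        },
    }

    # Example UniChart checkpoint id; substitute whatever the chosen module expects.
    generate_response(queries, "ahmed-masry/unichart-chartqa-960")

    # Responses are filled in place under queries[k]['response'].
    with open("responses.json", "w") as f:
        json.dump(queries, f, indent=2)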