refactor: florencev2 batch processing (#68)
hrnn committed Sep 20, 2024
1 parent 2eea1c2 commit 8ead0c3
Showing 1 changed file with 62 additions and 39 deletions.
101 changes: 62 additions & 39 deletions vision_agent_tools/models/florencev2.py
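
In short, __call__ now normalizes string tasks to PromptTask, coerces a missing prompt to "", and routes single images, image batches, and video frames through one shared _batch_image_call helper instead of looping over _single_image_call. A minimal usage sketch of the refactored entry point follows; the Florencev2 class name and its no-argument constructor are assumptions inferred from the file path, while PromptTask, the OCR task, and the __call__ parameters all appear in the diff below:

# Hypothetical usage sketch; only the module path, PromptTask, and the
# __call__ parameters are taken from this commit.
from PIL import Image

from vision_agent_tools.models.florencev2 import Florencev2, PromptTask

model = Florencev2()  # assumed constructor
frames = [Image.open("page_0.png"), Image.open("page_1.png")]

# Every image shares one task/prompt pair; __call__ builds one text input
# per image and forwards the whole batch to _batch_image_call in one pass.
answers = model(task=PromptTask.OCR, images=frames)
assert len(answers) == len(frames)  # one parsed answer per input image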
@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Any, List, Optional
+from typing import Any, List
 
 import torch
 from PIL import Image
@@ -109,10 +109,17 @@ def __call__(
         Returns:
             Any: The output of the Florence-2 model based on the provided task, images/video, and prompt. The output type can vary depending on the chosen task.
         """
+
+        if isinstance(task, str):
+            try:
+                task = PromptTask(task)
+            except ValueError:
+                raise ValueError(f"Invalid task string: {task}")
+
         if prompt is None:
-            text_input = task
-        else:
-            text_input = task + prompt
+            prompt = ""
+        elif not isinstance(prompt, str):
+            raise ValueError("prompt must be a string or None.")
 
         # Validate input parameters
         if image is None and images is None and video is None:
@@ -127,36 +134,45 @@ def __call__(
             )
 
         if image is not None:
+            # Single image processing
+            text_input = str(task.value) + prompt
             image = self._process_image(image)
-            return self._single_image_call(text_input, image, task, prompt)
-        if images is not None:
-            results = []
-            for image in images:
-                processed_image = self._process_image(image)
-                result = self._single_image_call(
-                    text_input, processed_image, task, prompt
-                )
-                results.append(result)
-            return results
-        if video is not None:
-            images = self._process_video(video)
-            return [
-                self._single_image_call(text_input, image, task, prompt)
-                for image in images
-            ]
-
-    def _single_image_call(
+            results = self._batch_image_call([text_input], [image], task)
+            return results[0]
+        elif images is not None:
+            # Batch processing
+            images_list = [self._process_image(img) for img in images]
+            num_images = len(images_list)
+
+            # Create text_inputs by repeating the task and prompt for each image
+            text_input = str(task.value) + prompt
+            text_inputs = [text_input] * num_images
+
+            return self._batch_image_call(text_inputs, images_list, task)
+        elif video is not None:
+            # Process video frames
+            images_list = self._process_video(video)
+            num_images = len(images_list)
+
+            # Create text_inputs by repeating the task and prompt for each frame
+            text_input = str(task.value) + prompt
+            text_inputs = [text_input] * num_images
+
+            return self._batch_image_call(text_inputs, images_list, task)
+
+    def _batch_image_call(
         self,
-        text_input: str,
-        image: Image.Image,
+        text_inputs: List[str],
+        images: List[Image.Image],
         task: PromptTask,
-        prompt: str,
     ):
-        inputs = self._processor(text=text_input, images=image, return_tensors="pt").to(
-            self.device
-        )
+        inputs = self._processor(
+            text=text_inputs,
+            images=images,
+            return_tensors="pt",
+        ).to(self.device)
 
-        with torch.autocast(self.device):
+        with torch.autocast(device_type=self.device):
             generated_ids = self._model.generate(
                 input_ids=inputs["input_ids"],
                 pixel_values=inputs["pixel_values"],
@@ -165,26 +181,33 @@ def _single_image_call(
                 early_stopping=False,
                 do_sample=False,
             )
-        generated_text = self._processor.batch_decode(
-            generated_ids, skip_special_tokens=False
-        )[0]
 
-        parsed_answer = self._processor.post_process_generation(
-            generated_text, task=task, image_size=(image.width, image.height)
+        # Set skip_special_tokens based on the task
+        if task == PromptTask.OCR:
+            skip_special_tokens = True
+        else:
+            skip_special_tokens = False
+
+        generated_texts = self._processor.batch_decode(
+            generated_ids, skip_special_tokens=skip_special_tokens
         )
 
-        return parsed_answer
+        results = []
+        for text, img in zip(generated_texts, images):
+            parsed_answer = self._processor.post_process_generation(
+                text, task=task, image_size=(img.width, img.height)
+            )
+            results.append(parsed_answer)
+        return results
 
     def to(self, device: Device):
         self._model.to(device=device.value)
 
     def predict(
-        self, image: Image.Image, prompts: Optional[List[str]] = None, **kwargs
+        self, images: list[Image.Image], prompts: List[str] | None = None, **kwargs
     ) -> Any:
         task = kwargs.get("task", "")
-        results = []
-        for prompt in prompts:
-            results.append(self.__call__(images=image, task=task, prompt=prompt))
+        results = self.__call__(task=task, images=images, prompt=prompts)
         return results
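
The core of the refactor is that Florence-2's processor and batch_decode already operate on lists, so one generate call can serve many images; only post_process_generation needs each image's original size, hence the final zip over texts and images. A rough standalone sketch of that pattern against the public Hugging Face transformers API (the checkpoint ID and max_new_tokens value are assumptions, not taken from this diff):

# Sketch of the batched-inference pattern behind _batch_image_call.
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

checkpoint = "microsoft/Florence-2-base"  # assumed checkpoint
processor = AutoProcessor.from_pretrained(checkpoint, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(checkpoint, trust_remote_code=True)

task = "<OCR>"
images = [Image.open("a.png"), Image.open("b.png")]

# One text input per image: the processor pads and stacks them into a batch.
inputs = processor(text=[task] * len(images), images=images, return_tensors="pt")

with torch.no_grad():
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,  # assumed; generation settings sit outside this hunk
    )

# batch_decode yields one string per image; post-processing runs per image
# because it needs that image's own width and height.
texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
results = [
    processor.post_process_generation(t, task=task, image_size=(im.width, im.height))
    for t, im in zip(texts, images)
]

Decoding with skip_special_tokens=True only suits plain-text OCR output, which is presumably why the commit special-cases PromptTask.OCR: other tasks keep their special tokens, such as location markers, for post_process_generation to parse.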

