refactor: florencev2 batch processing #68

Merged 5 commits on Sep 20, 2024
101 changes: 62 additions & 39 deletions vision_agent_tools/models/florencev2.py
@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Any, List, Optional
+from typing import Any, List
 
 import torch
 from PIL import Image
@@ -109,10 +109,17 @@ def __call__(
         Returns:
             Any: The output of the Florence-2 model based on the provided task, images/video, and prompt. The output type can vary depending on the chosen task.
         """
+
+        if isinstance(task, str):
+            try:
+                task = PromptTask(task)
+            except ValueError:
+                raise ValueError(f"Invalid task string: {task}")
+
         if prompt is None:
-            text_input = task
-        else:
-            text_input = task + prompt
+            prompt = ""
+        elif not isinstance(prompt, str):
+            raise ValueError("prompt must be a string or None.")
 
         # Validate input parameters
         if image is None and images is None and video is None:
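
The validation added above lets __call__ accept either a PromptTask member or its raw string value, and normalizes a missing prompt to an empty suffix instead of special-casing it later. A minimal sketch of the resulting behavior, assuming the enum has a CAPTION member whose value is "<CAPTION>" (the enum's members are not shown in this diff):

# Illustration only; PromptTask.CAPTION == "<CAPTION>" is an assumption.
task = PromptTask("<CAPTION>")          # raw string coerced to the enum member
prompt = ""                             # None normalized to an empty string
text_input = str(task.value) + prompt   # "<CAPTION>" once joined further down
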
@@ -127,36 +134,45 @@
             )
 
         if image is not None:
+            # Single image processing
+            text_input = str(task.value) + prompt
             image = self._process_image(image)
-            return self._single_image_call(text_input, image, task, prompt)
-        if images is not None:
-            results = []
-            for image in images:
-                processed_image = self._process_image(image)
-                result = self._single_image_call(
-                    text_input, processed_image, task, prompt
-                )
-                results.append(result)
-            return results
-        if video is not None:
-            images = self._process_video(video)
-            return [
-                self._single_image_call(text_input, image, task, prompt)
-                for image in images
-            ]
+            results = self._batch_image_call([text_input], [image], task)
+            return results[0]
+        elif images is not None:
+            # Batch processing
+            images_list = [self._process_image(img) for img in images]
+            num_images = len(images_list)
+
+            # Create text_inputs by repeating the task and prompt for each image
+            text_input = str(task.value) + prompt
+            text_inputs = [text_input] * num_images
+
+            return self._batch_image_call(text_inputs, images_list, task)
+        elif video is not None:
+            # Process video frames
+            images_list = self._process_video(video)
+            num_images = len(images_list)
+
+            # Create text_inputs by repeating the task and prompt for each frame
+            text_input = str(task.value) + prompt
+            text_inputs = [text_input] * num_images
+
+            return self._batch_image_call(text_inputs, images_list, task)
 
-    def _single_image_call(
+    def _batch_image_call(
         self,
-        text_input: str,
-        image: Image.Image,
+        text_inputs: List[str],
+        images: List[Image.Image],
         task: PromptTask,
-        prompt: str,
     ):
-        inputs = self._processor(text=text_input, images=image, return_tensors="pt").to(
-            self.device
-        )
+        inputs = self._processor(
+            text=text_inputs,
+            images=images,
+            return_tensors="pt",
+        ).to(self.device)
 
-        with torch.autocast(self.device):
+        with torch.autocast(device_type=self.device):
             generated_ids = self._model.generate(
                 input_ids=inputs["input_ids"],
                 pixel_values=inputs["pixel_values"],
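
The batched path above builds one text prompt per image and hands the whole batch to the processor in a single call, so tokenization and pixel preprocessing are collated once per batch rather than once per image. A standalone sketch of the same pattern, assuming the microsoft/Florence-2-base checkpoint and transformers' AutoProcessor as a stand-in for self._processor (not the class's actual setup; file names are placeholders):

# Hedged sketch of batched Florence-2 preprocessing.
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained(
    "microsoft/Florence-2-base", trust_remote_code=True
)
images = [Image.open("a.jpg"), Image.open("b.jpg")]
text_inputs = ["<CAPTION>"] * len(images)  # one prompt per image
inputs = processor(text=text_inputs, images=images, return_tensors="pt")
# inputs["input_ids"] and inputs["pixel_values"] are batch-first tensors,
# ready for a single _model.generate call.
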
@@ -165,26 +181,33 @@ def _single_image_call(
             early_stopping=False,
             do_sample=False,
         )
-        generated_text = self._processor.batch_decode(
-            generated_ids, skip_special_tokens=False
-        )[0]
-
-        parsed_answer = self._processor.post_process_generation(
-            generated_text, task=task, image_size=(image.width, image.height)
-        )
+        # Set skip_special_tokens based on the task
+        if task == PromptTask.OCR:
+            skip_special_tokens = True
+        else:
+            skip_special_tokens = False
+
+        generated_texts = self._processor.batch_decode(
+            generated_ids, skip_special_tokens=skip_special_tokens
+        )
 
-        return parsed_answer
+        results = []
+        for text, img in zip(generated_texts, images):
+            parsed_answer = self._processor.post_process_generation(
+                text, task=task, image_size=(img.width, img.height)
+            )
+            results.append(parsed_answer)
+        return results
 
     def to(self, device: Device):
         self._model.to(device=device.value)
 
     def predict(
-        self, image: Image.Image, prompts: Optional[List[str]] = None, **kwargs
+        self, images: list[Image.Image], prompts: List[str] | None = None, **kwargs
     ) -> Any:
         task = kwargs.get("task", "")
-        results = []
-        for prompt in prompts:
-            results.append(self.__call__(images=image, task=task, prompt=prompt))
+        results = self.__call__(task=task, images=images, prompt=prompts)
         return results
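
Taken together, the refactor funnels single images, image batches, and video frames through one _batch_image_call, and predict now forwards the whole image list in a single call instead of looping per prompt. OCR output is decoded with skip_special_tokens=True while other tasks keep their special tokens, presumably because post_process_generation needs the task and location markers to parse structured results. A hedged usage sketch of the post-refactor API (the Florencev2 class name and the CAPTION member are assumptions inferred from the file path, not confirmed by this diff; file paths are placeholders):

# Hypothetical usage of the refactored batch API.
from PIL import Image
from vision_agent_tools.models.florencev2 import Florencev2, PromptTask

model = Florencev2()
frames = [Image.open("frame1.jpg"), Image.open("frame2.jpg")]
# One task fans out across the batch; one parsed answer comes back per frame.
results = model(task=PromptTask.CAPTION, images=frames)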