diff --git a/vision_agent_tools/models/florencev2.py b/vision_agent_tools/models/florencev2.py
index 9aaa298..c1071cd 100644
--- a/vision_agent_tools/models/florencev2.py
+++ b/vision_agent_tools/models/florencev2.py
@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Any, List, Optional
+from typing import Any, List
 
 import torch
 from PIL import Image
@@ -109,10 +109,17 @@ def __call__(
         Returns:
             Any: The output of the Florence-2 model based on the provided task, images/video, and prompt. The output type can vary depending on the chosen task.
         """
+
+        if isinstance(task, str):
+            try:
+                task = PromptTask(task)
+            except ValueError:
+                raise ValueError(f"Invalid task string: {task}")
+
         if prompt is None:
-            text_input = task
-        else:
-            text_input = task + prompt
+            prompt = ""
+        elif not isinstance(prompt, str):
+            raise ValueError("prompt must be a string or None.")
 
         # Validate input parameters
         if image is None and images is None and video is None:
@@ -127,36 +134,45 @@ def __call__(
         )
 
         if image is not None:
+            # Single image: run as a batch of one and unwrap the result
+            text_input = str(task.value) + prompt
             image = self._process_image(image)
-            return self._single_image_call(text_input, image, task, prompt)
-        if images is not None:
-            results = []
-            for image in images:
-                processed_image = self._process_image(image)
-                result = self._single_image_call(
-                    text_input, processed_image, task, prompt
-                )
-                results.append(result)
-            return results
-        if video is not None:
-            images = self._process_video(video)
-            return [
-                self._single_image_call(text_input, image, task, prompt)
-                for image in images
-            ]
-
-    def _single_image_call(
+            results = self._batch_image_call([text_input], [image], task)
+            return results[0]
+        elif images is not None:
+            # Batch processing
+            images_list = [self._process_image(img) for img in images]
+            num_images = len(images_list)
+
+            # Repeat the task-and-prompt text once per image
+            text_input = str(task.value) + prompt
+            text_inputs = [text_input] * num_images
+
+            return self._batch_image_call(text_inputs, images_list, task)
+        elif video is not None:
+            # Process the video frames as one batch of images
+            images_list = self._process_video(video)
+            num_images = len(images_list)
+
+            # Repeat the task-and-prompt text once per frame
+            text_input = str(task.value) + prompt
+            text_inputs = [text_input] * num_images
+
+            return self._batch_image_call(text_inputs, images_list, task)
+
+    def _batch_image_call(
         self,
-        text_input: str,
-        image: Image.Image,
+        text_inputs: List[str],
+        images: List[Image.Image],
         task: PromptTask,
-        prompt: str,
     ):
-        inputs = self._processor(text=text_input, images=image, return_tensors="pt").to(
-            self.device
-        )
+        inputs = self._processor(
+            text=text_inputs,
+            images=images,
+            return_tensors="pt",
+        ).to(self.device)
 
-        with torch.autocast(self.device):
+        with torch.autocast(device_type=self.device):
             generated_ids = self._model.generate(
                 input_ids=inputs["input_ids"],
                 pixel_values=inputs["pixel_values"],
@@ -165,26 +181,33 @@ def _single_image_call(
             early_stopping=False,
             do_sample=False,
         )
-        generated_text = self._processor.batch_decode(
-            generated_ids, skip_special_tokens=False
-        )[0]
-        parsed_answer = self._processor.post_process_generation(
-            generated_text, task=task, image_size=(image.width, image.height)
+        # Only OCR is decoded without special tokens; the other tasks rely
+        # on them during post-processing.
+        skip_special_tokens = task == PromptTask.OCR
+
+        generated_texts = self._processor.batch_decode(
+            generated_ids, skip_special_tokens=skip_special_tokens
         )
-        return parsed_answer
+        results = []
+        for text, img in zip(generated_texts, images):
+            parsed_answer = self._processor.post_process_generation(
+                text, task=task, image_size=(img.width, img.height)
+            )
+            results.append(parsed_answer)
+        return results
 
     def to(self, device: Device):
         self._model.to(device=device.value)
 
     def predict(
-        self, image: Image.Image, prompts: Optional[List[str]] = None, **kwargs
+        self, images: List[Image.Image], prompts: List[str] | None = None, **kwargs
     ) -> Any:
         task = kwargs.get("task", "")
-        results = []
-        for prompt in prompts:
-            results.append(self.__call__(images=image, task=task, prompt=prompt))
+        # __call__ takes a single prompt string; use the first one if a list is given
+        prompt = prompts[0] if prompts else None
+        results = self.__call__(task=task, images=images, prompt=prompt)
         return results
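
A usage sketch for the batched call path added above. The class name
Florencev2 and its constructor are assumptions here (they sit outside this
diff), and the file names are placeholders; PromptTask.OCR is the task enum
member referenced in the change itself.

    from PIL import Image

    model = Florencev2()  # hypothetical constructor, not shown in this diff
    frames = [Image.open("page1.png"), Image.open("page2.png")]

    # One batched forward pass replaces the old per-image Python loop
    ocr_results = model(task=PromptTask.OCR, images=frames)

    # The single-image path still returns one result rather than a list
    first = model(task=PromptTask.OCR, image=frames[0])

With _batch_image_call, multi-image and video inputs now share a single
processor call and a single generate() pass instead of one of each per image.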