Skip to content

Commit

Permalink
Merge pull request #1111 from JohnSnowLabs/feature/implement-the-supp…
Browse files Browse the repository at this point in the history
…ort-for-multimodal-with-new-vqa-task

Feature/implement the support for multimodal with new vqa task
  • Loading branch information
chakravarthik27 authored Sep 17, 2024
2 parents d3a4663 + b337d2b commit 67c641d
Show file tree
Hide file tree
Showing 14 changed files with 731 additions and 4 deletions.
1 change: 1 addition & 0 deletions demo/tutorials/llm_notebooks/Visual_QA.ipynb

Large diffs are not rendered by default.

10 changes: 9 additions & 1 deletion langtest/datahandler/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,12 @@
"anti-stereotype": ["anti-stereotype"],
"unrelated": ["unrelated"],
},
"visualqa": {
"image": ["image", "image_1"],
"question": ["question"],
"options": ["options"],
"answer": ["answer"],
},
}


Expand Down Expand Up @@ -183,7 +189,7 @@ def __init__(self, file_path: Union[str, dict], task: TaskManager, **kwargs) ->
raise ValueError(Errors.E024)

if "data_source" not in file_path:
raise ValueError(Errors.E025)
raise ValueError(Errors.E025())
self._custom_label = file_path.copy()
self._file_path = file_path.get("data_source")
self._size = None
Expand Down Expand Up @@ -1246,6 +1252,7 @@ class HuggingFaceDataset(BaseDataset):
"summarization",
"ner",
"question-answering",
"visualqa",
]

LIB_NAME = "datasets"
Expand Down Expand Up @@ -1709,6 +1716,7 @@ class PandasDataset(BaseDataset):
"legal",
"factuality",
"stereoset",
"visualqa",
]
COLUMN_NAMES = {task: COLUMN_MAPPER[task] for task in supported_tasks}

Expand Down
6 changes: 6 additions & 0 deletions langtest/langtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,7 @@ def generated_results(self) -> Optional[pd.DataFrame]:
"model_name",
"category",
"test_type",
"original_image",
"original",
"context",
"prompt",
Expand All @@ -613,8 +614,10 @@ def generated_results(self) -> Optional[pd.DataFrame]:
"completion",
"test_case",
"perturbed_context",
"perturbed_image",
"perturbed_question",
"sentence",
"question",
"patient_info_A",
"patient_info_B",
"case",
Expand Down Expand Up @@ -838,6 +841,7 @@ def testcases(self, additional_cols=False) -> pd.DataFrame:
"model_name",
"category",
"test_type",
"original_image",
"original",
"context",
"original_context",
Expand All @@ -863,7 +867,9 @@ def testcases(self, additional_cols=False) -> pd.DataFrame:
"correct_sentence",
"incorrect_sentence",
"perturbed_context",
"perturbed_image",
"perturbed_question",
"question",
"ground_truth",
"options",
"expected_result",
Expand Down
55 changes: 55 additions & 0 deletions langtest/modelhandler/llm_modelhandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import logging
from functools import lru_cache
from langtest.utils.custom_types.helpers import HashableDict
from langchain.chat_models.base import BaseChatModel


class PretrainedModelForQA(ModelAPI):
Expand Down Expand Up @@ -452,3 +453,57 @@ class PretrainedModelForSycophancy(PretrainedModelForQA, ModelAPI):
"""

pass


class PretrainedModelForVisualQA(PretrainedModelForQA, ModelAPI):
"""A class representing a pretrained model for visual question answering.
Inherits:
PretrainedModelForQA: The base class for pretrained models.
"""

@lru_cache(maxsize=102400)
def predict(
self, text: Union[str, dict], prompt: dict, images: List[Any], *args, **kwargs
):
"""Perform prediction using the pretrained model.
Args:
text (Union[str, dict]): The input text or dictionary.
prompt (dict): The prompt configuration.
images (List[Any]): The list of images.
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Returns:
dict: A dictionary containing the prediction result.
- 'result': The prediction result.
"""
try:
if not isinstance(self.model, BaseChatModel):
ValueError("visualQA task is only supported for chat models")

# prepare prompt
prompt_template = PromptTemplate(**prompt)
from langchain_core.messages import HumanMessage

images = [
{
"type": "image_url",
"image_url": {"url": image},
}
for image in images
]

messages = HumanMessage(
content=[
{"type": "text", "text": prompt_template.format(**text)},
*images,
]
)

response = self.model.invoke([messages])
return response.content

except Exception as e:
raise ValueError(Errors.E089(error_message=e))
41 changes: 41 additions & 0 deletions langtest/tasks/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,3 +851,44 @@ def create_sample(

class FillMask(BaseTask):
pass


class VisualQA(BaseTask):
_name = "visualqa"
_default_col = {
"image": ["image"],
"question": ["question"],
"answer": ["answer"],
}
sample_class = samples.VisualQASample

def create_sample(
cls,
row_data: dict,
image: str = "image_1",
question: str = "question",
options: str = "options",
answer: str = "answer",
dataset_name: str = "",
) -> samples.VisualQASample:
"""Create a sample."""
keys = list(row_data.keys())

# auto-detect the default column names from the row_data
column_mapper = cls.column_mapping(keys, [image, question, options, answer])

options = row_data.get(column_mapper.get(options, "-"), "-")

if len(options) > 3 and options[0] == "[" and options[-1] == "]":
options = ast.literal_eval(row_data[column_mapper["options"]])
options = "\n".join(
[f"{chr(65 + i)}. {option}" for i, option in enumerate(options)]
)

return samples.VisualQASample(
original_image=row_data[column_mapper[image]],
question=row_data[column_mapper[question]],
options=options,
expected_result=row_data[column_mapper[answer]],
dataset_name=dataset_name,
)
3 changes: 3 additions & 0 deletions langtest/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from langtest.transform.grammar import GrammarTestFactory
from langtest.transform.safety import SafetyTestFactory

from langtest.transform import image

# Fixing the asyncio event loop
nest_asyncio.apply()

Expand All @@ -47,4 +49,5 @@
SycophancyTestFactory,
GrammarTestFactory,
SafetyTestFactory,
image,
]
3 changes: 3 additions & 0 deletions langtest/transform/image/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .robustness import ImageResizing, ImageRotation, ImageBlur, ImageNoise

__all__ = [ImageResizing, ImageRotation, ImageBlur, ImageNoise]
Loading

0 comments on commit 67c641d

Please sign in to comment.