Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/implement the support for multimodal with new vqa task #1111

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
c8a9511
implemented: basic structured to handle visualQA
chakravarthik27 Sep 14, 2024
f7b53e6
Refactor VisualQASample class to include additional attributes and do…
chakravarthik27 Sep 14, 2024
6eec7ca
Refactor llm_modelhandler.py to include PretrainedModelForVisualQA class
chakravarthik27 Sep 14, 2024
b95ecf3
Refactor VisualQA class to fix typo in base class name
chakravarthik27 Sep 14, 2024
adf18db
Merge remote-tracking branch 'origin/release/2.4.0' into feature/impl…
chakravarthik27 Sep 15, 2024
d3e6fa5
updated: image handling while loading dataset.
chakravarthik27 Sep 15, 2024
3ee5f8f
implemented the different tests under robusntess category and support…
chakravarthik27 Sep 15, 2024
3dd6770
Refactor image handling in robustness tests
chakravarthik27 Sep 15, 2024
d95e558
Refactor image handling in robustness tests and add support for multi…
chakravarthik27 Sep 15, 2024
ebd7bfd
Refactor image handling in robustness tests and update VisualQASample…
chakravarthik27 Sep 15, 2024
4538490
Refactor image handling in robustness tests and exclude image-related…
chakravarthik27 Sep 15, 2024
41f0db2
fixed: format issues.
chakravarthik27 Sep 15, 2024
3521927
Refactor image handling in robustness tests and remove commented code
chakravarthik27 Sep 16, 2024
a87e96c
Refactor image handling in robustness tests and update VisualQASample…
chakravarthik27 Sep 16, 2024
04e18e3
- added new tests in image robustness.
chakravarthik27 Sep 16, 2024
8039ef8
Add pillow library to pyproject.toml
chakravarthik27 Sep 16, 2024
a5ae26a
Merge remote-tracking branch 'origin/release/2.4.0' into feature/impl…
chakravarthik27 Sep 17, 2024
b29f9dd
resolve OutofMemory issues
chakravarthik27 Sep 17, 2024
16a3aa5
updated the notebook
chakravarthik27 Sep 17, 2024
b337d2b
Update pillow version to 10.0.0 and make it a required dependency
chakravarthik27 Sep 17, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions demo/tutorials/llm_notebooks/Visual_QA.ipynb

Large diffs are not rendered by default.

10 changes: 9 additions & 1 deletion langtest/datahandler/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,12 @@
"anti-stereotype": ["anti-stereotype"],
"unrelated": ["unrelated"],
},
"visualqa": {
"image": ["image", "image_1"],
"question": ["question"],
"options": ["options"],
"answer": ["answer"],
},
}


Expand Down Expand Up @@ -183,7 +189,7 @@ def __init__(self, file_path: Union[str, dict], task: TaskManager, **kwargs) ->
raise ValueError(Errors.E024)

if "data_source" not in file_path:
raise ValueError(Errors.E025)
raise ValueError(Errors.E025())
self._custom_label = file_path.copy()
self._file_path = file_path.get("data_source")
self._size = None
Expand Down Expand Up @@ -1246,6 +1252,7 @@ class HuggingFaceDataset(BaseDataset):
"summarization",
"ner",
"question-answering",
"visualqa",
]

LIB_NAME = "datasets"
Expand Down Expand Up @@ -1709,6 +1716,7 @@ class PandasDataset(BaseDataset):
"legal",
"factuality",
"stereoset",
"visualqa",
]
COLUMN_NAMES = {task: COLUMN_MAPPER[task] for task in supported_tasks}

Expand Down
6 changes: 6 additions & 0 deletions langtest/langtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,7 @@ def generated_results(self) -> Optional[pd.DataFrame]:
"model_name",
"category",
"test_type",
"original_image",
"original",
"context",
"prompt",
Expand All @@ -613,8 +614,10 @@ def generated_results(self) -> Optional[pd.DataFrame]:
"completion",
"test_case",
"perturbed_context",
"perturbed_image",
"perturbed_question",
"sentence",
"question",
"patient_info_A",
"patient_info_B",
"case",
Expand Down Expand Up @@ -838,6 +841,7 @@ def testcases(self, additional_cols=False) -> pd.DataFrame:
"model_name",
"category",
"test_type",
"original_image",
"original",
"context",
"original_context",
Expand All @@ -863,7 +867,9 @@ def testcases(self, additional_cols=False) -> pd.DataFrame:
"correct_sentence",
"incorrect_sentence",
"perturbed_context",
"perturbed_image",
"perturbed_question",
"question",
"ground_truth",
"options",
"expected_result",
Expand Down
55 changes: 55 additions & 0 deletions langtest/modelhandler/llm_modelhandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import logging
from functools import lru_cache
from langtest.utils.custom_types.helpers import HashableDict
from langchain.chat_models.base import BaseChatModel


class PretrainedModelForQA(ModelAPI):
Expand Down Expand Up @@ -452,3 +453,57 @@ class PretrainedModelForSycophancy(PretrainedModelForQA, ModelAPI):
"""

pass


class PretrainedModelForVisualQA(PretrainedModelForQA, ModelAPI):
"""A class representing a pretrained model for visual question answering.

Inherits:
PretrainedModelForQA: The base class for pretrained models.
"""

@lru_cache(maxsize=102400)
def predict(
self, text: Union[str, dict], prompt: dict, images: List[Any], *args, **kwargs
):
"""Perform prediction using the pretrained model.

Args:
text (Union[str, dict]): The input text or dictionary.
prompt (dict): The prompt configuration.
images (List[Any]): The list of images.
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.

Returns:
dict: A dictionary containing the prediction result.
- 'result': The prediction result.
"""
try:
if not isinstance(self.model, BaseChatModel):
ValueError("visualQA task is only supported for chat models")

# prepare prompt
prompt_template = PromptTemplate(**prompt)
from langchain_core.messages import HumanMessage

images = [
{
"type": "image_url",
"image_url": {"url": image},
}
for image in images
]

messages = HumanMessage(
content=[
{"type": "text", "text": prompt_template.format(**text)},
*images,
]
)

response = self.model.invoke([messages])
return response.content

except Exception as e:
raise ValueError(Errors.E089(error_message=e))
41 changes: 41 additions & 0 deletions langtest/tasks/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,3 +851,44 @@ def create_sample(

class FillMask(BaseTask):
pass


class VisualQA(BaseTask):
_name = "visualqa"
_default_col = {
"image": ["image"],
"question": ["question"],
"answer": ["answer"],
}
sample_class = samples.VisualQASample

def create_sample(
cls,
row_data: dict,
image: str = "image_1",
question: str = "question",
options: str = "options",
answer: str = "answer",
dataset_name: str = "",
) -> samples.VisualQASample:
"""Create a sample."""
keys = list(row_data.keys())

# auto-detect the default column names from the row_data
column_mapper = cls.column_mapping(keys, [image, question, options, answer])

options = row_data.get(column_mapper.get(options, "-"), "-")

if len(options) > 3 and options[0] == "[" and options[-1] == "]":
options = ast.literal_eval(row_data[column_mapper["options"]])
options = "\n".join(
[f"{chr(65 + i)}. {option}" for i, option in enumerate(options)]
)

return samples.VisualQASample(
original_image=row_data[column_mapper[image]],
question=row_data[column_mapper[question]],
options=options,
expected_result=row_data[column_mapper[answer]],
dataset_name=dataset_name,
)
3 changes: 3 additions & 0 deletions langtest/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from langtest.transform.grammar import GrammarTestFactory
from langtest.transform.safety import SafetyTestFactory

from langtest.transform import image

# Fixing the asyncio event loop
nest_asyncio.apply()

Expand All @@ -47,4 +49,5 @@
SycophancyTestFactory,
GrammarTestFactory,
SafetyTestFactory,
image,
]
3 changes: 3 additions & 0 deletions langtest/transform/image/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .robustness import ImageResizing, ImageRotation, ImageBlur, ImageNoise

__all__ = [ImageResizing, ImageRotation, ImageBlur, ImageNoise]
Loading
Loading