diff --git a/demo/tutorials/llm_notebooks/Visual_QA.ipynb b/demo/tutorials/llm_notebooks/Visual_QA.ipynb
new file mode 100644
index 000000000..7045e71c3
--- /dev/null
+++ b/demo/tutorials/llm_notebooks/Visual_QA.ipynb
@@ -0,0 +1 @@
+{"cells":[{"cell_type":"markdown","metadata":{"id":"D285OP467TeS"},"source":["![image.png]()"]},{"cell_type":"markdown","metadata":{"id":"_8dMBi8UNtg1"},"source":["[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/langtest/blob/main/demo/tutorials/llm_notebooks/Visual_QA.ipynb)"]},{"cell_type":"markdown","metadata":{"id":"_EzC6SKhjdk7"},"source":["**LangTest** is an open-source Python library designed to help developers deliver safe and effective Natural Language Processing (NLP) models. Whether you are using **John Snow Labs, Hugging Face, spaCy** models or **OpenAI, Cohere, AI21, Hugging Face Inference API and Azure-OpenAI** based LLMs, it has got you covered. You can test any Named Entity Recognition (NER), Text Classification, Fill-Mask, or Translation model using the library. We also support testing LLMs for Question-Answering, Visual Question-Answering, Summarization and Text-Generation tasks on benchmark datasets. The library supports 60+ out-of-the-box tests. For a complete list of supported test categories, please refer to the [documentation](http://langtest.org/docs/pages/docs/test_categories).\n","\n","Metrics are calculated by comparing the model's extractions in the original list of sentences against the extractions carried out in the noisy list of sentences. The original annotated labels are not used at any point; we are simply comparing the model against itself in two settings."]},{"cell_type":"markdown","metadata":{"id":"v9Yd7KhpZOTF"},"source":["# Getting started with LangTest"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"kJ-dxTWu7bcA"},"outputs":[],"source":["!pip install langtest==2.4.0"]},{"cell_type":"markdown","metadata":{"id":"cXOI5kBFlO6w"},"source":["# Harness and its Parameters\n","\n","The Harness class is a testing class for Natural Language Processing (NLP) models. 
It evaluates the performance of an NLP model on a given task using test data and generates a report with the test results. Harness can be imported from the LangTest library as follows."]},{"cell_type":"code","execution_count":1,"metadata":{"executionInfo":{"elapsed":4291,"status":"ok","timestamp":1692340616139,"user":{"displayName":"Prikshit sharma","userId":"07819241395213139913"},"user_tz":-330},"id":"w1g27-uxl1AA"},"outputs":[],"source":["#Import Harness from the LangTest library\n","from langtest import Harness"]},{"cell_type":"markdown","metadata":{"id":"PXBMpFHIl7n9"},"source":["This imports the Harness class, which provides a framework for conducting NLP testing. Instances of the Harness class can be customized or configured for different testing scenarios or environments.\n","\n","The following parameters can be passed to the Harness constructor:\n","\n","
\n","\n","\n","\n","| Parameter | Description |\n","| - | - |\n","| **task** | Task for which the model is to be evaluated (Visual Question Answering) |\n","| **model** | Specifies the model(s) to be evaluated. This parameter can be provided as either a dictionary or a list of dictionaries. Each dictionary should contain the following keys:
- model (mandatory): \tPipelineModel or path to a saved model or pretrained LLM pipeline/model from hub.
- hub (mandatory): Hub (library) to use in back-end for loading model from public models hub or from path
|\n","| **data** | The data to be used for evaluation. A dictionary providing flexibility and options for data sources. It should include the following keys: - data_source (mandatory): The source of the data.
- subset (optional): The subset of the data.
- feature_column (optional): The column containing the features.
- target_column (optional): The column containing the target labels.
- split (optional): The data split to be used.
- source (optional): Set to 'huggingface' when loading Hugging Face dataset.
|\n","| **config** | Configuration for the tests to be performed, specified in the form of a YAML file. |\n","\n","\n","
\n","
"]},{"cell_type":"markdown","metadata":{"id":"KLC_lBv09ZuN"},"source":["# Robustness Testing\n","\n","Model robustness can be described as the ability of a model to maintain similar levels of accuracy, precision, and recall when perturbations are made to the data it is predicting on. For example, In the case of images, the goal is to understand how modifications such as resizing, rotation, noise addition, or color adjustments affect the model's performance compared to the original images it was trained on.\n","\n","\n","**`Supported Robustness tests :`**
\n","\n","### Text\n","\n","| **Test Name** | **Short Description** |\n","|-------------------------------|----------------------------------------------------------------------------------------|\n","| **`uppercase`** | Capitalization of the text set is turned into uppercase |\n","| **`lowercase`** | Capitalization of the text set is turned into lowercase |\n","| **`titlecase`** | Capitalization of the text set is turned into title case |\n","| **`add_punctuation`** | Adds punctuation to the text set |\n","| **`strip_punctuation`** | Removes punctuation from the text set |\n","| **`add_typo`** | Introduces typographical errors into the text |\n","| **`swap_entities`** | Swaps named entities in the text |\n","| **`american_to_british`** | Converts American English spellings to British English |\n","| **`british_to_american`** | Converts British English spellings to American English |\n","| **`add_context`** | Adds additional context to the text set |\n","| **`add_contraction`** | Introduces contractions (e.g., do not → don't) |\n","| **`dyslexia_word_swap`** | Swaps words in a way that mimics dyslexic reading errors |\n","| **`number_to_word`** | Converts numbers to words in the text set (e.g., 1 → one) |\n","| **`add_ocr_typo`** | Adds optical character recognition (OCR) specific typos to the text |\n","| **`add_abbreviation`** | Replaces certain words with their abbreviations |\n","| **`add_speech_to_text_typo`** | Adds speech-to-text transcription errors |\n","| **`add_slangs`** | Introduces slang terms into the text |\n","| **`multiple_perturbations`** | Applies multiple perturbations to the text at once |\n","| **`adjective_synonym_swap`** | Swaps adjectives in the text with their synonyms |\n","| **`adjective_antonym_swap`** | Swaps adjectives in the text with their antonyms |\n","| **`strip_all_punctuation`** | Removes all punctuation from the text |\n","| **`randomize_age`** | Randomizes the age mentioned in the text |\n","| **`add_new_lines`** | Inserts 
new lines into the text set |\n","| **`add_tabs`** | Inserts tab characters into the text set |\n","\n","### Images\n","\n","| **Test Name** | **Short Description** |\n","|----------------------|--------------------------------------------------------|\n","| **`image_resize`** | Resizes the image to a different dimension |\n","| **`image_rotate`** | Rotates the image by a specified angle |\n","| **`image_blur`** | Applies a blur filter to the image |\n","| **`image_noise`** | Adds random noise to the image |\n","| **`image_contrast`** | Adjusts the contrast of the image |\n","| **`image_brightness`**| Adjusts the brightness of the image |\n","| **`image_sharpness`** | Adjusts the sharpness of the image |\n","| **`image_color`** | Adjusts the color balance of the image |\n","| **`image_flip`** | Flips the image either horizontally or vertically |\n","| **`image_crop`** | Crops a portion of the image |\n","\n","
"]},{"cell_type":"markdown","metadata":{"id":"cVIzXdGMjX47"},"source":["## Testing robustness of a pretrained LLM models\n","\n","Testing a LLM model's robustness gives us an idea on how our data may need to be modified to make the model more robust. We can use a pretrained model/pipeline or define our own custom pipeline or load a saved pre trained model to test.\n","\n","Here we are directly passing a pretrained model/pipeline from hub as the model parameter in harness and running the tests."]},{"cell_type":"markdown","metadata":{"id":"78THAZm3cRu7"},"source":["### Test Configuration\n","\n","Test configuration can be passed in the form of a YAML file as shown below or using .configure() method\n","\n","\n","**Config YAML format** :\n","```\n","tests: \n"," {\n"," \"defaults\": {\n"," \"min_pass_rate\": 0.5,\n"," },\n"," \"robustness\": {\n"," \"image_noise\": {\n"," \"min_pass_rate\": 0.5,\n"," \"parameters\": {\n"," \"noise_level\": 0.5\n"," }\n","\n"," },\n"," \"image_rotate\": {\n"," \"min_pass_rate\": 0.5,\n"," \"parameters\": {\n"," \"angle\": 45\n"," }\n"," },\n"," \"image_blur\": {\n"," \"min_pass_rate\": 0.5,\n"," \"parameters\": {\n"," \"radius\": 5\n"," }\n"," },\n"," \"image_resize\": {\n"," \"min_pass_rate\": 0.5,\n"," \"parameters\": {\n"," \"resize\": 0.5 # 0.01 to 1.0 means 1% to 100% of the original size\n"," }\n"," },\n"," }\n"," }\n"," \n","```\n","\n","If config file is not present, we can also use the **.configure()** method to manually configure the harness to perform the needed tests.\n"]},{"cell_type":"code","execution_count":2,"metadata":{},"outputs":[],"source":["import os \n","os.environ['OPENAI_API_KEY'] = \"sk-XXXXXXXX\""]},{"cell_type":"markdown","metadata":{},"source":["## Visual Question Answering (VQA)\n","\n","This notebook demonstrates how to perform a Visual Question Answering (VQA) using the `PIL` library to load images and a harness for running the task. 
The model being used is `gpt-4o-mini` from the OpenAI hub, and the data comes from the MMMU dataset, specifically the `Clinical_Medicine` subset."]},{"cell_type":"code","execution_count":3,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"c274bf01644a432fb0e254fd1e8ebb75","version_major":2,"version_minor":0},"text/plain":["Resolving data files: 0%| | 0/60 [00:00, ?it/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"3cc1882c9281421f8b7f42f54a3999ce","version_major":2,"version_minor":0},"text/plain":["Resolving data files: 0%| | 0/32 [00:00, ?it/s]"]},"metadata":{},"output_type":"display_data"},{"name":"stdout","output_type":"stream","text":["Test Configuration : \n"," {}\n"]}],"source":["harness = Harness(\n"," task=\"visualqa\",\n"," model={\n"," \"model\": \"gpt-4o-mini\",\n"," \"hub\": \"openai\"\n"," },\n"," data={\"data_source\": 'MMMU/MMMU',\n"," \"subset\": \"Clinical_Medicine\",\n"," # \"feature_column\": \"question\",\n"," # \"target_column\": 'answer',\n"," \"split\": \"dev\",\n"," \"source\": \"huggingface\"\n"," },\n"," config={}\n",")"]},{"cell_type":"markdown","metadata":{"id":"jGEN7Q0Ric8H"},"source":["We can use the .configure() method to manually define our test configuration for the robustness tests."]},{"cell_type":"code","execution_count":4,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":91,"status":"ok","timestamp":1692340473373,"user":{"displayName":"Prikshit sharma","userId":"07819241395213139913"},"user_tz":-330},"id":"C08dW5tue_6d","outputId":"c12433af-296e-4e9b-d2e2-cdd68f5426ea"},"outputs":[{"data":{"text/plain":["{'tests': {'defaults': {'min_pass_rate': 0.5},\n"," 'robustness': {'image_noise': {'min_pass_rate': 0.5,\n"," 'parameters': {'noise_level': 0.5}},\n"," 'image_rotate': {'min_pass_rate': 0.5, 'parameters': {'angle': 55}},\n"," 'image_blur': {'min_pass_rate': 0.5, 'parameters': {'radius': 
5}},\n"," 'image_resize': {'min_pass_rate': 0.5, 'parameters': {'resize': 0.5}}}}}"]},"execution_count":4,"metadata":{},"output_type":"execute_result"}],"source":["harness.configure({\n"," \"tests\": {\n"," \"defaults\": {\n"," \"min_pass_rate\": 0.5,\n"," },\n"," \"robustness\": {\n"," \"image_noise\": {\n"," \"min_pass_rate\": 0.5,\n"," \"parameters\": {\n"," \"noise_level\": 0.5\n"," }\n","\n"," },\n"," \"image_rotate\": {\n"," \"min_pass_rate\": 0.5,\n"," \"parameters\": {\n"," \"angle\": 55\n"," }\n"," },\n"," \"image_blur\": {\n"," \"min_pass_rate\": 0.5,\n"," \"parameters\": {\n"," \"radius\": 5\n"," }\n"," },\n"," \"image_resize\": {\n"," \"min_pass_rate\": 0.5,\n"," \"parameters\": {\n"," \"resize\": 0.5 # 0.01 to 1.0 means 1% to 100% of the original size\n"," }\n"," },\n"," }\n"," }\n","})"]},{"cell_type":"markdown","metadata":{"id":"FLLzeE_Pix2W"},"source":["Here we have configured the harness to perform image robustness tests (image_blur, image_resize, image_rotate, and image_noise) and defined the minimum pass rate for each test."]},{"cell_type":"markdown","metadata":{},"source":["To ensure we work with a smaller subset of data, we'll limit the dataset to the first 50 entries. This is useful for faster prototyping and testing without needing to process the entire dataset.\n"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["harness.data = harness.data[:50]"]},{"cell_type":"markdown","metadata":{},"source":["In this section, we will reset the test cases in the `Harness` object by setting `harness._testcases` to `None`. 
This can be useful if you want to clear any previously loaded test cases or start fresh without any predefined cases.\n"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["harness._testcases = None"]},{"cell_type":"markdown","metadata":{"id":"MomLlmTwjpzU"},"source":["\n","### Generating the test cases.\n","\n","\n"]},{"cell_type":"code","execution_count":5,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":23034,"status":"ok","timestamp":1692340496325,"user":{"displayName":"Prikshit sharma","userId":"07819241395213139913"},"user_tz":-330},"id":"njyA7h_tfMVo","outputId":"481382ae-630d-4c62-d6d8-c8108982df89"},"outputs":[{"name":"stderr","output_type":"stream","text":["Generating testcases...: 100%|██████████| 1/1 [00:00, ?it/s]\n"]},{"data":{"text/plain":[]},"execution_count":5,"metadata":{},"output_type":"execute_result"}],"source":["harness.generate()"]},{"cell_type":"markdown","metadata":{"id":"C_qyYdl8FYoD"},"source":["harness.generate() method automatically generates the test cases (based on the provided configuration)"]},{"cell_type":"markdown","metadata":{},"source":["This code snippet will display an HTML table based on the DataFrame returned by `harness.testcases()`. The `escape=False` parameter allows HTML content within the DataFrame to be rendered without escaping special characters."]},{"cell_type":"code","execution_count":11,"metadata":{},"outputs":[{"data":{"text/html":["\n"," \n"," \n"," | \n"," category | \n"," test_type | \n"," original_image | \n"," perturbed_image | \n"," question | \n"," options | \n","
\n"," \n"," \n"," \n"," 3 | \n"," robustness | \n"," image_noise | \n"," | \n"," | \n"," What person's name is associated with the fracture shown below? | \n"," A. Monteggia\\nB. Bennett\\nC. Jones\\nD. Smith | \n","
\n"," \n"," 15 | \n"," robustness | \n"," image_resize | \n"," | \n"," | \n"," Identify the following rhythm: | \n"," A. Sinus Rhythm with PAC's\\nB. Junctional Rhythm\\nC. 2nd Degree AV Block, Type I\\nD. 3rd Degree AV Block\\nE. Normal Sinus Rhythm with PVC's\\nF. Idioventricular Rhythm | \n","
\n"," \n"," 6 | \n"," robustness | \n"," image_rotate | \n"," | \n"," | \n"," A 56-year-old woman is undergoing chemotherapy for treatment of breast carcinoma. The gross appearance of her skin shown here is most typical for which of the following conditions? | \n"," A. Thrombocytopenia\\nB. Gangrene\\nC. Congestive heart failure\\nD. Metastatic breast carcinoma | \n","
\n"," \n"," 18 | \n"," robustness | \n"," image_resize | \n"," | \n"," | \n"," What person's name is associated with the fracture shown below? | \n"," A. Monteggia\\nB. Bennett\\nC. Jones\\nD. Smith | \n","
\n"," \n"," 17 | \n"," robustness | \n"," image_resize | \n"," | \n"," | \n"," Based on , what's the most likely diagnosis? | \n"," A. first degree atrioventricular block\\nB. third degree atrioventricular block\\nC. Second degree type II atrioventricular block\\nD. atrial flutter | \n","
\n"," \n","
"],"text/plain":[""]},"metadata":{},"output_type":"display_data"}],"source":["from IPython.display import display, HTML\n","\n","\n","df = harness.testcases()\n","html=df.sample(5).to_html(escape=False)\n","\n","display(HTML(html))"]},{"cell_type":"markdown","metadata":{"id":"fRyNPRBokXNZ"},"source":["### Running the tests."]},{"cell_type":"code","execution_count":12,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":68268,"status":"ok","timestamp":1692340564519,"user":{"displayName":"Prikshit sharma","userId":"07819241395213139913"},"user_tz":-330},"id":"3kUPTsNvjkgr","outputId":"4c4815e4-4cab-4dbf-99ba-1a231656f1e3"},"outputs":[{"name":"stderr","output_type":"stream","text":["Running testcases... : 100%|██████████| 20/20 [00:44<00:00, 2.21s/it]\n"]},{"data":{"text/plain":[]},"execution_count":12,"metadata":{},"output_type":"execute_result"}],"source":["harness.run()"]},{"cell_type":"code","execution_count":13,"metadata":{},"outputs":[{"data":{"text/html":["\n"," \n"," \n"," | \n"," category | \n"," test_type | \n"," original_image | \n"," perturbed_image | \n"," question | \n"," options | \n"," expected_result | \n"," actual_result | \n"," pass | \n","
\n"," \n"," \n"," \n"," 5 | \n"," robustness | \n"," image_rotate | \n"," | \n"," | \n"," Identify the following rhythm: | \n"," A. Sinus Rhythm with PAC's\\nB. Junctional Rhythm\\nC. 2nd Degree AV Block, Type I\\nD. 3rd Degree AV Block\\nE. Normal Sinus Rhythm with PVC's\\nF. Idioventricular Rhythm | \n"," Answer: UnRecognizable. | \n"," Answer: UnRecognizable. | \n"," True | \n","
\n"," \n"," 4 | \n"," robustness | \n"," image_noise | \n"," | \n"," | \n"," The best diagnosis for the appendix is: | \n"," A. simple appendicitis\\nB. appendix abscess\\nC. normal appendix\\nD. cellulite appendicitis | \n"," Answer: UnRecognizable. | \n"," I'm unable to recognize the content of the image. Thus, I cannot determine the correct diagnosis for the appendix. \\n\\nAnswer: UnRecognizable. | \n"," False | \n","
\n"," \n"," 7 | \n"," robustness | \n"," image_rotate | \n"," | \n"," | \n"," Based on , what's the most likely diagnosis? | \n"," A. first degree atrioventricular block\\nB. third degree atrioventricular block\\nC. Second degree type II atrioventricular block\\nD. atrial flutter | \n"," Answer: UnRecognizable. | \n"," Answer: UnRecognizable. | \n"," True | \n","
\n"," \n"," 9 | \n"," robustness | \n"," image_rotate | \n"," | \n"," | \n"," The best diagnosis for the appendix is: | \n"," A. simple appendicitis\\nB. appendix abscess\\nC. normal appendix\\nD. cellulite appendicitis | \n"," Answer: UnRecognizable. | \n"," Answer: A. simple appendicitis. | \n"," False | \n","
\n"," \n"," 0 | \n"," robustness | \n"," image_noise | \n"," | \n"," | \n"," Identify the following rhythm: | \n"," A. Sinus Rhythm with PAC's\\nB. Junctional Rhythm\\nC. 2nd Degree AV Block, Type I\\nD. 3rd Degree AV Block\\nE. Normal Sinus Rhythm with PVC's\\nF. Idioventricular Rhythm | \n"," Answer: UnRecognizable. | \n"," Answer: UnRecognizable. | \n"," True | \n","
\n"," \n","
"],"text/plain":[""]},"metadata":{},"output_type":"display_data"}],"source":["from IPython.display import display, HTML\n","\n","\n","df = harness.generated_results()\n","html=df.sample(5).to_html(escape=False)\n","\n","display(HTML(html))"]},{"cell_type":"markdown","metadata":{},"source":["Called after harness.generate() and is to used to run all the tests. Returns a pass/fail flag for each test."]},{"cell_type":"markdown","metadata":{"id":"106TE41ffw43"},"source":["This method returns the generated results in the form of a pandas dataframe, which provides a convenient and easy-to-use format for working with the test results. You can use this method to quickly identify the test cases that failed and to determine where fixes are needed."]},{"cell_type":"markdown","metadata":{"id":"_0gnozMlkoF0"},"source":["### Report of the tests"]},{"cell_type":"code","execution_count":15,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":112},"executionInfo":{"elapsed":22,"status":"ok","timestamp":1692340564522,"user":{"displayName":"Prikshit sharma","userId":"07819241395213139913"},"user_tz":-330},"id":"YKFvMs0RGHO7","outputId":"3a0ed33b-aa59-4e98-86d0-8d407391b0e4"},"outputs":[{"data":{"text/html":["\n","\n","
\n"," \n"," \n"," | \n"," category | \n"," test_type | \n"," fail_count | \n"," pass_count | \n"," pass_rate | \n"," minimum_pass_rate | \n"," pass | \n","
\n"," \n"," \n"," \n"," 0 | \n"," robustness | \n"," image_noise | \n"," 3 | \n"," 2 | \n"," 40% | \n"," 50% | \n"," False | \n","
\n"," \n"," 1 | \n"," robustness | \n"," image_rotate | \n"," 2 | \n"," 3 | \n"," 60% | \n"," 50% | \n"," True | \n","
\n"," \n"," 2 | \n"," robustness | \n"," image_blur | \n"," 2 | \n"," 3 | \n"," 60% | \n"," 50% | \n"," True | \n","
\n"," \n"," 3 | \n"," robustness | \n"," image_resize | \n"," 2 | \n"," 3 | \n"," 60% | \n"," 50% | \n"," True | \n","
\n"," \n","
\n","
"],"text/plain":[" category test_type fail_count pass_count pass_rate \\\n","0 robustness image_noise 3 2 40% \n","1 robustness image_rotate 2 3 60% \n","2 robustness image_blur 2 3 60% \n","3 robustness image_resize 2 3 60% \n","\n"," minimum_pass_rate pass \n","0 50% False \n","1 50% True \n","2 50% True \n","3 50% True "]},"execution_count":15,"metadata":{},"output_type":"execute_result"}],"source":["harness.report()"]},{"cell_type":"markdown","metadata":{"id":"bSP2QL6agTH_"},"source":["Called after harness.run() and it summarizes the results giving information about pass and fail counts and overall test pass/fail flag."]}],"metadata":{"accelerator":"GPU","colab":{"machine_shape":"hm","provenance":[],"toc_visible":true},"gpuClass":"standard","kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.8.10"}},"nbformat":4,"nbformat_minor":0}
diff --git a/langtest/datahandler/datasource.py b/langtest/datahandler/datasource.py
index 4de9999f4..c12a11662 100644
--- a/langtest/datahandler/datasource.py
+++ b/langtest/datahandler/datasource.py
@@ -95,6 +95,12 @@
"anti-stereotype": ["anti-stereotype"],
"unrelated": ["unrelated"],
},
+ "visualqa": {
+ "image": ["image", "image_1"],
+ "question": ["question"],
+ "options": ["options"],
+ "answer": ["answer"],
+ },
}
@@ -183,7 +189,7 @@ def __init__(self, file_path: Union[str, dict], task: TaskManager, **kwargs) ->
raise ValueError(Errors.E024)
if "data_source" not in file_path:
- raise ValueError(Errors.E025)
+ raise ValueError(Errors.E025())
self._custom_label = file_path.copy()
self._file_path = file_path.get("data_source")
self._size = None
@@ -1246,6 +1252,7 @@ class HuggingFaceDataset(BaseDataset):
"summarization",
"ner",
"question-answering",
+ "visualqa",
]
LIB_NAME = "datasets"
@@ -1709,6 +1716,7 @@ class PandasDataset(BaseDataset):
"legal",
"factuality",
"stereoset",
+ "visualqa",
]
COLUMN_NAMES = {task: COLUMN_MAPPER[task] for task in supported_tasks}
diff --git a/langtest/langtest.py b/langtest/langtest.py
index d7a1f15cd..09df1b57d 100644
--- a/langtest/langtest.py
+++ b/langtest/langtest.py
@@ -605,6 +605,7 @@ def generated_results(self) -> Optional[pd.DataFrame]:
"model_name",
"category",
"test_type",
+ "original_image",
"original",
"context",
"prompt",
@@ -613,8 +614,10 @@ def generated_results(self) -> Optional[pd.DataFrame]:
"completion",
"test_case",
"perturbed_context",
+ "perturbed_image",
"perturbed_question",
"sentence",
+ "question",
"patient_info_A",
"patient_info_B",
"case",
@@ -838,6 +841,7 @@ def testcases(self, additional_cols=False) -> pd.DataFrame:
"model_name",
"category",
"test_type",
+ "original_image",
"original",
"context",
"original_context",
@@ -863,7 +867,9 @@ def testcases(self, additional_cols=False) -> pd.DataFrame:
"correct_sentence",
"incorrect_sentence",
"perturbed_context",
+ "perturbed_image",
"perturbed_question",
+ "question",
"ground_truth",
"options",
"expected_result",
diff --git a/langtest/modelhandler/llm_modelhandler.py b/langtest/modelhandler/llm_modelhandler.py
index c65387402..968928e12 100644
--- a/langtest/modelhandler/llm_modelhandler.py
+++ b/langtest/modelhandler/llm_modelhandler.py
@@ -13,6 +13,7 @@
import logging
from functools import lru_cache
from langtest.utils.custom_types.helpers import HashableDict
+from langchain.chat_models.base import BaseChatModel
class PretrainedModelForQA(ModelAPI):
@@ -452,3 +453,57 @@ class PretrainedModelForSycophancy(PretrainedModelForQA, ModelAPI):
"""
pass
+
+
+class PretrainedModelForVisualQA(PretrainedModelForQA, ModelAPI):
+ """A class representing a pretrained model for visual question answering.
+
+ Inherits:
+ PretrainedModelForQA: The base class for pretrained models.
+ """
+
+ @lru_cache(maxsize=102400)
+ def predict(
+ self, text: Union[str, dict], prompt: dict, images: List[Any], *args, **kwargs
+ ):
+ """Perform prediction using the pretrained model.
+
+ Args:
+ text (Union[str, dict]): The input text or dictionary.
+ prompt (dict): The prompt configuration.
+ images (List[Any]): The list of images.
+ *args: Additional positional arguments.
+ **kwargs: Additional keyword arguments.
+
+ Returns:
+ str: The content of the chat model's response message.
+ """
+ try:
+ if not isinstance(self.model, BaseChatModel):
+ raise ValueError("visualQA task is only supported for chat models")
+
+ # prepare prompt
+ prompt_template = PromptTemplate(**prompt)
+ from langchain_core.messages import HumanMessage
+
+ images = [
+ {
+ "type": "image_url",
+ "image_url": {"url": image},
+ }
+ for image in images
+ ]
+
+ messages = HumanMessage(
+ content=[
+ {"type": "text", "text": prompt_template.format(**text)},
+ *images,
+ ]
+ )
+
+ response = self.model.invoke([messages])
+ return response.content
+
+ except Exception as e:
+ raise ValueError(Errors.E089(error_message=e))
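Reviewer sketch for the `predict` method above: the message content it builds is one text part followed by one `image_url` part per image. The block below reproduces that shape with plain dictionaries and the standard library only. The helper names `to_data_url` and `build_content` are hypothetical, and encoding images as base64 data URLs is an assumption about how callers prepare the `images` argument, not something this diff confirms.

```python
import base64


def to_data_url(image_bytes: bytes, mime_type: str = "image/png") -> str:
    """Encode raw image bytes as a base64 data URL (assumed input format)."""
    encoded = base64.b64encode(image_bytes).decode("ascii")
    return f"data:{mime_type};base64,{encoded}"


def build_content(question: str, image_urls):
    """Build one text part followed by one image_url part per image."""
    parts = [{"type": "text", "text": question}]
    parts.extend(
        {"type": "image_url", "image_url": {"url": url}} for url in image_urls
    )
    return parts


# Hypothetical usage: a four-byte stand-in for real PNG data.
content = build_content("What is shown in the image?", [to_data_url(b"\x89PNG")])
print(content)
```

Keeping the text part first and appending images afterwards mirrors the order used in `predict`, so the question is always the first element of the content list.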
diff --git a/langtest/tasks/task.py b/langtest/tasks/task.py
index 93af99114..0e5134eae 100644
--- a/langtest/tasks/task.py
+++ b/langtest/tasks/task.py
@@ -851,3 +851,44 @@ def create_sample(
class FillMask(BaseTask):
pass
+
+
+class VisualQA(BaseTask):
+ _name = "visualqa"
+ _default_col = {
+ "image": ["image"],
+ "question": ["question"],
+ "answer": ["answer"],
+ }
+ sample_class = samples.VisualQASample
+
+ def create_sample(
+ cls,
+ row_data: dict,
+ image: str = "image_1",
+ question: str = "question",
+ options: str = "options",
+ answer: str = "answer",
+ dataset_name: str = "",
+ ) -> samples.VisualQASample:
+ """Create a sample."""
+ keys = list(row_data.keys())
+
+ # auto-detect the default column names from the row_data
+ column_mapper = cls.column_mapping(keys, [image, question, options, answer])
+
+ options = row_data.get(column_mapper.get(options, "-"), "-")
+
+ if len(options) > 3 and options[0] == "[" and options[-1] == "]":
+ options = ast.literal_eval(row_data[column_mapper["options"]])
+ options = "\n".join(
+ [f"{chr(65 + i)}. {option}" for i, option in enumerate(options)]
+ )
+
+ return samples.VisualQASample(
+ original_image=row_data[column_mapper[image]],
+ question=row_data[column_mapper[question]],
+ options=options,
+ expected_result=row_data[column_mapper[answer]],
+ dataset_name=dataset_name,
+ )
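Reviewer sketch for the option handling in `VisualQA.create_sample` above: a stringified list is parsed with `ast.literal_eval` and each entry is labelled A, B, C and so on. The helper name `format_options` is hypothetical, and unlike the hunk above it also accepts an already-parsed list; treat it as an illustration under those assumptions.

```python
import ast


def format_options(raw):
    """Turn a stringified list such as "['x', 'y']" into lettered lines."""
    # Only parse strings that look like a Python list literal.
    if isinstance(raw, str) and len(raw) > 3 and raw[0] == "[" and raw[-1] == "]":
        raw = ast.literal_eval(raw)
    if isinstance(raw, str):
        # Not a list literal (for example the "-" placeholder): return unchanged.
        return raw
    # chr(65) is "A", so options are labelled A, B, C, ...
    return "\n".join(f"{chr(65 + i)}. {option}" for i, option in enumerate(raw))


print(format_options("['Monteggia', 'Bennett', 'Jones', 'Smith']"))
```

Using `ast.literal_eval` rather than `eval` means only Python literals are accepted, so a malformed or malicious dataset value raises an error instead of executing code.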
diff --git a/langtest/transform/__init__.py b/langtest/transform/__init__.py
index 3cb59ebd6..0c4f41c9b 100644
--- a/langtest/transform/__init__.py
+++ b/langtest/transform/__init__.py
@@ -22,6 +22,8 @@
from langtest.transform.grammar import GrammarTestFactory
from langtest.transform.safety import SafetyTestFactory
+from langtest.transform import image
+
# Fixing the asyncio event loop
nest_asyncio.apply()
@@ -47,4 +49,5 @@
SycophancyTestFactory,
GrammarTestFactory,
SafetyTestFactory,
+ image,
]
diff --git a/langtest/transform/image/__init__.py b/langtest/transform/image/__init__.py
new file mode 100644
index 000000000..f02586ce0
--- /dev/null
+++ b/langtest/transform/image/__init__.py
@@ -0,0 +1,3 @@
+from .robustness import ImageResizing, ImageRotation, ImageBlur, ImageNoise
+
+__all__ = ["ImageResizing", "ImageRotation", "ImageBlur", "ImageNoise"]
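Reviewer sketch for the `image_noise` test defined in the next file: each 8-bit channel value receives zero-mean Gaussian noise with standard deviation `255 * noise_level` and is then clipped to [0, 255]. The function name `add_gaussian_noise` and the explicit `rng` argument are hypothetical and exist only to keep this standard-library example repeatable.

```python
import random


def add_gaussian_noise(value: int, noise_level: float, rng: random.Random) -> int:
    """Add zero-mean Gaussian noise to one 8-bit channel value and clip it."""
    if not 0 <= noise_level <= 1:
        raise ValueError("Noise level must be in the range [0, 1].")
    noise = int(rng.gauss(0, 255 * noise_level))
    # Clip so the result is still a valid 8-bit channel value.
    return max(0, min(255, value + noise))


rng = random.Random(0)  # seeded only so the sketch is repeatable
row = [0, 64, 128, 255]
noisy_row = [add_gaussian_noise(v, 0.1, rng) for v in row]
print(noisy_row)
```

With `noise_level=0.5`, as in the notebook configuration, the standard deviation is about 127, so a large share of pixels saturate at 0 or 255 after clipping.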
diff --git a/langtest/transform/image/robustness.py b/langtest/transform/image/robustness.py
new file mode 100644
index 000000000..3444abfe9
--- /dev/null
+++ b/langtest/transform/image/robustness.py
@@ -0,0 +1,286 @@
+import random
+from typing import List, Tuple, Union
+from langtest.logger import logger
+from langtest.transform.robustness import BaseRobustness
+from langtest.utils.custom_types.sample import Sample
+from PIL import Image, ImageFilter
+
+
+class ImageResizing(BaseRobustness):
+ alias_name = "image_resize"
+ supported_tasks = ["visualqa"]
+
+ @staticmethod
+ def transform(
+ sample_list: List[Sample],
+ resize: Union[float, Tuple[int, int]] = 0.5,
+ *args,
+ **kwargs,
+ ) -> List[Sample]:
+ for sample in sample_list:
+ sample.category = "robustness"
+ sample.test_type = "image_resize"
+ if isinstance(resize, float):
+ sample.perturbed_image = sample.original_image.resize(
+ (
+ int(sample.original_image.width * resize),
+ int(sample.original_image.height * resize),
+ )
+ )
+ else:
+ sample.perturbed_image = sample.original_image.resize(resize)
+
+ return sample_list
+
+
+class ImageRotation(BaseRobustness):
+ alias_name = "image_rotate"
+ supported_tasks = ["visualqa"]
+
+ @staticmethod
+ def transform(
+ sample_list: List[Sample], angle: int = 90, expand=True, *args, **kwargs
+ ) -> List[Sample]:
+ for sample in sample_list:
+ sample.category = "robustness"
+ sample.test_type = "image_rotate"
+ sample.perturbed_image = sample.original_image.rotate(angle, expand=expand)
+
+ return sample_list
+
+
+class ImageBlur(BaseRobustness):
+ alias_name = "image_blur"
+ supported_tasks = ["visualqa"]
+
+ @staticmethod
+ def transform(
+ sample_list: List[Sample], radius: int = 2, *args, **kwargs
+ ) -> List[Sample]:
+ for sample in sample_list:
+ sample.category = "robustness"
+ sample.test_type = "image_blur"
+ sample.perturbed_image = sample.original_image.filter(
+ ImageFilter.GaussianBlur(radius)
+ )
+
+ return sample_list
+
+
+class ImageNoise(BaseRobustness):
+ alias_name = "image_noise"
+ supported_tasks = ["visualqa"]
+
+ @classmethod
+ def transform(
+ cls, sample_list: List[Sample], noise_level: float = 0.1, *args, **kwargs
+ ) -> List[Sample]:
+ try:
+ if noise_level < 0 or noise_level > 1:
+ raise ValueError("Noise level must be in the range [0, 1].")
+
+ # Add noise to each sample's original image
+ for sample in sample_list:
+ sample.category = "robustness"
+ sample.test_type = "image_noise"
+ sample.perturbed_image = cls.add_noise(
+ image=sample.original_image, noise_level=noise_level
+ )
+ return sample_list
+
+ except Exception as e:
+ logger.error(f"Error in adding noise to the image: {e}")
+ raise e
+
+ @staticmethod
+ def add_noise(image: Image.Image, noise_level: float) -> Image.Image:
+ width, height = image.size
+
+ # Create a new image to hold the noisy version
+ noisy_image = image.copy()
+ pixels = noisy_image.load() # Access pixel data
+
+ # Check if the image is grayscale or RGB
+ if image.mode == "L": # Grayscale image
+ for x in range(width):
+ for y in range(height):
+ # Get the pixel value
+ gray = image.getpixel((x, y))
+
+ # Generate random noise
+ noise_gray = int(random.gauss(0, 255 * noise_level))
+
+ # Add noise and clip the value to stay in [0, 255]
+ new_gray = max(0, min(255, gray + noise_gray))
+
+ # Set the new pixel value
+ pixels[x, y] = new_gray
+
+ elif image.mode == "RGB": # Color image
+ for x in range(width):
+ for y in range(height):
+ r, g, b = image.getpixel((x, y)) # Get the RGB values of the pixel
+
+ # Generate random noise for each channel
+ noise_r = int(random.gauss(0, 255 * noise_level))
+ noise_g = int(random.gauss(0, 255 * noise_level))
+ noise_b = int(random.gauss(0, 255 * noise_level))
+
+ # Add noise to each channel and clip values to stay in range [0, 255]
+ new_r = max(0, min(255, r + noise_r))
+ new_g = max(0, min(255, g + noise_g))
+ new_b = max(0, min(255, b + noise_b))
+
+ # Set the new pixel value
+ pixels[x, y] = (new_r, new_g, new_b)
+
+ else:
+ raise ValueError("The input image must be in 'L' (grayscale) or 'RGB' mode.")
+
+ return noisy_image
+
+
+class ImageContrast(BaseRobustness):
+ alias_name = "image_contrast"
+ supported_tasks = ["visualqa"]
+
+ @staticmethod
+ def transform(
+ sample_list: List[Sample], contrast_factor: float = 0.5, *args, **kwargs
+ ) -> List[Sample]:
+ from PIL import ImageEnhance
+
+ if contrast_factor < 0:
+ raise ValueError("Contrast factor must be non-negative.")
+
+ for sample in sample_list:
+ sample.category = "robustness"
+ sample.test_type = "image_contrast"
+ img = ImageEnhance.Contrast(sample.original_image)
+ sample.perturbed_image = img.enhance(contrast_factor)
+
+ return sample_list
+
+
+class ImageBrightness(BaseRobustness):
+ alias_name = "image_brightness"
+ supported_tasks = ["visualqa"]
+
+ @staticmethod
+ def transform(
+ sample_list: List[Sample], brightness_factor: float = 0.3, *args, **kwargs
+ ) -> List[Sample]:
+ from PIL import ImageEnhance
+
+ if brightness_factor < 0:
+ raise ValueError("Brightness factor must be non-negative.")
+
+ for sample in sample_list:
+ sample.category = "robustness"
+ sample.test_type = "image_brightness"
+ enhancer = ImageEnhance.Brightness(sample.original_image)
+ sample.perturbed_image = enhancer.enhance(brightness_factor)
+
+ return sample_list
+
+
+class ImageSharpness(BaseRobustness):
+ alias_name = "image_sharpness"
+ supported_tasks = ["visualqa"]
+
+ @staticmethod
+ def transform(
+ sample_list: List[Sample], sharpness_factor: float = 1.5, *args, **kwargs
+ ) -> List[Sample]:
+ from PIL import ImageEnhance
+
+ if sharpness_factor < 0:
+ raise ValueError("Sharpness factor must be non-negative.")
+
+ for sample in sample_list:
+ sample.category = "robustness"
+ sample.test_type = "image_sharpness"
+ enhancer = ImageEnhance.Sharpness(sample.original_image)
+ sample.perturbed_image = enhancer.enhance(sharpness_factor)
+
+ return sample_list
+
+
+class ImageColor(BaseRobustness):
+ alias_name = "image_color"
+ supported_tasks = ["visualqa"]
+
+ @staticmethod
+ def transform(
+ sample_list: List[Sample], color_factor: float = 0, *args, **kwargs
+ ) -> List[Sample]:
+ from PIL import ImageEnhance
+
+ if color_factor < 0:
+ raise ValueError("Color factor must be in the range [0, inf).")
+
+ for sample in sample_list:
+ sample.category = "robustness"
+ sample.test_type = "image_color"
+ enhancer = ImageEnhance.Color(sample.original_image)
+ sample.perturbed_image = enhancer.enhance(color_factor)
+
+ return sample_list
+
+
+class ImageFlip(BaseRobustness):
+ alias_name = "image_flip"
+ supported_tasks = ["visualqa"]
+
+ @staticmethod
+ def transform(
+ sample_list: List[Sample], flip: str = "horizontal", *args, **kwargs
+ ) -> List[Sample]:
+ if flip not in ["horizontal", "vertical"]:
+ raise ValueError("Flip must be either 'horizontal' or 'vertical'.")
+
+ for sample in sample_list:
+ sample.category = "robustness"
+ sample.test_type = "image_flip"
+ if flip == "horizontal":
+ sample.perturbed_image = sample.original_image.transpose(
+ Image.FLIP_LEFT_RIGHT
+ )
+ else:
+ sample.perturbed_image = sample.original_image.transpose(
+ Image.FLIP_TOP_BOTTOM
+ )
+
+ return sample_list
+
+
+class ImageCrop(BaseRobustness):
+ alias_name = "image_crop"
+ supported_tasks = ["visualqa"]
+
+ @staticmethod
+ def transform(
+ sample_list: List[Sample],
+ crop_size: Union[float, Tuple[int, int]] = (100, 100),
+ *args,
+ **kwargs,
+ ) -> List[Sample]:
+ for sample in sample_list:
+ sample.category = "robustness"
+ sample.test_type = "image_crop"
+ if isinstance(crop_size, float):
+ sample.perturbed_image = sample.original_image.crop(
+ (
+ 0,
+ 0,
+ int(sample.original_image.width * crop_size),
+ int(sample.original_image.height * crop_size),
+ )
+ )
+ else:
+ sample.perturbed_image = sample.original_image.crop(
+ (0, 0, crop_size[0], crop_size[1])
+ )
+
+ return sample_list
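
Reviewer sketch for the `image_noise` perturbation above: the per-pixel logic in `add_noise` draws a Gaussian offset with standard deviation `255 * noise_level`, truncates it to an integer, and clips the sum to `[0, 255]`. The standalone version below mirrors that logic on a plain list-of-lists grayscale grid so it can be checked without Pillow. The function name `add_gaussian_noise`, the `rng` parameter, and the sample grid are assumptions made for illustration and are not part of the patch.

```python
import random
from typing import List


def add_gaussian_noise(
    pixels: List[List[int]], noise_level: float, rng: random.Random
) -> List[List[int]]:
    """Return a noisy copy of a grayscale grid whose values lie in [0, 255].

    Hypothetical helper: it mirrors the clipping logic of the patch's
    add_noise method but works on nested lists instead of PIL images.
    """
    if not 0 <= noise_level <= 1:
        raise ValueError("Noise level must be in the range [0, 1].")

    sigma = 255 * noise_level  # standard deviation of the noise
    noisy = []
    for row in pixels:
        noisy_row = []
        for value in row:
            offset = int(rng.gauss(0, sigma))  # truncate toward zero
            noisy_row.append(max(0, min(255, value + offset)))  # clip
        noisy.append(noisy_row)
    return noisy


grid = [[0, 128, 255], [64, 192, 32]]

# A seeded generator keeps the perturbation reproducible across runs.
noisy_grid = add_gaussian_noise(grid, noise_level=0.1, rng=random.Random(0))

# With noise_level 0 the standard deviation is 0, so the grid is unchanged.
unchanged = add_gaussian_noise(grid, noise_level=0.0, rng=random.Random(0))
```

Passing an explicit `random.Random` instance is a design choice of this sketch only; the patch uses the module-level `random.gauss`, which makes perturbed images differ between runs unless the global seed is fixed.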
diff --git a/langtest/transform/utils.py b/langtest/transform/utils.py
index 4540155bf..0fc2dcd23 100644
--- a/langtest/transform/utils.py
+++ b/langtest/transform/utils.py
@@ -397,6 +397,8 @@ def filter_unique_samples(task: str, transformed_samples: list, test_name: str):
no_transformation_applied_tests[test_name] += 1
else:
no_transformation_applied_tests[test_name] = 1
+ elif task == "visualqa":
+ return transformed_samples, no_transformation_applied_tests
else:
for sample in transformed_samples:
if sample.original.replace(" ", "") != sample.test_case.replace(" ", ""):
diff --git a/langtest/utils/custom_types/__init__.py b/langtest/utils/custom_types/__init__.py
index 41d60e870..82e3e62f0 100644
--- a/langtest/utils/custom_types/__init__.py
+++ b/langtest/utils/custom_types/__init__.py
@@ -22,6 +22,7 @@
CrowsPairsSample,
StereoSetSample,
TextGenerationSample,
+ VisualQASample,
)
from .helpers import Span, Transformation
from .output import (
diff --git a/langtest/utils/custom_types/sample.py b/langtest/utils/custom_types/sample.py
index 8477fb9bb..f6e088b39 100644
--- a/langtest/utils/custom_types/sample.py
+++ b/langtest/utils/custom_types/sample.py
@@ -3,6 +3,8 @@
import importlib
from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union, Callable
from copy import deepcopy
+
+from langtest.modelhandler.modelhandler import ModelAPI
from ...errors import Errors
from pydantic import BaseModel, PrivateAttr, validator, Field
from .helpers import Transformation, Span
@@ -2751,6 +2753,320 @@ class FillMaskSample(TextGenerationSample):
pass
+class VisualQASample(BaseModel):
+ """
+ A class representing a sample for the Visual Question Answering task.
+
+ Attributes:
+ original_image (Union[Image, str, Any]): The original image, or a path, URL, or data URI from which it is loaded.
+ perturbed_image (Union[Image, str, Any]): The perturbed image used for the test.
+ question (str): The question asked about the image.
+ options (str): The answer options for multiple-choice questions, if any.
+ ground_truth (str): The ground truth answer to the question.
+ expected_results (str): The model's answer on the original image.
+ actual_results (str): The model's answer on the perturbed image.
+ """
+
+ from PIL.Image import Image
+
+ original_image: Union[Image, str, Any] = None
+ perturbed_image: Union[Image, str, Any] = None
+ question: str = None
+ options: str = None
+ ground_truth: str = None
+ expected_results: str = None
+ actual_results: str = None
+ dataset_name: str = None
+ category: str = None
+ test_type: str = None
+ ran_pass: bool = None
+ metric_name: str = None
+ config: Union[str, dict] = None
+ state: str = None
+ task: str = Field(default="visualqa", const=True)
+ distance_result: float = None
+ eval_model: str = None
+ feedback: str = None
+
+ class Config:
+ arbitrary_types_allowed = True
+
+ def __init__(self, **data):
+ super().__init__(**data)
+ self.original_image = self.__load_image(self.original_image)
+
+ def to_dict(self) -> Dict[str, Any]:
+ """
+ Converts the VisualQASample object to a dictionary.
+
+ Returns:
+ Dict[str, Any]: A dictionary representation of the VisualQASample object.
+ """
+ self.__update_params()
+
+ result = {
+ "category": self.category,
+ "test_type": self.test_type,
+ "original_image": self.convert_image_to_html(self.original_image),
+ "perturbed_image": self.convert_image_to_html(self.perturbed_image),
+ "question": self.question,
+ }
+
+ if self.options is not None:
+ result["options"] = self.options
+
+ if self.state == "done":
+ if self.expected_results is not None and self.actual_results is not None:
+ result.update(
+ {
+ "expected_result": self.expected_results,
+ "actual_result": self.actual_results,
+ "pass": self.is_pass(),
+ }
+ )
+ if "evaluation" in self.config and "metric" in self.config["evaluation"]:
+ if self.config["evaluation"]["metric"].lower() == "prometheus_eval":
+ result.update({"feedback": self.feedback})
+ elif self.config["evaluation"]["metric"].lower() != "llm_eval":
+ result.update({"eval_score": self.distance_result})
+
+ return result
+
+ def run(self, model: ModelAPI, **kwargs):
+ """
+ Run the VisualQASample test using the provided model.
+
+ Args:
+ model: The model used for VisualQASample testing.
+ **kwargs: Additional keyword arguments for the model.
+
+ Returns:
+ bool: True
+ """
+
+ dataset_name = self.dataset_name.split("-")[0].lower()
+ prompt_template = kwargs.get(
+ "user_prompt",
+ default_user_prompt.get(
+ dataset_name,
+ (
+ """You are an AI Vision bot specializing in providing accurate and concise answers to multiple-choice questions. You will be presented with a question and options. Choose the correct answer.
+
+Example:
+
+Question: What is the capital of France ?
+
+Options:
+A. Berlin
+B. Madrid
+C. Paris
+D. Rome
+
+Answer: C. Paris.
+
+Example 2:
+
+Question: What is in the image ?
+
+Options:
+A. Dog
+B. Cat
+C. Elephant
+D. Ear
+
+Answer: UnRecognizable.
+"""
+ " Similarly \n Question: {question}\nOptions: {options}\n Answer:"
+ ),
+ ),
+ )
+
+ server_prompt = kwargs.get("server_prompt", " ")
+
+ text_dict = {
+ "question": self.question,
+ }
+ input_variables = ["question"]
+
+ if self.options is not None:
+ text_dict["options"] = self.options
+ input_variables.append("options")
+
+ payload = {
+ "text": text_dict,
+ "prompt": {
+ "template": prompt_template,
+ "input_variables": input_variables,
+ },
+ }
+
+ # convert the image to base64 url
+ orig_image = self.convert_image_to_bas64_url(self.original_image)
+ pred_image = self.convert_image_to_bas64_url(self.perturbed_image)
+
+ self.expected_results = model(
+ **payload,
+ images=(orig_image,),
+ server_prompt=server_prompt,
+ )
+ self.actual_results = model(
+ **payload,
+ images=(pred_image,),
+ server_prompt=server_prompt,
+ )
+ return True
+
+ def transform(self, func: Callable, params: Dict, **kwargs):
+ """
+ Transform the original image using a specified function.
+
+ Args:
+ func (Callable): The transformation function.
+ params (Dict): Parameters for the transformation function.
+ **kwargs: Additional keyword arguments for the transformation.
+
+ """
+ sens = [self.original_image]
+ self.perturbed_image = func(sens, **params, **kwargs)
+ self.category = func.__module__.split(".")[-1]
+
+ return self
+
+ def __load_image(self, image_path):
+ # Resolve the input (dict of bytes, URL, data URI, PIL image, or file path) to an RGB image.
+ # Import the PIL.Image module, not the Image class: open() is a module-level function.
+ import re
+ import requests
+ from PIL import Image
+ import io
+ import base64
+
+ if isinstance(image_path, dict) and "bytes" in image_path:
+ image = Image.open(io.BytesIO(image_path["bytes"]))
+ elif isinstance(image_path, str) and re.match(r"^https?://", image_path):
+ response = requests.get(image_path)
+ image = Image.open(io.BytesIO(response.content))
+ elif isinstance(image_path, str) and re.match(r"^data:image", image_path):
+ image = Image.open(io.BytesIO(base64.b64decode(image_path.split(",")[1])))
+ elif isinstance(image_path, Image.Image):
+ image = image_path
+ else:
+ image = Image.open(image_path)
+ return image.convert("RGB")
+
+ def convert_image_to_html(self, image: Image):
+ import io
+ import base64
+
+ if image is not None:
+ image = image.copy()
+ buffered = io.BytesIO()
+ image.thumbnail((200, 200))
+ image.save(buffered, format="PNG")
+ img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
+ return f'<img src="data:image/png;base64,{img_str}" />'
+
+ def convert_image_to_bas64_url(self, image: Image):
+ import io
+ import base64
+
+ if image is not None:
+ image = image.copy()
+ buffered = io.BytesIO()
+ image.thumbnail((400, 400))
+ image.save(buffered, format="PNG")
+ img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
+ return f"data:image/png;base64,{img_str}"
+
+ def __update_params(self):
+ from ...langtest import HARNESS_CONFIG as harness_config
+
+ self.config = harness_config
+ self.metric_name = (
+ self.config.get("evaluation", {}).get("metric", "llm_eval").lower()
+ )
+
+ if self.state == "done":
+ from ...langtest import EVAL_MODEL
+
+ if (
+ "evaluation" in harness_config
+ and "metric" in harness_config["evaluation"]
+ ):
+ if harness_config["evaluation"]["metric"].lower() == "llm_eval":
+ model = harness_config["evaluation"].get("model", None)
+ hub = harness_config["evaluation"].get("hub", None)
+ if model and hub:
+ from ...tasks import TaskManager
+
+ load_eval_model = TaskManager(self.task)
+ self.eval_model = load_eval_model.model(
+ model, hub, **harness_config.get("model_parameters", {})
+ )
+
+ else:
+ self.eval_model = EVAL_MODEL
+
+ def is_pass(self) -> bool:
+ """Checks if the sample has passed the evaluation.
+
+ Returns:
+ bool: True if the sample passed the evaluation, False otherwise.
+ """
+
+ if self.ran_pass is not None:
+ return self.ran_pass
+ elif self.expected_results.strip().lower() == self.actual_results.strip().lower():
+ self.ran_pass = True
+ return True
+ else:
+ self.__update_params()
+ try:
+ metric_module = importlib.import_module(
+ "langtest.utils.custom_types.helpers"
+ )
+ metric_function = getattr(metric_module, f"is_pass_{self.metric_name}")
+ except (ImportError, AttributeError):
+ raise ValueError(f"Metric '{self.metric_name}' not found.")
+
+ if self.metric_name == "string_distance":
+ selected_distance = self.config["evaluation"].get("distance", "jaro")
+ threshold = self.config["evaluation"].get("threshold")
+
+ elif self.metric_name == "embedding_distance":
+ selected_distance = self.config["evaluation"].get("distance", "cosine")
+ threshold = self.config["evaluation"].get("threshold")
+
+ if self.metric_name in (
+ "string_distance",
+ "embedding_distance",
+ ):
+ self.distance_result, result = metric_function(
+ answer=self.expected_results,
+ prediction=self.actual_results,
+ selected_distance=selected_distance,
+ threshold=threshold,
+ )
+ self.ran_pass = result
+ return result
+ elif self.metric_name == "llm_eval":
+ if isinstance(self.eval_model, dict):
+ self.eval_model = list(self.eval_model.values())[-1]
+ result = metric_function(
+ eval_model=self.eval_model,
+ dataset_name=self.dataset_name,
+ original_question=" " + self.question,
+ answer=self.expected_results,
+ perturbed_question=" " + self.question,
+ prediction=self.actual_results,
+ )
+
+ self.ran_pass = result
+ return result
+
+ else:
+ raise ValueError(f"Metric '{self.metric_name}' not found.")
+
+
Sample = TypeVar(
"Sample",
MaxScoreSample,
@@ -2772,4 +3088,5 @@ class FillMaskSample(TextGenerationSample):
LegalSample,
CrowsPairsSample,
StereoSetSample,
+ VisualQASample,
)
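
Reviewer sketch for the image handling in `VisualQASample`: the loader dispatches on the input type (dict with `bytes`, `http(s)` URL, `data:image` URI, otherwise a file path), and `run` sends each image to the model as a base64 data URL. The standalone example below reproduces that dispatch and the encode/decode round trip with the standard library only, so the string handling can be checked without Pillow or network access. The helper names `classify_source`, `to_data_url`, and `from_data_url`, and the placeholder payload, are assumptions for illustration and do not exist in the patch.

```python
import base64
import re


def classify_source(source) -> str:
    """Name the branch the loader would take (hypothetical helper)."""
    if isinstance(source, dict) and "bytes" in source:
        return "bytes"
    if isinstance(source, str) and re.match(r"^https?://", source):
        return "url"
    if isinstance(source, str) and re.match(r"^data:image", source):
        return "data_url"
    return "path"


def to_data_url(raw: bytes, mime: str = "image/png") -> str:
    """Encode raw image bytes as a base64 data URL."""
    encoded = base64.b64encode(raw).decode("utf-8")
    return f"data:{mime};base64,{encoded}"


def from_data_url(data_url: str) -> bytes:
    """Decode the payload after the first comma of a data URL."""
    return base64.b64decode(data_url.split(",", 1)[1])


# Placeholder bytes standing in for real PNG content.
payload = b"\x89PNG\r\n\x1a\n placeholder image bytes"
url = to_data_url(payload)
round_trip = from_data_url(url)
```

The sketch splits on the first comma only (`split(",", 1)`); because the standard base64 alphabet contains no commas, this is equivalent to the patch's `split(",")[1]` for well-formed data URLs.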
diff --git a/poetry.lock b/poetry.lock
index 3526d8014..b3655893c 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -3344,7 +3344,7 @@ files = [
name = "pillow"
version = "10.0.0"
description = "Python Imaging Library (Fork)"
-optional = true
+optional = false
python-versions = ">=3.8"
files = [
{file = "Pillow-10.0.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:1f62406a884ae75fb2f818694469519fb685cc7eaff05d3451a9ebe55c646891"},
@@ -5753,4 +5753,4 @@ transformers = ["accelerate", "datasets", "torch", "transformers"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
-content-hash = "f43231a0fd18c0d2b740ccad37045fd68294240109c0744b13973dc3ec2f445d"
+content-hash = "7c8dc3eabf8a4d28f97b9be0f2a9fb70261baef10e3d2ef996fe56a906c36a45"
diff --git a/pyproject.toml b/pyproject.toml
index 724645aab..074ef4b92 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -82,6 +82,7 @@ importlib-resources = "^6.4.0"
click = "^8.1.7"
openpyxl = "^3.1.5"
tables = "3.8.0"
+pillow = "10.0.0"
[tool.poetry.extras]
transformers = ["transformers", "torch", "accelerate", "datasets"]
diff --git a/tests/test_robustness.py b/tests/test_robustness.py
index 70e6bd78f..8b332db87 100644
--- a/tests/test_robustness.py
+++ b/tests/test_robustness.py
@@ -469,7 +469,10 @@ def setUp(self) -> None:
test: list(scenarios.keys()) for test, scenarios in test_scenarios.items()
}
- self.perturbations_list = self.available_tests["robustness"]
+ self.perturbations_list = [
+ i for i in self.available_tests["robustness"] if not i.startswith("image_")
+ ]
+
self.supported_tests = self.available_test()
self.samples = {
"question-answering": [