From c2f4c2649114380345115e338a63b26880dd4963 Mon Sep 17 00:00:00 2001 From: Kenneth Enevoldsen Date: Wed, 4 Dec 2024 11:39:28 +0100 Subject: [PATCH] Add cohere models (#1538) * fix: bug cohere names * format --- mteb/leaderboard/app.py | 1 - mteb/load_results/benchmark_results.py | 4 ++-- mteb/models/cohere_models.py | 3 +-- pyproject.toml | 1 + 4 files changed, 4 insertions(+), 5 deletions(-) diff --git a/mteb/leaderboard/app.py b/mteb/leaderboard/app.py index c4e5e80ef..c51dc7a50 100644 --- a/mteb/leaderboard/app.py +++ b/mteb/leaderboard/app.py @@ -5,7 +5,6 @@ from pathlib import Path import gradio as gr -import pandas as pd from gradio_rangeslider import RangeSlider import mteb diff --git a/mteb/load_results/benchmark_results.py b/mteb/load_results/benchmark_results.py index 756024a4e..25a332e2c 100644 --- a/mteb/load_results/benchmark_results.py +++ b/mteb/load_results/benchmark_results.py @@ -5,7 +5,7 @@ from collections import defaultdict from collections.abc import Iterable from pathlib import Path -from typing import Any, Callable, Literal, Optional +from typing import Any, Callable, Literal import numpy as np import pandas as pd @@ -229,7 +229,7 @@ def filter_models( return type(self).model_construct(model_results=new_model_results) def join_revisions(self): - def parse_version(version_str: str) -> Optional[Version]: + def parse_version(version_str: str) -> Version | None: try: return Version(version_str) except (InvalidVersion, TypeError): diff --git a/mteb/models/cohere_models.py b/mteb/models/cohere_models.py index ec86d2d1b..2ed0b76a9 100644 --- a/mteb/models/cohere_models.py +++ b/mteb/models/cohere_models.py @@ -210,7 +210,7 @@ def encode( cohere_eng_3 = ModelMeta( loader=partial( CohereTextEmbeddingModel, - model_name="embed-multilingual-v3.0", + model_name="embed-english-v3.0", model_prompts=model_prompts, ), name="Cohere/Cohere-embed-english-v3.0", @@ -229,7 +229,6 @@ def encode( use_instructions=False, ) - cohere_mult_light_3 = ModelMeta( 
loader=partial( CohereTextEmbeddingModel, diff --git a/pyproject.toml b/pyproject.toml index b783ec80a..10154edc8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -121,6 +121,7 @@ ignore = ["E501", # line too long "D107", # Missing docstring in __init__ "D205", # 1 blank line required between summary line and description "D415", # First line should end with a period + "C408", # don't use unnecessary collection call, e.g. dict() instead of {} ] [tool.ruff.lint.flake8-implicit-str-concat]