Skip to content

Commit

Permalink
Correction of discrepancies for gte-Qweb model
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexeyVatolin committed Dec 27, 2024
1 parent 1b06601 commit 68369b4
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 14 deletions.
28 changes: 20 additions & 8 deletions mteb/models/gte_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,36 @@

from functools import partial

import torch

from mteb.encoder_interface import PromptType
from mteb.model_meta import ModelMeta
from mteb.models.instruct_wrapper import instruct_wrapper


def instruction_template(instruction: str) -> str:
return f"Instruct: {instruction}\nQuery: " if instruction else ""
def instruction_template(
instruction: str, prompt_type: PromptType | None = None
) -> str:
return (
f"Instruct: {instruction}\nQuery: "
if (prompt_type is None or prompt_type == PromptType.query) and instruction
else ""
)


gte_Qwen2_7B_instruct = ModelMeta(
loader=partial( # type: ignore
instruct_wrapper,
model_name_or_path="Alibaba-NLP/gte-Qwen2-7B-instruct",
instruction_template=instruction_template,
attn="cccc",
attn="bbcc",
pooling_method="lasttoken",
mode="embedding",
torch_dtype="auto",
torch_dtype=torch.float16,
# The ST script does not normalize while the HF one does so unclear what to do
# https://huggingface.co/Alibaba-NLP/gte-Qwen2-7B-instruct#sentence-transformers
normalized=True,
embed_eos="<|endoftext|>",
),
name="Alibaba-NLP/gte-Qwen2-7B-instruct",
languages=None,
Expand All @@ -44,11 +54,12 @@ def instruction_template(instruction: str) -> str:
instruct_wrapper,
model_name_or_path="Alibaba-NLP/gte-Qwen1.5-7B-instruct",
instruction_template=instruction_template,
attn="cccc",
attn="bbcc",
pooling_method="lasttoken",
mode="embedding",
torch_dtype="auto",
torch_dtype=torch.float16,
normalized=True,
embed_eos="<|endoftext|>",
),
name="Alibaba-NLP/gte-Qwen1.5-7B-instruct",
languages=["eng_Latn"],
Expand All @@ -72,11 +83,12 @@ def instruction_template(instruction: str) -> str:
instruct_wrapper,
model_name_or_path="Alibaba-NLP/gte-Qwen2-1.5B-instruct",
instruction_template=instruction_template,
attn="cccc",
attn="bbcc",
pooling_method="lasttoken",
mode="embedding",
torch_dtype="auto",
torch_dtype=torch.float16,
normalized=True,
embed_eos="<|endoftext|>",
),
name="Alibaba-NLP/gte-Qwen2-1.5B-instruct",
languages=["eng_Latn"],
Expand Down
2 changes: 1 addition & 1 deletion mteb/models/instruct_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def encode(
instruction = self.get_instruction(task_name, prompt_type)

if self.instruction_template:
instruction = self.format_instruction(instruction)
instruction = self.format_instruction(instruction, prompt_type)

logger.info(f"Using instruction: '{instruction}' for task: '{task_name}'")
embeddings = super().encode(
Expand Down
5 changes: 4 additions & 1 deletion mteb/models/linq_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@

import torch

from mteb.encoder_interface import PromptType
from mteb.model_meta import ModelMeta
from mteb.models.instruct_wrapper import instruct_wrapper


def instruction_template(instruction: str) -> str:
def instruction_template(
instruction: str, prompt_type: PromptType | None = None
) -> str:
return f"Instruct: {instruction}\nQuery: " if instruction else ""


Expand Down
4 changes: 3 additions & 1 deletion mteb/models/nvidia_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
logger = logging.getLogger(__name__)


def instruction_template(instruction: str) -> str:
def instruction_template(
instruction: str, prompt_type: PromptType | None = None
) -> str:
return f"Instruct: {instruction}\nQuery: " if instruction else ""


Expand Down
5 changes: 4 additions & 1 deletion mteb/models/salesforce_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@

from functools import partial

from mteb.encoder_interface import PromptType
from mteb.model_meta import ModelMeta
from mteb.models.instruct_wrapper import instruct_wrapper


def instruction_template(instruction: str) -> str:
def instruction_template(
instruction: str, prompt_type: PromptType | None = None
) -> str:
return f"Instruct: {instruction}\nQuery: " if instruction else ""


Expand Down
6 changes: 4 additions & 2 deletions mteb/models/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,14 +103,16 @@ def get_instruction(task_name: str, prompt_type: PromptType | None) -> str:
return task_metadata.prompt
return task.abstask_prompt

def format_instruction(self, instruction: str) -> str:
def format_instruction(
self, instruction: str, prompt_type: PromptType | None = None
) -> str:
if isinstance(self.instruction_template, str):
if "{instruction}" not in self.instruction_template:
raise ValueError(
"Instruction template must contain the string '{instruction}'."
)
return self.instruction_template.format(instruction=instruction)
return self.instruction_template(instruction)
return self.instruction_template(instruction, prompt_type)

def get_task_instruction(
self, task_name: str, prompt_type: PromptType | None
Expand Down

0 comments on commit 68369b4

Please sign in to comment.