Commit e768a59: fix unit tests
xpbowler committed Sep 3, 2024
1 parent 77576e3
Showing 3 changed files with 39 additions and 38 deletions.
src/rank_llm/rerank/rankllm.py (3 changes: 2 additions & 1 deletion)
@@ -3,12 +3,13 @@
 from typing import Any, Dict, List, Tuple, Union

 from rank_llm.data import Request, Result
+from rank_llm.rerank import Prompt

 logger = logging.getLogger(__name__)


 class RankLLM(ABC):
-    def __init__(self, model: str, context_size: int, prompt_mode: PromptMode) -> None:
+    def __init__(self, model: str, context_size: int, prompt_mode: Prompt) -> None:
         self._model = model
         self._context_size = context_size
         self._prompt_mode = prompt_mode
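With this change, the base class types `prompt_mode` as `Prompt` rather than `PromptMode`. A minimal sketch of the new call pattern for a concrete subclass; argument values are copied from the updated tests below, and the sketch itself is not part of the commit:

from rank_llm.rerank import Prompt
from rank_llm.rerank.listwise import RankListwiseOSLLM

# Construct a concrete RankLLM subclass with the renamed Prompt enum.
agent = RankListwiseOSLLM(
    model="castorini/rank_zephyr_7b_v1_full",
    name="rank_zephyr",
    context_size=4096,
    prompt_mode=Prompt.RANK_GPT,  # was PromptMode.RANK_GPT before this commit
    num_few_shot_examples=0,
    variable_passages=True,
    window_size=10,
)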
test/rerank/listwise/test_RankListwiseOSLLM.py (48 changes: 24 additions & 24 deletions)
@@ -4,15 +4,15 @@
 from dacite import from_dict

 from rank_llm.data import Result
-from rank_llm.rerank import PromptMode
+from rank_llm.rerank import Prompt
 from rank_llm.rerank.listwise import RankListwiseOSLLM

 # model, context_size, prompt_mode, num_few_shot_examples, variable_passages, window_size, system_message
 valid_inputs = [
     (
         "castorini/rank_zephyr_7b_v1_full",
         4096,
-        PromptMode.RANK_GPT,
+        Prompt.RANK_GPT,
         0,
         True,
         10,
@@ -21,7 +21,7 @@
     (
         "castorini/rank_zephyr_7b_v1_full",
         4096,
-        PromptMode.RANK_GPT,
+        Prompt.RANK_GPT,
         0,
         False,
         10,
@@ -30,20 +30,20 @@
     (
         "castorini/rank_zephyr_7b_v1_full",
         4096,
-        PromptMode.RANK_GPT,
+        Prompt.RANK_GPT,
         0,
         True,
         30,
         "Default Message",
     ),
-    ("castorini/rank_zephyr_7b_v1_full", 4096, PromptMode.RANK_GPT, 0, True, 10, ""),
-    ("castorini/rank_vicuna_7b_v1", 4096, PromptMode.RANK_GPT, 0, True, 10, ""),
-    ("castorini/rank_vicuna_7b_v1_noda", 4096, PromptMode.RANK_GPT, 0, True, 10, ""),
-    ("castorini/rank_vicuna_7b_v1_fp16", 4096, PromptMode.RANK_GPT, 0, True, 10, ""),
+    ("castorini/rank_zephyr_7b_v1_full", 4096, Prompt.RANK_GPT, 0, True, 10, ""),
+    ("castorini/rank_vicuna_7b_v1", 4096, Prompt.RANK_GPT, 0, True, 10, ""),
+    ("castorini/rank_vicuna_7b_v1_noda", 4096, Prompt.RANK_GPT, 0, True, 10, ""),
+    ("castorini/rank_vicuna_7b_v1_fp16", 4096, Prompt.RANK_GPT, 0, True, 10, ""),
     (
         "castorini/rank_vicuna_7b_v1_noda_fp16",
         4096,
-        PromptMode.RANK_GPT,
+        Prompt.RANK_GPT,
         0,
         True,
         10,
@@ -55,7 +55,7 @@
     (
         "castorini/rank_zephyr_7b_v1_full",
         4096,
-        PromptMode.UNSPECIFIED,
+        Prompt.UNSPECIFIED,
         0,
         True,
         30,
@@ -64,7 +64,7 @@
     (
         "castorini/rank_zephyr_7b_v1_full",
         4096,
-        PromptMode.LRL,
+        Prompt.LRL,
         0,
         True,
         30,
@@ -73,7 +73,7 @@
     (
         "castorini/rank_vicuna_7b_v1",
         4096,
-        PromptMode.UNSPECIFIED,
+        Prompt.UNSPECIFIED,
         0,
         True,
         30,
@@ -82,7 +82,7 @@
     (
         "castorini/rank_vicuna_7b_v1",
         4096,
-        PromptMode.LRL,
+        Prompt.LRL,
         0,
         True,
         30,
@@ -91,7 +91,7 @@
     (
         "castorini/rank_vicuna_7b_v1_noda",
         4096,
-        PromptMode.UNSPECIFIED,
+        Prompt.UNSPECIFIED,
         0,
         True,
         30,
@@ -100,7 +100,7 @@
     (
         "castorini/rank_vicuna_7b_v1_noda",
         4096,
-        PromptMode.LRL,
+        Prompt.LRL,
         0,
         True,
         30,
@@ -109,7 +109,7 @@
     (
         "castorini/rank_vicuna_7b_v1_fp16",
         4096,
-        PromptMode.UNSPECIFIED,
+        Prompt.UNSPECIFIED,
         0,
         True,
         30,
@@ -118,7 +118,7 @@
     (
         "castorini/rank_vicuna_7b_v1_fp16",
         4096,
-        PromptMode.LRL,
+        Prompt.LRL,
         0,
         True,
         30,
@@ -127,7 +127,7 @@
     (
         "castorini/rank_vicuna_7b_v1_noda_fp16",
         4096,
-        PromptMode.UNSPECIFIED,
+        Prompt.UNSPECIFIED,
         0,
         True,
         30,
@@ -136,7 +136,7 @@
     (
         "castorini/rank_vicuna_7b_v1_noda_fp16",
         4096,
-        PromptMode.LRL,
+        Prompt.LRL,
         0,
         True,
         30,
@@ -256,7 +256,7 @@ def test_num_output_tokens(self, mock_num_output_tokens):
             model="castorini/rank_zephyr_7b_v1_full",
             name="rank_zephyr",
             context_size=4096,
-            prompt_mode=PromptMode.RANK_GPT,
+            prompt_mode=Prompt.RANK_GPT,
             num_few_shot_examples=0,
             variable_passages=True,
             window_size=10,
@@ -272,7 +272,7 @@ def test_num_output_tokens(self, mock_num_output_tokens):
             model="castorini/rank_zephyr_7b_v1_full",
             name="rank_zephyr",
             context_size=4096,
-            prompt_mode=PromptMode.RANK_GPT,
+            prompt_mode=Prompt.RANK_GPT,
             num_few_shot_examples=0,
             variable_passages=True,
             window_size=5,
@@ -289,7 +289,7 @@ def test_run_llm(self, mock_run_llm):
             model="castorini/rank_zephyr_7b_v1_full",
             name="rank_zephyr",
             context_size=4096,
-            prompt_mode=PromptMode.RANK_GPT,
+            prompt_mode=Prompt.RANK_GPT,
             num_few_shot_examples=0,
             variable_passages=True,
             window_size=5,
@@ -311,7 +311,7 @@ def test_create_prompt(
             model="castorini/rank_zephyr_7b_v1_full",
             name="rank_zephyr",
             context_size=4096,
-            prompt_mode=PromptMode.RANK_GPT,
+            prompt_mode=Prompt.RANK_GPT,
             num_few_shot_examples=0,
             variable_passages=True,
             window_size=5,
@@ -339,7 +339,7 @@ def test_get_num_tokens(self, mock_get_num_tokens):
             model="castorini/rank_zephyr_7b_v1_full",
             name="rank_zephyr",
             context_size=4096,
-            prompt_mode=PromptMode.RANK_GPT,
+            prompt_mode=Prompt.RANK_GPT,
             num_few_shot_examples=0,
             variable_passages=True,
             window_size=5,
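The updated parameterized tuples above can be exercised with a loop like the following sketch. The field order is assumed from the file's own comment (model, context_size, prompt_mode, num_few_shot_examples, variable_passages, window_size, system_message), and the system_message keyword is an assumption based on that comment; the sketch is not part of the commit:

from rank_llm.rerank import Prompt
from rank_llm.rerank.listwise import RankListwiseOSLLM

# Instantiate one agent per updated input tuple.
for model, context_size, prompt_mode, n_shot, var_passages, window, sys_msg in valid_inputs:
    agent = RankListwiseOSLLM(
        model=model,
        context_size=context_size,
        prompt_mode=prompt_mode,  # now a Prompt member, e.g. Prompt.RANK_GPT
        num_few_shot_examples=n_shot,
        variable_passages=var_passages,
        window_size=window,
        system_message=sys_msg,  # assumed keyword, inferred from the tuple comment
    )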
test/rerank/listwise/test_SafeOpenai.py (26 changes: 13 additions & 13 deletions)
@@ -1,33 +1,33 @@
 import unittest
 from unittest.mock import patch

-from rank_llm.rerank import PromptMode
+from rank_llm.rerank import Prompt
 from rank_llm.rerank.listwise import SafeOpenai

 # model, context_size, prompt_mode, num_few_shot_examples, keys, key_start_id
 valid_inputs = [
-    ("gpt-3.5-turbo", 4096, PromptMode.RANK_GPT, 0, "OPEN_AI_API_KEY", None),
-    ("gpt-3.5-turbo", 4096, PromptMode.LRL, 0, "OPEN_AI_API_KEY", 3),
-    ("gpt-4", 4096, PromptMode.RANK_GPT, 0, "OPEN_AI_API_KEY", None),
-    ("gpt-4", 4096, PromptMode.LRL, 0, "OPEN_AI_API_KEY", 3),
+    ("gpt-3.5-turbo", 4096, Prompt.RANK_GPT, 0, "OPEN_AI_API_KEY", None),
+    ("gpt-3.5-turbo", 4096, Prompt.LRL, 0, "OPEN_AI_API_KEY", 3),
+    ("gpt-4", 4096, Prompt.RANK_GPT, 0, "OPEN_AI_API_KEY", None),
+    ("gpt-4", 4096, Prompt.LRL, 0, "OPEN_AI_API_KEY", 3),
 ]

 failure_inputs = [
-    ("gpt-3.5-turbo", 4096, PromptMode.RANK_GPT, 0, None),  # missing key
-    ("gpt-3.5-turbo", 4096, PromptMode.LRL, 0, None),  # missing key
+    ("gpt-3.5-turbo", 4096, Prompt.RANK_GPT, 0, None),  # missing key
+    ("gpt-3.5-turbo", 4096, Prompt.LRL, 0, None),  # missing key
     (
         "gpt-3.5-turbo",
         4096,
-        PromptMode.UNSPECIFIED,
+        Prompt.UNSPECIFIED,
         0,
         "OPEN_AI_API_KEY",
     ),  # unspecified prompt mode
-    ("gpt-4", 4096, PromptMode.RANK_GPT, 0, None),  # missing key
-    ("gpt-4", 4096, PromptMode.LRL, 0, None),  # missing key
+    ("gpt-4", 4096, Prompt.RANK_GPT, 0, None),  # missing key
+    ("gpt-4", 4096, Prompt.LRL, 0, None),  # missing key
     (
         "gpt-4",
         4096,
-        PromptMode.UNSPECIFIED,
+        Prompt.UNSPECIFIED,
         0,
         "OPEN_AI_API_KEY",
     ),  # unspecified prompt mode
@@ -84,7 +84,7 @@ def test_run_llm(self, mock_call_completion):
         agent = SafeOpenai(
             model="gpt-3.5",
             context_size=4096,
-            prompt_mode=PromptMode.RANK_GPT,
+            prompt_mode=Prompt.RANK_GPT,
             num_few_shot_examples=0,
             keys="OPEN_AI_API_KEY",
         )
@@ -97,7 +97,7 @@ def test_num_output_tokens(self):
         agent = SafeOpenai(
             model="gpt-3.5",
             context_size=4096,
-            prompt_mode=PromptMode.RANK_GPT,
+            prompt_mode=Prompt.RANK_GPT,
             num_few_shot_examples=0,
             keys="OPEN_AI_API_KEY",
         )
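The same rename in SafeOpenai terms, as a sketch mirroring the test bodies above; "gpt-3.5" and "OPEN_AI_API_KEY" are the placeholder model and key strings the tests use, and the sketch is not part of the commit:

from rank_llm.rerank import Prompt
from rank_llm.rerank.listwise import SafeOpenai

# The keys argument here is the test's placeholder string;
# a real run would need a valid OpenAI API key.
agent = SafeOpenai(
    model="gpt-3.5",
    context_size=4096,
    prompt_mode=Prompt.RANK_GPT,  # was PromptMode.RANK_GPT before this commit
    num_few_shot_examples=0,
    keys="OPEN_AI_API_KEY",
)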
