Skip to content

Commit

Permalink
adjust alpaca and llama reference scores in tests (#267)
Browse files Browse the repository at this point in the history
  • Loading branch information
dkupnicki authored Oct 14, 2024
1 parent 36444eb commit 02fb34e
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tests/test_pytorch_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def wrapper(**kwargs):
@unittest.skipIf(psutil.virtual_memory().available / 1024 ** 3 < 100, "too little memory")
@unittest.skipUnless('_aio_profiler_print' in dir(torch._C), "Ampere optimized PyTorch required")
def test_llama2_7b(self):
f1_ref = 0.349
f1_ref = 0.330
acc = run_process(self.wrapper,
{"model_name": "meta-llama/Llama-2-7b-chat-hf", "batch_size": 1, "num_runs": 50,
"timeout": None, "dataset_path": self.dataset_path})
Expand All @@ -60,7 +60,7 @@ def test_llama2_7b(self):
@unittest.skipIf(psutil.virtual_memory().available / 1024 ** 3 < 200, "too little memory")
@unittest.skipUnless('_aio_profiler_print' in dir(torch._C), "Ampere optimized PyTorch required")
def test_llama2_13b(self):
f1_ref = 0.195
f1_ref = 0.261
acc = run_process(self.wrapper,
{"model_name": "meta-llama/Llama-2-13b-chat-hf", "batch_size": 1, "num_runs": 50,
"timeout": None, "dataset_path": self.dataset_path})
Expand Down Expand Up @@ -93,7 +93,7 @@ def test_alpaca(self):
def wrapper(**kwargs):
kwargs["q"].put(run_pytorch_fp32(**kwargs)[0])

exact_match_ref, f1_ref = 0.260, 0.616
exact_match_ref, f1_ref = 0.180, 0.548
acc = run_process(wrapper, {"model_path": self.model_path, "batch_size": 1, "num_runs": 50,
"timeout": None, "dataset_path": self.dataset_path})
self.assertTrue(acc["exact_match"] / exact_match_ref > 0.95)
Expand Down

0 comments on commit 02fb34e

Please sign in to comment.