Skip to content

Commit

Permalink
test example scripts in ci (#40)
Browse files Browse the repository at this point in the history
* test example scripts in ci

* specify pool size

* smaller batch size

* try scifact

* try smaller partitions

* use max length from model

* smaller partition num
  • Loading branch information
edknv authored Dec 18, 2023
1 parent 05accee commit 2a7eefe
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 2 deletions.
5 changes: 3 additions & 2 deletions crossfit/op/tokenize.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import os
from typing import Optional

import cudf
from cudf.core.subword_tokenizer import SubwordTokenizer, _cast_to_appropriate_type
Expand All @@ -32,11 +33,11 @@ def __init__(
cols=None,
keep_cols=None,
pre=None,
max_length: int = 1024,
max_length: Optional[int] = None,
):
super().__init__(pre=pre, cols=cols, keep_cols=keep_cols)
self.model = model
self.max_length = max_length
self.max_length = max_length or model.max_seq_length()

# Make sure we download the tokenizer just once
GPUTokenizer.from_pretrained(self.model)
Expand Down
7 changes: 7 additions & 0 deletions examples/beir_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ def parse_arguments():
)
parser.add_argument("--batch-size", type=int, default=64, help="Batch size")
parser.add_argument("--k", type=int, default=10, help="Nearest neighbors")
parser.add_argument(
"--partition-num",
type=int,
default=50_000,
help="Number of items to allocate to each partition",
)

args = parser.parse_args()
return args
Expand All @@ -39,6 +45,7 @@ def main():
overwrite=args.overwrite,
sorted_data_loader=args.sorted_dataloader,
batch_size=args.batch_size,
partition_num=args.partition_num,
)

report.console()
Expand Down
46 changes: 46 additions & 0 deletions tests/examples/test_scripts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import os
import runpy
import shutil
import sys
import tempfile

import pytest

# Path to the "examples" directory, located two levels above the directory
# that contains this test file (symlinks in this file's path are resolved first).
examples_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..", "examples")


@pytest.mark.singlegpu
@pytest.mark.parametrize(
    "script",
    [
        "beir_report.py",
    ],
)
def test_script_execution(script):
    """Run an example script end to end as ``__main__``.

    The script is copied into a temporary directory and executed with
    ``runpy.run_path`` using a fixed command line supplied via ``sys.argv``.
    The test passes if the script finishes without raising.

    Args:
        script: File name of the script inside ``examples_dir``.
    """
    path = os.path.join(examples_dir, script)
    orig_sys_argv = sys.argv

    # sys.argv is process-global state. Restore it in ``finally`` so that a
    # failing script (including SystemExit raised by argparse) cannot leak
    # this fake command line into tests that run afterwards.
    try:
        with tempfile.TemporaryDirectory() as tmpdir:
            tmp_path = os.path.join(tmpdir, script)
            shutil.copy2(path, tmp_path)
            # argv[0] will be replaced by runpy
            sys.argv = [
                "",
                "--overwrite",
                "--num-workers",
                "1",
                "--dataset",
                "fiqa",
                "--pool-size",
                "12GB",
                "--batch-size",
                "8",
                "--partition-num",
                "100",
            ]
            runpy.run_path(
                tmp_path,
                run_name="__main__",
            )
    finally:
        sys.argv = orig_sys_argv

0 comments on commit 2a7eefe

Please sign in to comment.