Skip to content

Commit

Permalink
More repo tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dagardner-nv committed Apr 18, 2024
1 parent af9bb17 commit 61d391c
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 24 deletions.
43 changes: 27 additions & 16 deletions tests/test_abp.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,9 @@

@pytest.mark.slow
@pytest.mark.use_python
@pytest.mark.parametrize('num_threads', [1, 4])
@mock.patch('tritonclient.grpc.InferenceServerClient')
def test_abp_no_cpp(mock_triton_client, config: Config, tmp_path):
def test_abp_no_cpp(mock_triton_client: mock.MagicMock, config: Config, tmp_path: str, num_threads: int):
mock_metadata = {
"inputs": [{
'name': 'input__0', 'datatype': 'FP32', "shape": [-1, FEATURE_LENGTH]
Expand Down Expand Up @@ -83,7 +84,7 @@ def test_abp_no_cpp(mock_triton_client, config: Config, tmp_path):
config.pipeline_batch_size = 1024
config.feature_length = FEATURE_LENGTH
config.edge_buffer_size = 128
config.num_threads = 1
config.num_threads = num_threads

config.fil = ConfigFIL()
config.fil.feature_columns = load_labels_file(os.path.join(TEST_DIRS.data_dir, 'columns_fil.txt'))
Expand All @@ -108,21 +109,24 @@ def test_abp_no_cpp(mock_triton_client, config: Config, tmp_path):

pipe.run()
compare_class_to_scores(out_file, config.class_labels, '', 'score_', threshold=0.5)
results = calc_error_val(results_file_name)
assert results.diff_rows == 0

if num_threads == 1:
results = calc_error_val(results_file_name)
assert results.diff_rows == 0


@pytest.mark.slow
@pytest.mark.use_cpp
@pytest.mark.usefixtures("launch_mock_triton")
def test_abp_cpp(config, tmp_path):
@pytest.mark.parametrize('num_threads', [1, 4])
def test_abp_cpp(config: Config, tmp_path: str, num_threads: int):
config.mode = PipelineModes.FIL
config.class_labels = ["mining"]
config.model_max_batch_size = MODEL_MAX_BATCH_SIZE
config.pipeline_batch_size = 1024
config.feature_length = FEATURE_LENGTH
config.edge_buffer_size = 128
config.num_threads = 1
config.num_threads = num_threads

config.fil = ConfigFIL()
config.fil.feature_columns = load_labels_file(os.path.join(TEST_DIRS.data_dir, 'columns_fil.txt'))
Expand Down Expand Up @@ -151,14 +155,17 @@ def test_abp_cpp(config, tmp_path):

pipe.run()
compare_class_to_scores(out_file, config.class_labels, '', 'score_', threshold=0.5)
results = calc_error_val(results_file_name)
assert results.diff_rows == 0

if num_threads == 1:
results = calc_error_val(results_file_name)
assert results.diff_rows == 0


@pytest.mark.slow
@pytest.mark.use_python
@pytest.mark.parametrize('num_threads', [1, 4])
@mock.patch('tritonclient.grpc.InferenceServerClient')
def test_abp_multi_segment_no_cpp(mock_triton_client, config: Config, tmp_path):
def test_abp_multi_segment_no_cpp(mock_triton_client: mock.MagicMock, config: Config, tmp_path: str, num_threads: int):
mock_metadata = {
"inputs": [{
'name': 'input__0', 'datatype': 'FP32', "shape": [-1, FEATURE_LENGTH]
Expand Down Expand Up @@ -189,7 +196,7 @@ def test_abp_multi_segment_no_cpp(mock_triton_client, config: Config, tmp_path):
config.pipeline_batch_size = 1024
config.feature_length = FEATURE_LENGTH
config.edge_buffer_size = 128
config.num_threads = 1
config.num_threads = num_threads

config.fil = ConfigFIL()
config.fil.feature_columns = load_labels_file(os.path.join(TEST_DIRS.data_dir, 'columns_fil.txt'))
Expand Down Expand Up @@ -230,21 +237,24 @@ def test_abp_multi_segment_no_cpp(mock_triton_client, config: Config, tmp_path):
pipe.add_stage(WriteToFileStage(config, filename=out_file, overwrite=False))

pipe.run()
results = calc_error_val(results_file_name)
assert results.diff_rows == 0

if num_threads == 1:
results = calc_error_val(results_file_name)
assert results.diff_rows == 0


@pytest.mark.slow
@pytest.mark.use_cpp
@pytest.mark.usefixtures("launch_mock_triton")
def test_abp_multi_segment_cpp(config, tmp_path):
@pytest.mark.parametrize('num_threads', [1, 4])
def test_abp_multi_segment_cpp(config: Config, tmp_path: str, num_threads: int):
config.mode = PipelineModes.FIL
config.class_labels = ["mining"]
config.model_max_batch_size = MODEL_MAX_BATCH_SIZE
config.pipeline_batch_size = 1024
config.feature_length = FEATURE_LENGTH
config.edge_buffer_size = 128
config.num_threads = 1
config.num_threads = num_threads

config.fil = ConfigFIL()
config.fil.feature_columns = load_labels_file(os.path.join(TEST_DIRS.data_dir, 'columns_fil.txt'))
Expand Down Expand Up @@ -289,5 +299,6 @@ def test_abp_multi_segment_cpp(config, tmp_path):

pipe.run()

results = calc_error_val(results_file_name)
assert results.diff_rows == 0
if num_threads == 1:
results = calc_error_val(results_file_name)
assert results.diff_rows == 0
15 changes: 10 additions & 5 deletions tests/test_phishing.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from _utils import TEST_DIRS
from _utils import calc_error_val
from _utils import mk_async_infer
from morpheus.config import Config
from morpheus.config import PipelineModes
from morpheus.pipeline import LinearPipeline
from morpheus.stages.general.monitor_stage import MonitorStage
Expand All @@ -44,7 +45,7 @@
@pytest.mark.slow
@pytest.mark.use_python
@mock.patch('tritonclient.grpc.InferenceServerClient')
def test_email_no_cpp(mock_triton_client, config, tmp_path):
def test_email_no_cpp(mock_triton_client: mock.MagicMock, config: Config, tmp_path: str):
mock_metadata = {
"inputs": [{
"name": "input_ids", "datatype": "INT64", "shape": [-1, FEATURE_LENGTH]
Expand Down Expand Up @@ -104,21 +105,23 @@ def test_email_no_cpp(mock_triton_client, config, tmp_path):
pipe.add_stage(WriteToFileStage(config, filename=out_file, overwrite=False))

pipe.run()

results = calc_error_val(results_file_name)
assert results.diff_rows == 153


@pytest.mark.slow
@pytest.mark.use_cpp
@pytest.mark.usefixtures("launch_mock_triton")
def test_email_cpp(config, tmp_path):
@pytest.mark.parametrize('num_threads', [1, 4])
def test_email_cpp(config: Config, tmp_path: str, num_threads: int):
config.mode = PipelineModes.NLP
config.class_labels = load_labels_file(os.path.join(TEST_DIRS.data_dir, "labels_phishing.txt"))
config.model_max_batch_size = MODEL_MAX_BATCH_SIZE
config.pipeline_batch_size = 1024
config.feature_length = FEATURE_LENGTH
config.edge_buffer_size = 128
config.num_threads = 1
config.num_threads = num_threads

val_file_name = os.path.join(TEST_DIRS.validation_data_dir, 'phishing-email-validation-data.jsonlines')
vocab_file_name = os.path.join(TEST_DIRS.data_dir, 'bert-base-uncased-hash.txt')
Expand Down Expand Up @@ -147,5 +150,7 @@ def test_email_cpp(config, tmp_path):
pipe.add_stage(WriteToFileStage(config, filename=out_file, overwrite=False))

pipe.run()
results = calc_error_val(results_file_name)
assert results.diff_rows == 682

if num_threads == 1:
results = calc_error_val(results_file_name)
assert results.diff_rows == 682
8 changes: 5 additions & 3 deletions tests/test_triton_inference_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,9 @@ def test_stage_get_inference_worker(config: Config, pipeline_mode: PipelineModes
@pytest.mark.slow
@pytest.mark.use_python
@pytest.mark.parametrize('num_records', [1000, 2000, 4000])
@pytest.mark.parametrize('num_threads', [1, 4, 12])
@mock.patch('tritonclient.grpc.InferenceServerClient')
def test_triton_stage_pipe(mock_triton_client, config, num_records):
def test_triton_stage_pipe(mock_triton_client: mock.MagicMock, config: Config, num_records: int, num_threads: int):
mock_metadata = {
"inputs": [{
'name': 'input__0', 'datatype': 'FP32', "shape": [-1, 1]
Expand Down Expand Up @@ -185,7 +186,7 @@ def test_triton_stage_pipe(mock_triton_client, config, num_records):
config.pipeline_batch_size = 1024
config.feature_length = 1
config.edge_buffer_size = 128
config.num_threads = 1
config.num_threads = num_threads

config.fil = ConfigFIL()
config.fil.feature_columns = ['v']
Expand All @@ -202,4 +203,5 @@ def test_triton_stage_pipe(mock_triton_client, config, num_records):

pipe.run()

assert_results(comp_stage.get_results())
if num_threads == 1:
assert_results(comp_stage.get_results())

0 comments on commit 61d391c

Please sign in to comment.