From 23e3304f992fd7b8cba8c5d7a0c0e72857f2039c Mon Sep 17 00:00:00 2001 From: David Gardner Date: Mon, 7 Nov 2022 08:05:14 -0800 Subject: [PATCH] Explicitly re-run the test_add_scores_stage_multi_segment_pipe test with different values for the repeat parameter of the file source stage, to explicitly trigger the bug that was worked-around with 8a1f21cda2be2e9bc7757b98bcaff68bc45ed06b --- tests/test_add_scores_stage_pipe.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/test_add_scores_stage_pipe.py b/tests/test_add_scores_stage_pipe.py index 52ab084e1f..42c2d39427 100755 --- a/tests/test_add_scores_stage_pipe.py +++ b/tests/test_add_scores_stage_pipe.py @@ -80,14 +80,16 @@ def test_add_scores_stage_pipe(config, tmp_path, order, pipeline_batch_size, rep assert output_np.tolist() == expected.tolist() -def test_add_scores_stage_multi_segment_pipe(config, tmp_path): +@pytest.mark.parametrize('repeat', [1, 2, 5]) +def test_add_scores_stage_multi_segment_pipe(config, tmp_path, repeat): + # Intentionally using FileSourceStage's repeat argument as this triggers a bug in #443 config.class_labels = ['frogs', 'lizards', 'toads', 'turtles'] input_file = os.path.join(TEST_DIRS.tests_data_dir, "filter_probs.csv") out_file = os.path.join(tmp_path, 'results.csv') pipe = LinearPipeline(config) - pipe.set_source(FileSourceStage(config, filename=input_file, iterative=False)) + pipe.set_source(FileSourceStage(config, filename=input_file, iterative=False, repeat=repeat)) pipe.add_segment_boundary(MessageMeta) pipe.add_stage(DeserializeStage(config)) pipe.add_segment_boundary(MultiMessage) @@ -102,7 +104,8 @@ def test_add_scores_stage_multi_segment_pipe(config, tmp_path): assert os.path.exists(out_file) - expected = np.loadtxt(input_file, delimiter=",", skiprows=1) + expected_data = np.loadtxt(input_file, delimiter=",", skiprows=1) + expected = np.concatenate([expected_data for _ in range(repeat)]) # The output data will contain an additional id column that we will need to slice off # also somehow 0.7 ends up being 0.7000000000000001