From bda30f530931ab5217b182a17ff3d656752cd480 Mon Sep 17 00:00:00 2001 From: Spencer Sun Date: Wed, 22 May 2024 18:37:35 +1200 Subject: [PATCH] fix: update test_batch to enforce classification in basic test --- tests/test_batch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_batch.py b/tests/test_batch.py index 40a8bf61..9f33b6ce 100644 --- a/tests/test_batch.py +++ b/tests/test_batch.py @@ -67,7 +67,7 @@ def test_batch_basic(): y = np.arange(n) assert x.shape == (n, feature_count) - stream = NumpyStream(x, y) + stream = NumpyStream(x, y, target_type='categorical') learner = _DummyBatchClassifierSSL(batch_size, stream.schema, class_value_type=str) prequential_ssl_evaluation( stream=stream, learner=learner, label_probability=0.01, window_size=100