From d69d700e8220146ce63ff4b42044cbe90ebe65e6 Mon Sep 17 00:00:00 2001 From: Remy Date: Fri, 16 Dec 2022 07:09:29 -0500 Subject: [PATCH 1/2] fix: update data split index test --- tests/data/test_data.py | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/tests/data/test_data.py b/tests/data/test_data.py index a0748311..d9c2a584 100644 --- a/tests/data/test_data.py +++ b/tests/data/test_data.py @@ -9,39 +9,39 @@ def test_data_basic_properties(subtests): - xad = AnnData(X=X) - yad = AnnData(X=Y) + adata = AnnData(X=X) + adata.obsm["label"] = Y with subtests.test("No training splits"): - data = Data(x=xad, y=yad) + data = Data(adata) assert data.num_cells == 3 assert data.num_features == 2 assert data.cells == ["0", "1", "2"] assert data.train_idx is data.val_idx is data.test_idx is None with subtests.test("Training and testing splits"): - data = Data(x=xad, y=yad, train_size=2) - assert data.train_idx == ["0", "1"] + data = Data(adata, train_size=2) + assert data.train_idx == [0, 1] assert data.val_idx is None - assert data.test_idx == ["2"] + assert data.test_idx == [2] - data = Data(x=xad, y=yad, train_size=-1, test_size=1) - assert data.train_idx == ["0", "1"] + data = Data(adata, train_size=-1, test_size=1) + assert data.train_idx == [0, 1] assert data.val_idx is None - assert data.test_idx == ["2"] + assert data.test_idx == [2] with subtests.test("Training validation and testing splits"): - data = Data(x=xad, y=yad, train_size=1, val_size=1) - assert data.train_idx == ["0"] - assert data.val_idx == ["1"] - assert data.test_idx == ["2"] + data = Data(adata, train_size=1, val_size=1) + assert data.train_idx == [0] + assert data.val_idx == [1] + assert data.test_idx == [2] with subtests.test("Error sizes"): with pytest.raises(TypeError): - Data(x=xad, y=yad, train_size="1") + Data(adata, train_size="1") with pytest.raises(ValueError): # cannot have two -1 - Data(x=xad, y=yad, train_size=-1) + Data(adata, train_size=-1) with pytest.raises(ValueError): # train size exceeds data size - Data(x=xad, y=yad, train_size=5) + Data(adata, train_size=5) with pytest.raises(ValueError): # sum of sizes exceeds data size - Data(x=xad, y=yad, train_size=2, test_size=2) + Data(adata, train_size=2, test_size=2) From b28d9e57d3ead210a3d763c1242c137406eb690b Mon Sep 17 00:00:00 2001 From: Remy Date: Fri, 16 Dec 2022 07:18:07 -0500 Subject: [PATCH 2/2] add get split data test --- tests/data/test_data.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/tests/data/test_data.py b/tests/data/test_data.py index d9c2a584..adb642a8 100644 --- a/tests/data/test_data.py +++ b/tests/data/test_data.py @@ -10,7 +10,6 @@ def test_data_basic_properties(subtests): adata = AnnData(X=X) - adata.obsm["label"] = Y with subtests.test("No training splits"): data = Data(adata) @@ -45,3 +44,23 @@ def test_data_basic_properties(subtests): Data(adata, train_size=5) with pytest.raises(ValueError): # sum of sizes exceeds data size Data(adata, train_size=2, test_size=2) + + +def test_get_data(subtests): + adata = AnnData(X=X) + adata.obsm["label"] = Y + + with subtests.test("Single feature"): + data = Data(adata, train_size=2) + data.set_config(label_channel="label") + + x_train, y_train = data.get_train_data() + assert x_train.tolist() == [[0, 1], [1, 2]] + assert y_train.tolist() == [[0], [1]] + + x_test, y_test = data.get_test_data() + assert x_test.tolist() == [[2, 3]] + assert y_test.tolist() == [[2]] + + # Validation set not set + pytest.raises(KeyError, data.get_val_data)