diff --git a/dance/data/base.py b/dance/data/base.py
index 9f2135bb..e42a9874 100644
--- a/dance/data/base.py
+++ b/dance/data/base.py
@@ -3,15 +3,15 @@
 from abc import ABC, abstractmethod
 from copy import deepcopy
 
+import anndata
 import numpy as np
 import pandas as pd
 import scipy.sparse as sp
 import torch
-from anndata import AnnData
 from mudata import MuData
 
 from dance import logger
-from dance.typing import Any, Dict, FeatType, Iterator, List, Optional, Sequence, Tuple, Union
+from dance.typing import Any, Dict, FeatType, Iterator, List, Literal, Optional, Sequence, Tuple, Union
 
 
 def _ensure_iter(val: Optional[Union[List[str], str]]) -> Iterator[Optional[str]]:
@@ -70,7 +70,7 @@ class BaseData(ABC):
     _LABEL_CONFIGS: List[str] = ["label_mod", "label_channel", "label_channel_type"]
     _DATA_CHANNELS: List[str] = ["obs", "var", "obsm", "varm", "obsp", "varp", "layers", "uns"]
 
-    def __init__(self, data: Union[AnnData, MuData], train_size: Optional[int] = None, val_size: int = 0,
+    def __init__(self, data: Union[anndata.AnnData, MuData], train_size: Optional[int] = None, val_size: int = 0,
                  test_size: int = -1, split_index_range_dict: Optional[Dict[str, Tuple[int, int]]] = None):
         super().__init__()
 
@@ -325,7 +325,7 @@ def get_split_mask(self, split_name: str, return_type: FeatType = "numpy") -> Un
 
     @staticmethod
     def _get_feature(
-        in_data: Union[AnnData, MuData],
+        in_data: Union[anndata.AnnData, MuData],
         channel: Optional[str],
         channel_type: Optional[str],
         mod: Optional[str],
@@ -421,6 +421,91 @@ def get_feature(self, *, split_name: Optional[str] = None, return_type: FeatType
 
         return feature
 
+    def append(
+        self,
+        data: "BaseData",
+        *,
+        mode: Optional[Literal["merge", "rename", "new_split"]] = "merge",
+        rename_dict: Optional[Dict[str, str]] = None,
+        new_split_name: Optional[str] = None,
+        index_unique: Optional[str] = "_",
+    ) -> None:
+        """Append another dance data object to the current data object.
+
+        The two underlying data objects are concatenated via :func:`anndata.concat` and the split indexes of the
+        current object are updated according to ``mode``. The split indexes of ``data`` are left untouched.
+
+        Parameters
+        ----------
+        data
+            New dance data object to be added.
+        mode
+            How to combine the splits from the new data and the current data. (1) ``"merge"``: merge the splits from
+            the data, e.g., the training indexes from both data are used as the training indexes in the new combined
+            data. (2) ``"rename"``: rename the splits of the new data and add to the current split index dictionary,
+            e.g., renaming 'train' to 'ref'. Requires passing the ``rename_dict``. Raises an error if the newly renamed
+            key is already used in the current split index dictionary. (3) ``"new_split"``: assign the whole new data
+            to a new split. Requires passing the ``new_split_name`` that is not already used as a split name in the
+            current data. (4) ``None``: do not specify split index to the newly added data.
+        rename_dict
+            Optional argument that is only used when ``mode="rename"``. A dictionary to map the split names in the new
+            data to other names.
+        new_split_name
+            Optional argument that is only used when ``mode="new_split"``. Name of the split to assign to the new data.
+        index_unique
+            See :func:`anndata.concat`.
+
+        Raises
+        ------
+        ValueError
+            If ``mode`` is unknown, if an argument required by the selected mode is missing, or if a new split
+            name collides with an existing split name or with another renamed split.
+        KeyError
+            If ``mode="rename"`` and ``rename_dict`` does not cover every split of the new data.
+        TypeError
+            If ``mode="new_split"`` and ``new_split_name`` is not a string.
+
+        """
+        offset = self.shape[0]
+        # Shift the new indexes past the current data. Plain Python ints are used (rather than numpy scalars) so
+        # that all split indexes in the combined dictionary share the same type.
+        new_split_idx_dict = {i: sorted(int(k) + offset for k in j) for i, j in data._split_idx_dict.items()}
+
+        if mode == "merge":
+            for split_name, split_idxs in self._split_idx_dict.items():
+                if split_name in new_split_idx_dict:
+                    split_idxs = split_idxs + new_split_idx_dict[split_name]
+                new_split_idx_dict[split_name] = split_idxs
+        elif mode == "rename":
+            if rename_dict is None:
+                raise ValueError("Mode 'rename' is selected but 'rename_dict' is not specified.")
+            elif len(common_keys := set(self._split_idx_dict) & set(rename_dict.values())) > 0:
+                raise ValueError(f"'rename_dict' cannot contain split keys present in current data: {common_keys}")
+            elif len(missed_keys := [i for i in data._split_idx_dict if i not in rename_dict]) > 0:
+                raise KeyError(f"Missing rename mapping for keys: {missed_keys}")
+            new_names = [rename_dict[i] for i in new_split_idx_dict]
+            if len(set(new_names)) != len(new_names):
+                # Two splits mapped to the same name would silently overwrite one another below.
+                raise ValueError(f"'rename_dict' maps multiple splits of the new data to the same name: {new_names}")
+            new_split_idx_dict = {rename_dict[i]: j for i, j in new_split_idx_dict.items()}
+            new_split_idx_dict.update(self._split_idx_dict)
+        elif mode == "new_split":
+            if new_split_name is None:
+                raise ValueError("Mode 'new_split' is selected but 'new_split_name' is not specified.")
+            elif not isinstance(new_split_name, str):
+                raise TypeError(f"'new_split_name' must be a string, got {type(new_split_name)}: {new_split_name}.")
+            elif new_split_name in self._split_idx_dict:
+                raise ValueError(f"{new_split_name!r} is being used in the current splits. Please pick another name.")
+            new_split_idx_dict = {new_split_name: list(range(offset, offset + data.shape[0]))}
+            new_split_idx_dict.update(self._split_idx_dict)
+        elif mode is None:
+            new_split_idx_dict = self._split_idx_dict
+        else:
+            raise ValueError(f"Unknown mode {mode!r}. Available options are: 'merge', 'rename', 'new_split', None")
+
+        self._data = anndata.concat((self.data, data.data), index_unique=index_unique)
+        self._split_idx_dict = new_split_idx_dict
+
 
 
 class Data(BaseData):
diff --git a/tests/data/test_data.py b/tests/data/test_data.py
index 106806d4..ae31eb46 100644
--- a/tests/data/test_data.py
+++ b/tests/data/test_data.py
@@ -5,8 +5,8 @@
 
 from dance.data import Data
 
-X = np.array([[0, 1], [1, 2], [2, 3]])
-Y = np.array([[0], [1], [2]])
+X = np.array([[0, 1], [1, 2], [2, 3]], dtype=np.float32)
+Y = np.array([[0], [1], [2]], dtype=np.float32)
 
 
 def test_data_basic_properties(subtests):
@@ -127,3 +127,52 @@ def test_get_data(subtests):
         (x1, x2), _ = data.get_train_data()
         assert x1.tolist() == [0, 1, 2]
         assert x2.tolist() == [2, 3]
+
+
+def test_append(subtests):
+    data1 = Data(AnnData(X=X), train_size=1)
+    data2 = Data(AnnData(X=X), train_size=2)
+    data2_split_idx = {"train": [0, 1], "test": [2]}
+
+    with subtests.test(mode="merge"):
+        data = data1.copy()
+        data.append(data2, mode="merge")
+        assert data._split_idx_dict == {"train": [0, 3, 4], "test": [1, 2, 5]}
+        # Make sure the appended data is not modified in place
+        assert data2._split_idx_dict == data2_split_idx
+
+    with subtests.test(mode="rename"):
+        # Missing rename_dict
+        pytest.raises(ValueError, data1.copy().append, data2, mode="rename")
+
+        # Missing value for 'test'
+        pytest.raises(KeyError, data1.copy().append, data2, mode="rename", rename_dict={"train": "new"})
+
+        # Conflicting value for 'test'
+        pytest.raises(ValueError, data1.copy().append, data2, mode="rename", rename_dict={"train": "a", "test": "test"})
+
+        # Two splits of the new data renamed to the same name
+        pytest.raises(ValueError, data1.copy().append, data2, mode="rename", rename_dict={"train": "a", "test": "a"})
+
+        data = data1.copy()
+        data.append(data2, mode="rename", rename_dict={"train": "new_train", "test": "new_test"})
+        assert data._split_idx_dict == {"train": [0], "new_train": [3, 4], "test": [1, 2], "new_test": [5]}
+        assert data2._split_idx_dict == data2_split_idx
+
+    with subtests.test(mode="new_split"):
+        # Missing new_split_name
+        pytest.raises(ValueError, data1.copy().append, data2, mode="new_split")
+
+        # Conflicting name "test"
+        pytest.raises(ValueError, data1.copy().append, data2, mode="new_split", new_split_name="test")
+
+        data = data1.copy()
+        data.append(data2, mode="new_split", new_split_name="ref")
+        assert data._split_idx_dict == {"train": [0], "test": [1, 2], "ref": [3, 4, 5]}
+        assert data2._split_idx_dict == data2_split_idx
+
+    with subtests.test(mode=None):
+        data = data1.copy()
+        data.append(data2, mode=None)
+        assert data._split_idx_dict == {"train": [0], "test": [1, 2]}
+        assert data2._split_idx_dict == data2_split_idx