Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement append method to dance data object #209

Merged
merged 2 commits into from
Feb 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 70 additions & 4 deletions dance/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
from abc import ABC, abstractmethod
from copy import deepcopy

import anndata
import numpy as np
import pandas as pd
import scipy.sparse as sp
import torch
from anndata import AnnData
from mudata import MuData

from dance import logger
from dance.typing import Any, Dict, FeatType, Iterator, List, Optional, Sequence, Tuple, Union
from dance.typing import Any, Dict, FeatType, Iterator, List, Literal, Optional, Sequence, Tuple, Union


def _ensure_iter(val: Optional[Union[List[str], str]]) -> Iterator[Optional[str]]:
Expand Down Expand Up @@ -70,7 +70,7 @@ class BaseData(ABC):
_LABEL_CONFIGS: List[str] = ["label_mod", "label_channel", "label_channel_type"]
_DATA_CHANNELS: List[str] = ["obs", "var", "obsm", "varm", "obsp", "varp", "layers", "uns"]

def __init__(self, data: Union[AnnData, MuData], train_size: Optional[int] = None, val_size: int = 0,
def __init__(self, data: Union[anndata.AnnData, MuData], train_size: Optional[int] = None, val_size: int = 0,
test_size: int = -1, split_index_range_dict: Optional[Dict[str, Tuple[int, int]]] = None):
super().__init__()

Expand Down Expand Up @@ -325,7 +325,7 @@ def get_split_mask(self, split_name: str, return_type: FeatType = "numpy") -> Un

@staticmethod
def _get_feature(
in_data: Union[AnnData, MuData],
in_data: Union[anndata.AnnData, MuData],
channel: Optional[str],
channel_type: Optional[str],
mod: Optional[str],
Expand Down Expand Up @@ -421,6 +421,72 @@ def get_feature(self, *, split_name: Optional[str] = None, return_type: FeatType

return feature

def append(
self,
data,
*,
mode: Optional[Literal["merge", "rename", "new_split"]] = "merge",
rename_dict: Optional[Dict[str, str]] = None,
new_split_name: Optional[str] = None,
index_unique: Optional[str] = "_",
):
"""Append another dance data object to the current data object.

Parameters
----------
data
New dance data object to be added.
mode
How to combine the splits from the new data and the current data. (1) ``"merge"``: merge the splits from
the data, e.g., the training indexes from both data are used as the training indexes in the new combined
data. (2) ``"rename"``: rename the splits of the new data and add to the current split index dictionary,
e.g., renaming 'train' to 'ref'. Requires passing the ``rename_dict``. Raise an error if the newly renamed
key is already used in the current split index dictionary. (3) ``"new_split"``: assign the whole new data
to a new split. Requires pssing the ``new_split_name`` that is not already used as a split name in the
current data. (4) ``None``: do not specify split index to the newly added data.
rename_dict
Optional argument that is only used when ``mode="rename"``. A dictionary to map the split names in the new
data to other names.
new_split_name
Optional argument that is only used when ``mode="new_split"``. Name of the split to assign to the new data.
index_unique
See :meth:`anndata.concat`.

"""
offset = self.shape[0]
new_split_idx_dict = {i: sorted(np.array(j) + offset) for i, j in data._split_idx_dict.items()}

if mode == "merge":
for split_name, split_idxs in self._split_idx_dict.items():
if split_name in new_split_idx_dict:
split_idxs = split_idxs + new_split_idx_dict[split_name]
new_split_idx_dict[split_name] = split_idxs
elif mode == "rename":
if rename_dict is None:
raise ValueError("Mode 'rename' is selected but 'rename_dict' is not specified.")
elif len(common_keys := set(self._split_idx_dict) & set(rename_dict.values())) > 0:
raise ValueError(f"'rename_dict' cannot caontain split keys present in current data: {common_keys}")
elif len(missed_keys := [i for i in data._split_idx_dict if i not in rename_dict]) > 0:
raise KeyError(f"Missing rename mapping for keys: {missed_keys}")
new_split_idx_dict = {rename_dict[i]: j for i, j in new_split_idx_dict.items()}
new_split_idx_dict.update(self._split_idx_dict)
elif mode == "new_split":
if new_split_name is None:
raise ValueError("Mode 'new_split' is selected but 'new_split_name' is not specified.")
elif not isinstance(new_split_name, str):
raise TypeError(f"'new_split_name' must be a string, got {type(new_split_name)}: {new_split_name}.")
elif new_split_name in self._split_idx_dict:
raise ValueError(f"{new_split_name!r} is being used in the current splits. Please pick another name.")
new_split_idx_dict = {new_split_name: list(range(offset, offset + data.shape[0]))}
new_split_idx_dict.update(self._split_idx_dict)
elif mode is None:
new_split_idx_dict = self._split_idx_dict
else:
raise ValueError(f"Unknown mode {mode!r}. Available options are: 'merge', 'rename', 'new_split'")

self._data = anndata.concat((self.data, data.data), index_unique=index_unique)
self._split_idx_dict = new_split_idx_dict


class Data(BaseData):

Expand Down
50 changes: 48 additions & 2 deletions tests/data/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

from dance.data import Data

X = np.array([[0, 1], [1, 2], [2, 3]])
Y = np.array([[0], [1], [2]])
X = np.array([[0, 1], [1, 2], [2, 3]], dtype=np.float32)
Y = np.array([[0], [1], [2]], dtype=np.float32)


def test_data_basic_properties(subtests):
Expand Down Expand Up @@ -127,3 +127,49 @@ def test_get_data(subtests):
(x1, x2), _ = data.get_train_data()
assert x1.tolist() == [0, 1, 2]
assert x2.tolist() == [2, 3]


def test_append(subtests):
data1 = Data(AnnData(X=X), train_size=1)
data2 = Data(AnnData(X=X), train_size=2)
data2_split_idx = {"train": [0, 1], "test": [2]}

with subtests.test(mode="merge"):
data = data1.copy()
data.append(data2, mode="merge")
assert data._split_idx_dict == {"train": [0, 3, 4], "test": [1, 2, 5]}
# Make sure the appended data is not inplace modeified
assert data2._split_idx_dict == data2_split_idx

with subtests.test(mode="rename"):
# Missing rename_dict
pytest.raises(ValueError, data1.copy().append, data2, mode="rename")

# Missing value for 'test'
pytest.raises(KeyError, data1.copy().append, data2, mode="rename", rename_dict={"train": "new"})

# Conflicting value for 'test'
pytest.raises(ValueError, data1.copy().append, data2, mode="rename", rename_dict={"train": "a", "test": "test"})

data = data1.copy()
data.append(data2, mode="rename", rename_dict={"train": "new_train", "test": "new_test"})
assert data._split_idx_dict == {"train": [0], "new_train": [3, 4], "test": [1, 2], "new_test": [5]}
assert data2._split_idx_dict == data2_split_idx

with subtests.test(mode="new_split"):
# Missing new_split_name
pytest.raises(ValueError, data1.copy().append, data2, mode="new_split")

# Conflicting name "test"
pytest.raises(ValueError, data1.copy().append, data2, mode="new_split", new_split_name="test")

data = data1.copy()
data.append(data2, mode="new_split", new_split_name="ref")
assert data._split_idx_dict == {"train": [0], "test": [1, 2], "ref": [3, 4, 5]}
assert data2._split_idx_dict == data2_split_idx

with subtests.test(mode=None):
data = data1.copy()
data.append(data2, mode=None)
assert data._split_idx_dict == {"train": [0], "test": [1, 2]}
assert data2._split_idx_dict == data2_split_idx