Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add option to initialize splits from dict #207

Merged
merged 1 commit into from
Feb 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 28 additions & 4 deletions dance/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,22 +64,21 @@ class BaseData(ABC):
test_size
Number of cells to be used for testing. If set to -1, use what's left from training and validation.


"""

_FEATURE_CONFIGS: List[str] = ["feature_mod", "feature_channel", "feature_channel_type"]
_LABEL_CONFIGS: List[str] = ["label_mod", "label_channel", "label_channel_type"]
_DATA_CHANNELS: List[str] = ["obs", "var", "obsm", "varm", "obsp", "varp", "layers", "uns"]

def __init__(self, data: Union[AnnData, MuData], train_size: Optional[int] = None, val_size: int = 0,
test_size: int = -1):
test_size: int = -1, split_index_range_dict: Optional[Dict[str, Tuple[int, int]]] = None):
super().__init__()

self._data = data

# TODO: move _split_idx_dict into data.uns
self._split_idx_dict: Dict[str, Sequence[int]] = {}
self._setup_splits(train_size, val_size, test_size)
self._setup_splits(train_size, val_size, test_size, split_index_range_dict)

if "dance_config" not in self._data.uns:
self._data.uns["dance_config"] = dict()
Expand All @@ -88,7 +87,19 @@ def __repr__(self) -> str:
return f"{self.__class__.__name__} object that wraps (.data):\n{self.data}"

# WARNING: be careful when subsampling cells, as the indices are not automatically updated!!
def _setup_splits(self, train_size: Optional[Union[int, str]], val_size: int, test_size: int):
def _setup_splits(
    self,
    train_size: Optional[Union[int, str]],
    val_size: int,
    test_size: int,
    split_index_range_dict: Optional[Dict[str, Tuple[int, int]]],
):
    """Dispatch split setup to the range-based or the size-based routine.

    Parameters
    ----------
    train_size
        Forwarded to ``_setup_splits_default`` when no range dict is given.
    val_size
        Forwarded to ``_setup_splits_default`` when no range dict is given.
    test_size
        Forwarded to ``_setup_splits_default`` when no range dict is given.
    split_index_range_dict
        When not ``None``, it is passed to ``_setup_splits_range`` and the
        three size arguments are ignored.

    """
    # An explicit range dict takes precedence over the size-based settings.
    if split_index_range_dict is not None:
        self._setup_splits_range(split_index_range_dict)
        return

    self._setup_splits_default(train_size, val_size, test_size)

def _setup_splits_default(self, train_size: Optional[Union[int, str]], val_size: int, test_size: int):
if train_size is None:
return
elif isinstance(train_size, str) and train_size.lower() == "all":
Expand Down Expand Up @@ -127,6 +138,19 @@ def _setup_splits(self, train_size: Optional[Union[int, str]], val_size: int, te
if end - start > 0: # skip empty split
self._split_idx_dict[split_name] = list(range(start, end))

def _setup_splits_range(self, split_index_range_dict: Dict[str, Tuple[int, int]]):
    """Populate the split index dict from explicit ``(start, end)`` index ranges.

    Each entry maps a split name to a half-open range ``[start, end)``. The
    expanded list of indices is stored in ``self._split_idx_dict`` under the
    same split name. Ranges with ``end <= start`` are skipped and leave no
    entry.

    Parameters
    ----------
    split_index_range_dict
        Mapping from split name to a two-tuple of int ``(start, end)``.

    Raises
    ------
    TypeError
        If a value is not a tuple of length two, or if either element of the
        tuple is not an int.

    """
    for split_name, index_range in split_index_range_dict.items():
        # Only an actual tuple of length two is accepted (lists are rejected).
        if not isinstance(index_range, tuple) or len(index_range) != 2:
            raise TypeError("The split index range must be a two-tuple containing the start and end index. "
                            f"Got {index_range!r} for key {split_name!r}")
        if not all(isinstance(i, int) for i in index_range):
            raise TypeError("The split index range must be a two-tuple of int type. "
                            f"Got {index_range!r} for key {split_name!r}")

        start, end = index_range
        if end > start:  # skip empty split
            self._split_idx_dict[split_name] = list(range(start, end))

def __getitem__(self, idx: Sequence[int]) -> Any:
    """Index the wrapped data object (``self.data``) with ``idx`` and return the result."""
    wrapped = self.data
    return wrapped[idx]

Expand Down
18 changes: 18 additions & 0 deletions tests/data/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,24 @@ def test_data_basic_properties(subtests):
with pytest.raises(ValueError): # sum of sizes exceeds data size
Data(adata.copy(), train_size=2, test_size=2)

with subtests.test("Index range dict"):
split_index_range_dict = {"train": (0, 1), "ref": (0, 2), "inf": (2, 3)}
data = Data(adata.copy(), split_index_range_dict=split_index_range_dict)

assert data.train_idx == [0]
assert data.get_split_idx("ref") == [0, 1]
assert data.get_split_idx("inf") == [2]

with subtests.test("Index range dict errors"):
with pytest.raises(TypeError): # value must be a two tuple, not three tuple
Data(adata.copy(), split_index_range_dict={"train": (0, 1, 2)})

with pytest.raises(TypeError): # value must be a two tuple, not a list
Data(adata.copy(), split_index_range_dict={"train": [0, 1]})

with pytest.raises(TypeError): # value must be a two tuple of int, not str
Data(adata.copy(), split_index_range_dict={"train": ("0", "1")})


def test_get_data(subtests):
adata = AnnData(X=X, obs=pd.DataFrame(X, columns=["a", "b"]), var=pd.DataFrame(X.T, columns=["x", "y", "z"]))
Expand Down