Skip to content

Commit

Permalink
remove FilterTime
Browse files Browse the repository at this point in the history
  • Loading branch information
nbren12 committed Jul 15, 2024
1 parent a032ea9 commit 4f731e5
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 93 deletions.
91 changes: 10 additions & 81 deletions examples/generative/corrdiff/datasets/cwb.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,80 +171,6 @@ def __len__(self):
return self.valid_times.sum()


class FilterTime(DownscalingDataset):
    """Restrict a time-indexed dataset to the samples accepted by a predicate.

    Only ``time``, ``__getitem__`` and ``__len__`` are affected by the
    filter; every other method forwards straight to the wrapped dataset.
    """

    def __init__(self, dataset, filter_fn):
        """
        Args:
            dataset: the dataset to wrap; it is queried through ``time()``
                and integer indexing.
            filter_fn: predicate applied to each time value; a sample is
                kept when ``filter_fn(time)`` is truthy.
        """
        self._dataset = dataset
        self._filter_fn = filter_fn
        # Positions (in the wrapped dataset) of the samples that pass the
        # predicate, computed once at construction.
        kept = []
        for position, stamp in enumerate(self._dataset.time()):
            if filter_fn(stamp):
                kept.append(position)
        self._indices = kept

    def longitude(self):
        """Return the wrapped dataset's longitude values."""
        wrapped = self._dataset
        return wrapped.longitude()

    def latitude(self):
        """Return the wrapped dataset's latitude values."""
        wrapped = self._dataset
        return wrapped.latitude()

    def input_channels(self):
        """Return the wrapped dataset's input-channel metadata."""
        wrapped = self._dataset
        return wrapped.input_channels()

    def output_channels(self):
        """Return the wrapped dataset's output-channel metadata."""
        wrapped = self._dataset
        return wrapped.output_channels()

    def time(self):
        """Return the time values of the samples that passed the filter."""
        all_times = self._dataset.time()
        selected = []
        for position in self._indices:
            selected.append(all_times[position])
        return selected

    def info(self):
        """Return the wrapped dataset's info."""
        wrapped = self._dataset
        return wrapped.info()

    def image_shape(self):
        """Return the wrapped dataset's image shape."""
        wrapped = self._dataset
        return wrapped.image_shape()

    def normalize_input(self, x, channels=None):
        """Forward to the wrapped dataset's ``normalize_input``."""
        wrapped = self._dataset
        return wrapped.normalize_input(x, channels=channels)

    def denormalize_input(self, x, channels=None):
        """Forward to the wrapped dataset's ``denormalize_input``."""
        wrapped = self._dataset
        return wrapped.denormalize_input(x, channels=channels)

    def normalize_output(self, x, channels=None):
        """Forward to the wrapped dataset's ``normalize_output``."""
        wrapped = self._dataset
        return wrapped.normalize_output(x, channels=channels)

    def denormalize_output(self, x, channels=None):
        """Forward to the wrapped dataset's ``denormalize_output``."""
        wrapped = self._dataset
        return wrapped.denormalize_output(x, channels=channels)

    def __getitem__(self, idx):
        """Return the ``idx``-th retained sample from the wrapped dataset."""
        source_position = self._indices[idx]
        return self._dataset[source_position]

    def __len__(self):
        """Return the number of samples that passed the filter."""
        retained = self._indices
        return len(retained)


def is_2021(time):
    """Return whether ``time.year`` equals 2021."""
    target_year = 2021
    return time.year == target_year


def is_not_2021(time):
    """Return the logical negation of ``is_2021(time)``."""
    in_2021 = is_2021(time)
    return not in_2021


class ZarrDataset(DownscalingDataset):
"""A Dataset for loading paired training data from a Zarr-file with the
following schema::
Expand Down Expand Up @@ -333,11 +259,13 @@ def __init__(
train=True,
all_times=False,
):

self._dataset = dataset
if not all_times:
self._dataset = (
FilterTime(dataset, is_not_2021)
self._time_indices = (
[i for i, t in enumerate(self._dataset.time()) if t.year != 2021]
if train
else FilterTime(dataset, is_2021)
else [i for i, t in enumerate(self._dataset.time()) if t.year == 2021]
)
else:
self._dataset = dataset
Expand All @@ -354,8 +282,8 @@ def info(self):
"""Check if the given time is not in the year 2021."""
return self._dataset.info()

def __getitem__(self, idx):
(target, input, _) = self._dataset[idx]
def __getitem__(self, idx_in_split):
(target, input, _) = self._dataset[self._time_indices[idx_in_split]]
# channels
input = input[self.in_channels, :, :]
target = target[self.out_channels, :, :]
Expand All @@ -365,7 +293,7 @@ def __getitem__(self, idx):

target = target[:, : self.img_shape_x, : self.img_shape_y]
input = input[:, : self.img_shape_x, : self.img_shape_y]
return torch.as_tensor(target), torch.as_tensor(input), idx
return torch.as_tensor(target), torch.as_tensor(input), idx_in_split

def input_channels(self):
"""Metadata for the input channels. A list of dictionaries, one for each channel"""
Expand Down Expand Up @@ -393,7 +321,8 @@ def latitude(self):

def time(self):
    """Get the time values of the samples exposed by this dataset.

    Returns:
        list: the entries of ``self._dataset.time()`` selected by
        ``self._time_indices``, in that order. When the instance has no
        ``_time_indices`` attribute, every time value of the underlying
        dataset is returned.
    """
    src_times = self._dataset.time()
    # NOTE(review): ``_time_indices`` appears to be assigned only when a
    # train/test split is requested -- confirm against ``__init__``. Fall
    # back to the full time axis rather than raising AttributeError.
    split_indices = getattr(self, "_time_indices", None)
    if split_indices is None:
        return list(src_times)
    return [src_times[i] for i in split_indices]

def image_shape(self):
"""Get the shape of the image (same for input and output)."""
Expand Down
17 changes: 5 additions & 12 deletions examples/generative/corrdiff/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

from datasets.cwb import (
_ZarrDataset,
FilterTime,
get_zarr_dataset,
)
import torch
Expand Down Expand Up @@ -61,18 +60,12 @@ def test_zarr_dataset_get_valid_time_index(dataset):
assert isinstance(ans, np.int64)


def test_filter_time():
class MockData(torch.utils.data.Dataset):
def __getitem__(self, idx):
return self.time()[idx]
def test_train_test_split():
ds = get_zarr_dataset(data_path=path, train=True)
assert not any(t.year == 2021 for t in ds.time())

def time(self):
return [datetime.datetime(2018, 1, 1), datetime.datetime(1970, 1, 1)]

data = MockData()
filtered = FilterTime(data, lambda time: time.year > 1990)
assert filtered.time() == [datetime.datetime(2018, 1, 1)]
assert filtered[0]
ds = get_zarr_dataset(data_path=path, train=False)
assert all(t.year == 2021 for t in ds.time())


def hash_array(arr, tol=1e-3):
Expand Down

0 comments on commit 4f731e5

Please sign in to comment.