Skip to content

Commit

Permalink
Merge branch 'dev'
Browse files Browse the repository at this point in the history
  • Loading branch information
MArpogaus committed Sep 13, 2024
2 parents fdce998 + 3ce0cda commit 7b81c70
Show file tree
Hide file tree
Showing 8 changed files with 85 additions and 19 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pre-commit.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
- uses: actions/setup-python@v5
- uses: pre-commit/action@v3.0.1
12 changes: 6 additions & 6 deletions .github/workflows/release.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: 3.x
- name: Install pypa/build
Expand All @@ -24,7 +24,7 @@ jobs:
- name: Build a binary wheel and a source tarball
run: python3 -m build
- name: Store the distribution packages
uses: actions/upload-artifact@v3
uses: actions/upload-artifact@v4
with:
name: python-package-distributions
path: dist/
Expand All @@ -42,7 +42,7 @@ jobs:
id-token: write # IMPORTANT: mandatory for trusted publishing
steps:
- name: Download all the dists
uses: actions/download-artifact@v3
uses: actions/download-artifact@v4
with:
name: python-package-distributions
path: dist/
Expand All @@ -63,12 +63,12 @@ jobs:

steps:
- name: Download all the dists
uses: actions/download-artifact@v3
uses: actions/download-artifact@v4
with:
name: python-package-distributions
path: dist/
- name: Sign the dists with Sigstore
uses: sigstore/gh-action-sigstore-python@v1.2.3
uses: sigstore/gh-action-sigstore-python@v3.0.0
with:
inputs: >-
./dist/*.tar.gz
Expand Down Expand Up @@ -107,7 +107,7 @@ jobs:

steps:
- name: Download all the dists
uses: actions/download-artifact@v3
uses: actions/download-artifact@v4
with:
name: python-package-distributions
path: dist/
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ jobs:

strategy:
matrix:
platform: [ubuntu-latest, macos-latest, windows-latest]
platform: [ubuntu-latest, windows-latest]
python-version: [3.7, 3.8, 3.9, '3.10', '3.11']

runs-on: ${{ matrix.platform }}
Expand Down
10 changes: 10 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
## v0.2.0 (2024-09-13)

### Feat

- allows to filter out patches containing NaN values

### Fix

- disable NaN filtering per default to ensure errors if they are unexpected

## v0.1.2 (2024-02-19)

## v0.1.1 (2024-02-19)
Expand Down
3 changes: 2 additions & 1 deletion src/tensorflow_time_series_dataset/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# author : Marcel Arpogaus <marcel dot arpogaus at gmail dot com>
#
# created : 2022-01-07 09:02:38 (Marcel Arpogaus)
# changed : 2024-02-19 12:57:42 (Marcel Arpogaus)
# changed : 2024-09-12 16:21:24 (Marcel Arpogaus)
# DESCRIPTION #################################################################
# ...
# LICENSE #####################################################################
Expand Down Expand Up @@ -53,6 +53,7 @@ class WindowedTimeSeriesDatasetFactory:
"cycle_length": 1,
"shuffle_buffer_size": 1000,
"cache": True,
"filter_nans": False,
}

def __init__(
Expand Down
23 changes: 20 additions & 3 deletions src/tensorflow_time_series_dataset/pipeline/patch_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# author : Marcel Arpogaus <marcel dot arpogaus at gmail dot com>
#
# created : 2022-01-07 09:02:38 (Marcel Arpogaus)
# changed : 2024-02-19 12:52:06 (Marcel Arpogaus)
# changed : 2024-09-12 15:52:32 (Marcel Arpogaus)
# DESCRIPTION #################################################################
# ...
# LICENSE #####################################################################
Expand Down Expand Up @@ -35,20 +35,25 @@ class PatchGenerator:
The size of each patch.
shift : int
The shift between patches.
filter_nans : int
Apply a filter function to drop patches containing NaN values.
"""

def __init__(self, window_size: int, shift: int) -> None:
def __init__(self, window_size: int, shift: int, filter_nans: bool) -> None:
"""Parameters
----------
window_size : int
The size of each patch.
shift : int
The shift between patches.
filter_nans : int
If True, apply a filter function to drop patches containing NaN values.
"""
self.window_size: int = window_size
self.shift: int = shift
self.filter_nans: bool = filter_nans

def __call__(self, data: tf.Tensor) -> tf.data.Dataset:
"""Converts input data into patches of provided window size.
Expand All @@ -71,6 +76,18 @@ def __call__(self, data: tf.Tensor) -> tf.data.Dataset:
size=self.window_size,
shift=self.shift,
drop_remainder=True,
).flat_map(lambda sub: sub.batch(self.window_size, drop_remainder=True))
)

def sub_to_patch(sub):
return sub.batch(self.window_size, drop_remainder=True)

data_set = data_set.flat_map(sub_to_patch)

if self.filter_nans:

def filter_func(patch):
return tf.reduce_all(tf.logical_not(tf.math.is_nan(patch)))

data_set = data_set.filter(filter_func)

return data_set
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# author : Marcel Arpogaus <marcel dot arpogaus at gmail dot com>
#
# created : 2022-01-07 09:02:38 (Marcel Arpogaus)
# changed : 2024-02-19 12:53:06 (Marcel Arpogaus)
# changed : 2024-09-12 16:01:27 (Marcel Arpogaus)
# DESCRIPTION #################################################################
# ...
# LICENSE #####################################################################
Expand Down Expand Up @@ -64,6 +64,8 @@ class WindowedTimeSeriesPipeline:
Whether to cache the dataset in memory or to a specific file.
drop_remainder : bool
Whether to drop the remainder of batches that are not equal to the batch size.
filter_nans : int
Apply a filter function to drop patches containing NaN values.
Raises
------
Expand All @@ -85,6 +87,7 @@ def __init__(
shuffle_buffer_size: int,
cache: Union[str, bool],
drop_remainder: bool,
filter_nans: bool,
) -> None:
assert (
prediction_size > 0
Expand All @@ -101,6 +104,7 @@ def __init__(
self.shuffle_buffer_size = shuffle_buffer_size
self.cache = cache
self.drop_remainder = drop_remainder
self.filter_nans = filter_nans

def __call__(self, ds: Dataset) -> Dataset:
"""Applies the pipeline operations to the given dataset.
Expand All @@ -117,7 +121,7 @@ def __call__(self, ds: Dataset) -> Dataset:
"""
ds = ds.interleave(
PatchGenerator(self.window_size, self.shift),
PatchGenerator(self.window_size, self.shift, self.filter_nans),
cycle_length=self.cycle_length,
num_parallel_calls=tf.data.experimental.AUTOTUNE,
)
Expand Down
44 changes: 39 additions & 5 deletions tests/test_pipleine.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def patched_dataset(

ds = tf.data.Dataset.from_tensors(df[sorted(used_cols)])
ds = ds.interleave(
PatchGenerator(window_size=window_size, shift=shift),
PatchGenerator(window_size=window_size, shift=shift, filter_nans=False),
num_parallel_calls=tf.data.experimental.AUTOTUNE,
)
return ds, df, window_size, shift
Expand All @@ -53,7 +53,7 @@ def test_patch_generator(time_series_df, window_size, shift):

ds = tf.data.Dataset.from_tensors(df)
ds_patched = ds.interleave(
PatchGenerator(window_size=window_size, shift=shift),
PatchGenerator(window_size=window_size, shift=shift, filter_nans=False),
num_parallel_calls=tf.data.experimental.AUTOTUNE,
)
for i, patch in enumerate(ds_patched.as_numpy_iterator()):
Expand All @@ -65,6 +65,36 @@ def test_patch_generator(time_series_df, window_size, shift):
assert i + 1 == patches, "Not enough patches"


@pytest.mark.parametrize("window_size,shift", [(2 * 48, 48), (48 + 1, 1)])
def test_patch_generator_filter_nans(time_series_df, window_size, shift):
df = time_series_df.set_index("date_time")
# randomly set 20% of elemnts in the dataset for nans

df = time_series_df.set_index("date_time")
nan_mask = np.random.default_rng(1).uniform(0, 1, df.shape) < 0.01
df[nan_mask] = np.nan

initial_size = window_size - shift
data_size = df.index.size - initial_size
patches = data_size // shift

expected_shape = (window_size, len(df.columns))

ds = tf.data.Dataset.from_tensors(df)
ds_patched = ds.interleave(
PatchGenerator(window_size=window_size, shift=shift, filter_nans=True),
num_parallel_calls=tf.data.experimental.AUTOTUNE,
)
for i, patch in enumerate(ds_patched.as_numpy_iterator()):
assert patch.shape == expected_shape, "Wrong shape"
x1 = patch[0, 0]
idx = int(x1 % 1e5)
expected_values = df.iloc[idx : idx + window_size]
assert np.all(patch == expected_values), "Patch contains wrong data"
assert not np.isnan(patch).any(), "Patch contains NaNs."
assert i + 1 < patches, "No patches have been dropped"


@pytest.mark.parametrize("window_size,shift", [(2 * 48, 48), (48 + 1, 1)])
def test_patch_generator_groupby(groupby_dataset, window_size, shift):
ds, df = groupby_dataset
Expand All @@ -78,7 +108,7 @@ def test_patch_generator_groupby(groupby_dataset, window_size, shift):
expected_shape = (window_size, len(columns))

ds_patched = ds.interleave(
PatchGenerator(window_size=window_size, shift=shift),
PatchGenerator(window_size=window_size, shift=shift, filter_nans=True),
num_parallel_calls=tf.data.experimental.AUTOTUNE,
)

Expand Down Expand Up @@ -166,7 +196,9 @@ def test_windowed_time_series_pipeline(
batch_size=batch_size,
drop_remainder=True,
)
pipeline_kwargs = dict(cycle_length=1, shuffle_buffer_size=100, cache=True)
pipeline_kwargs = dict(
cycle_length=1, shuffle_buffer_size=100, cache=True, filter_nans=False
)

with validate_args(
history_size=history_size,
Expand Down Expand Up @@ -209,7 +241,9 @@ def test_windowed_time_series_pipeline_groupby(
batch_size=batch_size,
drop_remainder=False,
)
pipeline_kwargs = dict(cycle_length=len(ids), shuffle_buffer_size=1000, cache=True)
pipeline_kwargs = dict(
cycle_length=len(ids), shuffle_buffer_size=1000, cache=True, filter_nans=False
)

with validate_args(
history_size=history_size,
Expand Down

0 comments on commit 7b81c70

Please sign in to comment.