Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reimplementation of strict thresholding #316

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 63 additions & 39 deletions tobac/feature_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,9 @@
return not overlap


def remove_parents(features_thresholds, regions_i, regions_old):
def remove_parents(
features_thresholds, regions_i, regions_old, strict_thresholding=False
):
"""Remove parents of newly detected feature regions.

Remove features where its regions surround newly
Expand All @@ -279,6 +281,9 @@
threshold from previous threshold
(feature ids as keys).

strict_thresholding : bool, optional
If True, a feature can only be detected if all previous thresholds have been met.
Default is False.
Returns
-------
features_thresholds : pandas.DataFrame
Expand All @@ -288,26 +293,72 @@

try:
all_curr_pts = np.concatenate([vals for idx, vals in regions_i.items()])
except ValueError:
# the case where there are no new regions
if strict_thresholding:
return features_thresholds, {}
else:
return features_thresholds, regions_old
try:
all_old_pts = np.concatenate([vals for idx, vals in regions_old.items()])
except ValueError:
# the case where there are no regions
return features_thresholds
# the case where there are no old regions
if strict_thresholding:
return (
features_thresholds[
~features_thresholds["idx"].isin(list(regions_i.keys()))
],
{},
)
else:
return features_thresholds, regions_i

Check warning on line 314 in tobac/feature_detection.py

View check run for this annotation

Codecov / codecov/patch

tobac/feature_detection.py#L314

Added line #L314 was not covered by tests

old_feat_arr = np.empty((len(all_old_pts)))
curr_loc = 0
for idx_old in regions_old:
old_feat_arr[curr_loc : curr_loc + len(regions_old[idx_old])] = idx_old
curr_loc += len(regions_old[idx_old])

_, _, common_ix_old = np.intersect1d(all_curr_pts, all_old_pts, return_indices=True)
_, common_ix_new, common_ix_old = np.intersect1d(
all_curr_pts, all_old_pts, return_indices=True
)
list_remove = np.unique(old_feat_arr[common_ix_old])

if strict_thresholding:
new_feat_arr = np.empty((len(all_curr_pts)))
curr_loc = 0
for idx_new in regions_i:
new_feat_arr[curr_loc : curr_loc + len(regions_i[idx_new])] = idx_new
curr_loc += len(regions_i[idx_new])
regions_i_overlap = np.unique(new_feat_arr[common_ix_new])
no_prev_feature = np.array(list(regions_i.keys()))[

Check warning on line 334 in tobac/feature_detection.py

View check run for this annotation

Codecov / codecov/patch

tobac/feature_detection.py#L328-L334

Added lines #L328 - L334 were not covered by tests
np.logical_not(np.isin(list(regions_i.keys()), regions_i_overlap))
]
list_remove = np.concatenate([list_remove, no_prev_feature])

Check warning on line 337 in tobac/feature_detection.py

View check run for this annotation

Codecov / codecov/patch

tobac/feature_detection.py#L337

Added line #L337 was not covered by tests

# remove parent regions:
if features_thresholds is not None:
features_thresholds = features_thresholds[
~features_thresholds["idx"].isin(list_remove)
]

return features_thresholds
if strict_thresholding:
keep_new_keys = np.isin(list(regions_i.keys()), features_thresholds["idx"])
regions_old = {

Check warning on line 347 in tobac/feature_detection.py

View check run for this annotation

Codecov / codecov/patch

tobac/feature_detection.py#L346-L347

Added lines #L346 - L347 were not covered by tests
k: v for i, (k, v) in enumerate(regions_i.items()) if keep_new_keys[i]
}
else:
keep_old_keys = np.isin(
list(regions_old.keys()), features_thresholds["idx"]
)
regions_old = {
k: v for i, (k, v) in enumerate(regions_old.items()) if keep_old_keys[i]
}
regions_old.update(regions_i)
else:
regions_old = regions_i

Check warning on line 359 in tobac/feature_detection.py

View check run for this annotation

Codecov / codecov/patch

tobac/feature_detection.py#L359

Added line #L359 was not covered by tests

return features_thresholds, regions_old


def feature_detection_threshold(
Expand Down Expand Up @@ -995,42 +1046,15 @@
)

# For multiple thresholds, and features found in both the current and previous step, remove "parent" features from the DataFrame
if i_threshold > 0 and not features_thresholds.empty and regions_old:
if i_threshold > 0 and not features_thresholds.empty:
# for each threshold value: check if newly found features are surrounded by a feature based on a less restrictive threshold
features_thresholds = remove_parents(
features_thresholds, regions_i, regions_old
features_thresholds, regions_old = remove_parents(
features_thresholds,
regions_i,
regions_old,
strict_thresholding=strict_thresholding,
)

if strict_thresholding:
if regions_i:
# remove data in regions where no features were detected
valid_regions: np.ndarray = np.zeros_like(track_data)
region_indices: list[int] = list(regions_i.values())[
0
] # linear indices
valid_regions.ravel()[region_indices] = 1
track_data: np.ndarray = np.multiply(valid_regions, track_data)
else:
# since regions_i is empty no further features can be detected
logging.debug(
"Finished feature detection for threshold "
+ str(i_threshold)
+ " : "
+ str(threshold_i)
)
return features_thresholds

if i_threshold > 0 and not features_thresholds.empty and regions_old:
# Work out which regions are still in feature_thresholds to keep
# This is faster than calling "in" for every idx
keep_old_keys = np.isin(
list(regions_old.keys()), features_thresholds["idx"]
)
regions_old = {
k: v for i, (k, v) in enumerate(regions_old.items()) if keep_old_keys[i]
}
regions_old.update(regions_i)
else:
elif i_threshold == 0:
regions_old = regions_i

logging.debug(
Expand Down
78 changes: 78 additions & 0 deletions tobac/tests/test_feature_detection.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import tobac
import tobac.testing as tbtest
import tobac.feature_detection as feat_detect
import pytest
import numpy as np
import xarray as xr
from pandas.testing import assert_frame_equal


Expand Down Expand Up @@ -745,6 +747,82 @@ def test_strict_thresholding():
assert len(features) == 1
assert features["threshold_value"].item() == thresholds[0]

# Repeat for minima
test_data_iris = tbtest.make_dataset_from_arr(10 - test_data, data_type="iris")
# All of these thresholds will be met
thresholds = [9, 5, 2.5]

# This will detect 1 feature (at the last threshold value), since its parent from the first threshold is removed
features = feat_detect.feature_detection_multithreshold_timestep(
test_data_iris,
0,
dxy=1,
threshold=thresholds,
n_min_threshold=n_min_thresholds,
strict_thresholding=False,
target="minimum",
)
assert len(features) == 1
assert features["threshold_value"].item() == thresholds[-1]

# Since the second n_min_thresholds value is not met this will only detect 1 feature
features = feat_detect.feature_detection_multithreshold_timestep(
test_data_iris,
0,
dxy=1,
threshold=thresholds,
n_min_threshold=n_min_thresholds,
strict_thresholding=True,
target="minimum",
)
assert len(features) == 1
assert features["threshold_value"].item() == thresholds[0]

# Test example from documentation
input_field_arr = np.zeros((1, 101, 101))

for idx, side in enumerate([40, 20, 10, 5]):
input_field_arr[
:,
(50 - side - 4 * idx) : (50 + side - 4 * idx),
(50 - side - 4 * idx) : (50 + side - 4 * idx),
] = (
50 - side
)

input_field_iris = xr.DataArray(
input_field_arr,
dims=["time", "Y", "X"],
coords={"time": [np.datetime64("2019-01-01T00:00:00")]},
).to_iris()

thresholds = [8, 29, 39, 44]

n_min_thresholds = [79**2, input_field_arr.size, 8**2, 3**2]

features_demo = tobac.feature_detection_multithreshold(
input_field_iris,
dxy=1000,
threshold=thresholds,
n_min_threshold=n_min_thresholds,
strict_thresholding=False,
)

assert features_demo.iloc[0]["hdim_1"] == pytest.approx(37.5)
assert features_demo.iloc[0]["hdim_2"] == pytest.approx(37.5)

# Now repeat with strict thresholding
features_demo = tobac.feature_detection_multithreshold(
input_field_iris,
dxy=1000,
threshold=thresholds,
n_min_threshold=n_min_thresholds,
strict_thresholding=True,
)

assert features_demo.iloc[0]["hdim_1"] == pytest.approx(49.5)
assert features_demo.iloc[0]["hdim_2"] == pytest.approx(49.5)


@pytest.mark.parametrize(
"h1_indices, h2_indices, max_h1, max_h2, PBC_flag, position_threshold, expected_output",
Expand Down
Loading