From ced3044d60cda2d7040190ec45fe1b0e6ede4b93 Mon Sep 17 00:00:00 2001 From: William Jones Date: Fri, 28 Jul 2023 21:39:15 +0100 Subject: [PATCH 1/3] Fix to strict_thresholding that now works for both minima and maxima, and provides the same results as strict_thresholding=False if n_min_threshold is a fixed value --- tobac/feature_detection.py | 150 +++++++++++++++++++++++++++---------- 1 file changed, 111 insertions(+), 39 deletions(-) diff --git a/tobac/feature_detection.py b/tobac/feature_detection.py index 108069ae..0a4d7c63 100644 --- a/tobac/feature_detection.py +++ b/tobac/feature_detection.py @@ -258,7 +258,9 @@ def test_overlap(region_inner, region_outer): return not overlap -def remove_parents(features_thresholds, regions_i, regions_old): +def remove_parents( + features_thresholds, regions_i, regions_old, strict_thresholding=False +): """Remove parents of newly detected feature regions. Remove features where its regions surround newly @@ -279,6 +281,9 @@ def remove_parents(features_thresholds, regions_i, regions_old): threshold from previous threshold (feature ids as keys). + strict_thresholding: Bool, optional + If True, a feature can only be detected if all previous thresholds have been met. + Default is False. Returns ------- features_thresholds : pandas.DataFrame @@ -288,26 +293,73 @@ def remove_parents(features_thresholds, regions_i, regions_old): try: all_curr_pts = np.concatenate([vals for idx, vals in regions_i.items()]) + except ValueError: + # the case where there are no new regions + if strict_thresholding: + return features_thresholds, {} + else: + return features_thresholds, regions_old + try: all_old_pts = np.concatenate([vals for idx, vals in regions_old.items()]) except ValueError: - # the case where there are no regions - return features_thresholds + # the case where there are no old regions + if strict_thresholding: + return ( + features_thresholds[ + ~features_thresholds["idx"].isin(list(regions_i.keys())) + ], + {}, + ) + else: + return features_thresholds, regions_i + old_feat_arr = np.empty((len(all_old_pts))) curr_loc = 0 for idx_old in regions_old: old_feat_arr[curr_loc : curr_loc + len(regions_old[idx_old])] = idx_old curr_loc += len(regions_old[idx_old]) - _, _, common_ix_old = np.intersect1d(all_curr_pts, all_old_pts, return_indices=True) + _, common_ix_new, common_ix_old = np.intersect1d( + all_curr_pts, all_old_pts, return_indices=True + ) list_remove = np.unique(old_feat_arr[common_ix_old]) + if strict_thresholding: + new_feat_arr = np.empty((len(all_curr_pts))) + curr_loc = 0 + for idx_new in regions_i: + new_feat_arr[curr_loc : curr_loc + len(regions_i[idx_new])] = idx_new + curr_loc += len(regions_i[idx_new]) + # _, _, common_ix_new = np.intersect1d(all_old_pts, all_curr_pts, return_indices=True) + regions_i_overlap = np.unique(new_feat_arr[common_ix_new]) + no_prev_feature = np.array(list(regions_i.keys()))[ + np.logical_not(np.isin(list(regions_i.keys()), regions_i_overlap)) + ] + list_remove = np.concatenate([list_remove, no_prev_feature]) + # remove parent regions: if features_thresholds is not None: features_thresholds = features_thresholds[ ~features_thresholds["idx"].isin(list_remove) ] - return features_thresholds + if strict_thresholding: + keep_new_keys = np.isin(list(regions_i.keys()), features_thresholds["idx"]) + regions_old = { + k: v for i, (k, v) in enumerate(regions_i.items()) if keep_new_keys[i] + } + else: + keep_old_keys = np.isin( + list(regions_old.keys()), features_thresholds["idx"] + ) + regions_old = { + k: v for i, (k, v) 
in enumerate(regions_old.items()) if keep_old_keys[i] + } + regions_old.update(regions_i) + else: + regions_old = regions_i + + return features_thresholds, regions_old def feature_detection_threshold( @@ -994,45 +1046,65 @@ def feature_detection_multithreshold_timestep( [features_thresholds, features_threshold_i], ignore_index=True ) + # if i_threshold>0: + # print(regions_old.keys()) + # For multiple threshold, and features found both in the current and previous step, remove "parent" features from Dataframe - if i_threshold > 0 and not features_thresholds.empty and regions_old: + if i_threshold > 0 and not features_thresholds.empty: # for each threshold value: check if newly found features are surrounded by feature based on less restrictive threshold - features_thresholds = remove_parents( - features_thresholds, regions_i, regions_old - ) - - if strict_thresholding: - if regions_i: - # remove data in regions where no features were detected - valid_regions: np.ndarray = np.zeros_like(track_data) - region_indices: list[int] = list(regions_i.values())[ - 0 - ] # linear indices - valid_regions.ravel()[region_indices] = 1 - track_data: np.ndarray = np.multiply(valid_regions, track_data) - else: - # since regions_i is empty no further features can be detected - logging.debug( - "Finished feature detection for threshold " - + str(i_threshold) - + " : " - + str(threshold_i) - ) - return features_thresholds - - if i_threshold > 0 and not features_thresholds.empty and regions_old: - # Work out which regions are still in feature_thresholds to keep - # This is faster than calling "in" for every idx - keep_old_keys = np.isin( - list(regions_old.keys()), features_thresholds["idx"] + features_thresholds, regions_old = remove_parents( + features_thresholds, + regions_i, + regions_old, + strict_thresholding=strict_thresholding, ) - regions_old = { - k: v for i, (k, v) in enumerate(regions_old.items()) if keep_old_keys[i] - } - regions_old.update(regions_i) - else: + elif i_threshold == 0: regions_old = regions_i + # print(regions_i.keys()) + # if i_threshold>0: + # print(regions_old.keys()) + + # if strict_thresholding: + # if regions_i: + # # remove data in regions where no features were detected + # valid_regions: np.ndarray = np.zeros_like(track_data, dtype=bool) + # region_indices: list[int] = list(regions_i.values())[ + # 0 + # ] # linear indices + # valid_regions.ravel()[region_indices] = 1 + # if i_threshold > 2: + # raise RuntimeError + # # track_data[np.logical_not(valid_regions)] = threshold_i + # if target=="maximum": + # track_data[np.logical_not(valid_regions)] = np.minimum(track_data[np.logical_not(valid_regions)], threshold_i) + # # track_data[np.logical_not(valid_regions)] = -np.inf + # elif target=="minimum": + # track_data[np.logical_not(valid_regions)] = np.maximum(track_data[np.logical_not(valid_regions)], threshold_i) + # # track_data[np.logical_not(valid_regions)] = np.inf + # else: + # # since regions_i is empty no further features can be detected + # logging.debug( + # "Finished feature detection for threshold " + # + str(i_threshold) + # + " : " + # + str(threshold_i) + # ) + # return features_thresholds + + # if i_threshold > 0 and not features_thresholds.empty and regions_old: + # # Work out which regions are still in feature_thresholds to keep + # # This is faster than calling "in" for every idx + # keep_old_keys = np.isin( + # list(regions_old.keys()), features_thresholds["idx"] + # ) + # regions_old = { + # k: v for i, (k, v) in enumerate(regions_old.items()) if 
keep_old_keys[i] + # } + # regions_old.update(regions_i) + # else: + # regions_old = regions_i + logging.debug( "Finished feature detection for threshold " + str(i_threshold) From b762ee540f61acd34f6925c822ccb40f1da05d5b Mon Sep 17 00:00:00 2001 From: William Jones Date: Fri, 28 Jul 2023 21:53:43 +0100 Subject: [PATCH 2/3] Remove commented out code --- tobac/feature_detection.py | 48 -------------------------------------- 1 file changed, 48 deletions(-) diff --git a/tobac/feature_detection.py b/tobac/feature_detection.py index 0a4d7c63..25f9faa0 100644 --- a/tobac/feature_detection.py +++ b/tobac/feature_detection.py @@ -330,7 +330,6 @@ def remove_parents( for idx_new in regions_i: new_feat_arr[curr_loc : curr_loc + len(regions_i[idx_new])] = idx_new curr_loc += len(regions_i[idx_new]) - # _, _, common_ix_new = np.intersect1d(all_old_pts, all_curr_pts, return_indices=True) regions_i_overlap = np.unique(new_feat_arr[common_ix_new]) no_prev_feature = np.array(list(regions_i.keys()))[ np.logical_not(np.isin(list(regions_i.keys()), regions_i_overlap)) @@ -1046,9 +1045,6 @@ def feature_detection_multithreshold_timestep( [features_thresholds, features_threshold_i], ignore_index=True ) - # if i_threshold>0: - # print(regions_old.keys()) - # For multiple threshold, and features found both in the current and previous step, remove "parent" features from Dataframe if i_threshold > 0 and not features_thresholds.empty: # for each threshold value: check if newly found features are surrounded by feature based on less restrictive threshold @@ -1061,50 +1057,6 @@ def feature_detection_multithreshold_timestep( elif i_threshold == 0: regions_old = regions_i - # print(regions_i.keys()) - # if i_threshold>0: - # print(regions_old.keys()) - - # if strict_thresholding: - # if regions_i: - # # remove data in regions where no features were detected - # valid_regions: np.ndarray = np.zeros_like(track_data, dtype=bool) - # region_indices: list[int] = list(regions_i.values())[ - # 0 - # ] # linear indices - # valid_regions.ravel()[region_indices] = 1 - # if i_threshold > 2: - # raise RuntimeError - # # track_data[np.logical_not(valid_regions)] = threshold_i - # if target=="maximum": - # track_data[np.logical_not(valid_regions)] = np.minimum(track_data[np.logical_not(valid_regions)], threshold_i) - # # track_data[np.logical_not(valid_regions)] = -np.inf - # elif target=="minimum": - # track_data[np.logical_not(valid_regions)] = np.maximum(track_data[np.logical_not(valid_regions)], threshold_i) - # # track_data[np.logical_not(valid_regions)] = np.inf - # else: - # # since regions_i is empty no further features can be detected - # logging.debug( - # "Finished feature detection for threshold " - # + str(i_threshold) - # + " : " - # + str(threshold_i) - # ) - # return features_thresholds - - # if i_threshold > 0 and not features_thresholds.empty and regions_old: - # # Work out which regions are still in feature_thresholds to keep - # # This is faster than calling "in" for every idx - # keep_old_keys = np.isin( - # list(regions_old.keys()), features_thresholds["idx"] - # ) - # regions_old = { - # k: v for i, (k, v) in enumerate(regions_old.items()) if keep_old_keys[i] - # } - # regions_old.update(regions_i) - # else: - # regions_old = regions_i - logging.debug( "Finished feature detection for threshold " + str(i_threshold) From c21905efefd0b9ac69039cfd8e67afd55c6673b3 Mon Sep 17 00:00:00 2001 From: William Jones Date: Fri, 28 Jul 2023 22:03:35 +0100 Subject: [PATCH 3/3] Add additional tests for strict_thresholding --- 
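Illustration of the behaviour these tests check (a minimal sketch only, not part of the diff below; it assumes tobac is installed together with its xarray/iris dependencies, and the variable names features_loose / features_strict are chosen here for clarity):

    import numpy as np
    import xarray as xr
    import tobac

    # Nested squares of increasing value, each offset from the domain centre,
    # matching the documentation example reproduced in the new test.
    input_field_arr = np.zeros((1, 101, 101))
    for idx, side in enumerate([40, 20, 10, 5]):
        input_field_arr[
            :,
            (50 - side - 4 * idx) : (50 + side - 4 * idx),
            (50 - side - 4 * idx) : (50 + side - 4 * idx),
        ] = 50 - side

    input_field_iris = xr.DataArray(
        input_field_arr,
        dims=["time", "Y", "X"],
        coords={"time": [np.datetime64("2019-01-01T00:00:00")]},
    ).to_iris()

    thresholds = [8, 29, 39, 44]
    # The second n_min_threshold equals the total number of grid points,
    # so no region at threshold 29 can ever satisfy it.
    n_min_thresholds = [79**2, input_field_arr.size, 8**2, 3**2]

    # Without strict thresholding the unmet requirement is simply skipped and
    # the feature ends up at the innermost square (hdim_1/hdim_2 ~ 37.5).
    features_loose = tobac.feature_detection_multithreshold(
        input_field_iris,
        dxy=1000,
        threshold=thresholds,
        n_min_threshold=n_min_thresholds,
        strict_thresholding=False,
    )

    # With strict thresholding, once a threshold's n_min_threshold is not met,
    # no feature from a later threshold is kept, so the feature stays at the
    # first threshold region (hdim_1/hdim_2 ~ 49.5).
    features_strict = tobac.feature_detection_multithreshold(
        input_field_iris,
        dxy=1000,
        threshold=thresholds,
        n_min_threshold=n_min_thresholds,
        strict_thresholding=True,
    )

    print(features_loose[["hdim_1", "hdim_2"]])
    print(features_strict[["hdim_1", "hdim_2"]])
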
tobac/tests/test_feature_detection.py | 78 +++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) diff --git a/tobac/tests/test_feature_detection.py b/tobac/tests/test_feature_detection.py index 4eb0784a..37b6b586 100644 --- a/tobac/tests/test_feature_detection.py +++ b/tobac/tests/test_feature_detection.py @@ -1,7 +1,9 @@ +import tobac import tobac.testing as tbtest import tobac.feature_detection as feat_detect import pytest import numpy as np +import xarray as xr from pandas.testing import assert_frame_equal @@ -745,6 +747,82 @@ def test_strict_thresholding(): assert len(features) == 1 assert features["threshold_value"].item() == thresholds[0] + # Repeat for minima + test_data_iris = tbtest.make_dataset_from_arr(10 - test_data, data_type="iris") + # All of these thresholds will be met + thresholds = [9, 5, 2.5] + + # This will detect 2 features (first and last threshold value) + features = feat_detect.feature_detection_multithreshold_timestep( + test_data_iris, + 0, + dxy=1, + threshold=thresholds, + n_min_threshold=n_min_thresholds, + strict_thresholding=False, + target="minimum", + ) + assert len(features) == 1 + assert features["threshold_value"].item() == thresholds[-1] + + # Since the second n_min_thresholds value is not met this will only detect 1 feature + features = feat_detect.feature_detection_multithreshold_timestep( + test_data_iris, + 0, + dxy=1, + threshold=thresholds, + n_min_threshold=n_min_thresholds, + strict_thresholding=True, + target="minimum", + ) + assert len(features) == 1 + assert features["threshold_value"].item() == thresholds[0] + + # Test example from documentation + input_field_arr = np.zeros((1, 101, 101)) + + for idx, side in enumerate([40, 20, 10, 5]): + input_field_arr[ + :, + (50 - side - 4 * idx) : (50 + side - 4 * idx), + (50 - side - 4 * idx) : (50 + side - 4 * idx), + ] = ( + 50 - side + ) + + input_field_iris = xr.DataArray( + input_field_arr, + dims=["time", "Y", "X"], + coords={"time": [np.datetime64("2019-01-01T00:00:00")]}, + ).to_iris() + + thresholds = [8, 29, 39, 44] + + n_min_thresholds = [79**2, input_field_arr.size, 8**2, 3**2] + + features_demo = tobac.feature_detection_multithreshold( + input_field_iris, + dxy=1000, + threshold=thresholds, + n_min_threshold=n_min_thresholds, + strict_thresholding=False, + ) + + assert features_demo.iloc[0]["hdim_1"] == pytest.approx(37.5) + assert features_demo.iloc[0]["hdim_2"] == pytest.approx(37.5) + + # Now repeat with strict thresholding + features_demo = tobac.feature_detection_multithreshold( + input_field_iris, + dxy=1000, + threshold=thresholds, + n_min_threshold=n_min_thresholds, + strict_thresholding=True, + ) + + assert features_demo.iloc[0]["hdim_1"] == pytest.approx(49.5) + assert features_demo.iloc[0]["hdim_2"] == pytest.approx(49.5) + @pytest.mark.parametrize( "h1_indices, h2_indices, max_h1, max_h2, PBC_flag, position_threshold, expected_output",