diff --git a/tobac/feature_detection.py b/tobac/feature_detection.py index 108069ae..25f9faa0 100644 --- a/tobac/feature_detection.py +++ b/tobac/feature_detection.py @@ -258,7 +258,9 @@ def test_overlap(region_inner, region_outer): return not overlap -def remove_parents(features_thresholds, regions_i, regions_old): +def remove_parents( + features_thresholds, regions_i, regions_old, strict_thresholding=False +): """Remove parents of newly detected feature regions. Remove features where its regions surround newly @@ -279,6 +281,9 @@ def remove_parents(features_thresholds, regions_i, regions_old): threshold from previous threshold (feature ids as keys). + strict_thresholding: Bool, optional + If True, a feature can only be detected if all previous thresholds have been met. + Default is False. Returns ------- features_thresholds : pandas.DataFrame @@ -288,26 +293,72 @@ def remove_parents(features_thresholds, regions_i, regions_old): try: all_curr_pts = np.concatenate([vals for idx, vals in regions_i.items()]) + except ValueError: + # the case where there are no new regions + if strict_thresholding: + return features_thresholds, {} + else: + return features_thresholds, regions_old + try: all_old_pts = np.concatenate([vals for idx, vals in regions_old.items()]) except ValueError: - # the case where there are no regions - return features_thresholds + # the case where there are no old regions + if strict_thresholding: + return ( + features_thresholds[ + ~features_thresholds["idx"].isin(list(regions_i.keys())) + ], + {}, + ) + else: + return features_thresholds, regions_i + old_feat_arr = np.empty((len(all_old_pts))) curr_loc = 0 for idx_old in regions_old: old_feat_arr[curr_loc : curr_loc + len(regions_old[idx_old])] = idx_old curr_loc += len(regions_old[idx_old]) - _, _, common_ix_old = np.intersect1d(all_curr_pts, all_old_pts, return_indices=True) + _, common_ix_new, common_ix_old = np.intersect1d( + all_curr_pts, all_old_pts, return_indices=True + ) list_remove = np.unique(old_feat_arr[common_ix_old]) + if strict_thresholding: + new_feat_arr = np.empty((len(all_curr_pts))) + curr_loc = 0 + for idx_new in regions_i: + new_feat_arr[curr_loc : curr_loc + len(regions_i[idx_new])] = idx_new + curr_loc += len(regions_i[idx_new]) + regions_i_overlap = np.unique(new_feat_arr[common_ix_new]) + no_prev_feature = np.array(list(regions_i.keys()))[ + np.logical_not(np.isin(list(regions_i.keys()), regions_i_overlap)) + ] + list_remove = np.concatenate([list_remove, no_prev_feature]) + # remove parent regions: if features_thresholds is not None: features_thresholds = features_thresholds[ ~features_thresholds["idx"].isin(list_remove) ] - return features_thresholds + if strict_thresholding: + keep_new_keys = np.isin(list(regions_i.keys()), features_thresholds["idx"]) + regions_old = { + k: v for i, (k, v) in enumerate(regions_i.items()) if keep_new_keys[i] + } + else: + keep_old_keys = np.isin( + list(regions_old.keys()), features_thresholds["idx"] + ) + regions_old = { + k: v for i, (k, v) in enumerate(regions_old.items()) if keep_old_keys[i] + } + regions_old.update(regions_i) + else: + regions_old = regions_i + + return features_thresholds, regions_old def feature_detection_threshold( @@ -995,42 +1046,15 @@ def feature_detection_multithreshold_timestep( ) # For multiple threshold, and features found both in the current and previous step, remove "parent" features from Dataframe - if i_threshold > 0 and not features_thresholds.empty and regions_old: + if i_threshold > 0 and not features_thresholds.empty: # for each threshold value: check if newly found features are surrounded by feature based on less restrictive threshold - features_thresholds = remove_parents( - features_thresholds, regions_i, regions_old + features_thresholds, regions_old = remove_parents( + features_thresholds, + regions_i, + regions_old, + strict_thresholding=strict_thresholding, ) - - if strict_thresholding: - if regions_i: - # remove data in regions where no features were detected - valid_regions: np.ndarray = np.zeros_like(track_data) - region_indices: list[int] = list(regions_i.values())[ - 0 - ] # linear indices - valid_regions.ravel()[region_indices] = 1 - track_data: np.ndarray = np.multiply(valid_regions, track_data) - else: - # since regions_i is empty no further features can be detected - logging.debug( - "Finished feature detection for threshold " - + str(i_threshold) - + " : " - + str(threshold_i) - ) - return features_thresholds - - if i_threshold > 0 and not features_thresholds.empty and regions_old: - # Work out which regions are still in feature_thresholds to keep - # This is faster than calling "in" for every idx - keep_old_keys = np.isin( - list(regions_old.keys()), features_thresholds["idx"] - ) - regions_old = { - k: v for i, (k, v) in enumerate(regions_old.items()) if keep_old_keys[i] - } - regions_old.update(regions_i) - else: + elif i_threshold == 0: regions_old = regions_i logging.debug( diff --git a/tobac/tests/test_feature_detection.py b/tobac/tests/test_feature_detection.py index 4eb0784a..37b6b586 100644 --- a/tobac/tests/test_feature_detection.py +++ b/tobac/tests/test_feature_detection.py @@ -1,7 +1,9 @@ +import tobac import tobac.testing as tbtest import tobac.feature_detection as feat_detect import pytest import numpy as np +import xarray as xr from pandas.testing import assert_frame_equal @@ -745,6 +747,82 @@ def test_strict_thresholding(): assert len(features) == 1 assert features["threshold_value"].item() == thresholds[0] + # Repeat for minima + test_data_iris = tbtest.make_dataset_from_arr(10 - test_data, data_type="iris") + # All of these thresholds will be met + thresholds = [9, 5, 2.5] + + # This will detect 2 features (first and last threshold value) + features = feat_detect.feature_detection_multithreshold_timestep( + test_data_iris, + 0, + dxy=1, + threshold=thresholds, + n_min_threshold=n_min_thresholds, + strict_thresholding=False, + target="minimum", + ) + assert len(features) == 1 + assert features["threshold_value"].item() == thresholds[-1] + + # Since the second n_min_thresholds value is not met this will only detect 1 feature + features = feat_detect.feature_detection_multithreshold_timestep( + test_data_iris, + 0, + dxy=1, + threshold=thresholds, + n_min_threshold=n_min_thresholds, + strict_thresholding=True, + target="minimum", + ) + assert len(features) == 1 + assert features["threshold_value"].item() == thresholds[0] + + # Test example from documentation + input_field_arr = np.zeros((1, 101, 101)) + + for idx, side in enumerate([40, 20, 10, 5]): + input_field_arr[ + :, + (50 - side - 4 * idx) : (50 + side - 4 * idx), + (50 - side - 4 * idx) : (50 + side - 4 * idx), + ] = ( + 50 - side + ) + + input_field_iris = xr.DataArray( + input_field_arr, + dims=["time", "Y", "X"], + coords={"time": [np.datetime64("2019-01-01T00:00:00")]}, + ).to_iris() + + thresholds = [8, 29, 39, 44] + + n_min_thresholds = [79**2, input_field_arr.size, 8**2, 3**2] + + features_demo = tobac.feature_detection_multithreshold( + input_field_iris, + dxy=1000, + threshold=thresholds, + n_min_threshold=n_min_thresholds, + strict_thresholding=False, + ) + + assert features_demo.iloc[0]["hdim_1"] == pytest.approx(37.5) + assert features_demo.iloc[0]["hdim_2"] == pytest.approx(37.5) + + # Now repeat with strict thresholding + features_demo = tobac.feature_detection_multithreshold( + input_field_iris, + dxy=1000, + threshold=thresholds, + n_min_threshold=n_min_thresholds, + strict_thresholding=True, + ) + + assert features_demo.iloc[0]["hdim_1"] == pytest.approx(49.5) + assert features_demo.iloc[0]["hdim_2"] == pytest.approx(49.5) + @pytest.mark.parametrize( "h1_indices, h2_indices, max_h1, max_h2, PBC_flag, position_threshold, expected_output",