From 72bb55f816d71092da51698d0f3a6b3fffacae71 Mon Sep 17 00:00:00 2001 From: Mitchell Lyons Date: Mon, 30 Sep 2024 10:38:48 +1000 Subject: [PATCH 1/6] add probability array output to predict_xr --- Tools/dea_tools/classification.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/Tools/dea_tools/classification.py b/Tools/dea_tools/classification.py index 5ca8b411..660b1073 100644 --- a/Tools/dea_tools/classification.py +++ b/Tools/dea_tools/classification.py @@ -226,6 +226,7 @@ def predict_xr( chunk_size=None, persist=False, proba=False, + max_proba=True, clean=False, return_input=False, ): @@ -255,6 +256,11 @@ def predict_xr( distributed RAM. proba : bool If True, predict probabilities + max_proba : bool + If True, the probabilities array will be flattened to contain + only the probability for the "Predictions" class. If False, + the "Probabilities" object will be an array of prediction + probabilities for each class clean : bool If True, remove Infs and NaNs from input and output arrays return_input : bool @@ -282,7 +288,7 @@ def predict_xr( input_xr.chunks["y"][0] ) - def _predict_func(model, input_xr, persist, proba, clean, return_input): + def _predict_func(model, input_xr, persist, proba, proba_max, clean, return_input): x, y, crs = input_xr.x, input_xr.y, input_xr.geobox.crs input_data = [] @@ -331,7 +337,12 @@ def _predict_func(model, input_xr, persist, proba, clean, return_input): out_proba = model.predict_proba(input_data_flattened) # convert to % - out_proba = da.max(out_proba, axis=1) * 100.0 + if proba_max == True: + print(" returning single probability band.") + out_proba = da.max(out_proba, axis=1) * 100.0 + else: + print(" returning class probability array.") + out_proba = out_proba * 100.0 if clean == True: out_proba = da.where(da.isfinite(out_proba), out_proba, 0) From 507f53bd4cb3f23f612b8e1a2f0b46ab940459d1 Mon Sep 17 00:00:00 2001 From: Mitchell Lyons Date: Mon, 30 Sep 2024 10:45:31 +1000 Subject: [PATCH 2/6] 
predict_xr at proba_max args --- Tools/dea_tools/classification.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Tools/dea_tools/classification.py b/Tools/dea_tools/classification.py index 660b1073..f345e197 100644 --- a/Tools/dea_tools/classification.py +++ b/Tools/dea_tools/classification.py @@ -402,12 +402,12 @@ def _predict_func(model, input_xr, persist, proba, proba_max, clean, return_inpu model = ParallelPostFit(model) with joblib.parallel_backend("dask"): output_xr = _predict_func( - model, input_xr, persist, proba, clean, return_input + model, input_xr, persist, proba, proba_max, clean, return_input ) else: output_xr = _predict_func( - model, input_xr, persist, proba, clean, return_input + model, input_xr, persist, proba, proba_max, clean, return_input ).compute() return output_xr From 9a0072231dd8b5a02ee20f0018335a1e6549376c Mon Sep 17 00:00:00 2001 From: Mitchell Lyons Date: Mon, 30 Sep 2024 10:54:56 +1000 Subject: [PATCH 3/6] predict_xr match arg names --- Tools/dea_tools/classification.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/Tools/dea_tools/classification.py b/Tools/dea_tools/classification.py index f345e197..684978ae 100644 --- a/Tools/dea_tools/classification.py +++ b/Tools/dea_tools/classification.py @@ -288,7 +288,7 @@ def predict_xr( input_xr.chunks["y"][0] ) - def _predict_func(model, input_xr, persist, proba, proba_max, clean, return_input): + def _predict_func(model, input_xr, persist, proba, max_proba, clean, return_input): x, y, crs = input_xr.x, input_xr.y, input_xr.geobox.crs input_data = [] @@ -337,7 +337,7 @@ def _predict_func(model, input_xr, persist, proba, proba_max, clean, return_inpu out_proba = model.predict_proba(input_data_flattened) # convert to % - if proba_max == True: + if max_proba == True: print(" returning single probability band.") out_proba = da.max(out_proba, axis=1) * 100.0 else: @@ -402,12 +402,12 @@ def _predict_func(model, input_xr, persist, proba, 
proba_max, clean, return_inpu model = ParallelPostFit(model) with joblib.parallel_backend("dask"): output_xr = _predict_func( - model, input_xr, persist, proba, proba_max, clean, return_input + model, input_xr, persist, proba, max_proba, clean, return_input ) else: output_xr = _predict_func( - model, input_xr, persist, proba, proba_max, clean, return_input + model, input_xr, persist, proba, max_proba, clean, return_input ).compute() return output_xr From 8d11aa629751672ab4f1067d7387972c23833e55 Mon Sep 17 00:00:00 2001 From: Mitchell Lyons Date: Mon, 30 Sep 2024 11:54:34 +1000 Subject: [PATCH 4/6] xr_predict deal with multiband prob output --- Tools/dea_tools/classification.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/Tools/dea_tools/classification.py b/Tools/dea_tools/classification.py index 684978ae..692119f6 100644 --- a/Tools/dea_tools/classification.py +++ b/Tools/dea_tools/classification.py @@ -340,18 +340,24 @@ def _predict_func(model, input_xr, persist, proba, max_proba, clean, return_inpu if max_proba == True: print(" returning single probability band.") out_proba = da.max(out_proba, axis=1) * 100.0 + out_proba = out_proba.reshape(len(y), len(x)) + out_proba = xr.DataArray( + out_proba, coords={"x": x, "y": y}, dims=["y", "x"] + ) else: print(" returning class probability array.") out_proba = out_proba * 100.0 - + # Loop through each DataArray in the Dataset + for band_name in out_proba.data_vars: + reshaped_band = out_proba[band_name].values.reshape(len(y), len(x)) + reshaped_band = xr.DataArray( + reshaped_band, coords={"x": x, "y": y}, dims=["y", "x"] + ) + output_xr[out_proba] = reshaped_band + if clean == True: out_proba = da.where(da.isfinite(out_proba), out_proba, 0) - - out_proba = out_proba.reshape(len(y), len(x)) - - out_proba = xr.DataArray( - out_proba, coords={"x": x, "y": y}, dims=["y", "x"] - ) + output_xr["Probabilities"] = out_proba if return_input == True: From 
c50ff736ec6ed1e1b847a01b3598fd3968434a82 Mon Sep 17 00:00:00 2001 From: Mitchell Lyons Date: Mon, 30 Sep 2024 13:27:33 +1000 Subject: [PATCH 5/6] xr_predict merge output probs --- Tools/dea_tools/classification.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/Tools/dea_tools/classification.py b/Tools/dea_tools/classification.py index 692119f6..a9d37e98 100644 --- a/Tools/dea_tools/classification.py +++ b/Tools/dea_tools/classification.py @@ -344,21 +344,29 @@ def _predict_func(model, input_xr, persist, proba, max_proba, clean, return_inpu out_proba = xr.DataArray( out_proba, coords={"x": x, "y": y}, dims=["y", "x"] ) + output_xr["Probabilities"] = out_proba else: print(" returning class probability array.") out_proba = out_proba * 100.0 - # Loop through each DataArray in the Dataset - for band_name in out_proba.data_vars: - reshaped_band = out_proba[band_name].values.reshape(len(y), len(x)) - reshaped_band = xr.DataArray( + + class_names = model.classes_ # Get the unique class names from the fitted classifier + + probabilities_dataset = xr.Dataset() + + # Loop through each class (band) + for i, class_name in enumerate(class_names): + reshaped_band = out_proba[:, i].reshape(len(y), len(x)) + reshaped_da = xr.DataArray( reshaped_band, coords={"x": x, "y": y}, dims=["y", "x"] ) - output_xr[out_proba] = reshaped_band + probabilities_dataset[f"prob_{class_name}"] = reshaped_da + + # merge in the probabilities + output_xr = xr.merge([output_xr, probabilities_dataset]) if clean == True: out_proba = da.where(da.isfinite(out_proba), out_proba, 0) - output_xr["Probabilities"] = out_proba if return_input == True: print(" input features...") From a6e937f89fd24ea91a6abcaf42e8f1dd09cb0c85 Mon Sep 17 00:00:00 2001 From: Mitchell Lyons Date: Mon, 30 Sep 2024 14:03:40 +1000 Subject: [PATCH 6/6] clean up comments and spacing --- Tools/dea_tools/classification.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git 
a/Tools/dea_tools/classification.py b/Tools/dea_tools/classification.py index a9d37e98..27ac5e32 100644 --- a/Tools/dea_tools/classification.py +++ b/Tools/dea_tools/classification.py @@ -336,7 +336,7 @@ def _predict_func(model, input_xr, persist, proba, max_proba, clean, return_inpu print(" probabilities...") out_proba = model.predict_proba(input_data_flattened) - # convert to % + # return either one band with the max probability, or the whole probability array if max_proba == True: print(" returning single probability band.") out_proba = da.max(out_proba, axis=1) * 100.0 @@ -348,12 +348,10 @@ def _predict_func(model, input_xr, persist, proba, max_proba, clean, return_inpu else: print(" returning class probability array.") out_proba = out_proba * 100.0 - class_names = model.classes_ # Get the unique class names from the fitted classifier - probabilities_dataset = xr.Dataset() - # Loop through each class (band) + probabilities_dataset = xr.Dataset() for i, class_name in enumerate(class_names): reshaped_band = out_proba[:, i].reshape(len(y), len(x)) reshaped_da = xr.DataArray(