Skip to content

Commit

Permalink
xr_predict deal with multiband prob outout
Browse files Browse the repository at this point in the history
  • Loading branch information
mitchest authored Sep 30, 2024
1 parent 9a00722 commit 8d11aa6
Showing 1 changed file with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions Tools/dea_tools/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,18 +340,24 @@ def _predict_func(model, input_xr, persist, proba, max_proba, clean, return_inpu
if max_proba == True:
print(" returning single probability band.")
out_proba = da.max(out_proba, axis=1) * 100.0
out_proba = out_proba.reshape(len(y), len(x))
out_proba = xr.DataArray(
out_proba, coords={"x": x, "y": y}, dims=["y", "x"]
)
else:
print(" returning class probability array.")
out_proba = out_proba * 100.0

# Loop through each DataArray in the Dataset
for band_name in out_proba.data_vars:
reshaped_band = out_proba[band_name].values.reshape(len(y), len(x))
reshaped_band = xr.DataArray(
reshaped_band, coords={"x": x, "y": y}, dims=["y", "x"]
)
output_xr[out_proba] = reshaped_band

if clean == True:
out_proba = da.where(da.isfinite(out_proba), out_proba, 0)

out_proba = out_proba.reshape(len(y), len(x))

out_proba = xr.DataArray(
out_proba, coords={"x": x, "y": y}, dims=["y", "x"]
)

output_xr["Probabilities"] = out_proba

if return_input == True:
Expand Down

0 comments on commit 8d11aa6

Please sign in to comment.