Skip to content

Commit

Permalink
xr_predict merge output probs
Browse files Browse the repository at this point in the history
  • Loading branch information
mitchest authored Sep 30, 2024
1 parent 8d11aa6 commit c50ff73
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions Tools/dea_tools/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,21 +344,29 @@ def _predict_func(model, input_xr, persist, proba, max_proba, clean, return_inpu
out_proba = xr.DataArray(
out_proba, coords={"x": x, "y": y}, dims=["y", "x"]
)
output_xr["Probabilities"] = out_proba
else:
print(" returning class probability array.")
out_proba = out_proba * 100.0
# Loop through each DataArray in the Dataset
for band_name in out_proba.data_vars:
reshaped_band = out_proba[band_name].values.reshape(len(y), len(x))
reshaped_band = xr.DataArray(

class_names = model.classes_ # Get the unique class names from the fitted classifier

probabilities_dataset = xr.Dataset()

# Loop through each class (band)
for i, class_name in enumerate(class_names):
reshaped_band = out_proba[:, i].reshape(len(y), len(x))
reshaped_da = xr.DataArray(
reshaped_band, coords={"x": x, "y": y}, dims=["y", "x"]
)
output_xr[out_proba] = reshaped_band
probabilities_dataset[f"prob_{class_name}"] = reshaped_da

# merge in the probabilities
output_xr = xr.merge([output_xr, probabilities_dataset])

if clean == True:
out_proba = da.where(da.isfinite(out_proba), out_proba, 0)

output_xr["Probabilities"] = out_proba

if return_input == True:
print(" input features...")
Expand Down

0 comments on commit c50ff73

Please sign in to comment.