From e4673385940b91c9a0b15228f45f185d9009ab3c Mon Sep 17 00:00:00 2001
From: CompRhys
Date: Mon, 5 Aug 2024 20:09:13 +0000
Subject: [PATCH] fix: working predict.py for a single model

---
 aviary/predict.py | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/aviary/predict.py b/aviary/predict.py
index 0fed9fd..f50d1e2 100644
--- a/aviary/predict.py
+++ b/aviary/predict.py
@@ -111,9 +111,9 @@ def make_ensemble_predictions(
         else:
             df[pred_col] = preds
 
-    if len(checkpoint_paths) > 1:
-        df_preds = df.filter(regex=r"_pred_\d")
+    df_preds = df.filter(regex=r"_pred_\d")
 
+    if len(checkpoint_paths) > 1:
         pred_ens_col = f"{target_col}_pred_ens" if target_col else "pred_ens"
         df[pred_ens_col] = ensemble_preds = df_preds.mean(axis=1)
 
@@ -134,7 +134,7 @@ def make_ensemble_predictions(
         ).mean(axis=1)
         df[pred_tot_std_ens] = (epistemic_std**2 + aleatoric_std**2) ** 0.5
 
-    if target_col:
+    if target_col is not None:
         targets = df[target_col]
         all_model_metrics = [
             get_metrics(targets, df_preds[col], task_type) for col in df_preds
@@ -145,11 +145,12 @@ def make_ensemble_predictions(
         print("\nSingle model performance:")
         print(df_metrics.describe().round(4).loc[["mean", "std"]])
 
-        ensemble_metrics = get_metrics(targets, ensemble_preds, task_type)
+        if len(checkpoint_paths) > 1:
+            ensemble_metrics = get_metrics(targets, ensemble_preds, task_type)
 
-        print("\nEnsemble performance:")
-        for key, val in ensemble_metrics.items():
-            print(f"{key:<8} {val:.3}")
+            print("\nEnsemble performance:")
+            for key, val in ensemble_metrics.items():
+                print(f"{key:<8} {val:.3}")
 
         return df, df_metrics
     return df
@@ -208,7 +209,7 @@ def predict_from_wandb_checkpoints(
         if not os.path.isfile(checkpoint_path):
             run.file(f"{checkpoint_filename}").download(root=out_dir)
 
-    if target_col in kwargs:
+    if target_col is not None:
         df, ensemble_metrics = make_ensemble_predictions(checkpoint_paths, **kwargs)
         # round to save disk space and speed up cloud storage uploads
         return df.round(6), ensemble_metrics
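
For context, the sketch below shows how the patched function might be called with a single checkpoint versus an ensemble, which is the scenario this fix targets. It is a minimal illustration under stated assumptions, not aviary's documented API: the df= and target_col= keyword arguments, column names, and file paths are assumptions; only make_ensemble_predictions(checkpoint_paths, **kwargs), the target_col handling, and the len(checkpoint_paths) > 1 branches are visible in the hunks above.

# Minimal usage sketch; keyword arguments, column names and paths are assumptions.
import pandas as pd

from aviary.predict import make_ensemble_predictions

# Assumed input frame containing model inputs and a known target column.
df_in = pd.read_csv("test_set.csv")  # hypothetical file

# Single checkpoint: with this patch the per-model prediction columns and
# single-model metrics are still produced, while the len(checkpoint_paths) > 1
# branches (ensemble mean/std columns and ensemble metrics) are skipped.
single_ckpt = ["checkpoints/model_0.pth"]  # hypothetical path
df_single, metrics_single = make_ensemble_predictions(
    single_ckpt,
    df=df_in,  # assumed kwarg for the input frame
    target_col="e_form",  # assumed target column; returns (df, df_metrics) when set
)

# Multiple checkpoints: the ensemble branches run as before, adding the
# *_pred_ens / *_std_ens columns and printing ensemble metrics.
ensemble_ckpts = [f"checkpoints/model_{idx}.pth" for idx in range(5)]  # hypothetical
df_ens, metrics_ens = make_ensemble_predictions(
    ensemble_ckpts,
    df=df_in,
    target_col="e_form",
)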