Skip to content

Commit

Permalink
add decoding
Browse files Browse the repository at this point in the history
  • Loading branch information
timonmerk committed Sep 19, 2024
1 parent 3e60480 commit c6c9517
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
14 changes: 7 additions & 7 deletions decode_eyes_open_closed.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def run_cv(df):
#X_train = X_train[cols_use]
#X_test = X_test[cols_use]

model = linear_model.LogisticRegression(class_weight="balanced")
#model = linear_model.LogisticRegression(class_weight="balanced")
model = CatBoostClassifier(verbose=0)
model.fit(X_train, y_train)
pred = model.predict(X_test)
Expand Down Expand Up @@ -108,16 +108,16 @@ def compute_modality(sub, mod):
ba_mean, cm_mean, coef_mean = run_cv(df_sub_ch.copy())

dict_out = {"sub": sub, "loc": loc, "ch" : str(ch), "mod": mod, "dout": disease, "ba": float(ba_mean)}
for i, coef in enumerate(coef_mean.T):
dict_out[f"coef_{f_bands[i]}"] = float(coef)
#for i, coef in enumerate(coef_mean.T):
# dict_out[f"coef_{f_bands[i]}"] = float(coef)
l_df.append(dict_out)
return l_df

if __name__ == "__main__":

df_all = pd.read_csv(PATH_FEATURES)
df_all["label_enc"] = df_all["label"].map({"SLEEP": 0, "EyesOpen": 1, "EyesClosed": 2})
df_all.query("label_enc != 0", inplace=True)
#df_all.query("label_enc != 0", inplace=True)

df_all = df_all.drop(columns=["time", "label"])
np.random.seed()
Expand All @@ -132,12 +132,12 @@ def compute_modality(sub, mod):
#compute_modality(subs[0], modality_[0])

PATH_BASE = r"C:\Users\ICN_admin\OneDrive - Charité - Universitätsmedizin Berlin\Dokumente\Decoding toolbox\EyesOpenBeijing\2708"
for mod in modality_:
for mod in modality_[::-1]:
l_df_ = Parallel(n_jobs=len(modality_))(delayed(compute_modality)(sub, mod) for sub in subs)

df_per = pd.DataFrame(list(np.concat(l_df_)))
df_per = pd.DataFrame(list(np.concatenate(l_df_)))

df_per.to_csv(os.path.join(PATH_BASE, f"out_per_loc_mod_{mod}_CB.csv"), index=False)
df_per.to_csv(os.path.join(PATH_BASE, f"out_per_loc_mod_{mod}_three_class_CB.csv"), index=False)



Expand Down
4 changes: 2 additions & 2 deletions read_res_decoding_single_ch.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

l_ = []
for mod in mods:
df = pd.read_csv(os.path.join(PATH_, f"out_per_loc_mod_{mod}.csv"))
df = pd.read_csv(os.path.join(PATH_, f"out_per_loc_mod_{mod}_three_class.csv"))
df["mod"] = mod
l_.append(df)

Expand Down Expand Up @@ -65,6 +65,6 @@
plt.figure(figsize=(4, 3), dpi=300)
sns.boxplot(data=df.query("loc == 'STN'"), x="dout", y="ba")
sns.swarmplot(data=df.query("loc == 'STN'"), x="dout", y="ba", color="gray", alpha=0.5)
plt.title("STN sleep-eyes open-eye closed alpha only")
plt.tight_layout()
plt.savefig("STN_per_comp_location.svg")
plt.show(block=True)

0 comments on commit c6c9517

Please sign in to comment.