Skip to content

Commit

Permalink
add single channel decoding
Browse files Browse the repository at this point in the history
  • Loading branch information
timonmerk committed Sep 12, 2024
1 parent 9ab8f42 commit e3fa53d
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 41 deletions.
100 changes: 60 additions & 40 deletions decode_eyes_open_closed.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
from sklearn import metrics, model_selection, linear_model
from joblib import Parallel, delayed
from catboost import CatBoostClassifier

PATH_FEATURES = r"C:\Users\ICN_admin\OneDrive - Charité - Universitätsmedizin Berlin\Dokumente\Decoding toolbox\EyesOpenBeijing\2708\features_all.csv"
def balance_classes(df, target_column):
Expand All @@ -25,6 +26,7 @@ def run_cv(df):
model = linear_model.LinearRegression()
cm_list = []
ba_list = []
coef_list = []
for train_, test_ in cv.split(df, df["label_enc"]):
# integer encode the label column
train_df = df.iloc[train_]
Expand All @@ -34,8 +36,6 @@ def run_cv(df):
y_train = train_df_balanced["label_enc"]

# ensure that the same number of samples are taken for training from each class


X_test = df.iloc[test_, :]
y_test = df.iloc[test_, :]["label_enc"]
X_test = X_test.drop(columns=["label_enc"])
Expand All @@ -44,9 +44,9 @@ def run_cv(df):
#cols_use = [c for c in X_train.columns if all([f in c for f in features])]
#X_train = X_train[cols_use]
#X_test = X_test[cols_use]


model = linear_model.LogisticRegression(class_weight="balanced")
model = CatBoostClassifier(verbose=0)
model.fit(X_train, y_train)
pred = model.predict(X_test)

Expand All @@ -60,7 +60,10 @@ def run_cv(df):

cm_list.append(cm)
ba_list.append(metrics.balanced_accuracy_score(y_test, pred))

if type(model) == linear_model.LogisticRegression:
coef_list.append(model.coef_)
else:
coef_list.append(model.get_feature_importance())
print(metrics.balanced_accuracy_score(y_test, pred))

# plot the mean confusion matrix
Expand All @@ -71,31 +74,43 @@ def run_cv(df):
#plt.show(block=True)

ba_mean = sum(ba_list) / len(ba_list)
return ba_mean, cm_mean
return ba_mean, cm_mean, np.mean(coef_list, axis=0)

def compute_modality(mod):
def compute_modality(sub, mod):



#for mod in modality_:
l_df = []
if mod != "all":
df_mod = df_all[[c for c in df_all.columns if mod in c] + ["label_enc"] + ["sub"] + ["disease"]].copy()
else:
df_mod = df_all.copy()

for sub in subs:
df_sub = df_mod[df_mod["sub"] == sub]
disease = df_sub["disease"].unique()[0]
for loc in locs:
cols_loc = [c for c in df_sub.columns if loc in c] + ["label_enc"]
if len(cols_loc) == 0:
continue
df_sub_loc = df_sub[cols_loc].copy()
# select only columns that have non NaN values
df_sub_loc = df_sub_loc.dropna(axis=1)
if len(df_sub_loc.columns) == 1:
continue

ba_mean, cm_mean = run_cv(df_sub_loc.copy())

l_df.append({"sub": sub, "loc": loc, "mod": mod, "disease": disease, "ba": ba_mean})
df_sub = df_mod[df_mod["sub"] == sub]
disease = df_sub["disease"].unique()[0]
for loc in locs:
print(loc)
cols_loc = [c for c in df_sub.columns if loc in c] + ["label_enc"]
if len(cols_loc) == 0:
continue
df_sub_loc = df_sub[cols_loc].copy()
# select only columns that have non NaN values
df_sub_loc = df_sub_loc.dropna(axis=1)
if len(df_sub_loc.columns) == 1:
continue

chs_ = np.unique([c.split("_")[0] for c in df_sub_loc.columns if "label" not in c])
for ch in chs_:
print(ch)
df_sub_ch = df_sub_loc[[c for c in df_sub_loc.columns if ch in c or "label" in c]].copy()
f_bands = [c.split("_")[2] for c in df_sub_ch.columns if "label" not in c]
ba_mean, cm_mean, coef_mean = run_cv(df_sub_ch.copy())

dict_out = {"sub": sub, "loc": loc, "ch" : str(ch), "mod": mod, "dout": disease, "ba": float(ba_mean)}
for i, coef in enumerate(coef_mean.T):
dict_out[f"coef_{f_bands[i]}"] = float(coef)
l_df.append(dict_out)
return l_df

if __name__ == "__main__":
Expand All @@ -109,36 +124,41 @@ def compute_modality(mod):

locs = ["STN", "GPI", "ECOG", "EEG"]
modality_ = ["theta", "alpha", "low beta", "high beta", "low gamma", "high gamma", "HFA", "fft", "fooof", "all"]
modality_ = ["fft", "alpha", ]

subs = df_all["sub"].unique()
diseases = df_all["disease"].unique()

#compute_modality(modality_[0])
#compute_modality(subs[0], modality_[0])

l_df_ = Parallel(n_jobs=len(modality_))(delayed(compute_modality)(mod) for mod in modality_)
PATH_BASE = r"C:\Users\ICN_admin\OneDrive - Charité - Universitätsmedizin Berlin\Dokumente\Decoding toolbox\EyesOpenBeijing\2708"
for mod in modality_:
l_df_ = Parallel(n_jobs=len(modality_))(delayed(compute_modality)(sub, mod) for sub in subs)

df_per = pd.DataFrame(list(np.concat(l_df_)))

df_per.to_csv(os.path.join(PATH_BASE, f"out_per_loc_mod_{mod}_CB.csv"), index=False)




df_per = pd.DataFrame(list(np.concat(l_df_)))
df_per.to_csv(r"C:\Users\ICN_admin\OneDrive - Charité - Universitätsmedizin Berlin\Dokumente\Decoding toolbox\EyesOpenBeijing\2708\out_per_loc_mod.csv", index=False)

plt.figure()
sns.boxplot(data=df_per, x="loc", y="ba", hue="mod", palette="viridis")
#sns.swarmplot(data=df_per, x="loc", y="ba", color=".25", hue="mod")
plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
plt.tight_layout()
plt.savefig("ba_")
plt.show(block=True)

plt.figure()
sns.boxplot(data=df_per.query("mod == 'all'"), x="disease", y="ba", hue="mod", palette="viridis")
sns.swarmplot(data=df_per.query("mod == 'all'"), x="disease", y="ba", color=".25", hue="mod")
plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
plt.tight_layout()
plt.savefig("ba_")
plt.show(block=True)
# plt.figure()
# sns.boxplot(data=df_per, x="loc", y="ba", hue="mod", palette="viridis")
# #sns.swarmplot(data=df_per, x="loc", y="ba", color=".25", hue="mod")
# plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
# plt.tight_layout()
# plt.savefig("ba_")
# plt.show(block=True)

# plt.figure()
# sns.boxplot(data=df_per.query("mod == 'all'"), x="disease", y="ba", hue="mod", palette="viridis")
# sns.swarmplot(data=df_per.query("mod == 'all'"), x="disease", y="ba", color=".25", hue="mod")
# plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
# plt.tight_layout()
# plt.savefig("ba_")
# plt.show(block=True)

# for loc in locs:
# features = ["gamma",]
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ dependencies = [
"numpy >= 1.21.2",
"pandas >= 2.0.0",
"scikit-image",
"scikit-learn >= 0.24.2",
"scikit-optimize",
"scipy >= 1.7.1",
"seaborn >= 0.11",
Expand All @@ -57,6 +56,7 @@ dependencies = [
"pydantic>=2.8.2",
"mne-qt-browser>=0.6.3",
"specparam>=2.0.0rc2",
"catboost>=1.2.7",
]

[project.optional-dependencies]
Expand Down
34 changes: 34 additions & 0 deletions read_res_decoding_single_ch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import pandas as pd
import os

from matplotlib import pyplot as plt
import seaborn as sns

PATH_ = r"C:\Users\ICN_admin\OneDrive - Charité - Universitätsmedizin Berlin\Dokumente\Decoding toolbox\EyesOpenBeijing\2708"

mods = ["alpha", "fft"]

l_ = []
for mod in mods:
df = pd.read_csv(os.path.join(PATH_, f"out_per_loc_mod_{mod}.csv"))
df["mod"] = mod
l_.append(df)

df_all = pd.concat(l_)

plt.figure()
sns.boxplot(data=df_all, x="loc", y="ba", hue="mod", palette="viridis")
plt.tight_layout()
plt.show(block=True)

df = pd.read_csv(os.path.join(PATH_, f"out_per_loc_mod_fft.csv"))
# melt the dataframe that all columns with coef_ become a column
df_melt = df.melt(id_vars=["ba", "loc", "sub", "dout"], value_vars=[c for c in df.columns if "coef_" in c])

plt.figure()
# melt the dataframe that coef_ becomes a column
sns.boxplot(data=df_melt, x="variable", y="value", palette="viridis")
plt.xticks(rotation=90)
plt.ylabel("Coef")
plt.tight_layout()
plt.show(block=True)

0 comments on commit e3fa53d

Please sign in to comment.