Skip to content

Commit

Permalink
add annotations for bad segments rejection
Browse files Browse the repository at this point in the history
  • Loading branch information
dmalt committed Aug 27, 2022
1 parent 349b5f7 commit 2068c1e
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 12 deletions.
1 change: 1 addition & 0 deletions rawdata/code/configs/042-gen_ic_audio_mi_report.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ deriv_paths: ${paths.042-gen_ic_audio_mi_report}
input:
raw: ${paths.021-apply_maxfilter.maxfilt_raw}
ica: ${paths.041-compute_ica.solution}
annots: ${paths.031-annotate_postmaxfilt.annots}
output:
report: ${deriv_paths.report}

Expand Down
41 changes: 29 additions & 12 deletions rawdata/code/preproc/042-gen_ic_audio_mi_report.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
#!/usr/bin/env python
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple

import hydra
import matplotlib.pyplot as plt # type: ignore
import mne # type: ignore
import numpy as np
from hydra.core.config_store import ConfigStore
from matplotlib.figure import Figure # type: ignore
from mne import Report # type: ignore
from mne.io import Raw, RawArray, read_raw_fif # type: ignore
from mne.preprocessing import ICA # type: ignore
from mne.preprocessing import read_ica # type: ignore
from mne.preprocessing import ICA, read_ica # type: ignore
from sklearn.feature_selection import mutual_info_regression # type: ignore
from tqdm import tqdm, trange # type: ignore
from utils import BaseConfig, prepare_script
Expand All @@ -23,6 +22,7 @@
class Input:
raw: str
ica: str
annots: str


@dataclass
Expand Down Expand Up @@ -70,22 +70,22 @@ def compute_crosscorr(
cc1 = np.correlate(c1, c1)
cc2 = np.correlate(c2, c2)
times = np.arange(-shift_nsamp, shift_nsamp + 1) / sr
return times, np.correlate(c1, c2) / np.sqrt(cc1 * cc2)
return times, np.correlate(c2, c1) / np.sqrt(cc1 * cc2)


def retreive_audio_envelope(raw: Raw, audio_ch: str, dsamp_sfreq: float) -> np.ndarray:
def retreive_audio_envelope(raw: mne.io.Raw, audio_ch: str, dsamp_sfreq: float) -> np.ndarray:
audio_array = raw.get_data(picks=audio_ch, reject_by_annotation="omit")
audio_env = smooth(amplitude_env(np.squeeze(audio_array)))[np.newaxis, :]
audio_info = raw.info.copy().pick_channels([audio_ch])
audio_env_smooth_raw = RawArray(audio_env, audio_info).resample(dsamp_sfreq)
audio_env_smooth_raw = mne.io.RawArray(audio_env, audio_info).resample(dsamp_sfreq)
return np.squeeze(audio_env_smooth_raw.get_data())


def retreive_ics_envelope(
ica: ICA, raw: Raw, dsamp_sfreq: float, ica_muscle_band_filt: IcaMuscleBandFilt
ica: ICA, raw: mne.io.Raw, dsamp_sfreq: float, ica_muscle_band_filt: IcaMuscleBandFilt
) -> np.ndarray:
ics = ica.get_sources(raw)
assert isinstance(ics, Raw)
assert isinstance(ics, mne.io.Raw)
ics._first_samps = raw._first_samps
ics._last_samps = raw._last_samps
del raw
Expand Down Expand Up @@ -120,13 +120,29 @@ def gen_crosscorrelation_fig(times: np.ndarray, corr: np.ndarray) -> Figure:
return corr_fig


def plot_envelopes(ica_env: np.ndarray, audio_env: np.ndarray, dsamp_sfreq: float) -> None:
import matplotlib
from mne import create_info

matplotlib.use("TkAgg")
ch_names = [f"IC {i}" for i in range(len(ica_env))] + ["Audio"]
info = create_info(sfreq=dsamp_sfreq, ch_names=ch_names)
data = np.concatenate([ica_env, audio_env[np.newaxis, :]], axis=0)
raw = mne.io.RawArray(data, info)
raw.plot(block=True)


@hydra.main(config_path="../configs/", config_name="042-gen_ic_audio_mi_report")
def main(cfg: Config) -> None:
prepare_script(logger, script_name=__file__)

raw = read_raw_fif(cfg.input.raw, preload=True)
raw = mne.io.read_raw_fif(cfg.input.raw, preload=True)
if Path(cfg.input.annots).exists():
raw.set_annotations(mne.read_annotations(cfg.input.annots))
else:
logger.warning(f"Annotation file is missing at {cfg.input.annots}")

report = Report()
report = mne.Report()
report.add_raw(raw, title="Raw")

audio_env = retreive_audio_envelope(raw, cfg.audio_ch, cfg.dsamp_sfreq)
Expand All @@ -146,7 +162,8 @@ def main(cfg: Config) -> None:
report.add_figure(topo_fig, title=f"ICA {i_comp} topo")
report.add_figure(corr_fig, title=f"ICA {i_comp} - audio cross-correlation")

report.save(cfg.output.report)
report.save(cfg.output.report, overwrite=True)
plot_envelopes(ica_env, audio_env, cfg.dsamp_sfreq)


if __name__ == "__main__":
Expand Down

0 comments on commit 2068c1e

Please sign in to comment.