Skip to content

Commit

Permalink
add unittests, lint
Browse files Browse the repository at this point in the history
  • Loading branch information
tab-cmd committed Dec 13, 2024
1 parent 09c6656 commit 34845c9
Show file tree
Hide file tree
Showing 8 changed files with 94 additions and 51 deletions.
2 changes: 2 additions & 0 deletions .coveragerc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ omit =
*/gui/experiments/*
*/gui/viewer/*
*/gui/BCInterface.py
*/signal/model/offline_analysis.py
*/signal/evaluate/fusion.py

[report]
exclude_lines =
Expand Down
3 changes: 2 additions & 1 deletion bcipy/signal/evaluate/fusion.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# mypy: disable-error-code="assignment,var-annotated"
import numpy as np
from sklearn.utils import resample
from typing import List, Tuple
Expand Down Expand Up @@ -33,7 +34,7 @@ def calculate_eeg_gaze_fusion_acc(
data_folder: str,
n_iterations: int = 10,
eeg_model: SignalModel = PcaRdaKdeModel,
gaze_model: SignalModel = GaussianProcess) -> Tuple[float, float, float]:
gaze_model: SignalModel = GaussianProcess) -> Tuple[List[float], List[float], List[float]]:
"""
Preprocess the EEG and gaze data. Calculate the accuracy of the fusion of EEG and Gaze models.
Args:
Expand Down
10 changes: 7 additions & 3 deletions bcipy/signal/generator/generator.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
import numpy as np
from typing import List, NamedTuple
from typing import List


def truncate_float(num: float, precision: float) -> float:
def truncate_float(num: float, precision: int) -> float:
"""Truncate a float to a given precision."""
return float(str(num)[:precision])


def gen_random_data(low, high, channel_count, precision=8) -> List[float]:
def gen_random_data(
low: float,
high: float,
channel_count: int,
precision: int = 8) -> List[float]:
"""Generate random data.
This function generates random data for testing purposes within a given range. The data is
Expand Down
2 changes: 1 addition & 1 deletion bcipy/signal/model/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __repr__(self):
return f"SignalModelMetadata(device_spec={self.device_spec}, transform={self.transform}, " \
f"evidence_type={self.evidence_type}, auc={self.auc}, accuracy={self.acc}, " \
f"balanced_accuracy={self.balanced_accuracy})"

def __str__(self):
return self.__repr__()

Expand Down
46 changes: 35 additions & 11 deletions bcipy/signal/model/gaussian_mixture/gaussian_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,30 @@
warnings.filterwarnings("ignore") # ignore DeprecationWarnings from tensorflow


class GazeModelType(Enum):
"""Enum for gaze model types"""
GAUSSIAN_PROCESS = "GaussianProcess"
GM_INDIVIDUAL = "GMIndividual"
GM_CENTRALIZED = "GMCentralized"

def __str__(self):
return self.value

def __repr__(self):
return self.value

@staticmethod
def from_str(label: str):
if label == "GaussianProcess":
return GazeModelType.GAUSSIAN_PROCESS
elif label == "GMIndividual":
return GazeModelType.GM_INDIVIDUAL
elif label == "GMCentralized":
return GazeModelType.GM_CENTRALIZED
else:
raise ValueError(f"Model type {label} not recognized.")


class GazeModelResolver:
"""Factory class for gaze models
Expand All @@ -22,24 +46,24 @@ class GazeModelResolver:
@staticmethod
def resolve(model_type: str, *args, **kwargs) -> SignalModel:
"""Load a gaze model from the provided path."""
if model_type == "GaussianProcess":
model = GaussianProcess(*args, **kwargs)
elif model_type == "GMIndividual":
model = GMIndividual(*args, **kwargs)
elif model_type == "GMCentralized":
model = GMCentralized( *args, **kwargs)
model_type = GazeModelType.from_str(model_type)
if model_type == GazeModelType.GAUSSIAN_PROCESS:
return GaussianProcess(*args, **kwargs)
elif model_type == GazeModelType.GM_INDIVIDUAL:
return GMIndividual(*args, **kwargs)
elif model_type == GazeModelType.GM_CENTRALIZED:
return GMCentralized(*args, **kwargs)
else:
raise ValueError(f"Model type {model_type} not recognized.")

return model
raise ValueError(
f"Model type {model_type} not able to resolve. Not registered in GazeModelResolver.")


class GaussianProcess(SignalModel):

name = "GaussianProcessGazeModel"
reshaper = GazeReshaper()

def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs):
self.ready_to_predict = False
self.acc = None

Expand Down Expand Up @@ -212,7 +236,7 @@ class GMCentralized(SignalModel):
reshaper = GazeReshaper()
name = "gaze_model_combined"

def __init__(self, num_components=4, random_state=0, *args, **kwargs):
def __init__(self, num_components=4, random_state=0, *args, **kwargs):
self.num_components = num_components # number of gaussians to fit
self.random_state = random_state
self.acc = None
Expand Down
6 changes: 3 additions & 3 deletions bcipy/signal/model/offline_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ def analyze_erp(
device_spec: DeviceSpec,
data_folder: str,
estimate_balanced_acc: bool,
save_figures: bool=False,
show_figures: bool=False) -> SignalModel:
save_figures: bool = False,
show_figures: bool = False) -> SignalModel:
"""Analyze ERP data and return/save the ERP model.
Extract relevant information from raw data object.
Extract timing information from trigger file.
Expand Down Expand Up @@ -233,7 +233,7 @@ def analyze_gaze(
parameters: Parameters,
device_spec: DeviceSpec,
data_folder: str,
model_type: str="GaussianProcess",
model_type: str = "GaussianProcess",
symbol_set: List[str] = alphabet()) -> SignalModel:
"""Analyze gaze data and return/save the gaze model.
Extract relevant information from gaze data object.
Expand Down
56 changes: 31 additions & 25 deletions bcipy/signal/tests/model/gaussian_mixture/test_gaussian_mixture.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,45 @@
import shutil
import tempfile
import unittest
from pathlib import Path

import numpy as np
from bcipy.signal.model.gaussian_mixture import (
GaussianProcess,
GMCentralized,
GMIndividual,
GazeModelResolver
)

from bcipy.signal.model import GaussianProcess

class TestGazeModelResolver(unittest.TestCase):

class ModelSetup(unittest.TestCase):
@classmethod
def setUpClass(cls):
np.random.seed(0)
cls.tmp_dir = Path(tempfile.mkdtemp())
def test_resolve(self):
response = GazeModelResolver.resolve('GaussianProcess')
self.assertIsInstance(response, GaussianProcess)

@classmethod
def tearDownClass(cls):
shutil.rmtree(cls.tmp_dir)
def test_resolve_centralized(self):
response = GazeModelResolver.resolve('GMCentralized')
self.assertIsInstance(response, GMCentralized)

def test_resolve_individual(self):
response = GazeModelResolver.resolve('GMIndividual')
self.assertIsInstance(response, GMIndividual)

class TestGaussianMixtureInternals(ModelSetup):
@classmethod
def setUpClass(cls):
super().setUpClass()
def test_resolve_raises_value_error_on_invalid_model(self):
with self.assertRaises(ValueError):
GazeModelResolver.resolve('InvalidModel')

@classmethod
def tearDownClass(cls):
super().tearDownClass

def setUp(self):
np.random.seed(0)
self.model = GaussianProcess()
class TestModelInit(unittest.TestCase):

def test_predict(self):
...
def test_gaussian_process(self):
model = GaussianProcess()
self.assertIsInstance(model, GaussianProcess)

def test_centrailized(self):
model = GMCentralized()
self.assertIsInstance(model, GMCentralized)

def test_individual(self):
model = GMIndividual()
self.assertIsInstance(model, GMIndividual)


if __name__ == "__main__":
Expand Down
20 changes: 13 additions & 7 deletions bcipy/signal/tests/model/test_offline_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ def test_model_auc(self):
def test_model_metadata_loads(self):
self.assertIsNotNone(self.model.metadata)
self.assertAlmostEqual(
self.model.metadata.auc, self.get_auc(list(expected_output_folder.glob("model_eeg_*.pkl"))[0].name), delta=0.005)
self.model.metadata.auc, self.get_auc(
list(expected_output_folder.glob("model_eeg_*.pkl"))[0].name), delta=0.005)
self.assertIsNotNone(self.model.metadata.transform)


Expand All @@ -104,12 +105,16 @@ def setUpClass(cls):

# expand eyetracker_data_tobii.csv.gz into tmp_dir
with gzip.open(file_loc, "rb") as f_source:
with open(cls.tmp_dir / f"eyetracker_data_tobii-p0.csv", "wb") as f_dest:
with open(cls.tmp_dir / "eyetracker_data_tobii-p0.csv", "wb") as f_dest:
shutil.copyfileobj(f_source, f_dest)

# copy the other required inputs into tmp_dir
shutil.copyfile(eye_tracking_input_folder / TRIGGER_FILENAME, cls.tmp_dir / TRIGGER_FILENAME)
shutil.copyfile(eye_tracking_input_folder / DEFAULT_DEVICE_SPEC_FILENAME, cls.tmp_dir / DEFAULT_DEVICE_SPEC_FILENAME)
shutil.copyfile(
eye_tracking_input_folder /
DEFAULT_DEVICE_SPEC_FILENAME,
cls.tmp_dir /
DEFAULT_DEVICE_SPEC_FILENAME)

params_path = pwd.parent.parent.parent / "parameters" / DEFAULT_PARAMETERS_FILENAME
cls.parameters = load_json_parameters(params_path, value_cast=True)
Expand Down Expand Up @@ -149,8 +154,8 @@ class TestOfflineAnalysisFusion(unittest.TestCase):
This test is slow because it runs the full offline analysis pipeline and compares its' output
to a set of expected outputs. The expected outputs are generated by running the pipeline on
the same input data and saving them to the expected_output_folder. See the main `signal` module
README.md for more information
README.md for more information
The test will fail if the acc is not within 0.005 of the expected acc or if the auc is not within
0.005 of the expected auc.
"""
Expand All @@ -173,7 +178,7 @@ def setUpClass(cls):

# expand eyetracker_data_tobii.csv.gz into tmp_dir
with gzip.open(eye_tracking_file_loc, "rb") as f_source:
with open(cls.tmp_dir / f"eyetracker_data_tobii-p0.csv", "wb") as f_dest:
with open(cls.tmp_dir / "eyetracker_data_tobii-p0.csv", "wb") as f_dest:
shutil.copyfileobj(f_source, f_dest)

# copy the other required inputs into tmp_dir
Expand Down Expand Up @@ -206,7 +211,7 @@ def get_acc(model_filename):
raise ValueError()
return None
return float(match[1])

@staticmethod
def get_auc(model_filename):
match = re.search("^model_eeg_([.0-9]+).pkl$", model_filename)
Expand All @@ -224,5 +229,6 @@ def test_model_auc(self):
found_auc = self.get_auc(list(self.tmp_dir.glob("model_eeg_*.pkl"))[0].name)
self.assertAlmostEqual(expected_auc, found_auc, delta=0.005)


if __name__ == "__main__":
unittest.main()

0 comments on commit 34845c9

Please sign in to comment.