diff --git a/.coveragerc b/.coveragerc index 437e6b79..3d8cc939 100644 --- a/.coveragerc +++ b/.coveragerc @@ -11,6 +11,8 @@ omit = */gui/experiments/* */gui/viewer/* */gui/BCInterface.py + */signal/model/offline_analysis.py + */signal/evaluate/fusion.py [report] exclude_lines = diff --git a/bcipy/signal/evaluate/fusion.py b/bcipy/signal/evaluate/fusion.py index f18c90a3..6d91cddb 100644 --- a/bcipy/signal/evaluate/fusion.py +++ b/bcipy/signal/evaluate/fusion.py @@ -1,3 +1,4 @@ +# mypy: disable-error-code="assignment,var-annotated" import numpy as np from sklearn.utils import resample from typing import List, Tuple @@ -33,7 +34,7 @@ def calculate_eeg_gaze_fusion_acc( data_folder: str, n_iterations: int = 10, eeg_model: SignalModel = PcaRdaKdeModel, - gaze_model: SignalModel = GaussianProcess) -> Tuple[float, float, float]: + gaze_model: SignalModel = GaussianProcess) -> Tuple[List[float], List[float], List[float]]: """ Preprocess the EEG and gaze data. Calculate the accuracy of the fusion of EEG and Gaze models. Args: diff --git a/bcipy/signal/generator/generator.py b/bcipy/signal/generator/generator.py index 62b03e2c..b9f411cc 100644 --- a/bcipy/signal/generator/generator.py +++ b/bcipy/signal/generator/generator.py @@ -1,13 +1,17 @@ import numpy as np -from typing import List, NamedTuple +from typing import List -def truncate_float(num: float, precision: float) -> float: +def truncate_float(num: float, precision: int) -> float: """Truncate a float to a given precision.""" return float(str(num)[:precision]) -def gen_random_data(low, high, channel_count, precision=8) -> List[float]: +def gen_random_data( + low: float, + high: float, + channel_count: int, + precision: int = 8) -> List[float]: """Generate random data. This function generates random data for testing purposes within a given range. The data is diff --git a/bcipy/signal/model/base_model.py b/bcipy/signal/model/base_model.py index 02e8790e..af9f7028 100644 --- a/bcipy/signal/model/base_model.py +++ b/bcipy/signal/model/base_model.py @@ -24,7 +24,7 @@ def __repr__(self): return f"SignalModelMetadata(device_spec={self.device_spec}, transform={self.transform}, " \ f"evidence_type={self.evidence_type}, auc={self.auc}, accuracy={self.acc}, " \ f"balanced_accuracy={self.balanced_accuracy})" - + def __str__(self): return self.__repr__() diff --git a/bcipy/signal/model/gaussian_mixture/gaussian_mixture.py b/bcipy/signal/model/gaussian_mixture/gaussian_mixture.py index 77768c8f..634e131b 100644 --- a/bcipy/signal/model/gaussian_mixture/gaussian_mixture.py +++ b/bcipy/signal/model/gaussian_mixture/gaussian_mixture.py @@ -13,6 +13,30 @@ warnings.filterwarnings("ignore") # ignore DeprecationWarnings from tensorflow +class GazeModelType(Enum): + """Enum for gaze model types""" + GAUSSIAN_PROCESS = "GaussianProcess" + GM_INDIVIDUAL = "GMIndividual" + GM_CENTRALIZED = "GMCentralized" + + def __str__(self): + return self.value + + def __repr__(self): + return self.value + + @staticmethod + def from_str(label: str): + if label == "GaussianProcess": + return GazeModelType.GAUSSIAN_PROCESS + elif label == "GMIndividual": + return GazeModelType.GM_INDIVIDUAL + elif label == "GMCentralized": + return GazeModelType.GM_CENTRALIZED + else: + raise ValueError(f"Model type {label} not recognized.") + + class GazeModelResolver: """Factory class for gaze models @@ -22,16 +46,16 @@ class GazeModelResolver: @staticmethod def resolve(model_type: str, *args, **kwargs) -> SignalModel: """Load a gaze model from the provided path.""" - if model_type == "GaussianProcess": - model = GaussianProcess(*args, **kwargs) - elif model_type == "GMIndividual": - model = GMIndividual(*args, **kwargs) - elif model_type == "GMCentralized": - model = GMCentralized( *args, **kwargs) + model_type = GazeModelType.from_str(model_type) + if model_type == GazeModelType.GAUSSIAN_PROCESS: + return GaussianProcess(*args, **kwargs) + elif model_type == GazeModelType.GM_INDIVIDUAL: + return GMIndividual(*args, **kwargs) + elif model_type == GazeModelType.GM_CENTRALIZED: + return GMCentralized(*args, **kwargs) else: - raise ValueError(f"Model type {model_type} not recognized.") - - return model + raise ValueError( + f"Model type {model_type} not able to resolve. Not registered in GazeModelResolver.") class GaussianProcess(SignalModel): @@ -39,7 +63,7 @@ class GaussianProcess(SignalModel): name = "GaussianProcessGazeModel" reshaper = GazeReshaper() - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs): self.ready_to_predict = False self.acc = None @@ -212,7 +236,7 @@ class GMCentralized(SignalModel): reshaper = GazeReshaper() name = "gaze_model_combined" - def __init__(self, num_components=4, random_state=0, *args, **kwargs): + def __init__(self, num_components=4, random_state=0, *args, **kwargs): self.num_components = num_components # number of gaussians to fit self.random_state = random_state self.acc = None diff --git a/bcipy/signal/model/offline_analysis.py b/bcipy/signal/model/offline_analysis.py index 920e10e1..ab0b3f9d 100644 --- a/bcipy/signal/model/offline_analysis.py +++ b/bcipy/signal/model/offline_analysis.py @@ -79,8 +79,8 @@ def analyze_erp( device_spec: DeviceSpec, data_folder: str, estimate_balanced_acc: bool, - save_figures: bool=False, - show_figures: bool=False) -> SignalModel: + save_figures: bool = False, + show_figures: bool = False) -> SignalModel: """Analyze ERP data and return/save the ERP model. Extract relevant information from raw data object. Extract timing information from trigger file. @@ -233,7 +233,7 @@ def analyze_gaze( parameters: Parameters, device_spec: DeviceSpec, data_folder: str, - model_type: str="GaussianProcess", + model_type: str = "GaussianProcess", symbol_set: List[str] = alphabet()) -> SignalModel: """Analyze gaze data and return/save the gaze model. Extract relevant information from gaze data object. diff --git a/bcipy/signal/tests/model/gaussian_mixture/test_gaussian_mixture.py b/bcipy/signal/tests/model/gaussian_mixture/test_gaussian_mixture.py index fbea8274..df373f28 100644 --- a/bcipy/signal/tests/model/gaussian_mixture/test_gaussian_mixture.py +++ b/bcipy/signal/tests/model/gaussian_mixture/test_gaussian_mixture.py @@ -1,39 +1,45 @@ -import shutil -import tempfile import unittest -from pathlib import Path -import numpy as np +from bcipy.signal.model.gaussian_mixture import ( + GaussianProcess, + GMCentralized, + GMIndividual, + GazeModelResolver +) -from bcipy.signal.model import GaussianProcess +class TestGazeModelResolver(unittest.TestCase): -class ModelSetup(unittest.TestCase): - @classmethod - def setUpClass(cls): - np.random.seed(0) - cls.tmp_dir = Path(tempfile.mkdtemp()) + def test_resolve(self): + response = GazeModelResolver.resolve('GaussianProcess') + self.assertIsInstance(response, GaussianProcess) - @classmethod - def tearDownClass(cls): - shutil.rmtree(cls.tmp_dir) + def test_resolve_centralized(self): + response = GazeModelResolver.resolve('GMCentralized') + self.assertIsInstance(response, GMCentralized) + def test_resolve_individual(self): + response = GazeModelResolver.resolve('GMIndividual') + self.assertIsInstance(response, GMIndividual) -class TestGaussianMixtureInternals(ModelSetup): - @classmethod - def setUpClass(cls): - super().setUpClass() + def test_resolve_raises_value_error_on_invalid_model(self): + with self.assertRaises(ValueError): + GazeModelResolver.resolve('InvalidModel') - @classmethod - def tearDownClass(cls): - super().tearDownClass - def setUp(self): - np.random.seed(0) - self.model = GaussianProcess() +class TestModelInit(unittest.TestCase): - def test_predict(self): - ... + def test_gaussian_process(self): + model = GaussianProcess() + self.assertIsInstance(model, GaussianProcess) + + def test_centrailized(self): + model = GMCentralized() + self.assertIsInstance(model, GMCentralized) + + def test_individual(self): + model = GMIndividual() + self.assertIsInstance(model, GMIndividual) if __name__ == "__main__": diff --git a/bcipy/signal/tests/model/test_offline_analysis.py b/bcipy/signal/tests/model/test_offline_analysis.py index 918fa7ad..5f0a8502 100644 --- a/bcipy/signal/tests/model/test_offline_analysis.py +++ b/bcipy/signal/tests/model/test_offline_analysis.py @@ -79,7 +79,8 @@ def test_model_auc(self): def test_model_metadata_loads(self): self.assertIsNotNone(self.model.metadata) self.assertAlmostEqual( - self.model.metadata.auc, self.get_auc(list(expected_output_folder.glob("model_eeg_*.pkl"))[0].name), delta=0.005) + self.model.metadata.auc, self.get_auc( + list(expected_output_folder.glob("model_eeg_*.pkl"))[0].name), delta=0.005) self.assertIsNotNone(self.model.metadata.transform) @@ -104,12 +105,16 @@ def setUpClass(cls): # expand eyetracker_data_tobii.csv.gz into tmp_dir with gzip.open(file_loc, "rb") as f_source: - with open(cls.tmp_dir / f"eyetracker_data_tobii-p0.csv", "wb") as f_dest: + with open(cls.tmp_dir / "eyetracker_data_tobii-p0.csv", "wb") as f_dest: shutil.copyfileobj(f_source, f_dest) # copy the other required inputs into tmp_dir shutil.copyfile(eye_tracking_input_folder / TRIGGER_FILENAME, cls.tmp_dir / TRIGGER_FILENAME) - shutil.copyfile(eye_tracking_input_folder / DEFAULT_DEVICE_SPEC_FILENAME, cls.tmp_dir / DEFAULT_DEVICE_SPEC_FILENAME) + shutil.copyfile( + eye_tracking_input_folder / + DEFAULT_DEVICE_SPEC_FILENAME, + cls.tmp_dir / + DEFAULT_DEVICE_SPEC_FILENAME) params_path = pwd.parent.parent.parent / "parameters" / DEFAULT_PARAMETERS_FILENAME cls.parameters = load_json_parameters(params_path, value_cast=True) @@ -149,8 +154,8 @@ class TestOfflineAnalysisFusion(unittest.TestCase): This test is slow because it runs the full offline analysis pipeline and compares its' output to a set of expected outputs. The expected outputs are generated by running the pipeline on the same input data and saving them to the expected_output_folder. See the main `signal` module - README.md for more information - + README.md for more information + The test will fail if the acc is not within 0.005 of the expected acc or if the auc is not within 0.005 of the expected auc. """ @@ -173,7 +178,7 @@ def setUpClass(cls): # expand eyetracker_data_tobii.csv.gz into tmp_dir with gzip.open(eye_tracking_file_loc, "rb") as f_source: - with open(cls.tmp_dir / f"eyetracker_data_tobii-p0.csv", "wb") as f_dest: + with open(cls.tmp_dir / "eyetracker_data_tobii-p0.csv", "wb") as f_dest: shutil.copyfileobj(f_source, f_dest) # copy the other required inputs into tmp_dir @@ -206,7 +211,7 @@ def get_acc(model_filename): raise ValueError() return None return float(match[1]) - + @staticmethod def get_auc(model_filename): match = re.search("^model_eeg_([.0-9]+).pkl$", model_filename) @@ -224,5 +229,6 @@ def test_model_auc(self): found_auc = self.get_auc(list(self.tmp_dir.glob("model_eeg_*.pkl"))[0].name) self.assertAlmostEqual(expected_auc, found_auc, delta=0.005) + if __name__ == "__main__": unittest.main()