diff --git a/aix360/datasets/climate_dataset.py b/aix360/datasets/climate_dataset.py index 7531fb0..70f63b5 100644 --- a/aix360/datasets/climate_dataset.py +++ b/aix360/datasets/climate_dataset.py @@ -23,7 +23,10 @@ class ClimateDataset: """ - def __init__(self): + def __init__( + self, + url: str = None, + ): self.data_folder = os.path.realpath( os.path.join( os.path.dirname(os.path.realpath(__file__)), "../data", "climate_data" @@ -32,7 +35,11 @@ def __init__(self): self.data_file = os.path.realpath( os.path.join(self.data_folder, "jena_climate_2009_2016.csv") ) - climate_data_url = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/jena_climate_2009_2016.csv.zip" + climate_data_url = ( + url + if url is not None + else "https://storage.googleapis.com/tensorflow/tf-keras-datasets/jena_climate_2009_2016.csv.zip" + ) self.input_length = 500 # download data diff --git a/aix360/datasets/diabetes_dataset.py b/aix360/datasets/diabetes_dataset.py index f857bcd..b2e4d5f 100644 --- a/aix360/datasets/diabetes_dataset.py +++ b/aix360/datasets/diabetes_dataset.py @@ -18,7 +18,10 @@ class DiabetesDataset: """ - def __init__(self): + def __init__( + self, + url: str = None, + ): self.data_folder = os.path.realpath( os.path.join( os.path.dirname(os.path.realpath(__file__)), "../data", "diabetes_data" @@ -27,7 +30,11 @@ def __init__(self): self.data_file = os.path.realpath( os.path.join(self.data_folder, "diabetes.csv") ) - diabetes_url = "https://www4.stat.ncsu.edu/~boos/var.select/diabetes.tab.txt" + diabetes_url = ( + url + if url is not None + else "https://www4.stat.ncsu.edu/~boos/var.select/diabetes.tab.txt" + ) if not os.path.exists(self.data_file): response = requests.get(diabetes_url) diff --git a/aix360/datasets/ford_dataset.py b/aix360/datasets/ford_dataset.py index 9d01c3d..acabeb5 100644 --- a/aix360/datasets/ford_dataset.py +++ b/aix360/datasets/ford_dataset.py @@ -27,7 +27,7 @@ class FordDataset: """ - def __init__(self, category_a: bool = True): + def __init__(self, url: str = None, category_a: bool = True): self.data_folder = os.path.realpath( os.path.join( os.path.dirname(os.path.realpath(__file__)), "../data", "ford_data" @@ -41,8 +41,12 @@ def __init__(self, category_a: bool = True): ) self.category = "A" if category_a else "B" - ford_data_url = "http://timeseriesclassification.com/ClassificationDownloads/Ford{}.zip".format( - self.category + ford_data_url = ( + url + if url is not None + else "https://timeseriesclassification.com/aeon-toolkit/Ford{}.zip".format( + self.category + ) ) self.input_length = 500 diff --git a/aix360/datasets/sunspots_dataset.py b/aix360/datasets/sunspots_dataset.py index 24d4470..249860f 100644 --- a/aix360/datasets/sunspots_dataset.py +++ b/aix360/datasets/sunspots_dataset.py @@ -23,7 +23,10 @@ class SunspotDataset: """ - def __init__(self): + def __init__( + self, + url: str = None, + ): self.data_folder = os.path.realpath( os.path.join( os.path.dirname(os.path.realpath(__file__)), "../data", "sunspots_data" @@ -32,7 +35,11 @@ def __init__(self): self.data_file = os.path.realpath( os.path.join(self.data_folder, "sunspots.csv") ) - sunspots_url = "https://raw.githubusercontent.com/PacktPublishing/Practical-Time-Series-Analysis/master/Data%20Files/monthly-sunspot-number-zurich-17.csv" + sunspots_url = ( + url + if url is not None + else "https://raw.githubusercontent.com/PacktPublishing/Practical-Time-Series-Analysis/master/Data%20Files/monthly-sunspot-number-zurich-17.csv" + ) if not os.path.exists(self.data_file): response = requests.get(sunspots_url) diff --git a/tests/tslime/test_tslime.py b/tests/tslime/test_tslime.py index f8869e4..a665686 100644 --- a/tests/tslime/test_tslime.py +++ b/tests/tslime/test_tslime.py @@ -102,3 +102,7 @@ def test_tslime(self): self.assertIn("surrogate_prediction", explanation) self.assertEqual(explanation["history_weights"].shape[0], relevant_history) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/tssaliency/test_tssaliency.py b/tests/tssaliency/test_tssaliency.py index 2ab2f62..281883d 100644 --- a/tests/tssaliency/test_tssaliency.py +++ b/tests/tssaliency/test_tssaliency.py @@ -97,3 +97,7 @@ def test_tssaliency(self): self.assertIn("base_value_prediction", explanation) self.assertEqual(explanation["saliency"].shape, test_window.shape) + + +if __name__ == "__main__": + unittest.main()