-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
13 changed files
with
1,592 additions
and
0 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import unittest | ||
from os import remove | ||
|
||
from numpy import array | ||
|
||
from ecnet.models.mlp import MultilayerPerceptron | ||
|
||
|
||
class TestMLP(unittest.TestCase): | ||
|
||
def test_mlp(self): | ||
|
||
print('\nUNIT TEST: MultilayerPerceptron') | ||
mlp = MultilayerPerceptron() | ||
mlp.add_layer(2, 'relu', 2) | ||
mlp.add_layer(5, 'relu') | ||
mlp.add_layer(5, 'relu') | ||
mlp.add_layer(1, 'linear') | ||
mlp.fit(array([[1, 0], [0, 1]]), array([[1], [0]]), epochs=15000) | ||
|
||
results = mlp.use(array([[1, 0], [0, 1]])) | ||
self.assertAlmostEqual(results[0][0], 1, 3) | ||
self.assertAlmostEqual(results[1][0], 0, 3) | ||
|
||
mlp.save() | ||
mlp_saved = MultilayerPerceptron() | ||
mlp_saved.load() | ||
results = mlp_saved.use(array([[1, 0], [0, 1]])) | ||
self.assertAlmostEqual(results[0][0], 1, 3) | ||
self.assertAlmostEqual(results[1][0], 0, 3) | ||
remove('model.h5') | ||
|
||
|
||
if __name__ == '__main__': | ||
|
||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import unittest | ||
|
||
from ecnet.utils.data_utils import DataFrame | ||
from ecnet.tasks.limit_inputs import limit_rforest | ||
|
||
|
||
class TestLimit(unittest.TestCase): | ||
|
||
def test_limit_rforest(self): | ||
|
||
print('\nUNIT TEST: limit_rforest') | ||
df = DataFrame('cn_model_v1.0.csv') | ||
df_res = limit_rforest(df, 2) | ||
self.assertEqual(len(df._input_names), 2) | ||
self.assertIn(df_res._input_names[0], df._input_names) | ||
self.assertIn(df_res._input_names[1], df._input_names) | ||
|
||
|
||
if __name__ == '__main__': | ||
|
||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import unittest | ||
from copy import deepcopy | ||
|
||
from ecnet.utils.data_utils import DataFrame | ||
from ecnet.tasks.remove_outliers import remove_outliers | ||
|
||
|
||
class TestRemoveOutliers(unittest.TestCase): | ||
|
||
def test_limit_rforest(self): | ||
|
||
print('UNIT TEST: remove_outliers') | ||
df = DataFrame('cn_model_v1.0.csv') | ||
df_res = remove_outliers(deepcopy(df)) | ||
self.assertLess(len(df_res), len(df)) | ||
|
||
|
||
if __name__ == '__main__': | ||
|
||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import unittest | ||
|
||
from ecnet.utils.data_utils import DataFrame | ||
from ecnet.utils.server_utils import default_config | ||
from ecnet.tasks.tuning import tune_hyperparameters | ||
|
||
|
||
class TestTune(unittest.TestCase): | ||
|
||
def test_tune_hyperparameters(self): | ||
|
||
print('\nUNIT TEST: tune_hyperparameters') | ||
df = DataFrame('cn_model_v1.0.csv') | ||
df.create_sets(random=True) | ||
config = default_config() | ||
new_hp = tune_hyperparameters(df, config, 2, 1) | ||
self.assertGreaterEqual(new_hp['beta_1'], 0) | ||
self.assertLessEqual(new_hp['beta_1'], 1) | ||
self.assertGreaterEqual(new_hp['beta_2'], 0) | ||
self.assertLessEqual(new_hp['beta_2'], 1) | ||
self.assertGreaterEqual(new_hp['decay'], 0) | ||
self.assertLessEqual(new_hp['decay'], 1) | ||
self.assertGreaterEqual(new_hp['epsilon'], 0) | ||
self.assertLessEqual(new_hp['epsilon'], 1) | ||
self.assertGreaterEqual(new_hp['learning_rate'], 0) | ||
self.assertLessEqual(new_hp['learning_rate'], 1) | ||
self.assertGreaterEqual(new_hp['hidden_layers'][0][0], 1) | ||
self.assertLessEqual(new_hp['hidden_layers'][0][0], 50) | ||
self.assertGreaterEqual(new_hp['hidden_layers'][1][0], 1) | ||
self.assertLessEqual(new_hp['hidden_layers'][1][0], 50) | ||
|
||
|
||
if __name__ == '__main__': | ||
|
||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
import unittest | ||
|
||
from models.test_mlp import TestMLP | ||
from tasks.test_limit_inputs import TestLimit | ||
from tasks.test_remove_outliers import TestRemoveOutliers | ||
from tasks.test_tune_hyperparameters import TestTune | ||
from tools.test_conversions import TestConversions | ||
from tools.test_database import TestDatabase | ||
from tools.test_project import TestUseProject | ||
from utils.test_data_utils import TestDataUtils | ||
from utils.test_error_utils import TestErrors | ||
from utils.test_server_utils import TestServerUtils | ||
|
||
|
||
if __name__ == '__main__': | ||
|
||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import unittest | ||
from os import remove | ||
from csv import DictReader | ||
|
||
import ecnet.tools.conversions as conversions | ||
|
||
|
||
class TestConversions(unittest.TestCase): | ||
|
||
def test_get_smiles(self): | ||
|
||
print('\nUNIT TEST: get_smiles') | ||
smiles = conversions.get_smiles('Propane') | ||
self.assertIn('CCC', smiles) | ||
|
||
def test_smiles_to_descriptors(self): | ||
|
||
print('\nUNIT TEST: smiles_to_descriptors') | ||
with open('smiles.smi', 'w') as smi_file: | ||
smi_file.write('CCC') | ||
smi_file.close() | ||
conversions.smiles_to_descriptors('smiles.smi', 'desc.csv') | ||
with open('desc.csv', 'r') as desc_file: | ||
reader = DictReader(desc_file) | ||
for row in reader: | ||
mol_row = row | ||
self.assertEqual(len(list(mol_row.keys())), 1876) | ||
desc_file.close() | ||
remove('smiles.smi') | ||
remove('desc.csv') | ||
|
||
def test_smiles_to_mdl(self): | ||
|
||
print('\nUNIT TEST: smiles_to_mdl') | ||
with open('smiles.smi', 'w') as smi_file: | ||
smi_file.write('CCC') | ||
smi_file.close() | ||
conversions.smiles_to_mdl('smiles.smi', 'mdl.mdl') | ||
with open('mdl.mdl', 'r') as mdl_file: | ||
self.assertGreater(len(mdl_file.read()), 0) | ||
mdl_file.close() | ||
remove('smiles.smi') | ||
remove('mdl.mdl') | ||
|
||
def test_mdl_to_descriptors(self): | ||
|
||
print('\nUNIT TEST: mdl_to_descriptors') | ||
with open('smiles.smi', 'w') as smi_file: | ||
smi_file.write('CCC') | ||
smi_file.close() | ||
conversions.smiles_to_mdl('smiles.smi', 'mdl.mdl') | ||
conversions.mdl_to_descriptors('mdl.mdl', 'desc.csv') | ||
with open('desc.csv', 'r') as desc_file: | ||
reader = DictReader(desc_file) | ||
for row in reader: | ||
mol_row = row | ||
self.assertEqual(len(list(mol_row.keys())), 1876) | ||
desc_file.close() | ||
remove('smiles.smi') | ||
remove('mdl.mdl') | ||
remove('desc.csv') | ||
|
||
|
||
if __name__ == '__main__': | ||
|
||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import unittest | ||
from os import remove | ||
|
||
from ecnet.utils.data_utils import DataFrame | ||
import ecnet.tools.database as database | ||
|
||
|
||
class TestDatabase(unittest.TestCase): | ||
|
||
def test_create_db(self): | ||
|
||
print('\nUNIT TEST: create_db') | ||
with open('smiles.smi', 'w') as smi_file: | ||
smi_file.write('CCC') | ||
smi_file.close() | ||
database.create_db('smiles.smi', 'database.csv', form='smiles') | ||
df = DataFrame('database.csv') | ||
self.assertEqual(len(df), 1) | ||
self.assertEqual(len(df._input_names), 1875) | ||
remove('smiles.smi') | ||
remove('database.csv') | ||
|
||
|
||
if __name__ == '__main__': | ||
|
||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import unittest | ||
from os import remove | ||
|
||
from ecnet import Server | ||
from ecnet.tools.project import predict | ||
|
||
|
||
class TestUseProject(unittest.TestCase): | ||
|
||
def test_predict(self): | ||
|
||
print('\nUNIT TEST: project.predict') | ||
sv = Server() | ||
sv.load_data('cn_model_v2.0.csv') | ||
sv.create_project('test_project', 1, 1) | ||
sv.train() | ||
sv.save_project() | ||
|
||
with open('smiles.smi', 'w') as smi_file: | ||
smi_file.write('CCC') | ||
smi_file.close() | ||
predict('smiles.smi', 'results.csv', 'test_project', form='smiles') | ||
with open('results.csv', 'r') as res_file: | ||
self.assertGreater(len(res_file.read()), 0) | ||
res_file.close() | ||
remove('smiles.smi') | ||
remove('results.csv') | ||
remove('test_project.prj') | ||
remove('config.yml') | ||
|
||
|
||
if __name__ == '__main__': | ||
|
||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
import unittest | ||
from os import remove | ||
|
||
import ecnet.utils.data_utils as data_utils | ||
|
||
|
||
class TestDataUtils(unittest.TestCase): | ||
|
||
def test_df_init(self): | ||
|
||
print('\nUNIT TEST: DataFrame init') | ||
df = data_utils.DataFrame('cn_model_v1.0.csv') | ||
self.assertEqual(len(df._string_names), 7) | ||
self.assertEqual(len(df._group_names), 1) | ||
self.assertEqual(len(df._target_names), 1) | ||
self.assertEqual(len(df._input_names), 15) | ||
self.assertEqual(len(df), 482) | ||
|
||
def test_set_creation(self): | ||
|
||
print('\nUNIT TEST: DataFrame set creation') | ||
df = data_utils.DataFrame('cn_model_v1.0.csv') | ||
|
||
df.create_sets() | ||
self.assertEqual(len(df.learn_set), 329) | ||
self.assertEqual(len(df.valid_set), 118) | ||
self.assertEqual(len(df.test_set), 35) | ||
|
||
df.create_sets(random=True, split=[0.7, 0.2, 0.1]) | ||
self.assertEqual(len(df.learn_set), 337) | ||
self.assertEqual(len(df.valid_set), 96) | ||
self.assertEqual(len(df.test_set), 49) | ||
|
||
def test_normalize(self): | ||
|
||
print('\nUNIT TEST: DataFrame normalize') | ||
df = data_utils.DataFrame('cn_model_v1.0.csv') | ||
df.normalize() | ||
df.create_sets(random=True) | ||
pd = df.package_sets() | ||
for entry in pd.learn_x: | ||
for val in entry: | ||
self.assertGreaterEqual(val, 0) | ||
self.assertLessEqual(val, 1) | ||
for entry in pd.valid_x: | ||
for val in entry: | ||
self.assertGreaterEqual(val, 0) | ||
self.assertLessEqual(val, 1) | ||
for entry in pd.test_x: | ||
for val in entry: | ||
self.assertGreaterEqual(val, 0) | ||
self.assertLessEqual(val, 1) | ||
|
||
def test_shuffle(self): | ||
|
||
print('\nUNIT TEST: DataFrame shuffle') | ||
df = data_utils.DataFrame('cn_model_v1.0.csv') | ||
|
||
df.shuffle(sets='all', split=[0.7, 0.2, 0.1]) | ||
self.assertEqual(len(df.learn_set), 337) | ||
self.assertEqual(len(df.valid_set), 96) | ||
self.assertEqual(len(df.test_set), 49) | ||
|
||
df.shuffle(sets='train', split=[0.7, 0.2, 0.1]) | ||
self.assertEqual(len(df.learn_set), 337) | ||
self.assertEqual(len(df.valid_set), 96) | ||
self.assertEqual(len(df.test_set), 49) | ||
|
||
def test_package_sets(self): | ||
|
||
print('\nUNIT TEST: DataFrame package_sets') | ||
df = data_utils.DataFrame('cn_model_v1.0.csv') | ||
df.shuffle(sets='all', split=[0.7, 0.2, 0.1]) | ||
|
||
pd = df.package_sets() | ||
self.assertEqual(len(pd.learn_x), 337) | ||
self.assertEqual(len(pd.valid_x), 96) | ||
self.assertEqual(len(pd.test_x), 49) | ||
for entry in pd.learn_x: | ||
self.assertEqual(len(entry), 15) | ||
for entry in pd.valid_x: | ||
self.assertEqual(len(entry), 15) | ||
for entry in pd.test_x: | ||
self.assertEqual(len(entry), 15) | ||
self.assertEqual(len(pd.learn_y), 337) | ||
self.assertEqual(len(pd.valid_y), 96) | ||
self.assertEqual(len(pd.test_y), 49) | ||
for entry in pd.learn_y: | ||
self.assertEqual(len(entry), 1) | ||
for entry in pd.valid_y: | ||
self.assertEqual(len(entry), 1) | ||
for entry in pd.test_y: | ||
self.assertEqual(len(entry), 1) | ||
|
||
def test_set_inputs(self): | ||
|
||
print('\nUNIT TEST: DataFrame set_inputs') | ||
df = data_utils.DataFrame('cn_model_v1.0.csv') | ||
df.set_inputs(['PHI', 'piPC05']) | ||
self.assertEqual(len(df._input_names), 2) | ||
df.create_sets(random=True) | ||
pd = df.package_sets() | ||
self.assertEqual(len(pd.learn_x[0]), 2) | ||
|
||
def test_save_df(self): | ||
|
||
print('\nUNIT TEST: DataFrame save') | ||
df = data_utils.DataFrame('cn_model_v1.0.csv') | ||
df.save('cn_test_save.csv') | ||
df_new = data_utils.DataFrame('cn_test_save.csv') | ||
self.assertEqual( | ||
len(df), | ||
len(df_new) | ||
) | ||
self.assertEqual( | ||
df._string_names, | ||
df_new._string_names | ||
) | ||
self.assertEqual( | ||
df._group_names, | ||
df_new._group_names | ||
) | ||
self.assertEqual( | ||
df._target_names, | ||
df_new._target_names | ||
) | ||
self.assertEqual( | ||
df._input_names, | ||
df_new._input_names | ||
) | ||
remove('cn_test_save.csv') | ||
|
||
|
||
if __name__ == '__main__': | ||
|
||
unittest.main() |
Oops, something went wrong.