Skip to content

Commit

Permalink
Merge pull request #34 from nlesc-nano/devel2
Browse files Browse the repository at this point in the history
Devel2
  • Loading branch information
BvB93 authored Jun 13, 2019
2 parents affdbb5 + 42fa389 commit 152cb6c
Show file tree
Hide file tree
Showing 15 changed files with 247 additions and 159 deletions.
4 changes: 2 additions & 2 deletions CAT/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from .data_handling import (
Database, mol_to_file,
read_mol, set_prop,
read_mol, set_mol_prop,
sanitize_optional, sanitize_input_mol, sanitize_path
)

Expand All @@ -32,7 +32,7 @@
'init_qd_opt', 'init_ligand_opt', 'init_qd_construction', 'init_ligand_anchoring',

'Database', 'mol_to_file',
'read_mol', 'set_prop',
'read_mol', 'set_mol_prop',
'sanitize_optional', 'sanitize_input_mol', 'sanitize_path',

'prep',
Expand Down
2 changes: 1 addition & 1 deletion CAT/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.4.1'
__version__ = '0.4.2'
10 changes: 4 additions & 6 deletions CAT/analysis/ligand_bde.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,12 +564,10 @@ def get_lig_core_combinations(xy, res_list, lig_count=2):
:parameter int lig_count: The number of ligands (*n*) in XYn.
:return:
"""
ret = {}
dict_ = {}
for core, lig in xy.T:
try:
ret[res_list[0][core].id].append([at.id for at in res_list[lig]])
dict_[res_list[0][core].id].append([at.id for at in res_list[lig]])
except KeyError:
ret[res_list[0][core].id] = [[at.id for at in res_list[lig]]]
for i in ret:
ret[i] = combinations(ret[i], lig_count)
return ret
dict_[res_list[0][core].id] = [[at.id for at in res_list[lig]]]
return {k: combinations(v, lig_count) for k, v in dict_.items()}
4 changes: 2 additions & 2 deletions CAT/attachment/ligand_attach.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,10 @@ def _get_indices(mol, index):
break
except ValueError:
pass
k += i
k += i - 1

# Append and return
ref_name = mol[k].properties.pdb_info.Name
ref_name = mol[k+1].properties.pdb_info.Name
for i, at in enumerate(mol.atoms[k:], k+1):
if at.properties.pdb_info.Name == ref_name:
at.properties.anchor = True
Expand Down
36 changes: 25 additions & 11 deletions CAT/attachment/ligand_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,33 @@ def init_ligand_opt(ligand_df, arg):
# Print messages
print(get_time() + ligand.properties.name + message)
if lig_new:
ligand_df.loc[idx, 'mol'] = lig_new
if len(lig_new) == 1:  # pd.DataFrame.loc has serious issues when assigning 1 molecule
idx, _ = next(ligand_df[idx].iterrows())
ligand_df.at[idx, ('mol', '')] = lig_new[0]
else:
ligand_df.loc[idx, 'mol'] = lig_new
print('')

remove_duplicates(ligand_df)

# Write newly optimized structures to the database
if 'ligand' in arg.optional.database.write and arg.optional.ligand.optimize:
recipe = Settings()
recipe['1'] = {'key': 'RDKit_' + rdkit.__version__, 'value': 'UFF'}
columns = [('formula', ''), ('hdf5 index', ''), ('settings', '1')]
database.update_csv(
ligand_df, columns=columns, job_recipe=recipe, database='ligand', overwrite=overwrite
)
path = arg.optional.ligand.dirname
mol_to_file(ligand_df['mol'], path, overwrite, arg.optional.database.mol_format)


def remove_duplicates(ligand_df):
"""Remove duplicate rows from a dataframe.
Duplicates are identified based on their index.
Performs an inplace update of **ligand_df**.
"""
# Remove duplicate ligands and sort
if ligand_df.index.duplicated().any():
idx_name = ligand_df.index.names
Expand All @@ -80,16 +104,6 @@ def init_ligand_opt(ligand_df, arg):
ligand_df.index.names = idx_name
ligand_df.sort_index(inplace=True)

# Write newly optimized structures to the database
if 'ligand' in arg.optional.database.write and arg.optional.ligand.optimize:
recipe = Settings()
recipe['1'] = {'key': 'RDKit_' + rdkit.__version__, 'value': 'UFF'}
columns = [('formula', ''), ('hdf5 index', ''), ('settings', '1')]
database.update_csv(ligand_df, columns=columns,
job_recipe=recipe, database='ligand', overwrite=overwrite)
path = arg.optional.ligand.dirname
mol_to_file(ligand_df['mol'], path, overwrite, arg.optional.database.mol_format)


@add_to_class(Molecule)
def split_bond(self, bond, atom_type='H', bond_length=1.1):
Expand Down
106 changes: 88 additions & 18 deletions CAT/data_handling/CAT_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,36 @@ def _create_yaml(path, name='job_settings.yaml'):
return path


def even_index(df1: pd.DataFrame,
               df2: pd.DataFrame) -> pd.DataFrame:
    """Ensure that ``df2.index`` is a subset of ``df1.index``.

    Rows whose index values appear in **df2** but not in **df1** are appended
    to **df1**, filled with placeholder (NaN-like) values.

    Parameters
    ----------
    df1 : |pd.DataFrame|_
        A DataFrame whose index is to-be a superset of ``df2.index``.

    df2 : |pd.DataFrame|_
        A DataFrame whose index is to-be a subset of ``df1.index``.

    Returns
    -------
    |pd.DataFrame|_
        **df1** itself if its index already covers ``df2.index``; otherwise a
        new DataFrame consisting of **df1** plus the missing rows.

    """
    # Check whether ``df2.index`` is already a subset of ``df1.index``
    # (the original comments stated this relation backwards)
    bool_ar = df2.index.isin(df1.index)
    if bool_ar.all():
        return df1

    # Append the rows missing from ``df1``, filled with placeholder values
    nan_row = get_nan_row(df1)
    idx = df2.index[~bool_ar]
    df_tmp = pd.DataFrame(len(idx) * [nan_row], index=idx, columns=df1.columns)
    # pd.concat replaces the deprecated (and since-removed) DataFrame.append
    return pd.concat([df1, df_tmp], sort=True)


class Database():
""" The Database class.
Expand Down Expand Up @@ -298,6 +328,7 @@ class open_yaml():
:param bool write: Whether or not the database file should be updated after
closing **self**.
"""

def __init__(self, path=None, write=True):
self.path = path or getcwd()
self.write = write
Expand Down Expand Up @@ -330,6 +361,7 @@ class open_csv_lig():
:param bool write: Whether or not the database file should be updated after
closing **self**.
"""

def __init__(self, path=None, write=True):
self.path = path or getcwd()
self.write = write
Expand All @@ -338,7 +370,9 @@ def __init__(self, path=None, write=True):
def __enter__(self):
# Open the .csv file
dtype = {'hdf5 index': int, 'formula': str, 'settings': str}
self.df = pd.read_csv(self.path, index_col=[0, 1], header=[0, 1], dtype=dtype)
self.df = Database.DF(
pd.read_csv(self.path, index_col=[0, 1], header=[0, 1], dtype=dtype)
)

# Fix the columns
idx_tups = [(i, '') if 'Unnamed' in j else (i, j) for i, j in self.df.columns]
Expand All @@ -358,15 +392,18 @@ class open_csv_qd():
:param bool write: Whether or not the database file should be updated after
closing **self**.
"""

def __init__(self, path=None, write=True):
self.path = path or getcwd()
self.write = write
self.df = None

def __enter__(self):
# Open the .csv file
dtype = {'hdf5 index': int, 'ligand count': np.int64, 'settings': str}
self.df = pd.read_csv(self.path, index_col=[0, 1, 2, 3], header=[0, 1], dtype=dtype)
dtype = {'hdf5 index': int, 'ligand count': int, 'settings': str}
self.df = Database.DF(
pd.read_csv(self.path, index_col=[0, 1, 2, 3], header=[0, 1], dtype=dtype)
)

# Fix the columns
idx_tups = [(i, '') if 'Unnamed' in j else (i, j) for i, j in self.df.columns]
Expand All @@ -379,6 +416,43 @@ def __exit__(self, type, value, traceback):
self.df.to_csv(self.path)
self.df = None

class DF(dict):
    """A mutable container for holding dataframes.

    The wrapped :class:`pandas.DataFrame` is stored under the reserved
    ``'df'`` key; all other attribute and item access is transparently
    forwarded to that DataFrame.  Because the container itself is stable,
    the underlying DataFrame can be swapped out (``self['df'] = new_df``)
    without invalidating outside references to the DF instance.
    """

    def __init__(self, df: pd.DataFrame) -> None:
        # Store *df* via dict.__setitem__ directly, bypassing the
        # type-checking logic in this class' own __setitem__.
        super().__init__()
        super().__setitem__('df', df)

    def __getattribute__(self, key):
        # Dunder attributes and 'update_df' resolve on the container itself;
        # every other attribute lookup is forwarded to the wrapped DataFrame.
        # NOTE(review): 'update_df' is not defined in the visible code —
        # presumably added elsewhere on this class; confirm.
        if key == 'update_df' or (key.startswith('__') and key.endswith('__')):
            return super().__getattribute__(key)
        return self['df'].__getattribute__(key)

    def __setattr__(self, key, value):
        # Attribute assignment is forwarded to the wrapped DataFrame; the
        # container itself therefore never grows instance attributes.
        self['df'].__setattr__(key, value)

    def __setitem__(self, key, value):
        # Assigning to 'df' replaces the wrapped DataFrame.  A DF instance is
        # accepted too (it is unwrapped first); anything else raises TypeError.
        if key == 'df' and not isinstance(value, pd.DataFrame):
            try:
                value = value['df']
                if not isinstance(value, pd.DataFrame):
                    raise KeyError
                super().__setitem__('df', value)
            except KeyError:
                err = ("Instance of 'pandas.DataFrame' or 'CAT.Database.DF' expected;"
                       " observed type: '{}'")
                raise TypeError(err.format(value.__class__.__name__))
        elif key == 'df':
            super().__setitem__('df', value)
        else:
            # Any other key is a column assignment on the wrapped DataFrame
            self['df'].__setitem__(key, value)

    def __getitem__(self, key):
        # 'df' returns the wrapped DataFrame itself; any other key is a
        # column lookup forwarded to it.
        df = super().__getitem__('df')
        if key == 'df':
            return df
        return df.__getitem__(key)

""" ################################# Updating the database ############################## """

def update_csv(self, df, database='ligand', columns=None, overwrite=False, job_recipe=None):
Expand Down Expand Up @@ -406,30 +480,26 @@ def update_csv(self, df, database='ligand', columns=None, overwrite=False, job_r
# Update **self.yaml**
if job_recipe is not None:
job_settings = self.update_yaml(job_recipe)
for key in job_settings:
df[('settings', key)] = job_settings[key]
for key, value in job_settings.items():
df[('settings', key)] = value

with open_csv(path, write=True) as db:
# Update **db.index**
nan_row = get_nan_row(db)
for i in df.index:
if i not in db.index:
db.at[i, :] = nan_row
db['hdf5 index'] = db['hdf5 index'].astype(int, copy=False) # Fix the data type
db['df'] = even_index(db['df'], df)

# Filter columns
if not columns:
df_columns = df.columns
else:
df_columns = columns + [i for i in df.columns if i[0] == 'settings']
df_columns = pd.Index(columns + [i for i in df.columns if i[0] == 'settings'])

# Update **db.columns**
for i in df_columns:
if i not in db.columns:
try:
db[i] = np.array((None), dtype=df[i].dtype)
except TypeError: # e.g. if csv[i] consists of the datatype np.int64
db[i] = -1
bool_ar = df_columns.isin(db.columns)
for i in df_columns[~bool_ar]:
try:
db[i] = np.array((None), dtype=df[i].dtype)
except TypeError: # e.g. if csv[i] consists of the datatype np.int64
db[i] = -1

# Update **self.hdf5**; returns a new series of indices
hdf5_series = self.update_hdf5(df, database=database, overwrite=overwrite)
Expand Down Expand Up @@ -536,7 +606,7 @@ def from_csv(self, df, database='ligand', get_mol=True, inplace=True):

# Update the *hdf5 index* column in **df**
with open_csv(path, write=False) as db:
df.update(db, overwrite=True)
df.update(db['df'], overwrite=True)
df['hdf5 index'] = df['hdf5 index'].astype(int, copy=False)

# **df** has been updated and **get_mol** = *False*
Expand Down
4 changes: 2 additions & 2 deletions CAT/data_handling/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
""" Modules related to the importing, exporting and general handling of data. """

from .CAT_database import (Database, mol_to_file)
from .mol_import import (read_mol, set_prop)
from .mol_import import (read_mol, set_mol_prop)
from .input_sanitizer import (sanitize_optional, sanitize_input_mol, sanitize_path)


__all__ = [
'Database', 'mol_to_file',
'read_mol', 'set_prop',
'read_mol', 'set_mol_prop',
'sanitize_optional', 'sanitize_input_mol', 'sanitize_path'
]
37 changes: 19 additions & 18 deletions CAT/data_handling/input_sanitizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,15 @@ def sanitize_path(arg):
elif isinstance(arg.path, str):
if arg.path.lower() in ('none', '.', 'pwd', '$pwd', 'cwd'):
arg.path = os.getcwd()
else:
if not os.path.exists(arg.path):
raise FileNotFoundError(get_time() + 'path ' + arg.path + ' not found')
elif os.path.isfile(arg.path):
raise TypeError(get_time() + 'path ' + arg.path + ' is a file, not a directory')
elif not os.path.exists(arg.path):
raise FileNotFoundError(get_time() + "path '{}' not found".format(arg.path))
elif os.path.isfile(arg.path):
raise OSError(get_time() + "path '{}' is a file, not a directory".format(arg.path))
return arg

else:
error = 'arg.path should be None or a string, ' + str(type(arg.path))
error += ' is not a valid type'
raise TypeError(error)
error = "arg.path should be None or a string, '{}' is not a valid type"
raise TypeError(error.format(arg.path.__class__.__name__))


""" ########################## Sanitize input_ligands & input_cores ######################## """
Expand Down Expand Up @@ -99,16 +97,19 @@ def get_mol_defaults(mol_list, path=None, core=False):
tmp.path = path
tmp.is_core = core

if isinstance(mol, dict):
for key1 in mol:
tmp.mol = key1
for key2 in mol[key1]:
try:
tmp[key2] = key_dict[key2](mol[key1][key2])
except KeyError:
raise KeyError(str(key2) + ' is not a valid argument for ' + str(key1))
if key2 == 'guess_bonds':
tmp.tmp_guess = True
if not isinstance(mol, dict):
ret.append(tmp)
continue

for k1, v1 in mol.items():
tmp.mol = k1
for k2, v2 in v1.items():
try:
tmp[k2] = key_dict[k2](v2)
except KeyError:
raise KeyError("'{}' is not a valid argument for '{}'".format(str(k2), str(k1)))
if k2 == 'guess_bonds':
tmp.tmp_guess = True

ret.append(tmp)
return ret
Expand Down
Loading

0 comments on commit 152cb6c

Please sign in to comment.