diff --git a/sidechainnet/utils/manual_adjustment.py b/sidechainnet/utils/manual_adjustment.py index 249c67e..0a7e29c 100644 --- a/sidechainnet/utils/manual_adjustment.py +++ b/sidechainnet/utils/manual_adjustment.py @@ -68,31 +68,37 @@ def _repair_1GJJ_1_A(datadict): """ # Locate positions in data arrays found_splits_indices = [] - for split in scn.DATA_SPLITS: - for idx, cur_id in enumerate(datadict[split]["ids"]): - if cur_id == '1GJJ_1_A': + for split in datadict.keys(): + try: + ids = datadict[split]["ids"] + except: + continue + for idx, cur_id in enumerate(ids): + if '1GJJ_1_A' in cur_id: found_splits_indices.append((split, idx)) # Carefully split into two entries containing the appropriate data ranges for split, idx in found_splits_indices: for key in datadict[split].keys(): if key == 'res': - datadict[split][key].append(datadict[split][key][idx]) + datadict[split][key].insert(idx + 1, datadict[split][key][idx]) elif key == 'ids': - datadict[split][key].append(datadict[split][key][idx] + "2") - datadict[split][key][idx] = datadict[split][key][idx] + "1" + original_id = str(datadict[split][key][idx]) + datadict[split][key].insert(idx + 1, original_id + "2") + datadict[split][key][idx] = original_id + "1" elif key == 'ums': - datadict[split][key].append(" ".join( - datadict[split][key][idx].split()[110:153])) - datadict[split][key][idx] = " ".join( - datadict[split][key][idx].split()[0:50]) + res_list = datadict[split][key][idx].split() + datadict[split][key].insert(idx + 1, " ".join(res_list[110:153])) + datadict[split][key][idx] = " ".join(res_list[0:50]) elif key == 'crd': - datadict[split][key].append( - datadict[split][key][idx][110 * NUM_COORDS_PER_RES:153 * - NUM_COORDS_PER_RES]) - datadict[split][key][idx] = datadict[split][key][idx][ - 0 * NUM_COORDS_PER_RES:50 * NUM_COORDS_PER_RES] + original_crds = datadict[split][key][idx].copy() + datadict[split][key].insert( + idx + 1, + original_crds[110 * NUM_COORDS_PER_RES:153 * NUM_COORDS_PER_RES]) + datadict[split][key][idx] = original_crds[0 * NUM_COORDS_PER_RES:50 * + NUM_COORDS_PER_RES] else: - datadict[split][key].append(datadict[split][key][idx][110:153]) - datadict[split][key][idx] = datadict[split][key][idx][0:50] + original_str = datadict[split][key][idx] + datadict[split][key].insert(idx + 1, original_str[110:153]) + datadict[split][key][idx] = original_str[0:50] return datadict