diff --git a/graphein/ml/datasets/torch_geometric_dataset.py b/graphein/ml/datasets/torch_geometric_dataset.py index 1f145849..9f464fd7 100644 --- a/graphein/ml/datasets/torch_geometric_dataset.py +++ b/graphein/ml/datasets/torch_geometric_dataset.py @@ -301,9 +301,12 @@ def __init__( root, pdb_codes: Optional[List[str]] = None, uniprot_ids: Optional[List[str]] = None, - graph_label_map: Optional[Dict[str, int]] = None, - node_label_map: Optional[Dict[str, int]] = None, - chain_selection_map: Optional[Dict[str, List[str]]] = None, + # graph_label_map: Optional[Dict[str, int]] = None, + graph_labels: Optional[List[torch.Tensor]] = None, + node_labels: Optional[List[torch.Tensor]] = None, + chain_selections: Optional[List[str]] = None, + # node_label_map: Optional[Dict[str, int]] = None, + # chain_selection_map: Optional[Dict[str, List[str]]] = None, graphein_config: ProteinGraphConfig = ProteinGraphConfig(), graph_format_convertor: GraphFormatConvertor = GraphFormatConvertor( src_format="nx", dst_format="pyg" @@ -395,10 +398,25 @@ def __init__( self.af_version = af_version # Labels & Chains - self.graph_label_map = graph_label_map - self.node_label_map = node_label_map - self.chain_selection_map = chain_selection_map - self.bad_pdbs: List[str] = [] + + self.examples: Dict[int, str] = dict(enumerate(self.structures)) + + if graph_labels is not None: + self.graph_label_map = dict(enumerate(graph_labels)) + else: + self.graph_label_map = None + + if node_labels is not None: + self.node_label_map = dict(enumerate(node_labels)) + else: + self.node_label_map = None + + if chain_selections is not None: + self.chain_selection_map = dict(enumerate(chain_selections)) + else: + self.graph_label_map = None + self.validate_input() + self.bad_pdbs: List[str] = [] # Configs self.config = graphein_config @@ -422,7 +440,34 @@ def raw_file_names(self) -> List[str]: @property def processed_file_names(self) -> List[str]: """Names of processed files to look for""" - return [f"{pdb}.pt" for pdb in self.structures] + if self.chain_selection_map is not None: + return [ + f"{pdb}_{chain}.pt" + for pdb, chain in zip( + self.structures, self.chain_selection_map.values() + ) + ] + else: + return [f"{pdb}.pt" for pdb in self.structures] + + def validate_input(self): + assert len(self.structures) == len( + self.graph_label_map + ), "Number of proteins and graph labels must match" + assert len(self.structures) == len( + self.node_label_map + ), "Number of proteins and node labels must match" + assert len(self.structures) == len( + self.chain_selection_map + ), "Number of proteins and chain selections must match" + assert len( + { + f"{pdb}_{chain}" + for pdb, chain in zip( + self.structures, self.chain_selection_map + ) + } + ) == len(self.structures), "Duplicate protein/chain combinations" def download(self): """Download the PDB files from RCSB or Alphafold.""" @@ -489,48 +534,46 @@ def divide_chunks(l: List[str], n: int = 2) -> List[List[str]]: for i in range(0, len(l), n): yield l[i : i + n] - chunks = list(divide_chunks(self.structures, chunk_size)) + # chunks = list(divide_chunks(self.structures, chunk_size)) + chunks: List[int] = list( + divide_chunks(list(self.examples.keys()), chunk_size) + ) for chunk in tqdm(chunks): + pdbs = [self.examples[idx] for idx in chunk] # Get chain selections - if self.chain_selection_map: + if self.chain_selection_map is not None: chain_selections = [ - self.chain_selection_map[pdb] - if pdb in self.chain_selection_map.keys() - else "all" - for pdb in chunk + self.chain_selection_map[idx] for idx in chunk ] else: - chain_selections = None + chain_selections = ["all"] * len(chunk) # Create graph objects - file_names = [f"{self.raw_dir}/{pdb}.pdb" for pdb in chunk] + file_names = [f"{self.raw_dir}/{pdb}.pdb" for pdb in pdbs] graphs = construct_graphs_mp( pdb_path_it=file_names, config=self.config, chain_selections=chain_selections, - return_dict=True, + return_dict=False, ) if self.graph_transformation_funcs is not None: - graphs = { - k: self.transform_graphein_graphs(v) - for k, v in graphs.items() - } + graphs = [self.transform_graphein_graphs(g) for g in graphs] + # Convert to PyTorch Geometric Data - graphs = { - k: self.graph_format_convertor(v) for k, v in graphs.items() - } - graphs = dict(zip(chunk, graphs.values())) + graphs = [self.graph_format_convertor(g) for g in graphs] # Assign labels if self.graph_label_map: - for k, v in self.graph_label_map.items(): - graphs[k].graph_y = v + labels = [self.graph_label_map[idx] for idx in chunk] + for i, _ in enumerate(chunk): + graphs[i].graph_y = labels[i] if self.node_label_map: - for k, v in self.node_label_map.items(): - graphs[k].node_y = v + labels = [self.node_label_map[idx] for idx in chunk] + for i, _ in enumerate(chunk): + graphs[i].graph_y = labels[i] - data_list = list(graphs.values()) + data_list = graphs del graphs @@ -540,18 +583,11 @@ def divide_chunks(l: List[str], n: int = 2) -> List[List[str]]: if self.pre_transform is not None: data_list = [self.pre_transform(data) for data in data_list] - idxs = [ - i - for i in range(idx * chunk_size, idx * chunk_size + len(chunk)) - ] - - for data, id in zip(data_list, idxs): + for i, (pdb, chain) in enumerate(zip(pdbs, chain_selections)): torch.save( - data, - os.path.join( - self.processed_dir, f"{self.structures[id]}.pt" - ), + data_list[i], + os.path.join(self.processed_dir, f"{pdb}_{chain}.pt"), ) idx += 1 @@ -563,9 +599,17 @@ def get(self, idx: int): :type idx: int :return: PyTorch Geometric Data object. """ - return torch.load( - os.path.join(self.processed_dir, f"{self.structures[idx]}.pt") - ) + if self.chain_selection_map is not None: + return torch.load( + os.path.join( + self.processed_dir, + f"{self.structures[idx]}_{self.chain_selection_map[idx]}.pt", + ) + ) + else: + return torch.load( + os.path.join(self.processed_dir, f"{self.structures[idx]}.pt") + ) class ProteinGraphListDataset(InMemoryDataset):