This repository and its references contain the models, data and scripts used to carry out the experiments in the paper "Masked graph modeling for molecule generation".
We used a Linux OS with an Nvidia Tesla P100-SXM2 GPU with 16 GB of memory.
A conda environment file (environment.yml) is provided as part of this repository. It may contain packages beyond those needed to run the scripts here. If not using this file, please install the following dependencies.
python 3.7
pytorch 1.4.0
tensorflow 1.14
rdkit 2019.09.3
guacamol 0.5.0
dgl 0.4.3post2
tensorboardx 2.0
scipy 1.4.1
CUDA (We used version 10.0)
The PyTorch, TensorFlow and DGL installations should correspond to the CUDA version used.
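When environment.yml is not used, the installed versions can be compared with the list above. The following is a minimal sketch, not part of the repository. The distribution names (for example `torch` for pytorch) are assumptions and may differ depending on how each package was installed; conda-installed packages such as rdkit may not expose this metadata at all. `check_versions` is a hypothetical helper name.

```python
# Sketch: compare installed package versions with the versions listed above.
try:
    from importlib.metadata import version, PackageNotFoundError
except ImportError:  # Python 3.7 needs the importlib_metadata backport
    from importlib_metadata import version, PackageNotFoundError

# Assumed distribution names mapped to the versions listed in this README.
EXPECTED = {
    "torch": "1.4.0",
    "tensorflow": "1.14",
    "guacamol": "0.5.0",
    "dgl": "0.4.3post2",
    "tensorboardX": "2.0",
    "scipy": "1.4.1",
}


def check_versions(expected, lookup=version):
    """Return {name: (wanted, found)} for packages that are missing or
    whose installed version does not start with the wanted version."""
    mismatches = {}
    for name, wanted in expected.items():
        try:
            found = lookup(name)
        except PackageNotFoundError:
            found = None
        # Prefix match so that "1.14" accepts "1.14.0".
        if found is None or not found.startswith(wanted):
            mismatches[name] = (wanted, found)
    return mismatches


for name, (wanted, found) in check_versions(EXPECTED).items():
    print(f"{name}: wanted {wanted}, found {found}")
```

An empty result means every listed package matched; the CUDA version itself still has to be checked separately.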
QM9 SMILES strings are included in this repository at data/QM9/QM9_smiles.txt.
To process QM9 SMILES for use in train and generation scripts:
python -m data.gen_targets --data-path data/QM9/QM9_smiles.txt --save-path data/QM9/QM9_processed.p --dataset-type QM9
To download the ChEMBL dataset:
Training set: wget -O data/ChEMBL/guacamol_v1_train.smiles https://ndownloader.figshare.com/files/13612760
Validation set: wget -O data/ChEMBL/guacamol_v1_valid.smiles https://ndownloader.figshare.com/files/13612766
Full dataset (training + validation + test): wget -O data/ChEMBL/guacamol_v1_all.smiles https://ndownloader.figshare.com/files/13612745
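For systems without wget, the same three files can be fetched from Python. This is a sketch using only the standard library; the URLs and file names are taken from the commands above, while `download_chembl` and its parameters are hypothetical names introduced for illustration.

```python
# Sketch: download the ChEMBL (GuacaMol) SMILES files without wget.
import urllib.request
from pathlib import Path

# File names and URLs copied from the wget commands above.
CHEMBL_FILES = {
    "guacamol_v1_train.smiles": "https://ndownloader.figshare.com/files/13612760",
    "guacamol_v1_valid.smiles": "https://ndownloader.figshare.com/files/13612766",
    "guacamol_v1_all.smiles": "https://ndownloader.figshare.com/files/13612745",
}


def download_chembl(target_dir="data/ChEMBL", files=CHEMBL_FILES, overwrite=False):
    """Download each file into target_dir and return the local paths.

    Files that already exist are skipped unless overwrite is True.
    """
    target = Path(target_dir)
    target.mkdir(parents=True, exist_ok=True)
    saved = []
    for filename, url in files.items():
        destination = target / filename
        if overwrite or not destination.exists():
            urllib.request.urlretrieve(url, str(destination))
        saved.append(destination)
    return saved
```

Calling `download_chembl()` from the repository root places the files where the processing commands expect them.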
To process the training dataset after downloading for use in train and generation scripts:
python -m data.gen_targets --data-path data/ChEMBL/guacamol_v1_train.smiles --save-path data/ChEMBL/ChEMBL_train_processed.p --dataset-type ChEMBL
To process the validation dataset after downloading for use in train and generation scripts:
python -m data.gen_targets --data-path data/ChEMBL/guacamol_v1_valid.smiles --save-path data/ChEMBL/ChEMBL_val_processed.p --dataset-type ChEMBL
Pretrained models are provided here for both datasets. To use these for generation, download the entire dumped folder to the repository root.
As an alternative to using pretrained models, the following are scripts for training models from scratch.
python train.py --data_path data/QM9/QM9_processed.p --graph_type QM9 --exp_name QM9_experiment --num_node_types 5 --num_edge_types 5 --max_nodes 9 --layer_norm --spatial_msg_res_conn --batch_size 1024 --val_batch_size 2500 --val_after 105 --num_epochs 200 --shuffle --mask_independently --force_mask_predict --optimizer adam,lr=0.0001 --tensorboard
python train.py --data_path data/ChEMBL/ChEMBL_train_processed.p --graph_type ChEMBL --exp_name chembl_experiment --val_data_path data/ChEMBL/ChEMBL_val_processed.p --num_node_types 12 --num_edge_types 5 --max_nodes 88 --min_charge -1 --max_charge 3 --mpnn_steps 6 --layer_norm --spatial_msg_res_conn --batch_size 32 --val_batch_size 64 --grad_accum_iters 16 --val_after 3200 --num_epochs 10 --shuffle --force_mask_predict --mask_independently --optimizer adam,lr=0.0001 --tensorboard
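The ChEMBL command pairs a small batch size with --grad_accum_iters. Assuming that flag has the usual gradient-accumulation meaning (gradients from that many consecutive batches are combined before each optimizer update), and assuming it defaults to 1 when omitted, neither of which is confirmed here against train.py, the number of examples contributing to each update can be worked out as below. `examples_per_update` is a hypothetical helper.

```python
# Sketch: examples per optimizer update under gradient accumulation.
# Assumption: --grad_accum_iters batches are accumulated before each update,
# and the flag defaults to 1 when it is not passed.
def examples_per_update(batch_size, grad_accum_iters=1):
    return batch_size * grad_accum_iters


print(examples_per_update(1024))    # QM9 command above      -> 1024
print(examples_per_update(32, 16))  # ChEMBL command above   -> 512
```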
In addition to scripts for training our model, we include the scripts for training the baseline autoregressive Transformer models (download preprocessed data for the Transformer models here). The code for the model is based on the following repository: https://github.com/facebookresearch/XLM
python train_ar.py --dump_path ./ --data_path /path/to/chembl/data --data_type ChEMBL --exp_name chembl_smiles_transformer_regular --optimizer adam_inverse_sqrt,beta1=0.9,beta2=0.98,lr=0.0001 --epoch_size 1000 --emb_dim 1024 --n_layers 6 --n_heads 8 --dropout 0.1 --attention_dropout 0.1 --batch_size 128
python train_ar.py --dump_path ./ --data_path /path/to/chembl/data --data_type ChEMBL --exp_name chembl_smiles_transformer_small --optimizer adam_inverse_sqrt,beta1=0.9,beta2=0.98,lr=0.0001 --epoch_size 1000 --emb_dim 512 --n_layers 4 --n_heads 8 --dropout 0.1 --attention_dropout 0.1 --batch_size 128
For QM9, take the above ChEMBL scripts and set the --data_type flag to QM9.
The following are scripts for generating from trained/pretrained models.
The --node_target_frac and --edge_target_frac options set the masking rate for node and edge features respectively.
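To give a feel for what a masking rate means, the sketch below replaces each entry of a feature list with a mask token independently, with probability equal to the rate. This is an illustrative assumption about how a rate can be applied, not the repository's implementation; `mask_features`, the `MASK` placeholder and the example atom list are all hypothetical.

```python
# Sketch: independent masking of features at a given rate (illustrative only).
import random

MASK = "<mask>"  # hypothetical placeholder for a masked feature


def mask_features(features, target_frac, seed=None):
    """Return a copy of features where each entry is replaced by MASK
    with probability target_frac, independently of the other entries."""
    rng = random.Random(seed)
    return [MASK if rng.random() < target_frac else f for f in features]


atoms = ["C", "C", "N", "O", "C", "C", "C", "O", "C"]  # made-up 9-atom example
masked = mask_features(atoms, target_frac=0.1, seed=0)
print(masked)
print(sum(f == MASK for f in masked), "of", len(atoms), "entries masked")
```

With a rate of 0.1 and nine entries, slightly fewer than one entry is masked on average, so many draws leave the list unchanged.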
To generate for QM9 using training initialisation and masking rate 0.1:
python generate.py --data_path data/QM9/QM9_processed.p --graph_type QM9 --model_path dumped/QM9_experiment/best_model --smiles_dataset_path data/QM9/QM9_smiles.txt --output_dir dumped/QM9_experiment/generation/train_init/mask10/results --num_node_types 5 --num_edge_types 5 --max_nodes 9 --layer_norm --embed_hs --spatial_msg_res_conn --num_iters 400 --num_sampling_iters 400 --cp_save_dir dumped/QM9_experiment/generation/train_init/mask10/generation_checkpoints --batch_size 2500 --checkpointing_period 400 --evaluation_period 20 --save_period 20 --evaluate_finegrained --save_finegrained --mask_independently --retrieve_train_graphs --node_target_frac 0.1 --edge_target_frac 0.1
To generate for QM9 using marginal initialisation and masking rate 0.2:
python generate.py --data_path data/QM9/QM9_processed.p --graph_type QM9 --model_path dumped/QM9_experiment/best_model --smiles_dataset_path data/QM9/QM9_smiles.txt --output_dir dumped/QM9_experiment/generation/marginal_init/mask20/results --num_node_types 5 --num_edge_types 5 --max_nodes 9 --layer_norm --embed_hs --spatial_msg_res_conn --num_iters 400 --num_sampling_iters 400 --cp_save_dir dumped/QM9_experiment/generation/marginal_init/mask20/generation_checkpoints --batch_size 2500 --checkpointing_period 400 --evaluation_period 20 --save_period 20 --evaluate_finegrained --save_finegrained --mask_independently --random_init --node_target_frac 0.2 --edge_target_frac 0.2
To generate for ChEMBL using training initialisation and masking rate 0.01:
python generate.py --data_path data/ChEMBL/ChEMBL_train_processed.p --graph_type ChEMBL --model_path dumped/chembl_experiment/best_model --smiles_dataset_path data/ChEMBL/guacamol_v1_all.smiles --output_dir dumped/chembl_experiment/generation/train_init/mask1/results --num_node_types 12 --num_edge_types 5 --max_nodes 88 --min_charge -1 --max_charge 3 --layer_norm --mpnn_steps 6 --embed_hs --spatial_msg_res_conn --num_iters 300 --num_sampling_iters 300 --checkpointing_period 300 --evaluation_period 20 --save_period 20 --evaluate_finegrained --save_finegrained --cp_save_dir dumped/chembl_experiment/generation/train_init/mask1/generation_checkpoints --batch_size 32 --mask_independently --retrieve_train_graphs --node_target_frac 0.01 --edge_target_frac 0.01
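The generation commands above differ in dataset settings, initialisation flag, masking rate and output paths. When sweeping over masking rates, the path and rate arguments can be derived programmatically, as sketched here. The helper names are hypothetical. The directory layout mirrors the examples above, with results and checkpoints kept under the same initialisation directory. The returned list covers only the arguments that vary, so the dataset-specific flags from the full commands still have to be appended.

```python
# Sketch: build the varying generate.py arguments for a masking-rate sweep.
# Flag names are taken from the commands above; helper names are hypothetical.
INIT_FLAGS = {
    "train_init": "--retrieve_train_graphs",  # training initialisation
    "marginal_init": "--random_init",         # marginal initialisation
}


def mask_dir_name(rate):
    """Map a masking rate to the directory name used above (0.1 -> mask10)."""
    return "mask{}".format(round(rate * 100))


def varying_generation_args(exp_name, init, rate):
    """Return the arguments that change between runs of generate.py."""
    if init not in INIT_FLAGS:
        raise ValueError("init must be one of {}".format(sorted(INIT_FLAGS)))
    base = "dumped/{}/generation/{}/{}".format(exp_name, init, mask_dir_name(rate))
    return [
        "--model_path", "dumped/{}/best_model".format(exp_name),
        "--output_dir", base + "/results",
        "--cp_save_dir", base + "/generation_checkpoints",
        INIT_FLAGS[init],
        "--node_target_frac", str(rate),
        "--edge_target_frac", str(rate),
    ]


for rate in (0.01, 0.1, 0.2):
    print(" ".join(varying_generation_args("QM9_experiment", "train_init", rate)))
```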
To generate from the trained baseline autoregressive Transformer models:
python generate_ar_distributional.py --model_path /path/to/trained/qm9/model/best_model.pth --dist_file QM9_all.smiles
python generate_ar_distributional.py --model_path /path/to/trained/chembl/model/best_model.pth --dist_file guacamol_v1_all.smiles
SMILES strings and distributional results at each recorded generation step can be found in <output_dir> from the MGM generation script used above.
To print generation results at each step in a dataframe:
python get_best_distributional_results.py <output_dir>
We also provide a list of SMILES strings of 20,000 generated molecules each for QM9 with a 10% masking rate and ChEMBL with a 1% masking rate here. Training initialisation was used in both cases.
If you have found the materials in this repository useful, please consider citing: Mahmood, O., Mansimov, E., Bonneau, R. et al. Masked graph modeling for molecule generation. Nat Commun 12, 3156 (2021). https://doi.org/10.1038/s41467-021-23415-2