
ADicksonLab/AGDIFF


AGDIFF: Attention-Enhanced Diffusion for Molecular Geometry Prediction

License: MIT


This repository contains the official implementation of the work "AGDIFF: Attention-Enhanced Diffusion for Molecular Geometry Prediction".

AGDIFF introduces a novel approach that enhances diffusion models with attention mechanisms and an improved SchNet architecture, achieving state-of-the-art performance in predicting molecular geometries.

Unique Features of AGDIFF

  • Attention Mechanisms: Enhances the global and local encoders with attention mechanisms for better feature extraction and integration.
  • Improved SchNet Architecture: Incorporates learnable activation functions, adaptive scaling modules, and dual pathway processing to increase model expressiveness.
  • Batch Normalization: Stabilizes training and improves convergence for the local encoder.
  • Feature Expansion: Extends the MLP Edge Encoder with feature expansion and processing, combining processed features and bond embeddings for more adaptable edge representations.
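To illustrate the kind of attention mechanism the encoders use, here is a generic single-head scaled dot-product self-attention over per-atom feature vectors, written in dependency-free Python. This is a minimal sketch, not AGDIFF's encoder code (which operates on PyTorch tensors inside the repository). The weight matrices `w_q`, `w_k`, `w_v` are hypothetical stand-ins for learned parameters.

```python
import math
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Multiply an (n x k) matrix by a (k x m) matrix."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a: Matrix) -> Matrix:
    return [list(row) for row in zip(*a)]


def softmax(row: List[float]) -> List[float]:
    # Subtract the row maximum before exponentiating for numerical stability.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def self_attention(x: Matrix, w_q: Matrix, w_k: Matrix, w_v: Matrix) -> Matrix:
    """Single-head self-attention over node features x (n atoms x d features)."""
    # Project node features into queries, keys and values (hypothetical weights).
    q, k, v = matmul(x, w_q), matmul(x, w_k), matmul(x, w_v)
    d_k = len(w_k[0])
    # Pairwise scores between every pair of atoms, scaled by sqrt(d_k).
    scores = matmul(q, transpose(k))
    weights = [softmax([s / math.sqrt(d_k) for s in row]) for row in scores]
    # Each output row is an attention-weighted average of the value vectors.
    return matmul(weights, v)
```

Each atom's output mixes information from all other atoms in proportion to learned similarity scores, which is what lets attention complement the purely local message passing of a GNN encoder.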


Contents

  1. Environment Setup
  2. Dataset
  3. Training
  4. Generation
  5. Evaluation
  6. Acknowledgment
  7. Citation

Environment Setup

Install dependencies via Conda/Mamba

conda env create -f agdiff.yml
conda activate agdiff
pip install torch_geometric
pip install torch-scatter -f https://data.pyg.org/whl/torch-2.4.0+cu121.html
pip install torch-sparse -f https://data.pyg.org/whl/torch-2.4.0+cu121.html
pip install torch-cluster -f https://data.pyg.org/whl/torch-2.4.0+cu121.html
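The wheel index URLs above are pinned to PyTorch 2.4.0 with CUDA 12.1. If your environment uses a different PyTorch or CUDA build, a small helper like the following (a sketch; the function names are hypothetical) can derive the matching URL from the standard `torch.__version__` and `torch.version.cuda` attributes:

```python
from typing import Optional


def cuda_tag(cuda_version: Optional[str]) -> str:
    """Map torch.version.cuda (e.g. "12.1", or None on CPU-only builds) to a PyG tag."""
    if cuda_version is None:
        return "cpu"
    return "cu" + cuda_version.replace(".", "")


def pyg_wheel_url(torch_version: str, tag: str) -> str:
    """Build the wheel index URL passed to `pip install ... -f <url>`.

    Any local suffix such as "+cu121" in torch_version is dropped first.
    """
    base = torch_version.split("+", 1)[0]
    return f"https://data.pyg.org/whl/torch-{base}+{tag}.html"
```

For example, `pyg_wheel_url(torch.__version__, cuda_tag(torch.version.cuda))` run inside the activated environment. PyG only publishes wheels for certain PyTorch/CUDA combinations, so consult the PyG installation guide if the resulting URL does not exist.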

Once you have installed all the dependencies, install the package locally in editable mode:

pip install -e .
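To confirm the installation before training, a quick check like the one below (a sketch, not part of the repository) reports which of the packages installed above can be found, using only the standard library's `importlib.util.find_spec`, which locates top-level modules without importing them:

```python
import importlib.util
from typing import Dict, Iterable

# Import names of the packages installed above (pip names use hyphens, modules use underscores).
REQUIRED_MODULES = ("torch", "torch_geometric", "torch_scatter", "torch_sparse", "torch_cluster")


def check_modules(names: Iterable[str]) -> Dict[str, bool]:
    """Return {module_name: found?} without actually importing the modules."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    for name, ok in check_modules(REQUIRED_MODULES).items():
        print(f"{name:16s} {'ok' if ok else 'MISSING'}")
```

Because `find_spec` does not import anything, a module reported as found can still fail at import time (for example, a CUDA mismatch in a compiled extension), so a real `import` is the final test.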

Dataset

Official Dataset

The preprocessed GEOM datasets provided by GEODIFF can be found in this [Google Drive folder]. After downloading and unzipping the dataset, place it at the path specified by the dataset variable in the configuration files (./configs/*.yml). The same link also provides the pretrained model, which you may want to use.
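A common failure when starting training is a dataset path in the config that does not match where the data was unzipped. The helper below is a sketch that lists missing files; it assumes (this README does not confirm it) a GeoDiff-style layout in which the `dataset` key maps split names such as `train`, `val`, and `test` to file paths. The function name is hypothetical.

```python
import os
from typing import Dict, List


def missing_dataset_files(config: Dict) -> List[str]:
    """List dataset paths under config["dataset"] that do not exist on disk.

    Assumption: config["dataset"] maps split names to file paths. Relative
    paths are resolved against the current working directory.
    """
    dataset = config.get("dataset", {})
    return [path for path in dataset.values()
            if isinstance(path, str) and not os.path.exists(path)]
```

To use it, load a config with PyYAML (`yaml.safe_load`) and run it from the repository root, since relative paths in the configs are presumably resolved from there.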

The official raw GEOM dataset is also available [here].

Training

AGDIFF's training details and hyper-parameters are provided in the config files (./configs/*.yml). Feel free to tune these parameters as needed.

To train the model, use the following commands:

python scripts/train.py ./configs/qm9_default.yml
python scripts/train.py ./configs/drugs_default.yml

Model checkpoints, configuration YAML files, and training logs are saved in the directory specified by the --logdir argument of train.py.
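The generation step below refers to checkpoints as `checkpoints/${iter}.pt`. Assuming checkpoint files are named by their iteration number, this sketch (hypothetical helper, not part of the repository) picks the most recent one from a checkpoint directory:

```python
import os
import re
from typing import Optional

# Assumption: checkpoint files are named "<iteration>.pt", e.g. "1000.pt".
_CKPT_RE = re.compile(r"^(\d+)\.pt$")


def latest_checkpoint(ckpt_dir: str) -> Optional[str]:
    """Return the path of the highest-iteration "<iter>.pt" file, or None if there is none."""
    best_iter, best_path = -1, None
    for fname in os.listdir(ckpt_dir):
        match = _CKPT_RE.match(fname)
        if match and int(match.group(1)) > best_iter:
            best_iter = int(match.group(1))
            best_path = os.path.join(ckpt_dir, fname)
    return best_path
```

Iterations are compared as integers rather than strings, because a plain lexicographic sort would rank `900.pt` after `1000.pt`.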

Generation

To generate conformations for all or part of the test set, use:

python scripts/test.py ./logs/path/to/checkpoints/${iter}.pt ./configs/qm9_default.yml \
    --start_idx 0 --end_idx 200

Here, start_idx and end_idx specify the range of test-set molecules to use. To reproduce the paper's results, use 0 and 200 for start_idx and end_idx, respectively. All sampling-related hyper-parameters can be set in test.py. For the QM9 model specifically, adding the argument --w_global 0.3 empirically gives slightly better results.
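Because --start_idx and --end_idx select a slice of the test set, the 0 to 200 range can be split across several processes or GPUs. The sketch below builds one command per chunk; it assumes end_idx is exclusive (Python-slice style), so consecutive chunks share boundaries without overlapping. The helper name and checkpoint path are hypothetical, and you should check where test.py writes its output so that parallel runs do not overwrite each other.

```python
import shlex
from typing import List, Optional


def chunked_test_commands(ckpt: str, config: str, start: int = 0, end: int = 200,
                          chunk: int = 50, w_global: Optional[float] = None) -> List[str]:
    """Split [start, end) into chunks and build one scripts/test.py command per chunk."""
    commands = []
    for lo in range(start, end, chunk):
        hi = min(lo + chunk, end)  # the last chunk may be shorter
        args = ["python", "scripts/test.py", ckpt, config,
                "--start_idx", str(lo), "--end_idx", str(hi)]
        if w_global is not None:  # e.g. 0.3 for the QM9 model, as suggested above
            args += ["--w_global", str(w_global)]
        commands.append(shlex.join(args))  # quotes arguments only when needed
    return commands
```

For example, `chunked_test_commands("logs/run/checkpoints/1000.pt", "configs/qm9_default.yml", chunk=80, w_global=0.3)` yields three commands covering 0 to 80, 80 to 160, and 160 to 200.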

We also provide an example of conformation generation for a specific molecule (alanine dipeptide) in the examples folder. To generate conformations for alanine dipeptide, use:

python examples/test_alanine_dipeptide.py ./logs/path/to/checkpoints/${iter}.pt ./configs/qm9_default.yml 

Evaluation

After generating conformations, evaluate the results of benchmark tasks using the following commands.

Task 1. Conformation Generation

Calculate COV and MAT scores on the GEOM datasets with:

python scripts/evaluation/eval_covmat.py path/to/samples/sample_all.pkl
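For reference, COV (coverage) and MAT (matching) are computed from the RMSDs between reference and generated conformers of each molecule. The sketch below shows the recall-style definitions given a precomputed RMSD matrix (after optimal alignment, e.g. with RDKit's alignment utilities); it is an illustration, not the code in eval_covmat.py. The threshold `delta` is an assumption here: values of 0.5 Å for QM9 and 1.25 Å for Drugs are common in the GEODIFF line of work, but the script's own settings are authoritative.

```python
from typing import List, Tuple


def cov_mat(rmsd: List[List[float]], delta: float) -> Tuple[float, float]:
    """Recall-style COV and MAT for one molecule.

    rmsd[i][j] is the RMSD (in Å) between reference conformer i and generated conformer j.
    COV: fraction of reference conformers matched by some generated conformer within delta.
    MAT: mean over reference conformers of the smallest RMSD to any generated conformer.
    """
    best = [min(row) for row in rmsd]  # closest generated conformer per reference
    cov = sum(1 for r in best if r <= delta) / len(best)
    mat = sum(best) / len(best)
    return cov, mat
```

Passing the transposed matrix (`[list(col) for col in zip(*rmsd)]`) swaps the roles of the two sets and gives the precision-style variants. Higher COV and lower MAT are better.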

Acknowledgment

Our implementation is based on GEODIFF, PyTorch, PyG, and SchNet.

Citation

If you use our code or method in your work, please consider citing the following:

@misc{wyzykowskiAGDIFFAttentionEnhancedDiffusion2024,
  title = {{{AGDIFF}}: {{Attention-Enhanced Diffusion}} for {{Molecular Geometry Prediction}}},
  shorttitle = {{{AGDIFF}}},
  author = {Wyzykowski, Andr{\'e} Brasil Vieira and Fathi Niazi, Fatemeh and Dickson, Alex},
  year = {2024},
  month = oct,
  publisher = {ChemRxiv},
  doi = {10.26434/chemrxiv-2024-wrvr4},
  urldate = {2024-10-09},
  archiveprefix = {ChemRxiv},
  langid = {english},
  keywords = {attention,conformer,diffusion models,generative,GNN,graph neural network,machine learning,structure}
}

Please direct any questions to André Wyzykowski (abvwmc@gmail.com) and Alex Dickson (alexrd@msu.edu).