CrysTens is a representation of crystal structure information that is originally stored in Crystallographic Information Files (CIFs). A CrysTens is a 64x64x4 tensor that can be used in any type of machine learning application involving crystal structures. This repository contains code for creating a stack of CrysTens' and using that stack to train a Vanilla Generative Adversarial Network (GAN), a Wasserstein GAN, or a diffusion model based on imagen-pytorch (https://github.com/lucidrains/imagen-pytorch/tree/main/imagen_pytorch). Once a model has been trained, new CIFs can be created from a stack of generated CrysTens'. The CrysTens representation, model training, and model analysis for material discovery are described in detail in the paper published in Digital Discovery: https://pubs.rsc.org/en/content/articlelanding/2024/dd/d3dd00137g
Training a CrysTens generative model for material discovery requires a stack of CrysTens'. A Stacked CrysTens of any size (set with --num_examples) can be created from a crystal dictionary using get_stacked_crys_tens.py:
python get_stacked_crys_tens.py --crys_dict_path=Data/CrystalDictionary.jsonl --num_examples=500 --crys_tens_path=Data/StackedCrysTensor.npy
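To confirm that the resulting stack has the expected size without loading it into memory, the header of the .npy file can be read directly. The sketch below uses only the Python standard library and the documented .npy layout (magic string, version bytes, header length, then a Python dict literal). The helpers read_npy_header and looks_like_crys_tens_stack are hypothetical, not part of the repository, and the shape check is an assumption: it expects axis 0 to index examples and the remaining axes to hold 64, 64 and 4 in some order, since whether the repository stores channels first or last is not stated here.

```python
import ast
import struct
from typing import BinaryIO, Tuple

NPY_MAGIC = b"\x93NUMPY"


def read_npy_header(f: BinaryIO) -> Tuple[tuple, str, bool]:
    """Return (shape, dtype descr, fortran_order) from an open .npy file.

    Only the header is read, so large stacks are never loaded into memory.
    """
    if f.read(6) != NPY_MAGIC:
        raise ValueError("not a .npy file")
    major, minor = f.read(2)  # one byte each: format version
    if major == 1:
        # Version 1.0: header length is a little-endian uint16.
        (header_len,) = struct.unpack("<H", f.read(2))
        encoding = "latin1"
    elif major in (2, 3):
        # Versions 2.0 and 3.0: header length is a little-endian uint32.
        (header_len,) = struct.unpack("<I", f.read(4))
        encoding = "latin1" if major == 2 else "utf8"
    else:
        raise ValueError(f"unsupported .npy version {major}.{minor}")
    # The header is a dict literal padded with spaces and ending in '\n'.
    header = ast.literal_eval(f.read(header_len).decode(encoding))
    return header["shape"], header["descr"], header["fortran_order"]


def looks_like_crys_tens_stack(shape: tuple, num_examples: int) -> bool:
    """Assumption: axis 0 indexes examples; other axes are 64, 64, 4 in any order."""
    return (len(shape) == 4 and shape[0] == num_examples
            and sorted(shape[1:]) == [4, 64, 64])
```

For example, open Data/StackedCrysTensor.npy in binary mode, pass the file object to read_npy_header, and check the returned shape with looks_like_crys_tens_stack(shape, 500). If NumPy is installed, np.load(path, mmap_mode="r").shape reports the same shape without reading the array data.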
A crystal dictionary is any .jsonl file that lists crystal structures, one JSON object per line, in the following form:
{"Crystal Structure Name": {"a": ..., "b": ..., "c": ..., "alpha": ..., "beta": ..., "gamma": ..., "sg": ..., "siteList": [[Atomic Number 1, [X_1, Y_1, Z_1]], [Atomic Number 2, [X_2, Y_2, Z_2]], ...]}}
Here is an example, where sg is the space group number and the site coordinates are fractional:
{"Ca4MgPd-1934448.cif": {"a": 14.545, "b": 14.545, "c": 14.545, "alpha": 90.0, "beta": 90.0, "gamma": 90.0, "sg": 216, "siteList": [[20, [0.56266, 0.25, 0.25]], [20, [0.43734, 0.75, 0.25]], ...]}}
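When assembling a crystal dictionary from your own structures, a small writer and checker can catch malformed lines before get_stacked_crys_tens.py is run. This is a hypothetical sketch: make_entry and iter_crystals are illustrative names, not repository functions; it assumes one crystal per JSONL line as in the template above; and it checks only structure (required keys, a space group number between 1 and 230, three coordinates per site), not physical validity.

```python
import json
from typing import Dict, Iterable, Iterator, List, Tuple

# Keys taken from the crystal dictionary template.
LATTICE_KEYS = ("a", "b", "c", "alpha", "beta", "gamma")
REQUIRED_KEYS = LATTICE_KEYS + ("sg", "siteList")


def make_entry(name: str, lattice: Tuple[float, ...], sg: int,
               sites: List[Tuple[int, Tuple[float, float, float]]]) -> str:
    """Serialize one crystal as a single JSONL line."""
    if len(lattice) != 6:
        raise ValueError("lattice must be (a, b, c, alpha, beta, gamma)")
    record = dict(zip(LATTICE_KEYS, lattice))
    record["sg"] = sg
    # Each site is [atomic number, [x, y, z]] with fractional coordinates.
    record["siteList"] = [[z, list(xyz)] for z, xyz in sites]
    return json.dumps({name: record})


def iter_crystals(lines: Iterable[str]) -> Iterator[Tuple[str, Dict]]:
    """Yield (name, record) pairs, raising ValueError on malformed records."""
    for lineno, line in enumerate(lines, start=1):
        if not line.strip():
            continue  # tolerate blank lines
        for name, rec in json.loads(line).items():
            missing = [k for k in REQUIRED_KEYS if k not in rec]
            if missing:
                raise ValueError(f"line {lineno} ({name}): missing {missing}")
            if not 1 <= rec["sg"] <= 230:
                raise ValueError(f"line {lineno} ({name}): invalid space group")
            for _z, xyz in rec["siteList"]:
                if len(xyz) != 3:
                    raise ValueError(f"line {lineno} ({name}): bad coordinates")
            yield name, rec
```

For example, iterating iter_crystals over an open Data/CrystalDictionary.jsonl counts the well-formed entries and stops at the first bad line, while make_entry lines can be written out, one per line, to build a new dictionary.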
Once a Stacked CrysTens has been created, we can train a generative model to produce synthetic CrysTens'. The options are a Vanilla GAN (lowest performance), a Wasserstein GAN (intermediate performance), and a diffusion model (highest performance):
python train_vanilla_gan.py --data_path=Data/StackedCrysTensor.npy --save_path=Data/
python train_wasserstein_gan.py --data_path=Data/StackedCrysTensor.npy --save_path=Data/
python train_diffusion_model.py --data_path=Data/StackedCrysTensor.npy --save_path=Data/diffusion.pt --unet_number=1
followed by a second run that loads the saved checkpoint (--model_path) and trains U-Net 2:
python train_diffusion_model.py --data_path=Data/StackedCrysTensor.npy --save_path=Data/diffusion.pt --model_path=Data/diffusion.pt --unet_number=2
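The training commands can be wrapped in a small driver so that the two diffusion stages always run in order and the second stage is skipped if the first fails. This is a hypothetical sketch, not a script shipped with the repository: training_commands and run are illustrative names, it uses only the scripts and flags shown above, and it launches them with sys.executable rather than a bare python so that the current interpreter is used.

```python
import subprocess
import sys
from typing import List

DATA = "Data/StackedCrysTensor.npy"  # path used in the commands above


def training_commands(model: str, data_path: str = DATA) -> List[List[str]]:
    """Return the command lines for one training option, in run order."""
    py = sys.executable
    if model == "vanilla":
        return [[py, "train_vanilla_gan.py", f"--data_path={data_path}", "--save_path=Data/"]]
    if model == "wasserstein":
        return [[py, "train_wasserstein_gan.py", f"--data_path={data_path}", "--save_path=Data/"]]
    if model == "diffusion":
        ckpt = "Data/diffusion.pt"
        base = [py, "train_diffusion_model.py", f"--data_path={data_path}", f"--save_path={ckpt}"]
        # Stage 1 trains U-Net 1; stage 2 reloads the same checkpoint
        # through --model_path and trains U-Net 2.
        return [base + ["--unet_number=1"],
                base + [f"--model_path={ckpt}", "--unet_number=2"]]
    raise ValueError(f"unknown model: {model!r}")


def run(model: str) -> None:
    """Run each stage in turn; check=True stops the sequence on a failure."""
    for cmd in training_commands(model):
        subprocess.run(cmd, check=True)
```

Calling run("diffusion") would execute both U-Net stages back to back, while run("wasserstein") issues the single Wasserstein GAN command.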
Once a trained generative model is obtained, we can generate a stack of synthetic CrysTens' using generate_new_crys_tens.py.
python generate_new_crys_tens.py --model_path=Data/diffusion.pt --num_crys_tens=200 --crys_tens_path=Data/GenStackedCrysTens.npy
With a stack of generated CrysTens', generate_crystal_statistics.py can be used to convert the CrysTens' into CIFs and to collect statistics on the performance of the generative model that produced them.
python generate_crystal_statistics.py --crys_tens_path=Data/GenStackedCrysTens.npy --cif_folder=Data/ --stats_folder=Data/
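As a quick sanity check on the output, the unit-cell parameters of the generated CIFs can be compared with those in the crystal dictionary. The sketch below assumes that generate_crystal_statistics.py writes files with a .cif extension into --cif_folder and that they use the standard CIF data names (_cell_length_a, _cell_angle_alpha, and _space_group_IT_number or the older _symmetry_Int_Tables_number); the actual output format is not described here, and read_cell and read_folder are hypothetical helpers. It handles only simple one-line tag/value pairs, so a full CIF parser such as pymatgen's CifParser is the more robust choice.

```python
import re
from pathlib import Path
from typing import Dict

# Standard CIF data names mapped onto the crystal dictionary keys.
CIF_TAGS = {
    "_cell_length_a": "a", "_cell_length_b": "b", "_cell_length_c": "c",
    "_cell_angle_alpha": "alpha", "_cell_angle_beta": "beta",
    "_cell_angle_gamma": "gamma",
    "_space_group_IT_number": "sg", "_symmetry_Int_Tables_number": "sg",
}
# A number, optionally followed by a standard uncertainty such as 14.545(2).
NUMBER = re.compile(r"^([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)(?:\(\d+\))?$")


def read_cell(cif_text: str) -> Dict[str, float]:
    """Extract cell lengths, angles and the space group number from CIF text."""
    cell: Dict[str, float] = {}
    for line in cif_text.splitlines():
        parts = line.split()
        if len(parts) == 2 and parts[0] in CIF_TAGS:
            m = NUMBER.match(parts[1])
            if m:
                key = CIF_TAGS[parts[0]]
                value = float(m.group(1))
                cell[key] = int(value) if key == "sg" else value
    return cell


def read_folder(folder: str) -> Dict[str, Dict[str, float]]:
    """Map each *.cif file name in folder to its parsed cell parameters."""
    return {p.name: read_cell(p.read_text()) for p in sorted(Path(folder).glob("*.cif"))}
```

For example, read_folder("Data/") returns one dictionary per generated CIF, keyed like the crystal dictionary (a, b, c, alpha, beta, gamma, sg), which makes side-by-side comparison straightforward.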