diff --git a/.gitignore b/.gitignore index 0d10310c..63fe7879 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,4 @@ _C_flare* **/xml **/dist **egg-info +*.DS_Store diff --git a/CMakeLists.txt b/CMakeLists.txt index 4684ae7f..d9d71dd6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -9,7 +9,7 @@ list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/CMakeModules/) include(ExternalProject) -set(CMAKE_CXX_STANDARD 11) +set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS OFF) @@ -111,6 +111,22 @@ endif() # Greatly reduces the code bloat target_compile_options(pybind11 INTERFACE "-fvisibility=hidden") + +# Distmatrix +############################################################################### +FetchContent_Declare( + distmatrix + GIT_REPOSITORY https://github.com/mir-group/distmatrix.git +) +FetchContent_MakeAvailable(distmatrix) + +link_directories(${distmatrix_BINARY_DIR}) +link_directories(${distmatrix_BINARY_DIR}/scalapack_build/lib) + +include_directories(${distmatrix_SOURCE_DIR}/include) + + +#target_link_libraries(flare PUBLIC distmatrix) ############################################################################### # Specify source files. @@ -119,14 +135,19 @@ set(FLARE_SOURCES src/flare_pp/radial.cpp src/flare_pp/cutoffs.cpp src/flare_pp/structure.cpp + src/flare_pp/utils.cpp src/flare_pp/bffs/sparse_gp.cpp + src/flare_pp/bffs/parallel_sgp.cpp src/flare_pp/bffs/gp.cpp src/flare_pp/descriptors/descriptor.cpp + src/flare_pp/descriptors/b1.cpp src/flare_pp/descriptors/b2.cpp src/flare_pp/descriptors/b2_norm.cpp src/flare_pp/descriptors/b2_simple.cpp src/flare_pp/descriptors/b3.cpp - src/flare_pp/descriptors/wigner3j.cpp + src/flare_pp/descriptors/bk.cpp + src/flare_pp/descriptors/indices.cpp + src/flare_pp/descriptors/coeffs.cpp src/flare_pp/descriptors/two_body.cpp src/flare_pp/descriptors/three_body.cpp src/flare_pp/descriptors/three_body_wide.cpp @@ -152,6 +173,14 @@ include_directories( add_library(flare STATIC ${FLARE_SOURCES}) set_target_properties(flare PROPERTIES POSITION_INDEPENDENT_CODE ON) add_dependencies(flare eigen_project) +add_dependencies(flare distmatrix) + +# Link to MPI +find_package(MPI REQUIRED) +target_link_libraries(flare PUBLIC MPI::MPI_CXX) + +# Link to distmatrix +target_link_libraries(flare PUBLIC distmatrix -lgfortran) # Link to json. target_link_libraries(flare PUBLIC nlohmann_json::nlohmann_json) @@ -175,6 +204,8 @@ if (DEFINED ENV{MKL_INCLUDE} AND DEFINED ENV{MKL_LIBS}) include_directories($ENV{MKL_INCLUDE}) target_link_libraries(flare PUBLIC $ENV{MKL_LIBS}) target_compile_definitions(flare PUBLIC EIGEN_USE_MKL_ALL=1) + target_include_directories(flare PUBLIC $ENV{MKL_MPI}) + message(MKL_MPI=$ENV{MKL_MPI}) # Give the option to do a "bare-bones" build without Lapack/Lapacke. 
elseif(DEFINED ENV{NO_LAPACK}) message(STATUS "Building without Lapack.") diff --git a/README.md b/README.md index 5edf95b8..a487b916 100644 --- a/README.md +++ b/README.md @@ -13,3 +13,37 @@ If you're installing on Harvard's compute cluster, load the following modules first: ``` module load cmake/3.17.3-fasrc01 python/3.6.3-fasrc01 gcc/9.3.0-fasrc01 ``` + +### MPI compilation +To install the MPI version of the sparse GP on Harvard's compute cluster, load the following modules first: +``` +module load cmake/3.17.3-fasrc01 python/3.6.3-fasrc01 gcc/9.3.0-fasrc01 +module load intel-mkl/2017.2.174-fasrc01 openmpi/4.0.5-fasrc01 +``` + +Then clone the repository and check out the `mpi_distmat` branch: +``` +git clone https://github.com/mir-group/flare_pp.git +cd flare_pp +git checkout mpi_distmat +``` + +Create a directory for building the library: +``` +mkdir build +cd build +``` + +Compile with `cmake` and `make`. Here we need to specify the MPI compilers with +`CC=mpicc CXX=mpic++ FC=mpif90`. We also set the option `-DSCALAPACK_LIB=NOTFOUND`, +so that a static ScaLAPACK library is downloaded and compiled, which is required for +the Python bindings of the MPI-parallelized sparse GP: +``` +CC=mpicc CXX=mpic++ FC=mpif90 cmake .. -DSCALAPACK_LIB=NOTFOUND +make -j +``` + +Finally, copy the Python binding library into the `flare_pp` folder: +``` +cp _C_flare.cpython-36m-x86_64-linux-gnu.so ../flare_pp +``` diff --git a/ctests/CMakeLists.txt b/ctests/CMakeLists.txt index 687b0ba6..b7d9f8d3 100644 --- a/ctests/CMakeLists.txt +++ b/ctests/CMakeLists.txt @@ -1,3 +1,8 @@ +if (NOT EXISTS ${CMAKE_CURRENT_BINARY_DIR}/dft_data.xyz) + file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/dft_data.xyz + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) +endif () + # make executable add_executable( tests @@ -6,8 +11,10 @@ add_executable( test_structure.cpp test_n_body.cpp test_sparse_gp.cpp + test_parallel_sgp.cpp test_descriptor.cpp test_json.cpp + test_utils.cpp ) include_directories(../src/flare_pp) diff --git a/ctests/dft_data.xyz b/ctests/dft_data.xyz new file mode 100644 index 00000000..8ae59311 --- /dev/null +++ b/ctests/dft_data.xyz @@ -0,0 +1,17 @@ +2 +Lattice="1.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 1.0" Properties=species:S:1:pos:R:3:momenta:R:3:forces:R:3 sparse_indices= energy=0.9133510291442271 stress="0.4314439934824156 0.7292230215429331 0.38721540888657746 0.7292230215429331 0.326605318720777 0.25504993852246405 0.38721540888657746 0.25504993852246405 0.3185010782409763" pbc="T T T" +He 0.15713841 0.99710398 0.64092684 0.19861472 0.12959016 0.46411878 0.48801624 0.78256914 0.52750114 +He 0.69338569 0.90346464 0.14626089 0.92338744 0.37879257 0.79694737 0.92320467 0.71863087 0.54575622 +4 +Lattice="1.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 1.0" Properties=species:S:1:pos:R:3:momenta:R:3:forces:R:3 sparse_indices=2 energy=0.5845339958967781 stress="0.8832643728374101 0.4736589482749042 0.6125575360846077 0.4736589482749042 0.04982809557310608 0.9519383879956919 0.6125575360846077 0.9519383879956919 0.22321884060680475" pbc="T T T" +H 0.98699548 0.27109152 0.82840818 0.39723635 0.69230317 0.71742402 0.97501125 0.69790992 0.42736135 +He 0.95542741 0.62792707 0.71255717 0.20846441 0.71391509 0.56988269 0.73802987 0.55611127 0.20978028 +H 0.95551793 0.79583919 0.01626457 0.50770084 0.31544704 0.09250715 0.97886919 0.15580748 0.02892780 +He 0.34877543 0.15749443 0.15293379 0.64543458 0.63466428 0.15787333 0.67504974 0.30204244 0.41673029 +5 +Lattice="1.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 1.0" Properties=species:S:1:pos:R:3:momenta:R:3:forces:R:3 sparse_indices="0 2 4" energy=0.26796364806180983 stress="0.5706389111071576
0.3483258644776196 0.06089244456604703 0.3483258644776196 0.5296282371894802 0.21476826115341807 0.06089244456604703 0.21476826115341807 0.008875719863243536" pbc="T T T" +H 0.35071505 0.73963362 0.02696195 0.50938856 0.68674285 0.31711543 0.21822063 0.99866869 0.23949716 +H 0.78288668 0.67299943 0.47254490 0.03722967 0.36556295 0.28723150 0.36477796 0.28274931 0.78633052 +H 0.43430905 0.05355055 0.52035306 0.40526494 0.75298417 0.54189491 0.15815742 0.15396677 0.20679348 +He 0.50335965 0.87700387 0.47565834 0.59881725 0.71577512 0.87647597 0.44898467 0.56835792 0.32339863 +He 0.28826927 0.14879537 0.18758152 0.00333732 0.24138565 0.27410589 0.43239731 0.43696276 0.94641792 diff --git a/ctests/test_descriptor.cpp b/ctests/test_descriptor.cpp index 1c9ff0c7..f9cb1f94 100644 --- a/ctests/test_descriptor.cpp +++ b/ctests/test_descriptor.cpp @@ -1,52 +1,132 @@ +#include "bk.h" +#include "b1.h" +#include "b2.h" #include "b3.h" #include "descriptor.h" #include "test_structure.h" #include "gtest/gtest.h" #include #include +#include #include -TEST_F(StructureTest, RotationTest) { +// Test different types B1, B2, B3 to match with Bk +template +class DescTest : public StructureTest { +public: + using List = std::list; +}; + +using DescTypes = ::testing::Types; +TYPED_TEST_SUITE(DescTest, DescTypes); + +TYPED_TEST(DescTest, TestBk) { + // Set up B1/2/3 descriptors + std::vector descriptor_settings_1{this->n_species, this->N, this->L}; + std::vector descriptors1; + + TypeParam desc1(this->radial_string, this->cutoff_string, this->radial_hyps, + this->cutoff_hyps, descriptor_settings_1); + descriptors1.push_back(&desc1); + + // Set up Bk descriptors + int K = desc1.K; + std::vector descriptor_settings_2{this->n_species, K, this->N, this->L}; + Bk desc2(this->radial_string, this->cutoff_string, this->radial_hyps, + this->cutoff_hyps, descriptor_settings_2); + std::vector descriptors2; + descriptors2.push_back(&desc2); + + Structure struc1 = Structure(this->cell, this->species, this->positions, this->cutoff, descriptors1); + Structure struc2 = Structure(this->cell, this->species, this->positions, this->cutoff, descriptors2); + + // Check the descriptor dimensions + std::vector last_index = desc2.nu[desc2.nu.size()-1]; + int n_d = last_index[last_index.size()-1] + 1; // the size of list nu + int n_d1 = struc1.descriptors[0].n_descriptors; + int n_d2 = struc2.descriptors[0].n_descriptors; + EXPECT_EQ(n_d, n_d1); + EXPECT_EQ(n_d1, n_d2); + + // Check that Bk and B3 give the same descriptors. 
+ double d1, d2; + int nu_ind; + for (int i = 0; i < struc1.descriptors.size(); i++) { + for (int j = 0; j < struc1.descriptors[i].descriptors.size(); j++) { + for (int k = 0; k < struc1.descriptors[i].descriptors[j].rows(); k++) { + for (int l = 0; l < struc1.descriptors[i].descriptors[j].cols(); l++) { + d1 = struc1.descriptors[i].descriptors[j](k, l); + d2 = struc2.descriptors[i].descriptors[j](k, l); + EXPECT_NEAR(d1, d2, 1e-8); + } + } + } + } + + for (int i = 0; i < struc1.descriptors.size(); i++) { + for (int j = 0; j < struc1.descriptors[i].descriptor_force_dervs.size(); j++) { + for (int k = 0; k < struc1.descriptors[i].descriptor_force_dervs[j].rows(); k++) { + for (int l = 0; l < struc1.descriptors[i].descriptor_force_dervs[j].cols(); l++) { + d1 = struc1.descriptors[i].descriptor_force_dervs[j](k, l); + d2 = struc2.descriptors[i].descriptor_force_dervs[j](k, l); + EXPECT_NEAR(d1, d2, 1e-8); + } + } + } + } + +} +// Test Bk with K=1,2,3 for rotational invariance +class DescRotTest : public StructureTest, + public testing::WithParamInterface { +public: // Choose arbitrary rotation angles. double xrot = 1.28; double yrot = -3.21; double zrot = 0.42; + Eigen::MatrixXd rotated_pos; + Eigen::MatrixXd rotated_cell; + // Define rotation matrices. Eigen::MatrixXd Rx{3, 3}, Ry{3, 3}, Rz{3, 3}, R{3, 3}; - Rx << 1, 0, 0, 0, cos(xrot), -sin(xrot), 0, sin(xrot), cos(xrot); - Ry << cos(yrot), 0, sin(yrot), 0, 1, 0, -sin(yrot), 0, cos(yrot); - Rz << cos(zrot), -sin(zrot), 0, sin(zrot), cos(zrot), 0, 0, 0, 1; - R = Rx * Ry * Rz; - - Eigen::MatrixXd rotated_pos = positions * R.transpose(); - Eigen::MatrixXd rotated_cell = cell * R.transpose(); + DescRotTest() { + Rx << 1, 0, 0, 0, cos(xrot), -sin(xrot), 0, sin(xrot), cos(xrot); + Ry << cos(yrot), 0, sin(yrot), 0, 1, 0, -sin(yrot), 0, cos(yrot); + Rz << cos(zrot), -sin(zrot), 0, sin(zrot), cos(zrot), 0, 0, 0, 1; + R = Rx * Ry * Rz; + rotated_pos = positions * R.transpose(); + rotated_cell = cell * R.transpose(); + } +}; - // Define descriptors. - descriptor_settings[2] = 2; - B3 descriptor = B3(radial_string, cutoff_string, radial_hyps, cutoff_hyps, - descriptor_settings); +TEST_P(DescRotTest, RotationTest) { + int K = GetParam(); + std::vector descriptor_settings{n_species, K, N, L}; + Bk desc(radial_string, cutoff_string, radial_hyps, cutoff_hyps, + descriptor_settings); std::vector descriptors; - descriptors.push_back(&descriptor); + descriptors.push_back(&desc); Structure struc1 = Structure(cell, species, positions, cutoff, descriptors); Structure struc2 = - Structure(rotated_cell, species, rotated_pos, cutoff, descriptors); + Structure(this->rotated_cell, species, this->rotated_pos, cutoff, descriptors); // Check that B1 is rotationally invariant. - double d1, d2, diff; - double tol = 1e-10; + double d1, d2; + std::cout << "n_descriptors=" << struc1.descriptors[0].n_descriptors << std::endl; for (int n = 0; n < struc1.descriptors[0].n_descriptors; n++) { d1 = struc1.descriptors[0].descriptors[0](0, n); d2 = struc2.descriptors[0].descriptors[0](0, n); - diff = d1 - d2; - EXPECT_LE(abs(diff), tol); + EXPECT_NEAR(d1, d2, 1e-10); } } +INSTANTIATE_TEST_SUITE_P(DescBodies, DescRotTest, testing::Values(1, 2, 3)); + // TEST_F(DescriptorTest, SingleBond) { // // Check that B1 descriptors match the corresponding elements of the // // single bond vector. 
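The typed and parameterized tests above exercise the new generalized `Bk` descriptor, whose settings vector is `{n_species, K, N, L}`, with `K` controlling the body order (`K` = 1, 2, 3 reproduce `B1`, `B2`, `B3`). As a reading aid, here is a minimal sketch of constructing a `Bk` descriptor through the Python bindings touched later in this patch; the radial basis name, cutoff function name, and hyperparameter values are illustrative assumptions, not values taken from the patch.

```python
# Sketch only: the basis names and hyperparameters below are assumed for illustration.
import numpy as np
from flare_pp._C_flare import Bk, Structure

n_species, K, N, L = 2, 2, 8, 3   # K = 2 corresponds to the B2 descriptor
cutoff = 5.0
radial_hyps = [0.0, cutoff]       # assumed support of the radial basis
cutoff_hyps = []

# Same argument order as Bk(radial_basis, cutoff_function, radial_hyps,
# cutoff_hyps, descriptor_settings) used in the C++ tests above.
bk = Bk("chebyshev", "quadratic", radial_hyps, cutoff_hyps,
        [n_species, K, N, L])

# Toy periodic structure with coded species 0/1; descriptors are computed
# when the descriptor list is passed to the Structure constructor.
cell = np.eye(3) * 10.0
species = [0, 1, 0, 1]
positions = np.random.rand(4, 3) * 10.0
struc = Structure(cell, species, positions, cutoff, [bk])
```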
diff --git a/ctests/test_parallel_sgp.cpp b/ctests/test_parallel_sgp.cpp new file mode 100644 index 00000000..abffadc2 --- /dev/null +++ b/ctests/test_parallel_sgp.cpp @@ -0,0 +1,377 @@ +#include "parallel_sgp.h" +#include "sparse_gp.h" +#include "test_structure.h" +#include "omp.h" +#include "mpi.h" +#include +#include +#include // Iota +#include + +class ParSGPTest : public StructureTest { +public: + double sigma_e = 1; + double sigma_f = 2; + double sigma_s = 3; + int n_atoms_1 = 10; + int n_atoms_2 = 17; + int n_atoms = 10; + int n_types = n_species; + std::vector kernels; + SparseGP sparse_gp; + std::vector training_strucs; + std::vector>> sparse_indices; + + ParSGPTest() { + blacs::initialize(); + + std::vector dc; + B2 ps(radial_string, cutoff_string, radial_hyps, cutoff_hyps, + descriptor_settings); + dc.push_back(&ps); + B2 ps1(radial_string, cutoff_string, radial_hyps, cutoff_hyps, + descriptor_settings); + dc.push_back(&ps1); + + kernels.push_back(&kernel_norm); + kernels.push_back(&kernel_3_norm); + sparse_gp = SparseGP(kernels, sigma_e, sigma_f, sigma_s); + + // Generate random labels + Eigen::VectorXd energy = Eigen::VectorXd::Random(1); + Eigen::VectorXd forces = Eigen::VectorXd::Random(n_atoms * 3); + Eigen::VectorXd stresses = Eigen::VectorXd::Random(6); + + // Broadcast data such that different procs won't generate different random numbers + MPI_Bcast(energy.data(), 1, MPI_DOUBLE, 0, MPI_COMM_WORLD); + MPI_Bcast(forces.data(), n_atoms * 3, MPI_DOUBLE, 0, MPI_COMM_WORLD); + MPI_Bcast(stresses.data(), 6, MPI_DOUBLE, 0, MPI_COMM_WORLD); + + test_struc = Structure(cell, species, positions, cutoff, dc); + test_struc.energy = energy; + test_struc.forces = forces; + test_struc.stresses = stresses; + + // Make positions. + Eigen::MatrixXd cell_1, cell_2; + std::vector species_1, species_2; + Eigen::MatrixXd positions_1, positions_2; + Eigen::VectorXd labels_1, labels_2; + + cell_1 = Eigen::MatrixXd::Identity(3, 3) * cell_size; + cell_2 = Eigen::MatrixXd::Identity(3, 3) * cell_size; + + positions_1 = Eigen::MatrixXd::Random(n_atoms_1, 3) * cell_size / 2; + positions_2 = Eigen::MatrixXd::Random(n_atoms_2, 3) * cell_size / 2; + MPI_Bcast(positions_1.data(), n_atoms_1 * 3, MPI_DOUBLE, 0, MPI_COMM_WORLD); + MPI_Bcast(positions_2.data(), n_atoms_2 * 3, MPI_DOUBLE, 0, MPI_COMM_WORLD); + + labels_1 = Eigen::VectorXd::Random(1 + n_atoms_1 * 3 + 6); + labels_2 = Eigen::VectorXd::Random(1 + n_atoms_2 * 3 + 6); + MPI_Bcast(labels_1.data(), 1 + n_atoms_1 * 3 + 6, MPI_DOUBLE, 0, MPI_COMM_WORLD); + MPI_Bcast(labels_2.data(), 1 + n_atoms_2 * 3 + 6, MPI_DOUBLE, 0, MPI_COMM_WORLD); + + // Make random species. 
+ for (int i = 0; i < n_atoms_1; i++) { + species_1.push_back(rand() % n_species); + } + for (int i = 0; i < n_atoms_2; i++) { + species_2.push_back(rand() % n_species); + } + MPI_Bcast(species_1.data(), n_atoms_1, MPI_INT, 0, MPI_COMM_WORLD); + MPI_Bcast(species_2.data(), n_atoms_2, MPI_INT, 0, MPI_COMM_WORLD); + + // Build kernel matrices for the parallel SGP + //std::vector<std::vector<std::vector<int>>> sparse_indices = {{{0, 1}, {2}}}; + //std::vector<std::vector<int>> comm_sparse_ind = {{0, 1, 5, 7}, {2, 3, 4}}; + std::vector<std::vector<int>> comm_sparse_ind = {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, {0, 1, 2, 4, 5, 6, 7, 8, 9}}; + sparse_indices = {comm_sparse_ind, comm_sparse_ind}; + + std::cout << "Start building" << std::endl; + Structure struc_1 = Structure(cell_1, species_1, positions_1); + struc_1.energy = labels_1.segment(0, 1); + struc_1.forces = labels_1.segment(1, n_atoms_1 * 3); + struc_1.stresses = labels_1.segment(1 + n_atoms_1 * 3, 6); + std::cout << "Done struc_1" << std::endl; + + Structure struc_2 = Structure(cell_2, species_2, positions_2); + struc_2.energy = labels_2.segment(0, 1); + struc_2.forces = labels_2.segment(1, n_atoms_2 * 3); + struc_2.stresses = labels_2.segment(1 + n_atoms_2 * 3, 6); + std::cout << "Done struc_2" << std::endl; + + training_strucs = {struc_1, struc_2}; + + if (blacs::mpirank == 0) { + // Build sparse_gp (non-parallel) + Structure train_struc_1 = Structure(cell_1, species_1, positions_1, cutoff, dc); + train_struc_1.energy = labels_1.segment(0, 1); + train_struc_1.forces = labels_1.segment(1, n_atoms_1 * 3); + train_struc_1.stresses = labels_1.segment(1 + n_atoms_1 * 3, 6); + + Structure train_struc_2 = Structure(cell_2, species_2, positions_2, cutoff, dc); + train_struc_2.energy = labels_2.segment(0, 1); + train_struc_2.forces = labels_2.segment(1, n_atoms_2 * 3); + train_struc_2.stresses = labels_2.segment(1 + n_atoms_2 * 3, 6); + + sparse_gp.add_training_structure(train_struc_1); + sparse_gp.add_specific_environments(train_struc_1, {sparse_indices[0][0], sparse_indices[1][0]}); + sparse_gp.add_training_structure(train_struc_2); + sparse_gp.add_specific_environments(train_struc_2, {sparse_indices[0][1], sparse_indices[1][1]}); + sparse_gp.update_matrices_QR(); + std::cout << "Done QR for sparse_gp" << std::endl; + + sparse_gp.write_mapping_coefficients("beta.txt", "Me", {0}, "potential"); + sparse_gp.write_mapping_coefficients("beta_var.txt", "Me", {0}, "uncertainty"); + } + } +}; + +TEST_F(ParSGPTest, BuildParSGP){ + std::vector<Descriptor *> dc; + B2 ps(radial_string, cutoff_string, radial_hyps, cutoff_hyps, + descriptor_settings); + dc.push_back(&ps); + B2 ps1(radial_string, cutoff_string, radial_hyps, cutoff_hyps, + descriptor_settings); + dc.push_back(&ps1); + + ParallelSGP parallel_sgp(kernels, sigma_e, sigma_f, sigma_s); + parallel_sgp.build(training_strucs, cutoff, dc, sparse_indices, n_types); + + // Compute likelihood and gradient + parallel_sgp.precompute_KnK(); + double likelihood_parallel = parallel_sgp.compute_likelihood_gradient_stable(true); + Eigen::VectorXd like_grad_parallel = parallel_sgp.likelihood_gradient; + + // Build a serial SparseGP and compare with ParallelSGP + if (blacs::mpirank == 0) { + parallel_sgp.write_mapping_coefficients("par_beta.txt", "Me", {0}, "potential"); + parallel_sgp.write_mapping_coefficients("par_beta_var.txt", "Me", {0}, "uncertainty"); + + // Check that the kernel matrices are consistent + int n_clusters = 0; + for (int i = 0; i < parallel_sgp.n_kernels; i++) { + n_clusters += parallel_sgp.sparse_descriptors[i].n_clusters; + } + EXPECT_EQ(n_clusters, sparse_gp.Sigma.rows()); +
EXPECT_EQ(parallel_sgp.sparse_descriptors[0].n_clusters, sparse_gp.sparse_descriptors[0].n_clusters); + for (int t = 0; t < parallel_sgp.sparse_descriptors[0].n_types; t++) { + for (int r = 0; r < parallel_sgp.sparse_descriptors[0].descriptors[t].rows(); r++) { + for (int c = 0; c < parallel_sgp.sparse_descriptors[0].descriptors[t].cols(); c++) { + double par_desc = parallel_sgp.sparse_descriptors[0].descriptors[t](r, c); + double sgp_desc = sparse_gp.sparse_descriptors[0].descriptors[t](r, c); + EXPECT_NEAR(par_desc, sgp_desc, 1e-6); + } + } + } + + for (int r = 0; r < parallel_sgp.Kuu.rows(); r++) { + for (int c = 0; c < parallel_sgp.Kuu.cols(); c++) { + // Sometimes the accuracy is only between 1e-6 and 1e-5 + EXPECT_NEAR(parallel_sgp.Kuu(r, c), sparse_gp.Kuu(r, c), 1e-6); + } + } + std::cout << "Kuu matches" << std::endl; + + for (int r = 0; r < parallel_sgp.Kuu_inverse.rows(); r++) { + for (int c = 0; c < parallel_sgp.Kuu_inverse.rows(); c++) { + // Sometimes the accuracy is only between 1e-6 and 1e-5 + EXPECT_NEAR(parallel_sgp.Kuu_inverse(r, c), sparse_gp.Kuu_inverse(r, c), 1e-5); + } + } + std::cout << "Kuu_inverse matches" << std::endl; + + for (int r = 0; r < parallel_sgp.alpha.size(); r++) { + EXPECT_NEAR(parallel_sgp.alpha(r), sparse_gp.alpha(r), 1e-6); + } + std::cout << "alpha matches" << std::endl; + + // Check that predictions on the testing structure are consistent + parallel_sgp.predict_local_uncertainties(test_struc); + Structure test_struc_copy(test_struc.cell, test_struc.species, test_struc.positions, cutoff, dc); + sparse_gp.predict_local_uncertainties(test_struc_copy); + + for (int r = 0; r < test_struc.mean_efs.size(); r++) { + EXPECT_NEAR(test_struc.mean_efs(r), test_struc_copy.mean_efs(r), 1e-5); + } + std::cout << "mean_efs matches" << std::endl; + + for (int i = 0; i < test_struc.local_uncertainties.size(); i++) { + for (int r = 0; r < test_struc.local_uncertainties[i].size(); r++) { + EXPECT_NEAR(test_struc.local_uncertainties[i](r), test_struc_copy.local_uncertainties[i](r), 1e-5); + } + } + + // Test likelihood & gradient + double likelihood_serial = sparse_gp.compute_likelihood_gradient_stable(false); + std::cout << "likelihood: " << likelihood_serial << " " << likelihood_parallel << std::endl; + EXPECT_NEAR(likelihood_serial, likelihood_parallel, 1e-6); + + Eigen::VectorXd like_grad_serial = sparse_gp.likelihood_gradient; + EXPECT_EQ(like_grad_serial.size(), like_grad_parallel.size()); + for (int i = 0; i < like_grad_serial.size(); i++) { + EXPECT_NEAR(like_grad_serial(i), like_grad_parallel(i), 1e-6); + std::cout << "like grad " << like_grad_serial(i) << " " << like_grad_parallel(i) << std::endl; + } + + } + parallel_sgp.finalize_MPI = false; +} + +TEST_F(ParSGPTest, UpdateTrainSet){ + std::vector<Descriptor *> dc; + B2 ps(radial_string, cutoff_string, radial_hyps, cutoff_hyps, + descriptor_settings); + dc.push_back(&ps); + B2 ps1(radial_string, cutoff_string, radial_hyps, cutoff_hyps, + descriptor_settings); + dc.push_back(&ps1); + + ParallelSGP parallel_sgp_1(kernels, sigma_e, sigma_f, sigma_s); + bool update = false; + parallel_sgp_1.build(training_strucs, cutoff, dc, sparse_indices, n_types, update); + + // Make positions.
+ int n_atoms_3 = 11; + Eigen::MatrixXd cell_3, positions_3; + std::vector<int> species_3; + Eigen::VectorXd labels_3; + + cell_3 = Eigen::MatrixXd::Identity(3, 3) * cell_size; + + positions_3 = Eigen::MatrixXd::Random(n_atoms_3, 3) * cell_size / 2; + MPI_Bcast(positions_3.data(), n_atoms_3 * 3, MPI_DOUBLE, 0, MPI_COMM_WORLD); + + labels_3 = Eigen::VectorXd::Random(1 + n_atoms_3 * 3 + 6); + MPI_Bcast(labels_3.data(), 1 + n_atoms_3 * 3 + 6, MPI_DOUBLE, 0, MPI_COMM_WORLD); + + // Make random species. + for (int i = 0; i < n_atoms_3; i++) { + species_3.push_back(rand() % n_species); + } + MPI_Bcast(species_3.data(), n_atoms_3, MPI_INT, 0, MPI_COMM_WORLD); + + // Build kernel matrices for the parallel SGP + //std::vector<std::vector<int>> comm_sparse_ind = {{0, 1, 5, 7}, {2, 3, 4}}; + std::vector<std::vector<int>> comm_sparse_ind = {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, {0, 1, 2, 4, 5, 6, 7, 8, 9}}; + //sparse_indices = {comm_sparse_ind, comm_sparse_ind}; + + std::cout << "Start building" << std::endl; + Structure struc_3 = Structure(cell_3, species_3, positions_3); + struc_3.energy = labels_3.segment(0, 1); + struc_3.forces = labels_3.segment(1, n_atoms_3 * 3); + struc_3.stresses = labels_3.segment(1 + n_atoms_3 * 3, 6); + std::cout << "Done struc_3" << std::endl; + + training_strucs.push_back(struc_3); + for (int i = 0; i < dc.size(); i++) + sparse_indices[i].push_back(comm_sparse_ind[i]); + + update = true; + parallel_sgp_1.build(training_strucs, cutoff, dc, sparse_indices, n_types, update); + parallel_sgp_1.finalize_MPI = false; + + parallel_sgp_1.precompute_KnK(); + double likelihood_1 = parallel_sgp_1.compute_likelihood_gradient_stable(true); + Eigen::VectorXd like_grad_1 = parallel_sgp_1.likelihood_gradient; + + ParallelSGP parallel_sgp_2(kernels, sigma_e, sigma_f, sigma_s); + parallel_sgp_2.build(training_strucs, cutoff, dc, sparse_indices, n_types, false); + parallel_sgp_2.finalize_MPI = false; + + parallel_sgp_2.precompute_KnK(); + double likelihood_2 = parallel_sgp_2.compute_likelihood_gradient_stable(true); + Eigen::VectorXd like_grad_2 = parallel_sgp_2.likelihood_gradient; + + if (blacs::mpirank == 0) { + // Check that predictions on the testing structure are consistent + Structure test_struc_1(test_struc.cell, test_struc.species, test_struc.positions, cutoff, dc); + parallel_sgp_1.predict_local_uncertainties(test_struc_1); + Structure test_struc_2(test_struc.cell, test_struc.species, test_struc.positions, cutoff, dc); + parallel_sgp_2.predict_local_uncertainties(test_struc_2); + + for (int r = 0; r < test_struc_1.mean_efs.size(); r++) { + EXPECT_NEAR(test_struc_1.mean_efs(r), test_struc_2.mean_efs(r), 1e-6); + } + std::cout << "parallel_sgp_1 and parallel_sgp_2 mean_efs matches" << std::endl; + + for (int i = 0; i < test_struc_1.local_uncertainties.size(); i++) { + for (int r = 0; r < test_struc_1.local_uncertainties[i].size(); r++) { + EXPECT_NEAR(test_struc_1.local_uncertainties[i](r), test_struc_2.local_uncertainties[i](r), 1e-6); + } + } + + EXPECT_NEAR(likelihood_1, likelihood_2, 1e-6); + EXPECT_EQ(like_grad_1.size(), like_grad_2.size()); + for (int i = 0; i < like_grad_1.size(); i++) { + EXPECT_NEAR(like_grad_1(i), like_grad_2(i), 1e-6); + } + } + +} + +TEST_F(ParSGPTest, ParPredict){ + std::vector<Descriptor *> dc; + B2 ps(radial_string, cutoff_string, radial_hyps, cutoff_hyps, + descriptor_settings); + dc.push_back(&ps); + B2 ps1(radial_string, cutoff_string, radial_hyps, cutoff_hyps, + descriptor_settings); + dc.push_back(&ps1); + + ParallelSGP parallel_sgp(kernels, sigma_e, sigma_f, sigma_s); + parallel_sgp.build(training_strucs, cutoff,
dc, sparse_indices, n_types); + + std::cout << "create testing structures" << std::endl; + //std::vector> test_strucs_par; + std::vector test_strucs_par; + std::vector test_strucs_ser; + for (int t = 0; t < 10; t++) { + Eigen::MatrixXd cell_1, positions_1; + std::vector species_1; + Eigen::VectorXd labels_1; + + cell_1 = Eigen::MatrixXd::Identity(3, 3) * cell_size; + + positions_1 = Eigen::MatrixXd::Random(n_atoms_1, 3) * cell_size / 2; + MPI_Bcast(positions_1.data(), n_atoms_1 * 3, MPI_DOUBLE, 0, MPI_COMM_WORLD); + + labels_1 = Eigen::VectorXd::Random(1 + n_atoms_1 * 3 + 6); + MPI_Bcast(labels_1.data(), 1 + n_atoms_1 * 3 + 6, MPI_DOUBLE, 0, MPI_COMM_WORLD); + + // Make random species. + for (int i = 0; i < n_atoms_1; i++) { + species_1.push_back(rand() % n_species); + } + MPI_Bcast(species_1.data(), n_atoms_1, MPI_INT, 0, MPI_COMM_WORLD); + + blacs::barrier(); + + //auto test_struc_1 = std::make_shared(cell_1, species_1, positions_1); + Structure test_struc_1(cell_1, species_1, positions_1); + test_strucs_par.push_back(test_struc_1); + + // Predict with serial SGP + Structure test_struc_2(cell_1, species_1, positions_1, cutoff, dc); + test_strucs_ser.push_back(test_struc_2); + + blacs::barrier(); + } + + test_strucs_par = parallel_sgp.predict_on_structures(test_strucs_par, cutoff, dc); + if (blacs::mpirank == 0) { + for (int t = 0; t < test_strucs_par.size(); t++) { + sparse_gp.predict_local_uncertainties(test_strucs_ser[t]); + for (int r = 0; r < test_strucs_ser[t].mean_efs.size(); r++) { + EXPECT_NEAR(test_strucs_par[t].mean_efs(r), test_strucs_ser[t].mean_efs(r), 1e-6); + std::cout << test_strucs_par[t].mean_efs(r) << " " << test_strucs_ser[t].mean_efs(r) << std::endl; + } + for (int a = 0; a < test_strucs_ser[t].noa; a++) { + for (int i = 0; i < dc.size(); i++) { + EXPECT_NEAR(test_strucs_par[t].local_uncertainties[i](a), test_strucs_ser[t].local_uncertainties[i](a), 1e-6); + } + } + } + } +} diff --git a/ctests/test_sparse_gp.cpp b/ctests/test_sparse_gp.cpp index 1a5c6738..362bf1f2 100644 --- a/ctests/test_sparse_gp.cpp +++ b/ctests/test_sparse_gp.cpp @@ -4,15 +4,6 @@ #include #include // Iota -// TEST(TestPar, TestPar){ -// std::cout << omp_get_max_threads() << std::endl; -// #pragma omp parallel for -// for (int atom = 0; atom < 4; atom++) { -// std::cout << omp_get_thread_num() << std::endl; -// std::this_thread::sleep_for(std::chrono::milliseconds(2000)); -// } -// } - TEST_F(StructureTest, SortTest){ double sigma_e = 1; double sigma_f = 2; @@ -148,14 +139,27 @@ TEST_F(StructureTest, TestAdd){ TEST_F(StructureTest, LikeGrad) { // Check that the DTC likelihood gradient is correctly computed. 
- double sigma_e = 1; - double sigma_f = 2; - double sigma_s = 3; + double sigma_e = 0.1; + double sigma_f = 0.1; + double sigma_s = 0.1; std::vector kernels; kernels.push_back(&kernel_3); SparseGP sparse_gp = SparseGP(kernels, sigma_e, sigma_f, sigma_s); + std::vector dc; + + descriptor_settings = {n_species, 1, N, 0}; + Bk b1(radial_string, cutoff_string, radial_hyps, cutoff_hyps, + descriptor_settings); + dc.push_back(&b1); + descriptor_settings = {n_species, 2, N, L}; + Bk b2(radial_string, cutoff_string, radial_hyps, cutoff_hyps, + descriptor_settings); + dc.push_back(&b2); + + test_struc = Structure(cell, species, positions, cutoff, dc); + Eigen::VectorXd energy = Eigen::VectorXd::Random(1); Eigen::VectorXd forces = Eigen::VectorXd::Random(n_atoms * 3); Eigen::VectorXd stresses = Eigen::VectorXd::Random(6); @@ -163,8 +167,8 @@ TEST_F(StructureTest, LikeGrad) { test_struc.forces = forces; test_struc.stresses = stresses; - sparse_gp.add_training_structure(test_struc); - sparse_gp.add_all_environments(test_struc); + sparse_gp.add_training_structure(test_struc, {0, 1, 3, 5}); + sparse_gp.add_specific_environments(test_struc, {{0, 2}, {0, 1, 3}}); EXPECT_EQ(sparse_gp.Sigma.rows(), 0); EXPECT_EQ(sparse_gp.Kuu_inverse.rows(), 0); @@ -172,8 +176,8 @@ TEST_F(StructureTest, LikeGrad) { sparse_gp.update_matrices_QR(); // Test mapping coefficients. - sparse_gp.write_mapping_coefficients("beta.txt", "Jon V", 0); - sparse_gp.write_varmap_coefficients("beta_var.txt", "YX", 0); + sparse_gp.write_mapping_coefficients("beta.txt", "Jon V", {0}, "potential"); + sparse_gp.write_mapping_coefficients("beta_var.txt", "YX", {0}, "uncertainty"); // Check the likelihood function. Eigen::VectorXd hyps = sparse_gp.hyperparameters; @@ -182,7 +186,7 @@ TEST_F(StructureTest, LikeGrad) { int n_hyps = hyps.size(); Eigen::VectorXd hyps_up, hyps_down; - double pert = 1e-4, like_up, like_down, fin_diff; + double pert = 1e-6, like_up, like_down, fin_diff; for (int i = 0; i < n_hyps; i++) { hyps_up = hyps; @@ -195,10 +199,114 @@ TEST_F(StructureTest, LikeGrad) { fin_diff = (like_up - like_down) / (2 * pert); - EXPECT_NEAR(like_grad(i), fin_diff, 1e-7); + EXPECT_NEAR(like_grad(i), fin_diff, 2e-6); } } +TEST_F(StructureTest, LikeGradStable) { + // Check that the DTC likelihood gradient is correctly computed. + double sigma_e = 0.1; + double sigma_f = 0.1; + double sigma_s = 0.1; + + std::vector kernels; + kernels.push_back(&kernel_norm); + kernels.push_back(&kernel_3_norm); + SparseGP sparse_gp = SparseGP(kernels, sigma_e, sigma_f, sigma_s); + + std::vector dc; + + descriptor_settings = {n_species, 1, N, 0}; + Bk b1(radial_string, cutoff_string, radial_hyps, cutoff_hyps, + descriptor_settings); + dc.push_back(&b1); + descriptor_settings = {n_species, 2, N, L}; + Bk b2(radial_string, cutoff_string, radial_hyps, cutoff_hyps, + descriptor_settings); + dc.push_back(&b2); + + test_struc = Structure(cell, species, positions, cutoff, dc); + + Eigen::VectorXd energy = Eigen::VectorXd::Random(1); + Eigen::VectorXd forces = Eigen::VectorXd::Random(n_atoms * 3); + Eigen::VectorXd stresses = Eigen::VectorXd::Random(6); + test_struc.energy = energy; + test_struc.forces = forces; + test_struc.stresses = stresses; + + sparse_gp.add_training_structure(test_struc, {0, 1, 3, 5}); + sparse_gp.add_specific_environments(test_struc, {{0, 2}, {0, 1, 3}}); + + EXPECT_EQ(sparse_gp.Sigma.rows(), 0); + EXPECT_EQ(sparse_gp.Kuu_inverse.rows(), 0); + + sparse_gp.update_matrices_QR(); + + // Check the likelihood function. 
+ Eigen::VectorXd hyps = sparse_gp.hyperparameters; + sparse_gp.set_hyperparameters(hyps); + sparse_gp.precompute_KnK(); + double like = sparse_gp.compute_likelihood_gradient_stable(true); + + // Debug: check KnK + Eigen::MatrixXd KnK_e = sparse_gp.Kuf * sparse_gp.e_noise_one.asDiagonal() * sparse_gp.Kuf.transpose(); + for (int i = 0; i < KnK_e.rows(); i++) { + for (int j = 0; j < KnK_e.rows(); j++) { + EXPECT_NEAR(KnK_e(i, j), sparse_gp.KnK_e(i, j), 1e-8); + } + } + + Eigen::VectorXd like_grad = sparse_gp.likelihood_gradient; + + // Check the likelihood function. + double like_original = sparse_gp.compute_likelihood_gradient(hyps); + Eigen::VectorXd like_grad_original = sparse_gp.likelihood_gradient; + EXPECT_NEAR(like, like_original, 1e-7); + + int n_hyps = hyps.size(); + Eigen::VectorXd hyps_up, hyps_down; + double pert = 1e-6, like_up, like_down, fin_diff; + + for (int i = 0; i < n_hyps; i++) { + hyps_up = hyps; + hyps_down = hyps; + hyps_up(i) += pert; + hyps_down(i) -= pert; + + sparse_gp.set_hyperparameters(hyps_up); + like_up = sparse_gp.compute_likelihood_gradient_stable(); + double datafit_up = sparse_gp.data_fit; + double complexity_up = sparse_gp.complexity_penalty; + + sparse_gp.set_hyperparameters(hyps_down); + like_down = sparse_gp.compute_likelihood_gradient_stable(); + double datafit_down = sparse_gp.data_fit; + double complexity_down = sparse_gp.complexity_penalty; + + fin_diff = (like_up - like_down) / (2 * pert); + + EXPECT_NEAR(like_grad(i), fin_diff, 1e-6); + EXPECT_NEAR(like_grad(i), like_grad_original(i), 1e-7); + } + + Eigen::VectorXd new_hyps = hyps + Eigen::VectorXd::Ones(hyps.size()); + sparse_gp.set_hyperparameters(new_hyps); + double like_plus1 = sparse_gp.compute_likelihood_gradient_stable(); + Eigen::VectorXd like_grad_plus1 = sparse_gp.likelihood_gradient; + + sparse_gp.set_hyperparameters(hyps); + double like_plusminus1 = sparse_gp.compute_likelihood_gradient_stable(); + Eigen::VectorXd like_grad_plusminus1 = sparse_gp.likelihood_gradient; + + EXPECT_NEAR(like, like_plusminus1, 1e-7); + std::cout << like << " " << like_plusminus1 << std::endl; + for (int i = 0; i < n_hyps; i++) { + EXPECT_NEAR(like_grad(i), like_grad_plusminus1(i), 1e-7); + std::cout << like_grad(i) << " " << like_grad_plusminus1(i) << std::endl; + } +} + + TEST_F(StructureTest, Set_Hyps) { // Check the reset hyperparameters method. @@ -230,16 +338,19 @@ TEST_F(StructureTest, Set_Hyps) { test_struc_2.stresses = stresses_2; // Add sparse environments and training structures. 
+ std::cout << "adding training structure" << std::endl; sparse_gp_1.add_training_structure(test_struc); sparse_gp_1.add_training_structure(test_struc_2); sparse_gp_1.add_all_environments(test_struc); sparse_gp_1.add_all_environments(test_struc_2); + std::cout << "adding training structure" << std::endl; sparse_gp_2.add_training_structure(test_struc); sparse_gp_2.add_training_structure(test_struc_2); sparse_gp_2.add_all_environments(test_struc); sparse_gp_2.add_all_environments(test_struc_2); + std::cout << "updating matrices" << std::endl; sparse_gp_1.update_matrices_QR(); sparse_gp_2.update_matrices_QR(); @@ -267,16 +378,17 @@ TEST_F(StructureTest, AddOrder) { double sigma_s = 3; std::vector kernels; - kernels.push_back(&kernel_3); + kernels.push_back(&kernel_3_norm); SparseGP sparse_gp_1 = SparseGP(kernels, sigma_e, sigma_f, sigma_s); SparseGP sparse_gp_2 = SparseGP(kernels, sigma_e, sigma_f, sigma_s); + test_struc = Structure(cell, species, positions, cutoff, {&ps}); Eigen::VectorXd energy = Eigen::VectorXd::Random(1); Eigen::VectorXd forces = Eigen::VectorXd::Random(n_atoms * 3); Eigen::VectorXd stresses = Eigen::VectorXd::Random(6); -// test_struc.energy = energy; - test_struc.forces = forces; - // test_struc.stresses = stresses; + test_struc.energy = energy; + test_struc.forces = forces; + test_struc.stresses = stresses; // Add structure first. sparse_gp_1.add_training_structure(test_struc); @@ -289,15 +401,78 @@ TEST_F(StructureTest, AddOrder) { sparse_gp_2.update_matrices_QR(); // Check that matrices match. + EXPECT_EQ(sparse_gp_1.y.size(), sparse_gp_2.y.size()); + for (int i = 0; i < sparse_gp_1.y.size(); i++) { + EXPECT_NEAR(sparse_gp_1.y(i), sparse_gp_2.y(i), 1e-8); + } + + EXPECT_EQ(sparse_gp_1.Kuf.rows(), sparse_gp_2.Kuf.rows()); + EXPECT_EQ(sparse_gp_1.Kuf.cols(), sparse_gp_2.Kuf.cols()); + for (int i = 0; i < sparse_gp_1.Kuf.rows(); i++) { + for (int j = 0; j < sparse_gp_1.Kuf.cols(); j++) { + EXPECT_NEAR(sparse_gp_1.Kuf(i, j), sparse_gp_2.Kuf(i, j), 1e-8); + } + } + + EXPECT_EQ(sparse_gp_1.Kuu.rows(), sparse_gp_2.Kuu.rows()); + EXPECT_EQ(sparse_gp_1.Kuu.cols(), sparse_gp_2.Kuu.cols()); + for (int i = 0; i < sparse_gp_1.Kuu.rows(); i++) { + for (int j = 0; j < sparse_gp_1.Kuu.cols(); j++) { + EXPECT_NEAR(sparse_gp_1.Kuu(i, j), sparse_gp_2.Kuu(i, j), 1e-8); + } + } +} + +TEST_F(StructureTest, AtomIndices) { + double sigma_e = 1; + double sigma_f = 2; + double sigma_s = 3; + + std::vector kernels; + kernels.push_back(&kernel_3); + SparseGP sparse_gp_1 = SparseGP(kernels, sigma_e, sigma_f, sigma_s); + SparseGP sparse_gp_2 = SparseGP(kernels, sigma_e, sigma_f, sigma_s); + + Eigen::VectorXd energy = Eigen::VectorXd::Random(1); + Eigen::VectorXd forces = Eigen::VectorXd::Random(n_atoms * 3); + Eigen::VectorXd stresses = Eigen::VectorXd::Random(6); + test_struc.energy = energy; + test_struc.forces = forces; + test_struc.stresses = stresses; + + // Add structure first. + sparse_gp_1.add_training_structure(test_struc, {-1}); + sparse_gp_1.add_all_environments(test_struc); + sparse_gp_1.update_matrices_QR(); + + // Add environments first. + std::vector atom_indices; + for (int i = 0; i < n_atoms; i++) { + atom_indices.push_back(i); + } + sparse_gp_2.add_training_structure(test_struc, atom_indices); + sparse_gp_2.add_all_environments(test_struc); + sparse_gp_2.update_matrices_QR(); + + // Check that matrices match. 
+ EXPECT_EQ(sparse_gp_1.y.size(), sparse_gp_2.y.size()); + for (int i = 0; i < sparse_gp_1.y.size(); i++) { + EXPECT_NEAR(sparse_gp_1.y(i), sparse_gp_2.y(i), 1e-8); + } + + EXPECT_EQ(sparse_gp_1.Kuf.rows(), sparse_gp_2.Kuf.rows()); + EXPECT_EQ(sparse_gp_1.Kuf.cols(), sparse_gp_2.Kuf.cols()); for (int i = 0; i < sparse_gp_1.Kuf.rows(); i++) { for (int j = 0; j < sparse_gp_1.Kuf.cols(); j++) { - EXPECT_EQ(sparse_gp_1.Kuf(i, j), sparse_gp_2.Kuf(i, j)); + EXPECT_NEAR(sparse_gp_1.Kuf(i, j), sparse_gp_2.Kuf(i, j), 1e-8); } } + EXPECT_EQ(sparse_gp_1.Kuu.rows(), sparse_gp_2.Kuu.rows()); + EXPECT_EQ(sparse_gp_1.Kuu.cols(), sparse_gp_2.Kuu.cols()); for (int i = 0; i < sparse_gp_1.Kuu.rows(); i++) { for (int j = 0; j < sparse_gp_1.Kuu.cols(); j++) { - EXPECT_EQ(sparse_gp_1.Kuu(i, j), sparse_gp_2.Kuu(i, j)); + EXPECT_NEAR(sparse_gp_1.Kuu(i, j), sparse_gp_2.Kuu(i, j), 1e-8); } } } diff --git a/ctests/test_structure.cpp b/ctests/test_structure.cpp index a4a1d453..784dfcb0 100644 --- a/ctests/test_structure.cpp +++ b/ctests/test_structure.cpp @@ -142,7 +142,7 @@ TEST_F(StructureTest, SqExpGrad) { } } -TEST_F(StructureTest, StrucStrucFull) { +TEST_F(StructureTest, EnergyForceKernel) { // TODO: Systematically test all implemented descriptors and kernels. @@ -173,11 +173,12 @@ TEST_F(StructureTest, StrucStrucFull) { double delta = 1e-4; double thresh = 1e-4; - // Check energy/force kernel. Eigen::MatrixXd positions_3, positions_4, positions_5, positions_6, cell_3, cell_4, cell_5, cell_6, kern_pert, kern_pert_2, kern_pert_3, kern_pert_4; Structure test_struc_3, test_struc_4, test_struc_5, test_struc_6; double fin_val, exact_val, abs_diff; + + // Check energy/force kernel. for (int p = 0; p < test_struc_2.noa; p++) { for (int m = 0; m < 3; m++) { positions_3 = positions_4 = positions_2; @@ -199,6 +200,24 @@ TEST_F(StructureTest, StrucStrucFull) { EXPECT_NEAR(fin_val, exact_val, thresh); } } +} + +TEST_F(StructureTest, ForceEnergyKernel) { + test_struc = Structure(cell, species, positions, cutoff, dc); + test_struc_2 = Structure(cell_2, species_2, positions_2, cutoff, dc); + struc_desc = test_struc.descriptors[0]; + + // Compute full kernel matrix. + Eigen::MatrixXd kernel_matrix = kernel.struc_struc( + struc_desc, test_struc_2.descriptors[0], kernel.kernel_hyperparameters); + + double delta = 1e-4; + double thresh = 1e-4; + + Eigen::MatrixXd positions_3, positions_4, positions_5, positions_6, cell_3, + cell_4, cell_5, cell_6, kern_pert, kern_pert_2, kern_pert_3, kern_pert_4; + Structure test_struc_3, test_struc_4, test_struc_5, test_struc_6; + double fin_val, exact_val, abs_diff; // Check force/energy kernel. for (int p = 0; p < test_struc.noa; p++) { @@ -222,6 +241,24 @@ TEST_F(StructureTest, StrucStrucFull) { EXPECT_NEAR(fin_val, exact_val, thresh); } } +} + +TEST_F(StructureTest, EnergyStressKernel) { + test_struc = Structure(cell, species, positions, cutoff, dc); + test_struc_2 = Structure(cell_2, species_2, positions_2, cutoff, dc); + struc_desc = test_struc.descriptors[0]; + + // Compute full kernel matrix. + Eigen::MatrixXd kernel_matrix = kernel.struc_struc( + struc_desc, test_struc_2.descriptors[0], kernel.kernel_hyperparameters); + + double delta = 1e-4; + double thresh = 1e-4; + + Eigen::MatrixXd positions_3, positions_4, positions_5, positions_6, cell_3, + cell_4, cell_5, cell_6, kern_pert, kern_pert_2, kern_pert_3, kern_pert_4; + Structure test_struc_3, test_struc_4, test_struc_5, test_struc_6; + double fin_val, exact_val, abs_diff; // Check energy/stress. 
int stress_ind_1 = 0; @@ -262,9 +299,27 @@ TEST_F(StructureTest, StrucStrucFull) { stress_ind_1++; } } +} + +TEST_F(StructureTest, StressEnergyKernel) { + test_struc = Structure(cell, species, positions, cutoff, dc); + test_struc_2 = Structure(cell_2, species_2, positions_2, cutoff, dc); + struc_desc = test_struc.descriptors[0]; + + // Compute full kernel matrix. + Eigen::MatrixXd kernel_matrix = kernel.struc_struc( + struc_desc, test_struc_2.descriptors[0], kernel.kernel_hyperparameters); + double delta = 1e-4; + double thresh = 1e-4; + + Eigen::MatrixXd positions_3, positions_4, positions_5, positions_6, cell_3, + cell_4, cell_5, cell_6, kern_pert, kern_pert_2, kern_pert_3, kern_pert_4; + Structure test_struc_3, test_struc_4, test_struc_5, test_struc_6; + double fin_val, exact_val, abs_diff; + // Check stress/energy. - stress_ind_1 = 0; + int stress_ind_1 = 0; for (int m = 0; m < 3; m++) { for (int n = m; n < 3; n++) { cell_3 = cell_4 = cell; @@ -302,7 +357,26 @@ TEST_F(StructureTest, StrucStrucFull) { stress_ind_1++; } } +} + +TEST_F(StructureTest, ForceForceKernel) { + test_struc = Structure(cell, species, positions, cutoff, dc); + test_struc_2 = Structure(cell_2, species_2, positions_2, cutoff, dc); + struc_desc = test_struc.descriptors[0]; + + // Compute full kernel matrix. + Eigen::MatrixXd kernel_matrix = kernel.struc_struc( + struc_desc, test_struc_2.descriptors[0], kernel.kernel_hyperparameters); + + double delta = 1e-4; + double thresh = 1e-4; + + Eigen::MatrixXd positions_3, positions_4, positions_5, positions_6, cell_3, + cell_4, cell_5, cell_6, kern_pert, kern_pert_2, kern_pert_3, kern_pert_4; + Structure test_struc_3, test_struc_4, test_struc_5, test_struc_6; + double fin_val, exact_val, abs_diff; + // Check force/force kernel. for (int m = 0; m < test_struc.noa; m++) { for (int n = 0; n < 3; n++) { @@ -349,7 +423,26 @@ TEST_F(StructureTest, StrucStrucFull) { } } } +} + + +TEST_F(StructureTest, ForceStressKernel) { + test_struc = Structure(cell, species, positions, cutoff, dc); + test_struc_2 = Structure(cell_2, species_2, positions_2, cutoff, dc); + struc_desc = test_struc.descriptors[0]; + // Compute full kernel matrix. + Eigen::MatrixXd kernel_matrix = kernel.struc_struc( + struc_desc, test_struc_2.descriptors[0], kernel.kernel_hyperparameters); + + double delta = 1e-4; + double thresh = 1e-4; + + Eigen::MatrixXd positions_3, positions_4, positions_5, positions_6, cell_3, + cell_4, cell_5, cell_6, kern_pert, kern_pert_2, kern_pert_3, kern_pert_4; + Structure test_struc_3, test_struc_4, test_struc_5, test_struc_6; + double fin_val, exact_val, abs_diff; + // Check force/stress kernel. for (int m = 0; m < test_struc.noa; m++) { for (int n = 0; n < 3; n++) { @@ -413,7 +506,26 @@ TEST_F(StructureTest, StrucStrucFull) { } } } +} + +TEST_F(StructureTest, StressForceKernel) { + test_struc = Structure(cell, species, positions, cutoff, dc); + test_struc_2 = Structure(cell_2, species_2, positions_2, cutoff, dc); + struc_desc = test_struc.descriptors[0]; + + // Compute full kernel matrix. + Eigen::MatrixXd kernel_matrix = kernel.struc_struc( + struc_desc, test_struc_2.descriptors[0], kernel.kernel_hyperparameters); + + double delta = 1e-4; + double thresh = 1e-4; + + Eigen::MatrixXd positions_3, positions_4, positions_5, positions_6, cell_3, + cell_4, cell_5, cell_6, kern_pert, kern_pert_2, kern_pert_3, kern_pert_4; + Structure test_struc_3, test_struc_4, test_struc_5, test_struc_6; + double fin_val, exact_val, abs_diff; + // Check stress/force kernel. 
for (int m = 0; m < test_struc_2.noa; m++) { for (int n = 0; n < 3; n++) { @@ -479,9 +591,28 @@ TEST_F(StructureTest, StrucStrucFull) { } } } +} + + +TEST_F(StructureTest, StressStressKernel) { + test_struc = Structure(cell, species, positions, cutoff, dc); + test_struc_2 = Structure(cell_2, species_2, positions_2, cutoff, dc); + struc_desc = test_struc.descriptors[0]; + + // Compute full kernel matrix. + Eigen::MatrixXd kernel_matrix = kernel.struc_struc( + struc_desc, test_struc_2.descriptors[0], kernel.kernel_hyperparameters); + double delta = 1e-4; + double thresh = 1e-4; + + Eigen::MatrixXd positions_3, positions_4, positions_5, positions_6, cell_3, + cell_4, cell_5, cell_6, kern_pert, kern_pert_2, kern_pert_3, kern_pert_4; + Structure test_struc_3, test_struc_4, test_struc_5, test_struc_6; + double fin_val, exact_val, abs_diff; + // Check stress/stress kernel. - stress_ind_1 = 0; + int stress_ind_1 = 0; for (int m = 0; m < 3; m++) { for (int n = m; n < 3; n++) { cell_3 = cell_4 = cell; diff --git a/ctests/test_structure.h b/ctests/test_structure.h index 255496b3..b9d33d89 100644 --- a/ctests/test_structure.h +++ b/ctests/test_structure.h @@ -1,7 +1,7 @@ #include "b2.h" #include "b2_norm.h" #include "b2_simple.h" -#include "b3.h" +#include "bk.h" #include "structure.h" #include "four_body.h" #include "normalized_dot_product.h" @@ -45,7 +45,7 @@ class StructureTest : public ::testing::Test { double sigma = 2.0; double ls = 0.9; - int power = 1; + int power = 2; SquaredExponential kernel_3; SquaredExponential kernel; NormalizedDotProduct kernel_norm, kernel_3_norm; @@ -85,7 +85,7 @@ class StructureTest : public ::testing::Test { struc_desc = test_struc.descriptors[0]; kernel = SquaredExponential(sigma, ls); - kernel_norm = NormalizedDotProduct(sigma, power); + kernel_norm = NormalizedDotProduct(sigma + 10.0, power); icm_coeffs = Eigen::MatrixXd::Zero(3, 3); // icm_coeffs << 1, 2, 3, 2, 3, 4, 3, 4, 5; diff --git a/ctests/test_utils.cpp b/ctests/test_utils.cpp new file mode 100644 index 00000000..2ff87c8e --- /dev/null +++ b/ctests/test_utils.cpp @@ -0,0 +1,69 @@ +#include "utils.h" +#include "structure.h" +#include "gtest/gtest.h" +#include +#include +#include + +class UtilsTest : public ::testing::Test { +protected: + std::vector struc_list; + std::vector>> sparse_indices; + std::string filename = std::string("dft_data.xyz"); + std::map species_map = {{"H", 0,}, {"He", 1,}}; +}; + +TEST_F(UtilsTest, XYZTest) { + std::tie(struc_list, sparse_indices) = utils::read_xyz(filename, species_map); + std::cout << "Done read" << std::endl; + + // Test if species are read correctly + EXPECT_EQ(struc_list[0].species[0], 1); + EXPECT_EQ(struc_list[1].species[2], 0); + EXPECT_EQ(struc_list[2].species[4], 1); + std::cout << "species matches" << std::endl; + + // Test lattice + Eigen::MatrixXd cell = Eigen::MatrixXd::Identity(3, 3); + for (int s = 0; s < struc_list.size(); s++) { + for (int i = 0; i < 3; i++) { + for (int j = 0; j < 3; j++) { + EXPECT_NEAR(struc_list[s].cell(i, j), cell(i, j), 1e-8); + } + } + } + std::cout << "cell matches" << std::endl; + + // Test if positions are read correctly + EXPECT_NEAR(struc_list[0].positions(1, 0), 0.69338569, 1e-8); + EXPECT_NEAR(struc_list[1].positions(2, 2), 0.01626457, 1e-8); + EXPECT_NEAR(struc_list[2].positions(3, 1), 0.87700387, 1e-8); + std::cout << "pos matches" << std::endl; + + // Test if energy are read correctly + EXPECT_NEAR(struc_list[0].energy(0), 0.9133510291442271, 1e-8); + EXPECT_NEAR(struc_list[1].energy(0), 0.5845339958967781, 1e-8); + 
EXPECT_NEAR(struc_list[2].energy(0), 0.2679636480618098, 1e-8); + std::cout << "energy matches" << std::endl; + + // Test if forces are read correctly + EXPECT_NEAR(struc_list[0].forces(1 * 3 + 0), 0.92320467, 1e-8); + EXPECT_NEAR(struc_list[1].forces(2 * 3 + 2), 0.02892780, 1e-8); + EXPECT_NEAR(struc_list[2].forces(3 * 3 + 1), 0.56835792, 1e-8); + std::cout << "forces matches" << std::endl; + + // Test if stress are read correctly + // the struc stress are the xx, xy, xz, yy, yz, zz of xyz file + EXPECT_NEAR(struc_list[0].stresses(2), -0.38721540888657746, 1e-8); + EXPECT_NEAR(struc_list[1].stresses(3), -0.04982809557310608, 1e-8); + EXPECT_NEAR(struc_list[2].stresses(5), -0.00887571986324354, 1e-8); + std::cout << "stress matches" << std::endl; + + // Test if sparse indices are read correctly + EXPECT_EQ(sparse_indices[0][0].size(), 0); + EXPECT_EQ(sparse_indices[0][1].size(), 1); + EXPECT_EQ(sparse_indices[0][2].size(), 3); + EXPECT_EQ(sparse_indices[0][1][0], 2); + EXPECT_EQ(sparse_indices[0][2][2], 4); + std::cout << "sparse ind matches" << std::endl; +} diff --git a/docs/conf.py b/docs/conf.py index 5c058ad4..54e98415 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -11,8 +11,9 @@ # documentation root, use os.path.abspath to make it absolute, like shown here. # # import os -# import sys +import sys # sys.path.insert(0, os.path.abspath('.')) +sys.path.append("..") # -- Project information ----------------------------------------------------- @@ -32,8 +33,14 @@ # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. -extensions = ["breathe"] - +extensions = [ + 'breathe', + 'sphinx.ext.autodoc', + 'sphinx.ext.imgmath', + 'sphinx_rtd_theme', + 'sphinx.ext.napoleon', + 'nbsphinx', +] # Breathe configuration breathe_projects = {"flare_pp": "xml"} breathe_default_project = "flare_pp" diff --git a/docs/flare_pp/bffs.rst b/docs/flare_pp/bffs.rst index abb81a1b..f49caf10 100644 --- a/docs/flare_pp/bffs.rst +++ b/docs/flare_pp/bffs.rst @@ -3,3 +3,6 @@ Bayesian force fields .. doxygenclass:: SparseGP :members: + +.. doxygenclass:: ParallelSGP + :members: diff --git a/docs/index.rst b/docs/index.rst index c047af6b..36886c8c 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -13,3 +13,4 @@ Contents :maxdepth: 4 flare_pp/flare_pp + python_api/index diff --git a/docs/python_api/index.rst b/docs/python_api/index.rst new file mode 100644 index 00000000..dc22cd67 --- /dev/null +++ b/docs/python_api/index.rst @@ -0,0 +1,9 @@ +Python API +========== + +.. toctree:: + :maxdepth: 3 + + sgp_wrapper + sgp_calculator + diff --git a/docs/python_api/sgp_calculator.rst b/docs/python_api/sgp_calculator.rst new file mode 100644 index 00000000..bceea984 --- /dev/null +++ b/docs/python_api/sgp_calculator.rst @@ -0,0 +1,5 @@ +ASE Calculator of Sparse GP +=========================== + +.. automodule:: flare_pp.sparse_gp_calculator + :members: diff --git a/docs/python_api/sgp_wrapper.rst b/docs/python_api/sgp_wrapper.rst new file mode 100644 index 00000000..fcd343b4 --- /dev/null +++ b/docs/python_api/sgp_wrapper.rst @@ -0,0 +1,5 @@ +Sparse GP Wrapper +================= + +.. 
automodule:: flare_pp.sparse_gp + :members: diff --git a/flare_pp/parallel_sgp.py b/flare_pp/parallel_sgp.py new file mode 100644 index 00000000..9c086b61 --- /dev/null +++ b/flare_pp/parallel_sgp.py @@ -0,0 +1,251 @@ +import json +import numpy as np +from copy import deepcopy +from scipy.optimize import minimize +from typing import List, Union, Tuple +import warnings +from ase.calculators.singlepoint import SinglePointCalculator +from flare import struc +from flare.ase.atoms import FLARE_Atoms +from flare.utils.element_coder import NumpyEncoder +from flare.utils.learner import is_std_in_bound + +from mpi4py import MPI +from memory_profiler import profile + +from flare_pp._C_flare import ParallelSGP +from flare_pp.sparse_gp import SGP_Wrapper +from flare_pp.sparse_gp_calculator import SGP_Calculator, sort_variances +from flare_pp.utils import convert_to_flarepp_structure + + +class ParSGP_Wrapper(SGP_Wrapper): + """Wrapper class used to make the C++ sparse GP object compatible with + OTF. Methods and properties are designed to mirror the GP class.""" + + def __init__( + self, + kernels: List, + descriptor_calculators: List, + cutoff: float, + sigma_e: float, + sigma_f: float, + sigma_s: float, + species_map: dict, + variance_type: str = "SOR", + single_atom_energies: dict = None, + energy_training=True, + force_training=True, + stress_training=True, + max_iterations=10, + opt_method="BFGS", + bounds=None, + ): + + super().__init__( + kernels=kernels, + descriptor_calculators=descriptor_calculators, + cutoff=cutoff, + sigma_e=sigma_e, + sigma_f=sigma_f, + sigma_s=sigma_s, + species_map=species_map, + variance_type=variance_type, + single_atom_energies=single_atom_energies, + energy_training=energy_training, + force_training=force_training, + stress_training=stress_training, + max_iterations=max_iterations, + opt_method=opt_method, + bounds=bounds, + ) + self.sparse_gp = ParallelSGP(kernels, sigma_e, sigma_f, sigma_s) + self.training_structures = [] + self.training_sparse_indices = [[] for i in range(len(descriptor_calculators))] + + @property + def training_data(self): + # global training dataset + return self.training_structures + + @property + def local_training_data(self): + # local training dataset on the current process + return self.sparse_gp.training_structures + + @profile + def build( + self, + training_strucs: Union[List[struc.Structure], List[FLARE_Atoms]], + training_sparse_indices: List[List[List[int]]], + update, + ): + # Check the shape of sparse_indices + assert ( + len(training_sparse_indices[0][0]) >= 0 + ), """Sparse indices should be a list + [[[atom1_of_kernel1_of_struc1, ...], + [atom1_of_kernel2_of_struc1, ...]], + [[atom1_of_kernel1_of_struc2, ...], + [atom1_of_kernel2_of_struc2, ...]]]""" + + # Convert flare Structure or FLARE_Atoms to flare_pp Structure + struc_list = [] + for structure in training_strucs: + try: + energy = structure.energy + except AttributeError: + energy = structure.potential_energy + forces = structure.forces + + # Convert stress order to xx, xy, xz, yy, yz, zz + s = structure.stress + stress = None + if s is not None: + if len(s) == 6: + stress = -s[[0, 5, 4, 1, 3, 2]] + elif s.shape == (3, 3): + stress = -np.array( + [s[0, 0], s[0, 1], s[0, 2], s[1, 1], s[1, 2], s[2, 2]] + ) + + structure_descriptor = convert_to_flarepp_structure( + structure, + self.species_map, + energy, + forces, + stress, + self.energy_training, + self.force_training, + self.stress_training, + self.single_atom_energies, + cutoff=None, + descriptor_calculators=None, + ) + 
+ struc_list.append(structure_descriptor) + + n_types = len(self.species_map) + self.sparse_gp.build( + struc_list, + self.cutoff, + self.descriptor_calculators, + training_sparse_indices, + n_types, + update, + ) + + self.training_structures = training_strucs + self.training_sparse_indices = training_sparse_indices + + + @profile + def update_db( + self, + structure, + custom_range=(), + mode: str = "all", + ): + if mode == "all": + sparse_inds = [i for i in range(len(structure))] + elif mode == "uncertain": + gp_calc = SGP_Calculator(self) + uncertainties = gp_calc.get_uncertainties(structure) + if len(custom_range) == len(self.descriptor_calculators): + sparse_inds = [] + for i in range(len(self.descriptor_calculators)): + sorted_ind = np.argsort(-uncertainties[:, i]).tolist() + sparse_inds.append(sorted_ind[: custom_range[i]]) + else: + raise Exception( + "The length of custom_range should equal the number of descriptors/kernels if mode='uncertain'" + ) + elif mode == "specific": + if len(custom_range) == len(self.descriptor_calculators): + sparse_inds = custom_range + else: + raise Exception( + "The length of custom_range should equal the number of descriptors/kernels if mode='specific'" + ) + + elif mode == "random": + if len(custom_range) == len(self.descriptor_calculators): + sparse_inds = [ + np.random.choice( + len(structure), size=custom_range[i], replace=False + ).tolist() + for i in range(len(self.descriptor_calculators)) + ] + comm = MPI.COMM_WORLD + rank = comm.Get_rank() + sparse_inds = comm.bcast(sparse_inds, root=0) + else: + raise Exception( + "The length of custom_range should equal the number of descriptors/kernels if mode='random'" + ) + else: + raise NotImplementedError + + self.training_structures.append(structure) + for k in range(len(self.descriptor_calculators)): + self.training_sparse_indices[k].append(sparse_inds[k]) + + # Build a new SGP: full build for the first frame, incremental update afterwards + print(self.training_sparse_indices) + if len(self.training_structures) == 1: + update = False + else: + update = True + self.build(self.training_structures, self.training_sparse_indices, update=update) + + def predict_on_structures(self, struc_list): + # Convert ASE Atoms to C++ Structure objects without computing descriptors + struc_desc_list = [] + for structure in struc_list: + structure_descriptor = convert_to_flarepp_structure(structure, self.species_map) + struc_desc_list.append(structure_descriptor) + + struc_desc_list = self.sparse_gp.predict_on_structures( + struc_desc_list, self.cutoff, self.descriptor_calculators + ) + + for s in range(len(struc_list)): + results = {} + mean_efs = deepcopy(struc_desc_list[s].mean_efs) + results["energy"] = mean_efs[0] + results["forces"] = mean_efs[1:-6].reshape(-1, 3) + + # Convert stress to ASE format. + flare_stress = mean_efs[-6:] + ase_stress = - flare_stress[[0, 3, 5, 4, 2, 1]] + results["stress"] = ase_stress + + ## save uncertainties + ## TODO: add "atom_indices" attribute to struc_desc_list for sort_variances() + + #n_kern = len(self.descriptor_calculators) + #stds_full = np.zeros((len(struc_list), 3)) + #assert n_kern <= 3, NotImplementedError # now only print out 3 components + + #for k in range(n_kern): + # variances = all_results[s][1][k] + # sorted_variances = sort_variances(struc_desc_list[s], variances) + # stds = np.zeros(len(sorted_variances)) + # for n in range(len(sorted_variances)): + # var = sorted_variances[n] + # if var > 0: + # stds[n] = np.sqrt(var) + # else: + # stds[n] = -np.sqrt(np.abs(var)) + + # # Divide by the signal std to get a unitless value.
+ # stds_full[:, k] = stds / self.gp_model.hyps[k] + + #results["stds"] = stds_full + + if struc_list[s].calc is not None: + struc_list[s].calc.results = results + else: + calc = SinglePointCalculator(struc_list[s]) + calc.results = results + struc_list[s].calc = calc diff --git a/flare_pp/sparse_gp.py b/flare_pp/sparse_gp.py index 031d0dc6..a12fa34d 100644 --- a/flare_pp/sparse_gp.py +++ b/flare_pp/sparse_gp.py @@ -1,8 +1,6 @@ import json from time import time import numpy as np -from flare_pp import _C_flare -from flare_pp._C_flare import SparseGP, Structure, NormalizedDotProduct from scipy.optimize import minimize from typing import List import warnings @@ -10,6 +8,11 @@ from flare.ase.atoms import FLARE_Atoms from flare.utils.element_coder import NumpyEncoder +from flare_pp._C_flare import SparseGP, Structure, NormalizedDotProduct, Bk +from flare_pp.utils import convert_to_flarepp_structure + +from memory_profiler import profile + class SGP_Wrapper: """Wrapper class used to make the C++ sparse GP object compatible with @@ -130,18 +133,19 @@ def as_dict(self): out_dict[key] = getattr(self, key, None) # save descriptor_settings + out_dict["descriptor_calculators"] = [] desc_calc = self.descriptor_calculators - assert (len(desc_calc) == 1) and (isinstance(desc_calc[0], _C_flare.B2)) - b2_calc = desc_calc[0] - b2_dict = { - "type": "B2", - "radial_basis": b2_calc.radial_basis, - "cutoff_function": b2_calc.cutoff_function, - "radial_hyps": b2_calc.radial_hyps, - "cutoff_hyps": b2_calc.cutoff_hyps, - "descriptor_settings": b2_calc.descriptor_settings, - } - out_dict["descriptor_calculators"] = [b2_dict] + for dc in self.descriptor_calculators: + assert isinstance(dc, Bk) + dc_dict = { + "type": "Bk", + "radial_basis": dc.radial_basis, + "cutoff_function": dc.cutoff_function, + "radial_hyps": dc.radial_hyps, + "cutoff_hyps": dc.cutoff_hyps, + "descriptor_settings": dc.descriptor_settings, + } + out_dict["descriptor_calculators"].append(dc_dict) # save hyps out_dict["hyps"], out_dict["hyp_labels"] = self.hyps_and_labels @@ -200,39 +204,42 @@ def from_dict(in_dict): """ Need an initialized GP """ - # Recover kernel from checkpoint. + # recover kernels from checkpoint kernel_list = in_dict["kernels"] - assert len(kernel_list) == 1 - kernel_hyps = kernel_list[0] - assert kernel_hyps[0] == "NormalizedDotProduct" - sigma = float(kernel_hyps[1]) - power = int(kernel_hyps[2]) - kernel = NormalizedDotProduct(sigma, power) - kernels = [kernel] - - # Recover descriptor from checkpoint. 
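For reference, each entry that `as_dict()` writes into `out_dict["descriptor_calculators"]` and that `from_dict()` reads back is a plain dict describing one `Bk` calculator. A hypothetical entry is sketched below; the hyperparameter values are invented, and the `[n_species, K, n_max, l_max]` ordering of `descriptor_settings` follows the settings used elsewhere in this patch:

```python
bk_entry = {
    "type": "Bk",
    "radial_basis": "chebyshev",
    "cutoff_function": "quadratic",
    "radial_hyps": [0.0, 5.0],            # invented cutoff of 5.0 A
    "cutoff_hyps": [],
    "descriptor_settings": [2, 2, 8, 3],  # [n_species, K, n_max, l_max]
}
```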
+ kernels = [] + for k, kern in enumerate(kernel_list): + if kern[0] != "NormalizedDotProduct": + raise NotImplementedError + assert kern[1] == in_dict["hyps"][k] + kernels.append(NormalizedDotProduct(kern[1], kern[2])) + + # recover descriptors from checkpoint desc_calc = in_dict["descriptor_calculators"] - assert len(desc_calc) == 1 - b2_dict = desc_calc[0] - assert b2_dict["type"] == "B2" - calc = _C_flare.B2( - b2_dict["radial_basis"], - b2_dict["cutoff_function"], - b2_dict["radial_hyps"], - b2_dict["cutoff_hyps"], - b2_dict["descriptor_settings"], - ) + desc_calc_list = [] + + for dc_dict in desc_calc: + if dc_dict["type"] == "Bk": + calc = Bk( + dc_dict["radial_basis"], + dc_dict["cutoff_function"], + dc_dict["radial_hyps"], + dc_dict["cutoff_hyps"], + dc_dict["descriptor_settings"], + ) + else: + raise NotImplementedError + desc_calc_list.append(calc) # change the keys of single_atom_energies and species_map to int if in_dict["single_atom_energies"] is not None: - sae_dict = {int(k): v for k, v in in_dict["single_atom_energies"].items()} + sae_dict = {int(k): v for k, v in in_dict["single_atom_energies"].items()} else: - sae_dict = None + sae_dict = None species_map = {int(k): v for k, v in in_dict["species_map"].items()} gp = SGP_Wrapper( - kernels=[kernel], - descriptor_calculators=[calc], + kernels=kernels, + descriptor_calculators=desc_calc_list, cutoff=in_dict["cutoff"], sigma_e=in_dict["hyps"][-3], sigma_f=in_dict["hyps"][-2], @@ -289,77 +296,61 @@ def update_db( mode: str = "specific", sgp: SparseGP = None, # for creating sgp_var update_qr=True, + atom_indices=[-1], ): - - # Convert coded species to 0, 1, 2, etc. - if isinstance(structure, (struc.Structure, FLARE_Atoms)): - coded_species = [] - for spec in structure.coded_species: - coded_species.append(self.species_map[spec]) - elif isinstance(structure, Structure): - coded_species = structure.species - else: - raise Exception - - # Convert flare structure to structure descriptor. - structure_descriptor = Structure( - structure.cell, - coded_species, - structure.positions, + # TODO: simplify the interface, remove forces, energy, and stress + + # Convert flare Structure or FLARE_Atoms to flare_pp Structure + structure_descriptor = convert_to_flarepp_structure( + structure, + self.species_map, + energy, + forces, + stress, + self.energy_training, + self.force_training, + self.stress_training, + self.single_atom_energies, self.cutoff, self.descriptor_calculators, ) - # Add labels to structure descriptor. - if (energy is not None) and (self.energy_training): - # Sum up single atom energies. - single_atom_sum = 0 - if self.single_atom_energies is not None: - for spec in coded_species: - single_atom_sum += self.single_atom_energies[spec] - - # Correct the energy label and assign to structure. - corrected_energy = energy - single_atom_sum - structure_descriptor.energy = np.array([[corrected_energy]]) - - if (forces is not None) and (self.force_training): - structure_descriptor.forces = forces.reshape(-1) - - if (stress is not None) and (self.stress_training): - structure_descriptor.stresses = stress - # Update the sparse GP. 
if sgp is None: sgp = self.sparse_gp - sgp.add_training_structure(structure_descriptor) + sgp.add_training_structure(structure_descriptor, atom_indices) if mode == "all": if not custom_range: sgp.add_all_environments(structure_descriptor) else: raise Exception("Set mode='specific' for a user-defined custom_range") elif mode == "uncertain": - if len(custom_range) == 1: # custom_range gives n_added + if len(custom_range) == len(sgp.kernels): # custom_range gives n_added n_added = custom_range sgp.add_uncertain_environments(structure_descriptor, n_added) else: raise Exception( - "The custom_range should be set as [n_added] if mode='uncertain'" + "The custom_range should length equal to the number of descriptors/kernels if mode='uncertain'" ) elif mode == "specific": if not custom_range: warnings.warn( "The mode='specific' but no custom_range is given, will not add sparse envs" ) - else: + elif len(custom_range) == len(sgp.kernels): sgp.add_specific_environments(structure_descriptor, custom_range) + else: + raise Exception( + "The custom_range should length equal to the number of descriptors/kernels if mode='specific'" + ) elif mode == "random": - if len(custom_range) == 1: # custom_range gives n_added + if len(custom_range) == len(sgp.kernels): # custom_range gives n_added n_added = custom_range sgp.add_random_environments(structure_descriptor, n_added) else: raise Exception( - "The custom_range should be set as [n_added] if mode='random'" + "The custom_range should length equal to the number of descriptors/kernels if mode='random'" ) else: raise NotImplementedError @@ -380,17 +371,17 @@ def train(self, logger_name=None): ) def write_mapping_coefficients(self, filename, contributor, kernel_idx): - self.sparse_gp.write_mapping_coefficients(filename, contributor, kernel_idx) + self.sparse_gp.write_mapping_coefficients( + filename, contributor, kernel_idx, "potential" + ) def write_varmap_coefficients(self, filename, contributor, kernel_idx): old_kernels = self.sparse_gp.kernels - assert (len(old_kernels) == 1) and ( - kernel_idx == 0 - ), "Not support multiple kernels" - assert isinstance(old_kernels[0], NormalizedDotProduct) - power = 1 - new_kernels = [NormalizedDotProduct(old_kernels[0].sigma, power)] + new_kernels = [] + for kern in old_kernels: + assert isinstance(kern, NormalizedDotProduct) + new_kernels.append(NormalizedDotProduct(kern.sigma, power)) # Build a power=1 SGP from scratch if self.sgp_var is None: @@ -399,14 +390,14 @@ def write_varmap_coefficients(self, filename, contributor, kernel_idx): self.sgp_var_kernels = new_kernels # Check hyperparameters and training data, if not match, construct a new SGP - assert len(self.sgp_var.kernels) == 1 - assert np.allclose(self.sgp_var.kernels[0].power, 1.0) + for kern in self.sgp_var.kernels: + assert np.allclose(kern.power, 1.0) is_same_hyps = np.allclose( self.sgp_var.hyperparameters, self.sparse_gp.hyperparameters ) n_sgp = len(self.training_data) n_sgp_var = len(self.sgp_var.training_structures) - is_same_data = n_sgp == n_sgp_var + is_same_data = n_sgp == n_sgp_var # Add new data if sparse_gp has more data than sgp_var if not is_same_data: @@ -443,7 +434,9 @@ def write_varmap_coefficients(self, filename, contributor, kernel_idx): new_kernels = self.sgp_var.kernels print("Map with current sgp_var") - self.sgp_var.write_varmap_coefficients(filename, contributor, kernel_idx) + self.sgp_var.write_mapping_coefficients( + filename, contributor, kernel_idx, "uncertainty" + ) return new_kernels @@ -501,8 +494,7 @@ def duplicate(self, 
new_hyps=None, new_kernels=None, new_powers=None): return new_gp, kernels -def compute_negative_likelihood(hyperparameters, sparse_gp, - print_vals=False): +def compute_negative_likelihood(hyperparameters, sparse_gp): """Compute the negative log likelihood and gradient with respect to the hyperparameters.""" @@ -512,27 +504,44 @@ def compute_negative_likelihood(hyperparameters, sparse_gp, sparse_gp.compute_likelihood() negative_likelihood = -sparse_gp.log_marginal_likelihood - if print_vals: - print_hyps(hyperparameters, negative_likelihood) + print_hyps(hyperparameters, negative_likelihood) return negative_likelihood -def compute_negative_likelihood_grad(hyperparameters, sparse_gp, - print_vals=False): +def compute_negative_likelihood_grad(hyperparameters, sparse_gp): """Compute the negative log likelihood and gradient with respect to the hyperparameters.""" assert len(hyperparameters) == len(sparse_gp.hyperparameters) - negative_likelihood = \ - -sparse_gp.compute_likelihood_gradient(hyperparameters) + negative_likelihood = -sparse_gp.compute_likelihood_gradient(hyperparameters) negative_likelihood_gradient = -sparse_gp.likelihood_gradient - if print_vals: - print_hyps_and_grad( - hyperparameters, negative_likelihood_gradient, negative_likelihood - ) + print_hyps_and_grad( + hyperparameters, negative_likelihood_gradient, negative_likelihood + ) + + return negative_likelihood, negative_likelihood_gradient + + +def compute_negative_likelihood_grad_stable( + hyperparameters, sparse_gp, precomputed=False +): + """Compute the negative log likelihood and gradient with respect to the + hyperparameters.""" + + assert len(hyperparameters) == len(sparse_gp.hyperparameters) + + print("python set_hyperparameters") + sparse_gp.set_hyperparameters(hyperparameters) + + negative_likelihood = -sparse_gp.compute_likelihood_gradient_stable(precomputed) + negative_likelihood_gradient = -sparse_gp.likelihood_gradient + + print_hyps_and_grad( + hyperparameters, negative_likelihood_gradient, negative_likelihood + ) return negative_likelihood, negative_likelihood_gradient @@ -554,10 +563,10 @@ def print_hyps_and_grad(hyperparameters, neglike_grad, neglike): print(-neglike) print("\n") - +@profile def optimize_hyperparameters( sparse_gp, - display_results=False, + display_results=True, gradient_tolerance=1e-4, max_iterations=10, bounds=None, @@ -566,11 +575,26 @@ def optimize_hyperparameters( """Optimize the hyperparameters of a sparse GP model.""" initial_guess = sparse_gp.hyperparameters - arguments = sparse_gp + + # If all kernels are NormalizedDotProduct, then some matrices can be + # pre-computed and stored + precompute = True + for kern in sparse_gp.kernels: + if not isinstance(kern, NormalizedDotProduct): + precompute = False + break + if precompute: + tic = time() + print("Precomputing KnK for hyps optimization") + sparse_gp.precompute_KnK() + print("Done precomputing. 
Time:", time() - tic) + arguments = (sparse_gp, precompute) + else: + arguments = (sparse_gp, precompute) if method == "BFGS": optimization_result = minimize( - compute_negative_likelihood_grad, + compute_negative_likelihood_grad_stable, initial_guess, arguments, method="BFGS", @@ -587,7 +611,7 @@ def optimize_hyperparameters( elif method == "L-BFGS-B": optimization_result = minimize( - compute_negative_likelihood_grad, + compute_negative_likelihood_grad_stable, initial_guess, arguments, method="L-BFGS-B", diff --git a/flare_pp/sparse_gp_calculator.py b/flare_pp/sparse_gp_calculator.py index 2d264726..6f6b1bfd 100644 --- a/flare_pp/sparse_gp_calculator.py +++ b/flare_pp/sparse_gp_calculator.py @@ -4,7 +4,7 @@ from flare_pp.sparse_gp import SGP_Wrapper import numpy as np import time, json - +from copy import deepcopy class SGP_Calculator(Calculator): @@ -54,11 +54,11 @@ def calculate(self, atoms=None, properties=None, system_changes=all_changes): self.gp_model.sparse_gp.predict_local_uncertainties(structure_descriptor) # Set results. - self.results["energy"] = structure_descriptor.mean_efs[0] - self.results["forces"] = structure_descriptor.mean_efs[1:-6].reshape(-1, 3) + self.results["energy"] = deepcopy(structure_descriptor.mean_efs[0]) + self.results["forces"] = deepcopy(structure_descriptor.mean_efs[1:-6].reshape(-1, 3)) # Convert stress to ASE format. - flare_stress = structure_descriptor.mean_efs[-6:] + flare_stress = deepcopy(structure_descriptor.mean_efs[-6:]) ase_stress = -np.array( [ flare_stress[0], @@ -89,19 +89,24 @@ def calculate(self, atoms=None, properties=None, system_changes=all_changes): # single atom-centered descriptor. # TODO: Generalize this variance type to multiple descriptors. elif self.gp_model.variance_type == "local": - variances = structure_descriptor.local_uncertainties[0] - sorted_variances = sort_variances(structure_descriptor, variances) - stds = np.zeros(len(sorted_variances)) - for n in range(len(sorted_variances)): - var = sorted_variances[n] - if var > 0: - stds[n] = np.sqrt(var) - else: - stds[n] = -np.sqrt(np.abs(var)) - stds_full = np.zeros((len(sorted_variances), 3)) + n_kern = len(structure_descriptor.local_uncertainties) + stds_full = np.zeros((len(atoms), 3)) + assert n_kern <= 3, NotImplementedError # now only print out 3 components + + for k in range(n_kern): + variances = structure_descriptor.local_uncertainties[k] + sorted_variances = sort_variances(structure_descriptor, variances) + stds = np.zeros(len(sorted_variances)) + for n in range(len(sorted_variances)): + var = sorted_variances[n] + if var > 0: + stds[n] = np.sqrt(var) + else: + stds[n] = -np.sqrt(np.abs(var)) + + # Divide by the signal std to get a unitless value. + stds_full[:, k] = stds / self.gp_model.hyps[k] - # Divide by the signal std to get a unitless value. 
- stds_full[:, 0] = stds / self.gp_model.hyps[0] self.results["stds"] = stds_full def get_uncertainties(self, atoms): diff --git a/flare_pp/tests/test_sparse_gp.py b/flare_pp/tests/test_sparse_gp.py new file mode 100644 index 00000000..89e4c55b --- /dev/null +++ b/flare_pp/tests/test_sparse_gp.py @@ -0,0 +1,304 @@ +import os, sys, shutil +import numpy as np +import pytest + +from ase.io import read, write +from ase.calculators import lammpsrun +from flare import struc +from flare.ase.atoms import FLARE_Atoms +from flare.lammps import lammps_calculator + +from flare_pp.sparse_gp import SGP_Wrapper +from flare_pp.sparse_gp_calculator import SGP_Calculator +from flare_pp._C_flare import NormalizedDotProduct, Bk, SparseGP, Structure + +import flare_pp + +np.random.seed(10) + +# Make random structure. +n_atoms = 4 +cell = np.eye(3) +train_positions = np.random.rand(n_atoms, 3) +# test_positions = np.random.rand(n_atoms, 3) +test_positions = train_positions +atom_types = [1, 2] +atom_masses = [2, 4] +species = [1, 2, 1, 2] +train_structure = struc.Structure(cell, species, train_positions) +test_structure = struc.Structure(cell, species, test_positions) + +# Test update db +custom_range = [1, 2, 3] +energy = np.random.rand() +forces = np.random.rand(n_atoms, 3) * 10 +stress = np.random.rand(6) +np.savez( + "random_data", + train_pos=train_positions, + test_pos=test_positions, + energy=energy, + forces=forces, + stress=stress, +) + +# Create sparse GP model. +sigma = 1.0 +power = 2 +kernel1 = NormalizedDotProduct(sigma, power) + +sigma = 2.0 +power = 2 +kernel2 = NormalizedDotProduct(sigma, power) + +sigma = 3.0 +power = 2 +kernel3 = NormalizedDotProduct(sigma, power) + +kernel_list = [kernel1, kernel2, kernel3] + +cutoff_function = "quadratic" +cutoff = 1.5 +many_body_cutoffs = [cutoff] +radial_basis = "chebyshev" +radial_hyps = [0.0, cutoff] +cutoff_hyps = [] + +settings = [len(atom_types), 1, 4, 0] +calc1 = Bk(radial_basis, cutoff_function, radial_hyps, cutoff_hyps, settings) + +settings = [len(atom_types), 2, 4, 3] +calc2 = Bk(radial_basis, cutoff_function, radial_hyps, cutoff_hyps, settings) + +settings = [len(atom_types), 3, 2, 2] +calc3 = Bk(radial_basis, cutoff_function, radial_hyps, cutoff_hyps, settings) + +calc_list = [calc1, calc2, calc3] + +sigma_e = 0.1 +sigma_f = 0.1 +sigma_s = 0.1 +species_map = {1: 0, 2: 1} +max_iterations = 20 + +sgp_py = SGP_Wrapper( + kernel_list, + calc_list, + cutoff, + sigma_e, + sigma_f, + sigma_s, + species_map, + max_iterations=max_iterations, + variance_type="local", +) +sgp_calc = SGP_Calculator(sgp_py) + + +def test_update_db(): + """Check that the covariance matrices have the correct size after the + sparse GP is updated.""" + + # sgp_py.update_db( + # train_structure, forces, custom_range, energy, stress, mode="specific" + # ) + + sgp_py.update_db( + train_structure, + forces, + custom_range=[3, 3, 3], + energy=energy, + stress=stress, + mode="uncertain", + ) + + n_envs = len(custom_range) + assert sgp_py.sparse_gp.Kuu.shape[0] == sgp_py.sparse_gp.Kuf.shape[0] + assert sgp_py.sparse_gp.Kuf.shape[1] == 1 + n_atoms * 3 + 6 + + +def test_train(): + """Check that the hyperparameters and likelihood are updated when the + train method is called.""" + + hyps_init = tuple(sgp_py.hyps) + sgp_py.train() + hyps_post = tuple(sgp_py.hyps) + + assert hyps_init != hyps_post + assert sgp_py.likelihood != 0.0 + + +def test_dict(): + """ + Check the method from_dict and as_dict + """ + + out_dict = sgp_py.as_dict() + assert len(sgp_py) == 
len(out_dict["training_structures"]) + new_sgp, _ = SGP_Wrapper.from_dict(out_dict) + assert len(sgp_py) == len(new_sgp) + assert len(sgp_py.sparse_gp.kernels) == len(new_sgp.sparse_gp.kernels) + assert np.allclose(sgp_py.hyps, new_sgp.hyps) + + +def test_coeff(): + # Dump potential coefficient file + sgp_py.write_mapping_coefficients("lmp.flare", "A", [0, 1, 2]) + + # Dump uncertainty coefficient file + # here the new kernel needs to be returned, otherwise the kernel won't be found in the current module + new_kern = sgp_py.write_varmap_coefficients("beta_var.txt", "B", [0, 1, 2]) + + assert ( + sgp_py.sparse_gp.sparse_indices[0] == sgp_py.sgp_var.sparse_indices[0] + ), "the sparse_gp and sgp_var don't have the same training data" + + for s in range(len(atom_types)): + org_desc = sgp_py.sparse_gp.sparse_descriptors[0].descriptors[s] + new_desc = sgp_py.sgp_var.sparse_descriptors[0].descriptors[s] + if not np.allclose(org_desc, new_desc): # the atomic order might change + assert np.allclose(org_desc.shape, new_desc.shape) + for i in range(org_desc.shape[0]): + flag = False + for j in range( + new_desc.shape[0] + ): # seek in new_desc for matching of org_desc + if np.allclose(org_desc[i], new_desc[j]): + flag = True + break + assert flag, "the sparse_gp and sgp_var don't have the same descriptors" + + +@pytest.mark.skipif( + not os.environ.get("lmp", False), + reason=( + "lmp not found " + "in environment: Please install LAMMPS " + "and set the $lmp env. " + "variable to point to the executatble." + ), +) +def test_lammps(): + # create ASE calc + lmp_command = os.environ.get("lmp") + specorder = ["H", "He"] + pot_file = "lmp.flare" + params = { + "command": lmp_command, + "pair_style": "flare", + "pair_coeff": [f"* * {pot_file}"], + } + files = [pot_file] + lmp_calc = lammpsrun.LAMMPS( + label=f"tmp", + keep_tmp_files=True, + tmp_dir="./tmp/", + parameters=params, + files=files, + specorder=specorder, + ) + + test_atoms = test_structure.to_ase_atoms() + test_atoms.calc = lmp_calc + lmp_f = test_atoms.get_forces() + lmp_e = test_atoms.get_potential_energy() + lmp_s = test_atoms.get_stress() + + print("GP predicting") + test_atoms.calc = None + test_atoms = FLARE_Atoms.from_ase_atoms(test_atoms) + test_atoms.calc = sgp_calc + sgp_f = test_atoms.get_forces() + sgp_e = test_atoms.get_potential_energy() + sgp_s = test_atoms.get_stress() + + print("Energy") + print(lmp_e, sgp_e) + assert np.allclose(lmp_e, sgp_e) + + print("Forces") + print(np.concatenate([lmp_f, sgp_f], axis=1)) + assert np.allclose(lmp_f, sgp_f) + + print("Stress") + print(lmp_s) + print(sgp_s) + assert np.allclose(lmp_s, sgp_s) + + +@pytest.mark.skipif( + not os.environ.get("lmp", False), + reason=( + "lmp not found " + "in environment: Please install LAMMPS " + "and set the $lmp env. " + "variable to point to the executatble." 
+ ), +) +def test_lammps_uncertainty(): + # create ASE calc + lmp_command = os.environ.get("lmp") + specorder = ["H", "He"] + pot_file = "lmp.flare" + params = { + "command": lmp_command, + "pair_style": "flare", + "pair_coeff": [f"* * {pot_file}"], + } + files = [pot_file] + lmp_calc = lammpsrun.LAMMPS( + label=f"tmp", + keep_tmp_files=True, + tmp_dir="./tmp/", + parameters=params, + files=files, + specorder=specorder, + ) + + test_atoms = test_structure.to_ase_atoms() + + # compute uncertainty + in_lmp = """ +atom_style atomic +units metal +boundary p p p +atom_modify sort 0 0.0 + +read_data data.lammps + +### interactions +pair_style flare +pair_coeff * * lmp.flare +mass 1 1.008000 +mass 2 4.002602 + +### run +fix fix_nve all nve +compute unc all flare/std/atom beta_var.txt +dump dump_all all custom 1 traj.lammps id type x y z vx vy vz fx fy fz c_unc[1] c_unc[2] c_unc[3] +thermo_style custom step temp press cpu pxx pyy pzz pxy pxz pyz ke pe etotal vol lx ly lz atoms +thermo_modify flush yes format float %23.16g +thermo 1 +run 0 +""" + os.chdir("tmp") + write("data.lammps", test_atoms, format="lammps-data") + with open("in.lammps", "w") as f: + f.write(in_lmp) + shutil.copyfile("../beta_var.txt", "./beta_var.txt") + os.system(f"{lmp_command} < in.lammps > log.lammps") + unc_atoms = read("traj.lammps", format="lammps-dump-text") + lmp_stds = [unc_atoms.get_array(f"c_unc[{i+1}]") / sgp_py.hyps[i] for i in range(len(calc_list))] + lmp_stds = np.squeeze(lmp_stds).T + + # Test mapped variance (need to use sgp_var) + test_atoms.calc = None + test_atoms = FLARE_Atoms.from_ase_atoms(test_atoms) + test_atoms.calc = sgp_calc + test_atoms.calc.gp_model.sparse_gp = sgp_py.sgp_var + test_atoms.calc.reset() + sgp_stds = test_atoms.calc.get_uncertainties(test_atoms) + print(sgp_stds) + print(lmp_stds) + assert np.allclose(sgp_stds, lmp_stds) diff --git a/flare_pp/utils.py b/flare_pp/utils.py new file mode 100644 index 00000000..94c7d269 --- /dev/null +++ b/flare_pp/utils.py @@ -0,0 +1,145 @@ +import numpy as np +from typing import List, Union, Tuple +from flare import struc +from flare.ase.atoms import FLARE_Atoms +from flare_pp._C_flare import Structure +from ase.io import read, write +from ase import Atoms +from ase.calculators.singlepoint import SinglePointCalculator + + +def convert_to_flarepp_structure( + structure: Union[struc.Structure, FLARE_Atoms, Structure], + species_map: dict, + energy: float = None, + forces: np.ndarray = None, + stress: np.ndarray = None, + enable_energy: bool = False, + enable_forces: bool = False, + enable_stress: bool = False, + single_atom_energies: dict = None, + cutoff: float = None, + descriptor_calculators: List = None, +): + """ + Assume the stress has been converted to the order xx, xy, xz, yy, yz, zz + and has a minus sign + """ + if isinstance(structure, (struc.Structure, FLARE_Atoms)): + # Convert coded species to 0, 1, 2, etc. + coded_species = [] + for spec in structure.coded_species: + coded_species.append(species_map[spec]) + elif isinstance(structure, Structure): + coded_species = structure.species + else: + raise TypeError + + # Convert flare structure to structure descriptor. + if (cutoff is None) or (descriptor_calculators is None): + structure_descriptor = Structure( + structure.cell, + coded_species, + structure.positions, + ) + else: + structure_descriptor = Structure( + structure.cell, + coded_species, + structure.positions, + cutoff, + descriptor_calculators, + ) + + # Add labels to structure descriptor. 
+ if (energy is not None) and (enable_energy): + # Sum up single atom energies. + single_atom_sum = 0 + if single_atom_energies is not None: + for spec in coded_species: + single_atom_sum += single_atom_energies[spec] + + # Correct the energy label and assign to structure. + corrected_energy = energy - single_atom_sum + structure_descriptor.energy = np.array([[corrected_energy]]) + + if (forces is not None) and (enable_forces): + structure_descriptor.forces = forces.reshape(-1) + + if (stress is not None) and (enable_stress): + structure_descriptor.stresses = stress + + return structure_descriptor + + +def add_sparse_indices_to_xyz(xyz_file_in, ind_file, xyz_file_out): + """ + Suppose we have an .xyz file saving all the DFT frames from OTF training, + and a file saving the sparse indices of atoms from each frame in the .xyz + file. For example, `sparse_indices.txt` has lines with format: + ``` + frame_number ind1 ind2 ind3 + ``` + Then this function combines the two files, such that the sparse indices + are written to the .xyz frames. + + Args: + xyz_file_in (str): the file name of the .xyz file, which saves the + DFT frames from OTF. + ind_file (str): the file name of the sparse indices, whose number of + rows should be equal to the number of frames in `xyz_file_in`. + xyz_file_out (str): the output file name. The output will be an .xyz file. + """ + frames = read(xyz_file_in, format="extxyz", index=":") + indices = open(ind_file).readlines() + assert len(frames) == len(indices) + for i in range(len(frames)): + sparse_ind = indices[i].split()[1:] + sparse_ind = np.array([int(s) for s in sparse_ind]) + frames[i].info["sparse_indices"] = sparse_ind + + write(xyz_file_out, frames, format="extxyz") + + +def struc_list_to_xyz(struc_list, xyz_file_out, species_map, sparse_indices=None): + """ + Write a list of `Structure` objects (and their sparse indices) into an .xyz file + + Args: + struc_list (list): a list of `Structure` objects. + xyz_file_out (str): the output file name. The output will be an .xyz file. + species_map (dict): the map from chemical number to coded species. For + example, `species_map = {14: 0, 6: 1}`, means that in the structure list, + the coded species 0 corresponds to silicon, and coded species 1 corresponds + to carbon. + sparse_indices (list): a list of arrays. Each array is the indices of sparse + atoms added to GP in the corresponding `Structure` object. 
+ """ + if sparse_indices: + assert len(sparse_indices) == len(struc_list) + + frames = [] + code2species = { + v: k for k, v in species_map.items() + } # an inverse map of species_map + for i, struc in enumerate(struc_list): + species_number = [code2species[s] for s in struc.species] + atoms = Atoms( + number=species_number, positions=struc.positions, cell=struc.cell, pbc=True + ) + + properties = ["forces", "energy", "stress"] + results = { + "forces": struc.forces, + "energy": struc.energy, + "stress": struc.stresses, + } + calculator = SinglePointCalculator(atoms, **results) + atoms.set_calculator(calculator) + + if sparse_indices: + atoms.info["sparse_indices"] = sparse_indices[i] + + frames.append(atoms) + + write(xyz_file_out, frames, format="extxyz") diff --git a/lammps_plugins/compute_flare_std_atom.cpp b/lammps_plugins/compute_flare_std_atom.cpp index 7b08b5f1..1e03fce0 100644 --- a/lammps_plugins/compute_flare_std_atom.cpp +++ b/lammps_plugins/compute_flare_std_atom.cpp @@ -20,6 +20,8 @@ #include "lammps_descriptor.h" #include "radial.h" #include "y_grad.h" +#include "indices.h" +#include "coeffs.h" using namespace LAMMPS_NS; @@ -35,6 +37,9 @@ ComputeFlareStdAtom::ComputeFlareStdAtom(LAMMPS *lmp, int narg, char **arg) : peratom_flag = 1; size_peratom_cols = 0; + + nmax = 0; + timeflag = 1; comm_reverse = 1; @@ -59,13 +64,19 @@ ComputeFlareStdAtom::~ComputeFlareStdAtom() { if (copymode) return; - memory->destroy(beta); - // if (allocated) { // memory->destroy(setflag); // memory->destroy(cutsq); // } + memory->destroy(radial_code); + memory->destroy(cutoff_code); + memory->destroy(K); + memory->destroy(n_max); + memory->destroy(l_max); + memory->destroy(beta_size); + memory->destroy(cutoffs); + memory->destroy(stds); memory->destroy(desc_derv); } @@ -100,8 +111,8 @@ void ComputeFlareStdAtom::compute_peratom() { if (atom->nmax > nmax) { memory->destroy(stds); nmax = atom->nmax; - memory->create(stds,nmax,"flare/std/atom:stds"); - vector_atom = stds; + memory->create(stds,nmax,num_kern,"flare/std/atom:stds"); + array_atom = stds; // each atom has an uncertainty vector from all kernels } int i, j, ii, jj, inum, jnum, itype, jtype, n_inner, n_count; @@ -125,16 +136,14 @@ void ComputeFlareStdAtom::compute_peratom() { numneigh = list->numneigh; firstneigh = list->firstneigh; - int beta_init, beta_counter; - double B2_norm_squared, B2_val_1, B2_val_2; - - Eigen::VectorXd single_bond_vals, B2_vals, B2_env_dot, beta_p, partial_forces, u; - Eigen::MatrixXd single_bond_env_dervs, B2_env_dervs; - for (ii = 0; ii < ntotal; ii++) { - stds[ii] = 0.0; + for (jj = 0; jj < num_kern; jj++) { + stds[ii][jj] = 0.0; + } } + double empty_thresh = 1e-8; + for (ii = 0; ii < inum; ii++) { i = ilist[ii]; itype = type[i]; @@ -144,36 +153,54 @@ void ComputeFlareStdAtom::compute_peratom() { ztmp = x[i][2]; jlist = firstneigh[i]; - // Count the atoms inside the cutoff. - n_inner = 0; - for (int jj = 0; jj < jnum; jj++) { - j = jlist[jj]; - - delx = x[j][0] - xtmp; - dely = x[j][1] - ytmp; - delz = x[j][2] - ztmp; - rsq = delx * delx + dely * dely + delz * delz; - if (rsq < (cutoff * cutoff)) { - n_inner++; + for (int kern = 0; kern < num_kern; kern++) { + cutoff_matrix = cutoff_matrices[kern]; + + // Count the atoms inside the cutoff. 
+ // TODO: this might be duplicated when multiple kernels share the same cutoff + n_inner = 0; + for (int jj = 0; jj < jnum; jj++) { + j = jlist[jj]; + + delx = x[j][0] - xtmp; + dely = x[j][1] - ytmp; + delz = x[j][2] - ztmp; + rsq = delx * delx + dely * dely + delz * delz; + if (rsq < (cutoff * cutoff)) { + n_inner++; + } + } + double norm_squared, variance; + Eigen::VectorXcd single_bond_vals, u; + Eigen::VectorXd vals, env_dot, partial_forces, beta_p; + Eigen::MatrixXcd single_bond_env_dervs; + Eigen::MatrixXd env_dervs; + + // Compute covariant descriptors. + // TODO: this function call is duplicated for multiple kernels + complex_single_bond(x, type, jnum, n_inner, i, xtmp, ytmp, ztmp, jlist, + basis_function[kern], cutoff_function[kern], + n_species, n_max[kern], l_max[kern], + radial_hyps[kern], cutoff_hyps[kern], + single_bond_vals, single_bond_env_dervs, cutoff_matrix); + + // Compute invariant descriptors. + compute_Bk(vals, env_dervs, norm_squared, env_dot, + single_bond_vals, single_bond_env_dervs, nu[kern], + n_species, K[kern], n_max[kern], l_max[kern], coeffs[kern], + beta_matrices[kern][itype - 1], u, &variance); + + // Continue if the environment is empty. + if (norm_squared < empty_thresh) + continue; + + // Compute local energy and partial forces. + // TODO: not needed if using "u" + beta_p = beta_matrices[kern][itype - 1] * vals; + stds[i][kern] = pow(abs(vals.dot(beta_p)) / norm_squared, 0.5); // the numerator could be negative } - - // Compute covariant descriptors. - single_bond(x, type, jnum, n_inner, i, xtmp, ytmp, ztmp, jlist, - basis_function, cutoff_function, cutoff, n_species, n_max, - l_max, radial_hyps, cutoff_hyps, single_bond_vals, - single_bond_env_dervs); - - // Compute invariant descriptors. - double variance; - B2_descriptor(B2_vals, B2_env_dervs, B2_norm_squared, B2_env_dot, - single_bond_vals, single_bond_env_dervs, n_species, n_max, - l_max, beta_matrices[itype - 1], u, &variance); - - // Compute local energy and partial forces. - stds[i] = pow(abs(variance), 0.5); // the numerator could be negative - } } @@ -187,8 +214,8 @@ int ComputeFlareStdAtom::pack_reverse_comm(int n, int first, double *buf) m = 0; last = first + n; for (i = first; i < last; i++) { - for (int comp = 0; comp < 3; comp++) { - buf[m++] = stds[i]; + for (int k = 0; k < num_kern; k++) { + buf[m++] = stds[i][k]; } } @@ -204,8 +231,8 @@ void ComputeFlareStdAtom::unpack_reverse_comm(int n, int *list, double *buf) m = 0; for (i = 0; i < n; i++) { j = list[i]; - for (int comp = 0; comp < 3; comp++) { - stds[j] += buf[m++]; + for (int k = 0; k < num_kern; k++) { + stds[j][k] += buf[m++]; } } @@ -255,6 +282,7 @@ void ComputeFlareStdAtom::coeff(int narg, char **arg) { error->all(FLERR, "Incorrect args for compute coefficients"); read_file(arg[3]); + size_peratom_cols = num_kern; } @@ -275,83 +303,134 @@ void ComputeFlareStdAtom::coeff(int narg, char **arg) { void ComputeFlareStdAtom::read_file(char *filename) { int me = comm->me; - char line[MAXLINE], radial_string[MAXLINE], cutoff_string[MAXLINE]; - int radial_string_length, cutoff_string_length; + char line[MAXLINE]; FILE *fptr; // Check that the potential file can be opened. 
- if (me == 0) { - fptr = utils::open_potential(filename,lmp,nullptr); - if (fptr == NULL) { - char str[128]; - snprintf(str, 128, "Cannot open variance file %s", filename); - error->one(FLERR, str); - } + fptr = utils::open_potential(filename,lmp,nullptr); + if (fptr == NULL) { + char str[128]; + snprintf(str, 128, "Cannot open potential file %s", filename); + error->one(FLERR, str); } int tmp, nwords; - if (me == 0) { - fgets(line, MAXLINE, fptr); + fgets(line, MAXLINE, fptr); + fgets(line, MAXLINE, fptr); + sscanf(line, "%i", &num_kern); // number of descriptors/kernels + + memory->create(radial_code, num_kern, "compute:radial_code"); + memory->create(cutoff_code, num_kern, "compute:cutoff_code"); + memory->create(K, num_kern, "compute:K"); + memory->create(n_max, num_kern, "compute:n_max"); + memory->create(l_max, num_kern, "compute:l_max"); + memory->create(beta_size, num_kern, "compute:beta_size"); + + for (int k = 0; k < num_kern; k++) { + char desc_str[MAXLINE]; fgets(line, MAXLINE, fptr); - sscanf(line, "%s", radial_string); // Radial basis set - radial_string_length = strlen(radial_string); + sscanf(line, "%s", desc_str); // Descriptor name + + char radial_str[MAXLINE], cutoff_str[MAXLINE]; fgets(line, MAXLINE, fptr); - sscanf(line, "%i %i %i %i", &n_species, &n_max, &l_max, &beta_size); + sscanf(line, "%s", radial_str); // Radial basis set + if (!strcmp(radial_str, "chebyshev")) { + radial_code[k] = 1; + } else { + char str[128]; + snprintf(str, 128, "Radial function %s is not supported\n.", radial_str); + error->all(FLERR, str); + } + fgets(line, MAXLINE, fptr); - sscanf(line, "%s", cutoff_string); // Cutoff function - cutoff_string_length = strlen(cutoff_string); + sscanf(line, "%i %i %i %i %i", &n_species, &K[k], &n_max[k], &l_max[k], &beta_size[k]); + fgets(line, MAXLINE, fptr); - sscanf(line, "%lg", &cutoff); // Cutoff - } - - MPI_Bcast(&n_species, 1, MPI_INT, 0, world); - MPI_Bcast(&n_max, 1, MPI_INT, 0, world); - MPI_Bcast(&l_max, 1, MPI_INT, 0, world); - MPI_Bcast(&beta_size, 1, MPI_INT, 0, world); - MPI_Bcast(&cutoff, 1, MPI_DOUBLE, 0, world); - MPI_Bcast(&radial_string_length, 1, MPI_INT, 0, world); - MPI_Bcast(&cutoff_string_length, 1, MPI_INT, 0, world); - MPI_Bcast(radial_string, radial_string_length + 1, MPI_CHAR, 0, world); - MPI_Bcast(cutoff_string, cutoff_string_length + 1, MPI_CHAR, 0, world); - - // Set number of descriptors. - int n_radial = n_max * n_species; - n_descriptors = (n_radial * (n_radial + 1) / 2) * (l_max + 1); - - // Check the relationship between the power spectrum and beta. - int beta_check = n_descriptors * n_descriptors; - if (beta_check != beta_size) - error->all(FLERR, "Beta size doesn't match the number of descriptors."); - - // Set the radial basis. - if (!strcmp(radial_string, "chebyshev")) { - basis_function = chebyshev; - radial_hyps = std::vector{0, cutoff}; - } + sscanf(line, "%s", cutoff_str); // Cutoff function + if (!strcmp(cutoff_str, "cosine")) { + cutoff_code[k] = 1; + } else if (!strcmp(cutoff_str, "quadratic")) { + cutoff_code[k] = 2; + } else { + char str[128]; + snprintf(str, 128, "Cutoff function %s is not supported\n.", cutoff_str); + error->all(FLERR, str); + } - // Set the cutoff function. - if (!strcmp(cutoff_string, "quadratic")) - cutoff_function = quadratic_cutoff; - else if (!strcmp(cutoff_string, "cosine")) - cutoff_function = cos_cutoff; - - // Parse the beta vectors. 
- //memory->create(beta, beta_size * n_species * n_species, "compute:beta"); - memory->create(beta, beta_size * n_species, "compute:beta"); - if (me == 0) - // grab(fptr, beta_size * n_species * n_species, beta); - grab(fptr, beta_size * n_species, beta); - //MPI_Bcast(beta, beta_size * n_species * n_species, MPI_DOUBLE, 0, world); - MPI_Bcast(beta, beta_size * n_species, MPI_DOUBLE, 0, world); - - // Fill in the beta matrix. - // TODO: Remove factor of 2 from beta. - int n_size = n_species * n_descriptors; - int beta_count = 0; - double beta_val; - for (int k = 0; k < n_species; k++) { -// for (int l = 0; l < n_species; l++) { + // Parse the cutoffs. + int n_cutoffs = n_species * n_species; + memory->create(cutoffs, n_cutoffs, "compute:cutoffs"); + if (me == 0) + grab(fptr, n_cutoffs, cutoffs); + MPI_Bcast(cutoffs, n_cutoffs, MPI_DOUBLE, 0, world); + + // Fill in the cutoff matrix. + cutoff = -1; + cutoff_matrix = Eigen::MatrixXd::Zero(n_species, n_species); + int cutoff_count = 0; + for (int i = 0; i < n_species; i++){ + for (int j = 0; j < n_species; j++){ + double cutoff_val = cutoffs[cutoff_count]; + cutoff_matrix(i, j) = cutoff_val; + if (cutoff_val > cutoff) cutoff = cutoff_val; + cutoff_count ++; + } + } + cutoff_matrices.push_back(cutoff_matrix); + + // Compute indices and coefficients + std::vector descriptor_settings = {n_species, K[k], n_max[k], l_max[k]}; + std::vector> nu_kern = compute_indices(descriptor_settings); + Eigen::VectorXd coeffs_kern = compute_coeffs(K[k], l_max[k]); + nu.push_back(nu_kern); + coeffs.push_back(coeffs_kern); + + // Set number of descriptors. + std::vector last_index = nu_kern[nu_kern.size()-1]; + int n_descriptors = last_index[last_index.size()-1] + 1; + + // Check the relationship between the power spectrum and beta. + int beta_check = n_descriptors * n_descriptors; + if (beta_check != beta_size[k]) { + char str[128]; + snprintf(str, 128, "The beta size of kernel %d doesn't match the number of descriptors.", k); + error->all(FLERR, str); + } + // Set the radial basis. + if (radial_code[k] == 1){ + basis_function.push_back(chebyshev); + std::vector rh = {0, cutoffs[0]}; // It does not matter what + // cutoff is used, will be + // modified to cutoff_matrix + // when computing descriptors + radial_hyps.push_back(rh); + std::vector ch; + cutoff_hyps.push_back(ch); + } + + // Set the cutoff function. + if (cutoff_code[k] == 1) + cutoff_function.push_back(cos_cutoff); + else if (cutoff_code[k] == 2) + cutoff_function.push_back(quadratic_cutoff); + + // Parse the beta vectors. + // TODO: check this memory creation + memory->create(beta, beta_size[k] * n_species, "pair:beta"); + grab(fptr, beta_size[k] * n_species, beta); + //MPI_Bcast(beta1, beta_size[k] * n_species, MPI_DOUBLE, 0, world); + + // Fill in the beta matrix. + // TODO: Remove factor of 2 from beta. 
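Before the beta matrices are filled in below, each kernel block of the coefficient file has already supplied its per-pair cutoffs and its descriptor count. A short Python sketch of that bookkeeping, with invented numbers (`n_descriptors` is taken from the last `nu` index in the real code; here it is simply assumed):

```python
import numpy as np

n_species = 2

# Per-pair cutoffs are read as a flat list and filled row by row into a
# species-by-species matrix; the scalar cutoff is the largest entry.
cutoffs = [3.5, 4.0, 4.0, 5.0]                       # made-up values
cutoff_matrix = np.array(cutoffs).reshape(n_species, n_species)
cutoff = cutoff_matrix.max()

# Each kernel's beta vector holds one square block per species, which is
# what the beta_size check above enforces.
n_descriptors = 30                                   # assumed: last nu index + 1
beta_flat = np.zeros(n_descriptors**2 * n_species)   # as read by grab()
beta_blocks = beta_flat.reshape(n_species, n_descriptors, n_descriptors)
```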
+ Eigen::MatrixXd beta_matrix; + std::vector beta_matrix_kern; + int n_size = n_species * n_descriptors; + int beta_count = 0; + double beta_val; + for (int k = 0; k < n_species; k++) { + // for (int l = 0; l < n_species; l++) { + beta_matrix = Eigen::MatrixXd::Zero(n_descriptors, n_descriptors); for (int i = 0; i < n_descriptors; i++) { for (int j = 0; j < n_descriptors; j++) { @@ -360,11 +439,15 @@ void ComputeFlareStdAtom::read_file(char *filename) { beta_count++; } } - beta_matrices.push_back(beta_matrix); -// } - } + beta_matrix_kern.push_back(beta_matrix); + // } + } + beta_matrices.push_back(beta_matrix_kern); + // TODO: check this memory destroy + memory->destroy(beta); + } } /* ---------------------------------------------------------------------- diff --git a/lammps_plugins/compute_flare_std_atom.h b/lammps_plugins/compute_flare_std_atom.h index 3d19de23..2be11458 100644 --- a/lammps_plugins/compute_flare_std_atom.h +++ b/lammps_plugins/compute_flare_std_atom.h @@ -32,26 +32,33 @@ class ComputeFlareStdAtom : public Compute { void init_list(int, class NeighList *); protected: - double *stds; + double **stds; double **desc_derv; class NeighList *list; - int n_species, n_max, l_max, n_descriptors, beta_size; + //int n_species, n_max, l_max, n_descriptors, beta_size; + int num_kern, n_species, n_descriptors; + int *K, *n_max, *l_max, *beta_size; + std::vector coeffs; // coefficient of A product in B + std::vector>> nu; // indices of A product - std::function &, std::vector &, double, int, - std::vector)> + int *radial_code, *cutoff_code; + std::vector &, std::vector &, double, int, + std::vector)>> basis_function; - std::function &, double, double, - std::vector)> + std::vector &, double, double, + std::vector)>> cutoff_function; - std::vector radial_hyps, cutoff_hyps; + std::vector> radial_hyps, cutoff_hyps; int nmax; // number of atoms + double *cutoffs; double cutoff; double *beta; - Eigen::MatrixXd beta_matrix; - std::vector beta_matrices; + Eigen::MatrixXd beta_matrix, cutoff_matrix; + std::vector cutoff_matrices; + std::vector> beta_matrices; virtual void allocate(); virtual void read_file(char *); diff --git a/lammps_plugins/lammps_descriptor.cpp b/lammps_plugins/lammps_descriptor.cpp index 3546e95b..54e92079 100644 --- a/lammps_plugins/lammps_descriptor.cpp +++ b/lammps_plugins/lammps_descriptor.cpp @@ -4,7 +4,7 @@ #include #include -void single_bond_multiple_cutoffs( +void complex_single_bond( double **x, int *type, int jnum, int n_inner, int i, double xtmp, double ytmp, double ztmp, int *jlist, std::function &, std::vector &, double, @@ -15,8 +15,8 @@ void single_bond_multiple_cutoffs( cutoff_function, int n_species, int N, int lmax, const std::vector &radial_hyps, - const std::vector &cutoff_hyps, Eigen::VectorXd &single_bond_vals, - Eigen::MatrixXd &single_bond_env_dervs, + const std::vector &cutoff_hyps, Eigen::VectorXcd &single_bond_vals, + Eigen::MatrixXcd &single_bond_env_dervs, const Eigen::MatrixXd &cutoff_matrix) { // Initialize basis vectors and spherical harmonics. 
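The renamed `complex_single_bond` keeps the flat index layout of the original single-bond vector, used in the hunks below: one block of `N * (lmax + 1)^2` entries per neighbor species, with harmonics enumerated as `l*l + m`. A small Python check of that layout (the sizes are illustrative only):

```python
n_species, N, lmax = 2, 4, 3
n_harmonics = (lmax + 1) ** 2
n_bond = n_species * N * n_harmonics      # length of single_bond_vals

def bond_index(s, n, l, m):
    """Flat index of species s, radial n, harmonic (l, m) with m in 0..2l."""
    return s * N * n_harmonics + n * n_harmonics + l * l + m

# The last (species, radial, harmonic) combination maps to the last entry.
assert bond_index(n_species - 1, N - 1, lmax, 2 * lmax) == n_bond - 1
```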
@@ -26,22 +26,24 @@ void single_bond_multiple_cutoffs( std::vector gz = std::vector(N, 0); int n_harmonics = (lmax + 1) * (lmax + 1); - std::vector h = std::vector(n_harmonics, 0); - std::vector hx = std::vector(n_harmonics, 0); - std::vector hy = std::vector(n_harmonics, 0); - std::vector hz = std::vector(n_harmonics, 0); + Eigen::VectorXcd h, hx, hy, hz; +// std::vector h = std::vector(n_harmonics, 0); +// std::vector hx = std::vector(n_harmonics, 0); +// std::vector hy = std::vector(n_harmonics, 0); +// std::vector hz = std::vector(n_harmonics, 0); // Prepare LAMMPS variables. int central_species = type[i] - 1; - double delx, dely, delz, rsq, r, bond, bond_x, bond_y, bond_z, g_val, gx_val, - gy_val, gz_val, h_val; + double delx, dely, delz, rsq, r; + double g_val, gx_val, gy_val, gz_val; + std::complex bond, bond_x, bond_y, bond_z, h_val; int j, s, descriptor_counter; // Initialize vectors. int n_radial = n_species * N; int n_bond = n_radial * n_harmonics; - single_bond_vals = Eigen::VectorXd::Zero(n_bond); - single_bond_env_dervs = Eigen::MatrixXd::Zero(n_inner * 3, n_bond); + single_bond_vals = Eigen::VectorXcd::Zero(n_bond); + single_bond_env_dervs = Eigen::MatrixXcd::Zero(n_inner * 3, n_bond); // Initialize radial hyperparameters. std::vector new_radial_hyps = radial_hyps; @@ -68,7 +70,7 @@ void single_bond_multiple_cutoffs( calculate_radial(g, gx, gy, gz, basis_function, cutoff_function, delx, dely, delz, r, cutoff, N, new_radial_hyps, cutoff_hyps); - get_Y(h, hx, hy, hz, delx, dely, delz, lmax); + get_complex_Y(h, hx, hy, hz, delx, dely, delz, lmax); // Store the products and their derivatives. descriptor_counter = s * N * n_harmonics; @@ -106,175 +108,80 @@ void single_bond_multiple_cutoffs( } } -void single_bond( - double **x, int *type, int jnum, int n_inner, int i, double xtmp, - double ytmp, double ztmp, int *jlist, - std::function &, std::vector &, double, - int, std::vector)> - basis_function, - std::function &, double, double, - std::vector)> - cutoff_function, - double cutoff, int n_species, int N, int lmax, - const std::vector &radial_hyps, - const std::vector &cutoff_hyps, Eigen::VectorXd &single_bond_vals, - Eigen::MatrixXd &single_bond_env_dervs) { - - // Initialize basis vectors and spherical harmonics. - std::vector g = std::vector(N, 0); - std::vector gx = std::vector(N, 0); - std::vector gy = std::vector(N, 0); - std::vector gz = std::vector(N, 0); - - int n_harmonics = (lmax + 1) * (lmax + 1); - std::vector h = std::vector(n_harmonics, 0); - std::vector hx = std::vector(n_harmonics, 0); - std::vector hy = std::vector(n_harmonics, 0); - std::vector hz = std::vector(n_harmonics, 0); - - // Prepare LAMMPS variables. - int itype = type[i]; - double delx, dely, delz, rsq, r, bond, bond_x, bond_y, bond_z, g_val, gx_val, - gy_val, gz_val, h_val; - int j, s, descriptor_counter; - double cutforcesq = cutoff * cutoff; - - // Initialize vectors. - int n_radial = n_species * N; - int n_bond = n_radial * n_harmonics; - single_bond_vals = Eigen::VectorXd::Zero(n_bond); - single_bond_env_dervs = Eigen::MatrixXd::Zero(n_inner * 3, n_bond); - - // Loop over neighbors. 
- int n_count = 0; - for (int jj = 0; jj < jnum; jj++) { - j = jlist[jj]; - - delx = x[j][0] - xtmp; - dely = x[j][1] - ytmp; - delz = x[j][2] - ztmp; - rsq = delx * delx + dely * dely + delz * delz; - r = sqrt(rsq); - - if (rsq < cutforcesq) { // minus a small value to prevent numerial error - s = type[j] - 1; - calculate_radial(g, gx, gy, gz, basis_function, cutoff_function, delx, - dely, delz, r, cutoff, N, radial_hyps, cutoff_hyps); - get_Y(h, hx, hy, hz, delx, dely, delz, lmax); - - // Store the products and their derivatives. - descriptor_counter = s * N * n_harmonics; - - for (int radial_counter = 0; radial_counter < N; radial_counter++) { - // Retrieve radial values. - g_val = g[radial_counter]; - gx_val = gx[radial_counter]; - gy_val = gy[radial_counter]; - gz_val = gz[radial_counter]; - - for (int angular_counter = 0; angular_counter < n_harmonics; - angular_counter++) { - - h_val = h[angular_counter]; - bond = g_val * h_val; - - // Calculate derivatives with the product rule. - bond_x = gx_val * h_val + g_val * hx[angular_counter]; - bond_y = gy_val * h_val + g_val * hy[angular_counter]; - bond_z = gz_val * h_val + g_val * hz[angular_counter]; - - // Update single bond basis arrays. - single_bond_vals(descriptor_counter) += bond; - - single_bond_env_dervs(n_count * 3, descriptor_counter) += bond_x; - single_bond_env_dervs(n_count * 3 + 1, descriptor_counter) += bond_y; - single_bond_env_dervs(n_count * 3 + 2, descriptor_counter) += bond_z; - - descriptor_counter++; - } - } - n_count++; +void compute_Bk(Eigen::VectorXd &Bk_vals, Eigen::MatrixXd &Bk_force_dervs, + double &norm_squared, Eigen::VectorXd &Bk_force_dots, + const Eigen::VectorXcd &single_bond_vals, + const Eigen::MatrixXcd &single_bond_force_dervs, + std::vector> nu, int nos, int K, int N, + int lmax, const Eigen::VectorXd &coeffs, + const Eigen::MatrixXd &beta_matrix, Eigen::VectorXcd &u, + double *evdwl) { + + int env_derv_cols = single_bond_force_dervs.cols(); + int env_derv_size = single_bond_force_dervs.rows(); + int n_neighbors = env_derv_size / 3; + + // The value of last counter is the number of descriptors + std::vector last_index = nu[nu.size()-1]; + int n_d = last_index[last_index.size()-1] + 1; + + // Initialize arrays. + Bk_vals = Eigen::VectorXd::Zero(n_d); + Bk_force_dervs = Eigen::MatrixXd::Zero(env_derv_size, n_d); + Bk_force_dots = Eigen::VectorXd::Zero(env_derv_size); + norm_squared = 0.0; + + Eigen::MatrixXcd dA_matr = Eigen::MatrixXd::Zero(n_d, env_derv_cols); + for (int i = 0; i < nu.size(); i++) { + std::vector nu_list = nu[i]; + std::vector single_bond_index = std::vector(nu_list.end() - 2 - K, nu_list.end() - 2); // Get n1_l, n2_l, n3_l, etc. + // Forward + std::complex A_fwd = 1; + Eigen::VectorXcd dA = Eigen::VectorXcd::Ones(K); + for (int t = 0; t < K - 1; t++) { + A_fwd *= single_bond_vals(single_bond_index[t]); + dA(t + 1) *= A_fwd; } - } -} - -void B2_descriptor(Eigen::VectorXd &B2_vals, Eigen::MatrixXd &B2_env_dervs, - double &norm_squared, Eigen::VectorXd &B2_env_dot, - const Eigen::VectorXd &single_bond_vals, - const Eigen::MatrixXd &single_bond_env_dervs, int n_species, - int N, int lmax, const Eigen::MatrixXd &beta_matrix, - Eigen::VectorXd &u, double *evdwl) { - - int env_derv_size = single_bond_env_dervs.rows(); - int neigh_size = env_derv_size / 3; - int n_radial = n_species * N; - int n_harmonics = (lmax + 1) * (lmax + 1); - int n_descriptors = (n_radial * (n_radial + 1) / 2) * (lmax + 1); - - int n1_l, n2_l, counter, n1_count, n2_count; - - // Zero the B2 vectors and matrices. 
- B2_vals = Eigen::VectorXd::Zero(n_descriptors); - B2_env_dervs = Eigen::MatrixXd::Zero(env_derv_size, n_descriptors); - B2_env_dot = Eigen::VectorXd::Zero(env_derv_size); - - // Compute the descriptor. - for (int n1 = n_radial - 1; n1 >= 0; n1--) { - n1_count = (n1 * (2 * n_radial - n1 + 1)) / 2; - - for (int n2 = n1; n2 < n_radial; n2++) { - n2_count = n2 - n1; - - for (int l = 0; l < (lmax + 1); l++) { - counter = l + (n1_count + n2_count) * (lmax + 1); - - for (int m = 0; m < (2 * l + 1); m++) { - n1_l = n1 * n_harmonics + (l * l + m); - n2_l = n2 * n_harmonics + (l * l + m); - - // Store B2 value. - B2_vals(counter) += single_bond_vals(n1_l) * single_bond_vals(n2_l); + // Backward + std::complex A_bwd = 1; + for (int t = K - 1; t > 0; t--) { + A_bwd *= single_bond_vals(single_bond_index[t]); + dA(t - 1) *= A_bwd; + } + std::complex A = A_fwd * single_bond_vals(single_bond_index[K - 1]); + + int counter = nu_list[nu_list.size() - 1]; + int m_index = nu_list[nu_list.size() - 2]; + Bk_vals(counter) += real(coeffs(m_index) * A); + +// // Prepare for partial force calculation +// for (int t = 0; t < K; t++) { +// dA_matr(counter, single_bond_index[t]) = coeffs(m_index) * dA(t); +// } + + // Store force derivatives. + for (int n = 0; n < n_neighbors; n++) { + for (int comp = 0; comp < 3; comp++) { + int ind = n * 3 + comp; + std::complex dA_dr = 0; + for (int t = 0; t < K; t++) { + dA_dr += dA(t) * single_bond_force_dervs(ind, single_bond_index[t]); } + Bk_force_dervs(ind, counter) += + real(coeffs(m_index) * dA_dr); } } } - // Compute w(n1, n2, l), where f_ik = w * dB/dr_ik - norm_squared = B2_vals.dot(B2_vals); - Eigen::VectorXd beta_p = beta_matrix * B2_vals; - *evdwl = B2_vals.dot(beta_p) / norm_squared; - Eigen::VectorXd w = 2 * (beta_p - *evdwl * B2_vals) / norm_squared; + // Compute descriptor norm and energy. 
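The `compute_Bk` loop above uses a forward and a backward pass so that each `dA(t)` ends up holding the product of all `K` single-bond factors except the `t`-th, without ever dividing (which would fail for zero-valued factors). A standalone Python sketch of the same trick, assuming `K = 3` and made-up complex values:

```python
import numpy as np

a = np.array([2.0 + 1.0j, 0.5 - 0.5j, 3.0 + 0.0j])  # hypothetical single-bond values
K = len(a)
dA = np.ones(K, dtype=complex)

# Forward pass: dA[t] accumulates a[0] * ... * a[t-1]
fwd = 1.0 + 0.0j
for t in range(K - 1):
    fwd *= a[t]
    dA[t + 1] *= fwd

# Backward pass: dA[t] also picks up a[t+1] * ... * a[K-1]
bwd = 1.0 + 0.0j
for t in range(K - 1, 0, -1):
    bwd *= a[t]
    dA[t - 1] *= bwd

A = fwd * a[K - 1]                 # full product of all K factors
assert np.allclose(dA * a, A)      # each dA[t] * a[t] recovers the full product
```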
+ Bk_force_dots = Bk_force_dervs * Bk_vals; + norm_squared = Bk_vals.dot(Bk_vals); + //Eigen::VectorXd beta_p = beta_matrix * Bk_vals; + //*evdwl = Bk_vals.dot(beta_p) / norm_squared; + //Eigen::VectorXd w = 2 * (beta_p - *evdwl * Bk_vals) / norm_squared; // same size as Bk_vals // Compute u(n1, l, m), where f_ik = u * dA/dr_ik - u = Eigen::VectorXd::Zero(single_bond_vals.size()); - double factor; - for (int n1 = n_radial - 1; n1 >= 0; n1--) { - for (int n2 = 0; n2 < n_radial; n2++) { - if (n1 == n2){ - n1_count = (n1 * (2 * n_radial - n1 + 1)) / 2; - n2_count = n2 - n1; - factor = 1.0; - } else if (n1 < n2) { - n1_count = (n1 * (2 * n_radial - n1 + 1)) / 2; - n2_count = n2 - n1; - factor = 0.5; - } else { - n1_count = (n2 * (2 * n_radial - n2 + 1)) / 2; - n2_count = n1 - n2; - factor = 0.5; - } - - for (int l = 0; l < (lmax + 1); l++) { - counter = l + (n1_count + n2_count) * (lmax + 1); - - for (int m = 0; m < (2 * l + 1); m++) { - n1_l = n1 * n_harmonics + (l * l + m); - n2_l = n2 * n_harmonics + (l * l + m); - - u(n1_l) += w(counter) * single_bond_vals(n2_l) * factor; - } - } - } - } - u *= 2; + //double factor; + //u = w.transpose() * dA_matr; + } diff --git a/lammps_plugins/lammps_descriptor.h b/lammps_plugins/lammps_descriptor.h index c59c47f6..6ed7e60a 100644 --- a/lammps_plugins/lammps_descriptor.h +++ b/lammps_plugins/lammps_descriptor.h @@ -6,21 +6,7 @@ #include -void single_bond( - double **x, int *type, int jnum, int n_inner, int i, double xtmp, - double ytmp, double ztmp, int *jlist, - std::function &, std::vector &, double, - int, std::vector)> - basis_function, - std::function &, double, double, - std::vector)> - cutoff_function, - double cutoff, int n_species, int N, int lmax, - const std::vector &radial_hyps, - const std::vector &cutoff_hyps, Eigen::VectorXd &single_bond_vals, - Eigen::MatrixXd &single_bond_env_dervs); - -void single_bond_multiple_cutoffs( +void complex_single_bond( double **x, int *type, int jnum, int n_inner, int i, double xtmp, double ytmp, double ztmp, int *jlist, std::function &, std::vector &, double, @@ -31,15 +17,18 @@ void single_bond_multiple_cutoffs( cutoff_function, int n_species, int N, int lmax, const std::vector &radial_hyps, - const std::vector &cutoff_hyps, Eigen::VectorXd &single_bond_vals, - Eigen::MatrixXd &single_bond_env_dervs, - const Eigen::MatrixXd &cutoff_matrix); + const std::vector &cutoff_hyps, Eigen::VectorXcd &single_bond_vals, + Eigen::MatrixXcd &single_bond_env_dervs, + const Eigen::MatrixXd &cutoff_matrix); -void B2_descriptor(Eigen::VectorXd &B2_vals, Eigen::MatrixXd &B2_env_dervs, - double &norm_squared, Eigen::VectorXd &B2_env_dot, - const Eigen::VectorXd &single_bond_vals, - const Eigen::MatrixXd &single_bond_env_dervs, int n_species, - int N, int lmax, const Eigen::MatrixXd &beta_matrix, - Eigen::VectorXd &u, double *evdwl); +void compute_Bk(Eigen::VectorXd &Bk_vals, Eigen::MatrixXd &Bk_env_dervs, + double &norm_squared, Eigen::VectorXd &Bk_env_dot, + const Eigen::VectorXcd &single_bond_vals, + const Eigen::MatrixXcd &single_bond_env_dervs, + std::vector> nu, + int n_species, int K, int N, int lmax, + const Eigen::VectorXd &coeffs, + const Eigen::MatrixXd &beta_matrix, + Eigen::VectorXcd &u, double *evdwl); #endif diff --git a/lammps_plugins/pair_flare.cpp b/lammps_plugins/pair_flare.cpp index 77757229..09090753 100644 --- a/lammps_plugins/pair_flare.cpp +++ b/lammps_plugins/pair_flare.cpp @@ -14,30 +14,19 @@ #include #include #include -#include // flare++ modules #include "cutoffs.h" #include "lammps_descriptor.h" 
#include "radial.h" #include "y_grad.h" +#include "indices.h" +#include "coeffs.h" using namespace LAMMPS_NS; #define MAXLINE 1024 -typedef unsigned long long timestamp_t; - -static timestamp_t -get_timestamp () -{ - struct timeval now; - gettimeofday (&now, NULL); - return now.tv_usec + (timestamp_t)now.tv_sec * 1000000; -} - - - /* ---------------------------------------------------------------------- */ PairFLARE::PairFLARE(LAMMPS *lmp) : Pair(lmp) { @@ -55,12 +44,19 @@ PairFLARE::~PairFLARE() { if (copymode) return; - memory->destroy(beta); - if (allocated) { memory->destroy(setflag); memory->destroy(cutsq); } + + memory->destroy(radial_code); + memory->destroy(cutoff_code); + memory->destroy(K); + memory->destroy(n_max); + memory->destroy(l_max); + memory->destroy(beta_size); + memory->destroy(cutoffs); + } /* ---------------------------------------------------------------------- */ @@ -86,10 +82,6 @@ void PairFLARE::compute(int eflag, int vflag) { numneigh = list->numneigh; firstneigh = list->firstneigh; - int beta_init, beta_counter; - double B2_norm_squared, B2_val_1, B2_val_2; - Eigen::VectorXd single_bond_vals, B2_vals, B2_env_dot, u; - Eigen::MatrixXd single_bond_env_dervs, B2_env_dervs; double empty_thresh = 1e-8; for (ii = 0; ii < inum; ii++) { @@ -101,74 +93,97 @@ void PairFLARE::compute(int eflag, int vflag) { ztmp = x[i][2]; jlist = firstneigh[i]; - // Count the atoms inside the cutoff. - n_inner = 0; - for (int jj = 0; jj < jnum; jj++) { - j = jlist[jj]; - int s = type[j] - 1; - double cutoff_val = cutoff_matrix(itype-1, s); - - delx = x[j][0] - xtmp; - dely = x[j][1] - ytmp; - delz = x[j][2] - ztmp; - rsq = delx * delx + dely * dely + delz * delz; - if (rsq < (cutoff_val * cutoff_val)) - n_inner++; - } + for (int kern = 0; kern < num_kern; kern++) { + cutoff_matrix = cutoff_matrices[kern]; + + // Count the atoms inside the cutoff. + // TODO: this might be duplicated when multiple kernels share the same cutoff + n_inner = 0; + for (int jj = 0; jj < jnum; jj++) { + j = jlist[jj]; + int s = type[j] - 1; + double cutoff_val = cutoff_matrix(itype-1, s); + delx = x[j][0] - xtmp; + dely = x[j][1] - ytmp; + delz = x[j][2] - ztmp; + rsq = delx * delx + dely * dely + delz * delz; + if (rsq < (cutoff_val * cutoff_val)) + n_inner++; + } - // Compute covariant descriptors. - double secs; - single_bond_multiple_cutoffs(x, type, jnum, n_inner, i, xtmp, ytmp, ztmp, - jlist, basis_function, cutoff_function, - n_species, n_max, l_max, radial_hyps, - cutoff_hyps, single_bond_vals, - single_bond_env_dervs, cutoff_matrix); - - // Compute invariant descriptors. - B2_descriptor(B2_vals, B2_env_dervs, B2_norm_squared, B2_env_dot, - single_bond_vals, single_bond_env_dervs, n_species, n_max, - l_max, beta_matrices[itype - 1], u, &evdwl); - - // Continue if the environment is empty. - if (B2_norm_squared < empty_thresh) - continue; - - // Update energy, force and stress arrays. - n_count = 0; - for (int jj = 0; jj < jnum; jj++) { - j = jlist[jj]; - int s = type[j] - 1; - double cutoff_val = cutoff_matrix(itype-1, s); - delx = xtmp - x[j][0]; - dely = ytmp - x[j][1]; - delz = ztmp - x[j][2]; - rsq = delx * delx + dely * dely + delz * delz; - - if (rsq < (cutoff_val * cutoff_val)) { - // Compute partial force f_ij = u * dA/dr_ij - double fx = single_bond_env_dervs.row(n_count * 3).dot(u); - double fy = single_bond_env_dervs.row(n_count * 3 + 1).dot(u); - double fz = single_bond_env_dervs.row(n_count * 3 + 2).dot(u); - // Compute local energy and partial forces. 
- - f[i][0] += fx; - f[i][1] += fy; - f[i][2] += fz; - f[j][0] -= fx; - f[j][1] -= fy; - f[j][2] -= fz; - - if (vflag) { - ev_tally_xyz(i, j, nlocal, newton_pair, 0.0, 0.0, fx, fy, fz, delx, - dely, delz); + double norm_squared; + Eigen::VectorXcd single_bond_vals, u; + Eigen::VectorXd vals, env_dot, partial_forces, beta_p; + Eigen::MatrixXcd single_bond_env_dervs; + Eigen::MatrixXd env_dervs; + + // Compute covariant descriptors. + // TODO: this function call is duplicated for multiple kernels + complex_single_bond(x, type, jnum, n_inner, i, xtmp, ytmp, ztmp, jlist, + basis_function[kern], cutoff_function[kern], + n_species, n_max[kern], l_max[kern], + radial_hyps[kern], cutoff_hyps[kern], + single_bond_vals, single_bond_env_dervs, cutoff_matrix); + + // Compute invariant descriptors. + compute_Bk(vals, env_dervs, norm_squared, env_dot, + single_bond_vals, single_bond_env_dervs, nu[kern], + n_species, K[kern], n_max[kern], l_max[kern], coeffs[kern], + beta_matrices[kern][itype - 1], u, &evdwl); + + // Continue if the environment is empty. + if (norm_squared < empty_thresh) + continue; + + // Compute local energy and partial forces. + // TODO: not needed if using "u" + beta_p = beta_matrices[kern][itype - 1] * vals; + evdwl = vals.dot(beta_p) / norm_squared; + + partial_forces = + 2 * (- env_dervs * beta_p + evdwl * env_dot) / norm_squared; + + // Update energy, force and stress arrays. + n_count = 0; + for (int jj = 0; jj < jnum; jj++) { + j = jlist[jj]; + int s = type[j] - 1; + double cutoff_val = cutoff_matrix(itype-1, s); + delx = xtmp - x[j][0]; + dely = ytmp - x[j][1]; + delz = ztmp - x[j][2]; + rsq = delx * delx + dely * dely + delz * delz; + + if (rsq < (cutoff_val * cutoff_val)) { + double fx = -partial_forces(n_count * 3); + double fy = -partial_forces(n_count * 3 + 1); + double fz = -partial_forces(n_count * 3 + 2); + +// // Compute partial force f_ij = u * dA/dr_ij +// double fx = real(single_bond_env_dervs.row(n_count * 3 + 0).dot(u)); +// double fy = real(single_bond_env_dervs.row(n_count * 3 + 1).dot(u)); +// double fz = real(single_bond_env_dervs.row(n_count * 3 + 2).dot(u)); + + f[i][0] += fx; + f[i][1] += fy; + f[i][2] += fz; + f[j][0] -= fx; + f[j][1] -= fy; + f[j][2] -= fz; + + if (vflag) { + ev_tally_xyz(i, j, nlocal, newton_pair, 0.0, 0.0, fx, fy, fz, delx, + dely, delz); + } + n_count++; } - n_count++; } + + // Compute local energy. + if (eflag) + ev_tally_full(i, 2.0 * evdwl, 0.0, 0.0, 0.0, 0.0, 0.0); } - // Compute local energy. - if (eflag) - ev_tally_full(i, 2.0 * evdwl, 0.0, 0.0, 0.0, 0.0, 0.0); } if (vflag_fdotr) @@ -254,110 +269,151 @@ double PairFLARE::init_one(int i, int j) { void PairFLARE::read_file(char *filename) { int me = comm->me; - char line[MAXLINE], radial_string[MAXLINE], cutoff_string[MAXLINE]; - int radial_string_length, cutoff_string_length; + char line[MAXLINE]; FILE *fptr; // Check that the potential file can be opened. 
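Note on the new per-kernel energy/force update above: `compute_Bk` returns the invariant descriptor `vals`, its derivatives `env_dervs`, the squared norm and `env_dot`, and the plugin then evaluates the normalized quadratic model B·(βB)/|B|² together with its quotient-rule derivative. The standalone Eigen sketch below restates that algebra with illustrative sizes and random inputs; it is a reference, not part of the patch.

```
// Serial sketch of the energy/force algebra used in PairFLARE::compute,
// assuming a symmetric beta matrix and descriptor derivatives stored
// row-wise as in env_dervs. Sizes and values are illustrative only.
#include <Eigen/Dense>
#include <iostream>

int main() {
  const int n_desc = 4, n_rows = 6;                               // illustrative
  Eigen::VectorXd B = Eigen::VectorXd::Random(n_desc);            // Bk values
  Eigen::MatrixXd dB = Eigen::MatrixXd::Random(n_rows, n_desc);   // d(Bk)/dr rows
  Eigen::MatrixXd tmp = Eigen::MatrixXd::Random(n_desc, n_desc);
  Eigen::MatrixXd beta = 0.5 * (tmp + tmp.transpose());           // symmetric model

  double norm_squared = B.dot(B);                                 // |B|^2
  Eigen::VectorXd env_dot = dB * B;                               // rows of dB dotted with B
  Eigen::VectorXd beta_p = beta * B;

  // Local energy of the normalized quadratic model, E = B^T beta B / |B|^2.
  double evdwl = B.dot(beta_p) / norm_squared;

  // Quotient rule: dE/dr = 2 (dB beta_p - E * dB B) / |B|^2. The plugin stores
  // the negated quantity and flips the sign again when accumulating fx, fy, fz.
  Eigen::VectorXd partial_forces =
      2 * (-dB * beta_p + evdwl * env_dot) / norm_squared;

  std::cout << "evdwl = " << evdwl << "\npartial_forces:\n"
            << partial_forces << std::endl;
  return 0;
}
```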
- if (me == 0) { - fptr = utils::open_potential(filename,lmp,nullptr); - if (fptr == NULL) { - char str[128]; - snprintf(str, 128, "Cannot open potential file %s", filename); - error->one(FLERR, str); - } + fptr = utils::open_potential(filename,lmp,nullptr); + if (fptr == NULL) { + char str[128]; + snprintf(str, 128, "Cannot open potential file %s", filename); + error->one(FLERR, str); } int tmp, nwords; - if (me == 0) { + fgets(line, MAXLINE, fptr); + fgets(line, MAXLINE, fptr); + sscanf(line, "%i", &num_kern); // number of descriptors/kernels + + memory->create(radial_code, num_kern, "pair:radial_code"); + memory->create(cutoff_code, num_kern, "pair:cutoff_code"); + memory->create(K, num_kern, "pair:K"); + memory->create(n_max, num_kern, "pair:n_max"); + memory->create(l_max, num_kern, "pair:l_max"); + memory->create(beta_size, num_kern, "pair:beta_size"); + + for (int k = 0; k < num_kern; k++) { + char desc_str[MAXLINE]; fgets(line, MAXLINE, fptr); + sscanf(line, "%s", desc_str); // Descriptor name + + char radial_str[MAXLINE], cutoff_str[MAXLINE]; fgets(line, MAXLINE, fptr); - sscanf(line, "%s", radial_string); // Radial basis set - radial_string_length = strlen(radial_string); + sscanf(line, "%s", radial_str); // Radial basis set + if (!strcmp(radial_str, "chebyshev")) { + radial_code[k] = 1; + } else { + char str[128]; + snprintf(str, 128, "Radial function %s is not supported\n.", radial_str); + error->all(FLERR, str); + } + fgets(line, MAXLINE, fptr); - sscanf(line, "%i %i %i %i", &n_species, &n_max, &l_max, &beta_size); + sscanf(line, "%i %i %i %i %i", &n_species, &K[k], &n_max[k], &l_max[k], &beta_size[k]); + fgets(line, MAXLINE, fptr); - sscanf(line, "%s", cutoff_string); // Cutoff function - cutoff_string_length = strlen(cutoff_string); - } - - MPI_Bcast(&n_species, 1, MPI_INT, 0, world); - MPI_Bcast(&n_max, 1, MPI_INT, 0, world); - MPI_Bcast(&l_max, 1, MPI_INT, 0, world); - MPI_Bcast(&beta_size, 1, MPI_INT, 0, world); - MPI_Bcast(&cutoff, 1, MPI_DOUBLE, 0, world); - MPI_Bcast(&radial_string_length, 1, MPI_INT, 0, world); - MPI_Bcast(&cutoff_string_length, 1, MPI_INT, 0, world); - MPI_Bcast(radial_string, radial_string_length + 1, MPI_CHAR, 0, world); - MPI_Bcast(cutoff_string, cutoff_string_length + 1, MPI_CHAR, 0, world); - - // Parse the cutoffs. - int n_cutoffs = n_species * n_species; - memory->create(cutoffs, n_cutoffs, "pair:cutoffs"); - if (me == 0) - grab(fptr, n_cutoffs, cutoffs); - MPI_Bcast(cutoffs, n_cutoffs, MPI_DOUBLE, 0, world); - - // Fill in the cutoff matrix. - cutoff = -1; - cutoff_matrix = Eigen::MatrixXd::Zero(n_species, n_species); - int cutoff_count = 0; - for (int i = 0; i < n_species; i++){ - for (int j = 0; j < n_species; j++){ - double cutoff_val = cutoffs[cutoff_count]; - cutoff_matrix(i, j) = cutoff_val; - if (cutoff_val > cutoff) cutoff = cutoff_val; - cutoff_count ++; + sscanf(line, "%s", cutoff_str); // Cutoff function + if (!strcmp(cutoff_str, "cosine")) { + cutoff_code[k] = 1; + } else if (!strcmp(cutoff_str, "quadratic")) { + cutoff_code[k] = 2; + } else { + char str[128]; + snprintf(str, 128, "Cutoff function %s is not supported\n.", cutoff_str); + error->all(FLERR, str); } - } - // Set number of descriptors. - int n_radial = n_max * n_species; - n_descriptors = (n_radial * (n_radial + 1) / 2) * (l_max + 1); - - // Check the relationship between the power spectrum and beta. 
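For orientation, the comment block below sketches the potential-file layout that the new multi-kernel parser implies. The ordering is taken from the `fgets`/`sscanf`/`grab` calls above; the example values are placeholders, and the beta block is parsed a little further down.

```
// Potential file layout implied by PairFLARE::read_file (ordering from the
// parser; example values are illustrative only):
//
//   <first line>                        read and discarded
//   <num_kern>                          number of descriptors/kernels, e.g. 1
//   for each kernel k:
//     <descriptor name>                 e.g. B2
//     <radial basis>                    "chebyshev" -> radial_code[k] = 1
//     <n_species> <K> <n_max> <l_max> <beta_size>
//     <cutoff function>                 "cosine" -> 1, "quadratic" -> 2
//     <n_species * n_species cutoffs>   filled row by row into cutoff_matrix
//     <beta_size[k] * n_species betas>  packed beta blocks, parsed below
```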
- int beta_check = n_descriptors * (n_descriptors + 1) / 2; - if (beta_check != beta_size) - error->all(FLERR, "Beta size doesn't match the number of descriptors."); - - // Set the radial basis. - if (!strcmp(radial_string, "chebyshev")) { - basis_function = chebyshev; - radial_hyps = std::vector{0, cutoff}; - } + // Parse the cutoffs. + int n_cutoffs = n_species * n_species; + memory->create(cutoffs, n_cutoffs, "pair:cutoffs"); + if (me == 0) + grab(fptr, n_cutoffs, cutoffs); + MPI_Bcast(cutoffs, n_cutoffs, MPI_DOUBLE, 0, world); + + // Fill in the cutoff matrix. + cutoff = -1; + cutoff_matrix = Eigen::MatrixXd::Zero(n_species, n_species); + int cutoff_count = 0; + for (int i = 0; i < n_species; i++){ + for (int j = 0; j < n_species; j++){ + double cutoff_val = cutoffs[cutoff_count]; + cutoff_matrix(i, j) = cutoff_val; + if (cutoff_val > cutoff) cutoff = cutoff_val; + cutoff_count ++; + } + } + cutoff_matrices.push_back(cutoff_matrix); + + // Compute indices and coefficients + std::vector descriptor_settings = {n_species, K[k], n_max[k], l_max[k]}; + std::vector> nu_kern = compute_indices(descriptor_settings); + Eigen::VectorXd coeffs_kern = compute_coeffs(K[k], l_max[k]); + nu.push_back(nu_kern); + coeffs.push_back(coeffs_kern); + + // Set number of descriptors. + std::vector last_index = nu_kern[nu_kern.size()-1]; + int n_descriptors = last_index[last_index.size()-1] + 1; + + // Check the relationship between the power spectrum and beta. + int beta_check = n_descriptors * (n_descriptors + 1) / 2; + if (beta_check != beta_size[k]) { + char str[128]; + snprintf(str, 128, "The beta size of kernel %d doesn't match the number of descriptors.", k); + error->all(FLERR, str); + } - // Set the cutoff function. - if (!strcmp(cutoff_string, "quadratic")) - cutoff_function = quadratic_cutoff; - else if (!strcmp(cutoff_string, "cosine")) - cutoff_function = cos_cutoff; - - // Parse the beta vectors. - memory->create(beta, beta_size * n_species, "pair:beta"); - if (me == 0) - grab(fptr, beta_size * n_species, beta); - MPI_Bcast(beta, beta_size * n_species, MPI_DOUBLE, 0, world); - - // Fill in the beta matrix. - // TODO: Remove factor of 2 from beta. - Eigen::MatrixXd beta_matrix; - int beta_count = 0; - double beta_val; - for (int k = 0; k < n_species; k++) { - beta_matrix = Eigen::MatrixXd::Zero(n_descriptors, n_descriptors); - for (int i = 0; i < n_descriptors; i++) { - for (int j = i; j < n_descriptors; j++) { - if (i == j) - beta_matrix(i, j) = beta[beta_count]; - else if (i != j) { - beta_val = beta[beta_count] / 2; - beta_matrix(i, j) = beta_val; - beta_matrix(j, i) = beta_val; + // Set the radial basis. + if (radial_code[k] == 1){ + basis_function.push_back(chebyshev); + std::vector rh = {0, cutoffs[0]}; // It does not matter what + // cutoff is used, will be + // modified to cutoff_matrix + // when computing descriptors + radial_hyps.push_back(rh); + std::vector ch; + cutoff_hyps.push_back(ch); + } + + // Set the cutoff function. + if (cutoff_code[k] == 1) + cutoff_function.push_back(cos_cutoff); + else if (cutoff_code[k] == 2) + cutoff_function.push_back(quadratic_cutoff); + + // Parse the beta vectors. + // TODO: check this memory creation + memory->create(beta, beta_size[k] * n_species, "pair:beta"); + grab(fptr, beta_size[k] * n_species, beta); + //MPI_Bcast(beta1, beta_size[k] * n_species, MPI_DOUBLE, 0, world); + + // Fill in the beta matrix. + // TODO: Remove factor of 2 from beta. 
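The `beta_check` above is the count of independent entries of a symmetric matrix, n_d (n_d + 1) / 2. As a reference for the fill loop that follows, here is a minimal serial restatement of the packing convention, with off-diagonal entries halved when the flat vector is expanded:

```
// Expand a flat, row-major upper-triangular beta block of length
// n_d * (n_d + 1) / 2 into a symmetric matrix, halving off-diagonal
// entries as the fill loop in read_file does.
#include <Eigen/Dense>
#include <cassert>
#include <vector>

Eigen::MatrixXd unpack_beta(const std::vector<double> &beta_flat, int n_d) {
  assert((int)beta_flat.size() == n_d * (n_d + 1) / 2);
  Eigen::MatrixXd beta_matrix = Eigen::MatrixXd::Zero(n_d, n_d);
  int count = 0;
  for (int i = 0; i < n_d; i++) {
    for (int j = i; j < n_d; j++) {
      if (i == j) {
        beta_matrix(i, j) = beta_flat[count];
      } else {
        beta_matrix(i, j) = beta_flat[count] / 2;
        beta_matrix(j, i) = beta_flat[count] / 2;
      }
      count++;
    }
  }
  return beta_matrix;
}
```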
+ Eigen::MatrixXd beta_matrix; + std::vector beta_matrix_kern; + int beta_count = 0; + double beta_val; + for (int s = 0; s < n_species; s++) { + beta_matrix = Eigen::MatrixXd::Zero(n_descriptors, n_descriptors); + for (int i = 0; i < n_descriptors; i++) { + for (int j = i; j < n_descriptors; j++) { + if (i == j) + beta_matrix(i, j) = beta[beta_count]; + else if (i != j) { + beta_val = beta[beta_count] / 2; + beta_matrix(i, j) = beta_val; + beta_matrix(j, i) = beta_val; + } + beta_count++; } - beta_count++; } + beta_matrix_kern.push_back(beta_matrix); } - beta_matrices.push_back(beta_matrix); + beta_matrices.push_back(beta_matrix_kern); + + // TODO: check this memory destroy + memory->destroy(beta); + } } diff --git a/lammps_plugins/pair_flare.h b/lammps_plugins/pair_flare.h index 1c984791..9cc677e8 100644 --- a/lammps_plugins/pair_flare.h +++ b/lammps_plugins/pair_flare.h @@ -28,21 +28,27 @@ class PairFLARE : public Pair { double init_one(int, int); protected: - int n_species, n_max, l_max, n_descriptors, beta_size; - - std::function &, std::vector &, double, int, - std::vector)> + int num_kern, n_species, n_descriptors; + int *K, *n_max, *l_max, *beta_size; + std::vector coeffs; // coefficient of A product in B + std::vector>> nu; // indices of A product + + int *radial_code, *cutoff_code; + std::vector &, std::vector &, double, int, + std::vector)>> basis_function; - std::function &, double, double, - std::vector)> + std::vector &, double, double, + std::vector)>> cutoff_function; - std::vector radial_hyps, cutoff_hyps; + std::vector> radial_hyps, cutoff_hyps; + double *cutoffs; double cutoff; - double *beta, *cutoffs; + double *beta; Eigen::MatrixXd beta_matrix, cutoff_matrix; - std::vector beta_matrices; + std::vector cutoff_matrices; + std::vector> beta_matrices; virtual void allocate(); virtual void read_file(char *); @@ -52,4 +58,4 @@ class PairFLARE : public Pair { } // namespace LAMMPS_NS #endif -#endif \ No newline at end of file +#endif diff --git a/scripts/CMakeLists.txt b/scripts/CMakeLists.txt new file mode 100644 index 00000000..7522ac1a --- /dev/null +++ b/scripts/CMakeLists.txt @@ -0,0 +1,7 @@ +add_executable(mpi_build mpi_build.cpp) +include_directories(../src/flare_pp) +target_link_libraries(mpi_build PUBLIC flare) + +add_executable(omp_build omp_build.cpp) +include_directories(../src/flare_pp) +target_link_libraries(omp_build PUBLIC flare) diff --git a/scripts/mpi_build.cpp b/scripts/mpi_build.cpp new file mode 100644 index 00000000..cadc2449 --- /dev/null +++ b/scripts/mpi_build.cpp @@ -0,0 +1,187 @@ +#include +#include +#include +#include +#include +#include +#include +#include "omp.h" +#include "mpi.h" + +#include "utils.h" +#include "b2.h" +#include "parallel_sgp.h" +#include "sparse_gp.h" +#include "structure.h" +#include "normalized_dot_product.h" + +int main(int argc, char* argv[]) { + // Default setting of descriptors + int N = 8; + int L = 4; + double cutoff = 4.0; + std::string radial_string = "chebyshev"; + std::string cutoff_string = "quadratic"; + + // Species setting + int n_species = 1; + bool input_species = false; + std::map species_map; + std::vector single_atom_energy; + + // Default setting of kernels + double sigma = 7.0; + int power = 2; + + // Default setting of SGP + double sigma_e = 0.1; + double sigma_f = 0.1; + double sigma_s = 0.001; + double Kuu_jitter = 1e-8; + + // Default input file + std::string filename = "dft_data.xyz"; + std::string coefname = "par_beta.txt"; + std::string contributor = "Me"; + + // Read input file + std::ifstream 
file("input.flare"); + std::vector v; + if (file.is_open()) { + std::string line; + while (std::getline(file, line)) { + v = utils::split(line, ' '); + if (v[0] == std::string("cutoff")) { + cutoff = std::stod(v[1]); + } else if (v[0] == std::string("nmax")) { + N = std::stoi(v[1]); + } else if (v[0] == std::string("lmax")) { + L = std::stoi(v[1]); + } else if (v[0] == std::string("radial_func")) { + radial_string = v[1]; + } else if (v[0] == std::string("cutoff_func")) { + cutoff_string = v[1]; + } else if (v[0] == std::string("species")) { + for (int s = 1; s < v.size(); s++) { + species_map[v[s]] = s - 1; + } + n_species = v.size() - 1; + input_species = true; + } else if (v[0] == std::string("single_atom_energy")) { + for (int s = 1; s < v.size(); s++) { + single_atom_energy.push_back(std::stod(v[s])); + } + } else if (v[0] == std::string("sigma")) { + sigma = std::stod(v[1]); + } else if (v[0] == std::string("power")) { + power = std::stoi(v[1]); + } else if (v[0] == std::string("sigma_e")) { + sigma_e = std::stod(v[1]); + } else if (v[0] == std::string("sigma_f")) { + sigma_f = std::stod(v[1]); + } else if (v[0] == std::string("sigma_s")) { + sigma_s = std::stod(v[1]); + } else if (v[0] == std::string("Kuu_jitter")) { + Kuu_jitter = std::stod(v[1]); + } else if (v[0] == std::string("data_file")) { + filename = v[1]; + } else if (v[0] == std::string("coef_file")) { + coefname = v[1]; + } else if (v[0] == std::string("contributor")) { + contributor = v[1]; + } + } + } + + if (!input_species) throw; // Check the species are input + if (single_atom_energy.size() == 0) { // Check single atom energies are input + for (int s = 0; s < n_species; s++) { // Otherwise set to 0 + single_atom_energy.push_back(0.0); + } + } else { // Single atoms energies should correspond + assert(single_atom_energy.size() == n_species); // to species list + } + + // Set descriptors for SGP + std::vector radial_hyps{0, cutoff}; + std::vector cutoff_hyps; + + std::vector descriptor_settings{n_species, N, L}; + + std::vector dc; + B2 ps(radial_string, cutoff_string, radial_hyps, cutoff_hyps, + descriptor_settings); + dc.push_back(&ps); + + // Set kernels for SGP + NormalizedDotProduct kernel_norm = NormalizedDotProduct(sigma, power); + std::vector kernels; + kernels.push_back(&kernel_norm); + + std::cout << "Kuu_jitter=" << Kuu_jitter << std::endl; + ParallelSGP parallel_sgp = ParallelSGP(kernels, sigma_e, sigma_f, sigma_s); + parallel_sgp.Kuu_jitter = Kuu_jitter; + + // Read data from .xyz file + std::vector struc_list; + std::vector>> sparse_indices; + std::tie(struc_list, sparse_indices) = utils::read_xyz(filename, species_map); + + // Build parallel sgp + int n_types = n_species; + parallel_sgp.build(struc_list, cutoff, dc, sparse_indices, n_types); + std::cout << "Parallel_sgp is built!" 
<< std::endl; + + if (blacs::mpirank == 0) { + parallel_sgp.write_mapping_coefficients(coefname, contributor, {0}); + std::cout << "Mapping coefficients are written" << std::endl; + + // validate the Kuu is symmetric + for (int r = 0; r < parallel_sgp.Kuu.rows(); r++) { + for (int c = r; c < parallel_sgp.Kuu.cols(); c++) { + double dKuu = parallel_sgp.Kuu(r, c) - parallel_sgp.Kuu(c, r); + if (std::abs(dKuu) > 1e-6) { + std::cout << r << " " << c << " " << dKuu << " " << parallel_sgp.Kuu(r, c) << std::endl; + } + } + } + + std::cout << "start L_inv" << std::endl; + for (int r = 0; r < parallel_sgp.L_inv.rows(); r++) { + for (int c = 0; c <= r; c++) { + std::cout << r << " " << c << " " << parallel_sgp.L_inv(r, c) << std::endl; + } + } + std::cout << "end L_inv" << std::endl; + + std::cout << "start R_inv" << std::endl; + for (int r = 0; r < parallel_sgp.R_inv.rows(); r++) { + for (int c = r; c < parallel_sgp.R_inv.cols(); c++) { + std::cout << r << " " << c << " " << parallel_sgp.R_inv(r, c) << std::endl; + } + } + std::cout << "end R_inv" << std::endl; + +// std::cout << "start R" << std::endl; +// for (int r = 0; r < parallel_sgp.R.rows(); r++) { +// for (int c = r; c < parallel_sgp.R.cols(); c++) { +// std::cout << r << " " << c << " " << std::setprecision (17) << parallel_sgp.R(r, c) << std::endl; +// } +// } +// std::cout << "end R" << std::endl; +// +// +// std::cout << "Start Q_b" << std::endl; +// for (int r = 0; r < parallel_sgp.Q_b.size(); r++) { +// std::cout << std::setprecision (17) << parallel_sgp.Q_b(r) << std::endl; +// } +// std::cout << "End Q_b" << std::endl; + + std::cout << "Start alpha" << std::endl; + for (int r = 0; r < parallel_sgp.alpha.size(); r++) { + std::cout << std::setprecision (17) << parallel_sgp.alpha(r) << std::endl; + } + std::cout << "End alpha" << std::endl; + } + +} diff --git a/scripts/omp_build.cpp b/scripts/omp_build.cpp new file mode 100644 index 00000000..61bfa860 --- /dev/null +++ b/scripts/omp_build.cpp @@ -0,0 +1,171 @@ +#include +#include +#include +#include +#include +#include +#include "omp.h" + +#include "utils.h" +#include "b2.h" +#include "sparse_gp.h" +#include "structure.h" +#include "normalized_dot_product.h" + +int main(int argc, char* argv[]) { + // Default setting of descriptors + int N = 8; + int L = 4; + double cutoff = 4.0; + std::string radial_string = "chebyshev"; + std::string cutoff_string = "quadratic"; + + // Species setting + int n_species = 1; + bool input_species = false; + std::map species_map; + + // Default setting of kernels + double sigma = 7.0; + int power = 2; + + // Default setting of SGP + double sigma_e = 0.1; + double sigma_f = 0.1; + double sigma_s = 0.001; + double Kuu_jitter = 1e-8; + + // Default input file + std::string filename = "dft_data.xyz"; + std::string coefname = "par_beta.txt"; + std::string contributor = "Me"; + + // Read input file + std::ifstream file("input.flare"); + std::vector v; + if (file.is_open()) { + std::string line; + while (std::getline(file, line)) { + v = utils::split(line, ' '); + if (v[0] == std::string("cutoff")) { + cutoff = std::stod(v[1]); + } else if (v[0] == std::string("nmax")) { + N = std::stoi(v[1]); + } else if (v[0] == std::string("lmax")) { + L = std::stoi(v[1]); + } else if (v[0] == std::string("radial_func")) { + radial_string = v[1]; + } else if (v[0] == std::string("cutoff_func")) { + cutoff_string = v[1]; + } else if (v[0] == std::string("species")) { + for (int s = 1; s < v.size(); s++) { + species_map[v[s]] = s - 1; + } + n_species = v.size() - 1; + 
input_species = true; + } else if (v[0] == std::string("sigma")) { + sigma = std::stod(v[1]); + } else if (v[0] == std::string("power")) { + power = std::stoi(v[1]); + } else if (v[0] == std::string("sigma_e")) { + sigma_e = std::stod(v[1]); + } else if (v[0] == std::string("sigma_f")) { + sigma_f = std::stod(v[1]); + } else if (v[0] == std::string("sigma_s")) { + sigma_s = std::stod(v[1]); + } else if (v[0] == std::string("Kuu_jitter")) { + Kuu_jitter = std::stod(v[1]); + } else if (v[0] == std::string("data_file")) { + filename = v[1]; + } else if (v[0] == std::string("coef_file")) { + coefname = v[1]; + } else if (v[0] == std::string("contributor")) { + contributor = v[1]; + } + } + } + + if (!input_species) throw; + + std::vector radial_hyps{0, cutoff}; + std::vector cutoff_hyps; + + std::vector descriptor_settings{n_species, N, L}; + + std::vector dc; + B2 ps(radial_string, cutoff_string, radial_hyps, cutoff_hyps, + descriptor_settings); + dc.push_back(&ps); + + NormalizedDotProduct kernel_norm = NormalizedDotProduct(sigma, power); + std::vector kernels; + kernels.push_back(&kernel_norm); + + std::cout << "Kuu_jitter=" << Kuu_jitter << std::endl; + SparseGP sparse_gp(kernels, sigma_e, sigma_f, sigma_s); + sparse_gp.Kuu_jitter = Kuu_jitter; + + // Read data from .xyz file + std::vector struc_list; + std::vector>> sparse_indices; + std::tie(struc_list, sparse_indices) = utils::read_xyz(filename, species_map); + + // Build parallel sgp + int n_types = n_species; + for (int t = 0; t < struc_list.size(); t++) { + std::cout << "Adding structure " << t << std::endl; + Structure struc(struc_list[t].cell, struc_list[t].species, + struc_list[t].positions, cutoff, dc); + + struc.energy = struc_list[t].energy; + struc.forces = struc_list[t].forces; + struc.stresses = struc_list[t].stresses; + + sparse_gp.add_training_structure(struc); + sparse_gp.add_specific_environments(struc, sparse_indices[0][t]); + } + std::cout << "Finish adding. Start building matrices" << std::endl; + sparse_gp.update_matrices_QR(); + std::cout << "Sparse GP is built!" 
<< std::endl; + + sparse_gp.write_mapping_coefficients(coefname, contributor, {0}); + std::cout << "Mapping coefficients are written" << std::endl; + + std::cout << "start L_inv" << std::endl; + for (int r = 0; r < sparse_gp.L_inv.rows(); r++) { + for (int c = 0; c <= r; c++) { + std::cout << r << " " << c << " " << sparse_gp.L_inv(r, c) << std::endl; + } + } + std::cout << "end L_inv" << std::endl; + + std::cout << "start R_inv" << std::endl; + for (int r = 0; r < sparse_gp.R_inv.rows(); r++) { + for (int c = r; c < sparse_gp.R_inv.cols(); c++) { + std::cout << r << " " << c << " " << sparse_gp.R_inv(r, c) << std::endl; + } + } + std::cout << "end R_inv" << std::endl; + +// std::cout << "start R" << std::endl; +// for (int r = 0; r < sparse_gp.R.rows(); r++) { +// for (int c = r; c < sparse_gp.R.cols(); c++) { +// std::cout << r << " " << c << " " << std::setprecision (17) << sparse_gp.R(r, c) << std::endl; +// } +// } +// std::cout << "end R" << std::endl; +// +// +// std::cout << "Start Q_b" << std::endl; +// for (int r = 0; r < sparse_gp.Q_b.size(); r++) { +// std::cout << std::setprecision (17) << sparse_gp.Q_b(r) << std::endl; +// } +// std::cout << "End Q_b" << std::endl; + + std::cout << "Start alpha" << std::endl; + for (int r = 0; r < sparse_gp.alpha.size(); r++) { + std::cout << std::setprecision (17) << sparse_gp.alpha(r) << std::endl; + } + std::cout << "End alpha" << std::endl; + +} diff --git a/setup.py b/setup.py index 99e80f0a..8423af25 100644 --- a/setup.py +++ b/setup.py @@ -161,7 +161,7 @@ def _decode(x): setuptools.setup( name="flare_pp", packages=setuptools.find_packages(exclude=[]), - version="0.0.23", + version="0.1.0", author="Materials Intelligence Research", author_email="mir@g.harvard.edu", description="Many-body extension of the flare code", diff --git a/src/ace_binding.cpp b/src/ace_binding.cpp index 63afcd37..339178ff 100644 --- a/src/ace_binding.cpp +++ b/src/ace_binding.cpp @@ -2,10 +2,11 @@ #include "structure.h" #include "y_grad.h" #include "sparse_gp.h" +#include "parallel_sgp.h" #include "b2.h" #include "b2_simple.h" #include "b2_norm.h" -#include "b3.h" +#include "bk.h" #include "two_body.h" #include "three_body.h" #include "three_body_wide.h" @@ -102,6 +103,16 @@ PYBIND11_MODULE(_C_flare, m) { .def(py::init &>()); + py::class_(m, "B2_Simple") + .def(py::init &, const std::vector &, + const std::vector &>()); + + py::class_(m, "B2_Norm") + .def(py::init &, const std::vector &, + const std::vector &>()); + py::class_(m, "B2") .def(py::init &, const std::vector &, @@ -117,20 +128,20 @@ PYBIND11_MODULE(_C_flare, m) { .def_readonly("cutoffs", &B2::cutoffs) .def_readonly("descriptor_settings", &B2::descriptor_settings); - py::class_(m, "B2_Simple") + py::class_(m, "Bk") .def(py::init &, const std::vector &, - const std::vector &>()); - - py::class_(m, "B2_Norm") - .def(py::init &, const std::vector &, - const std::vector &>()); - - py::class_(m, "B3") + const std::vector &>()) .def(py::init &, const std::vector &, - const std::vector &>()); + const std::vector &, + const Eigen::MatrixXd &>()) + .def_readonly("radial_basis", &Bk::radial_basis) + .def_readonly("cutoff_function", &Bk::cutoff_function) + .def_readonly("radial_hyps", &Bk::radial_hyps) + .def_readonly("cutoff_hyps", &Bk::cutoff_hyps) + .def_readonly("cutoffs", &Bk::cutoffs) + .def_readonly("descriptor_settings", &Bk::descriptor_settings); // Kernel functions py::class_(m, "Kernel"); @@ -172,10 +183,10 @@ PYBIND11_MODULE(_C_flare, m) { .def("compute_likelihood_stable", 
&SparseGP::compute_likelihood_stable) .def("compute_likelihood_gradient", &SparseGP::compute_likelihood_gradient) + .def("compute_likelihood_gradient_stable", + &SparseGP::compute_likelihood_gradient_stable) + .def("precompute_KnK", &SparseGP::precompute_KnK) .def("write_mapping_coefficients", &SparseGP::write_mapping_coefficients) - .def_readonly("varmap_coeffs", &SparseGP::varmap_coeffs) // for debugging and unit test - .def("compute_cluster_uncertainties", &SparseGP::compute_cluster_uncertainties) // for debugging and unit test - .def("write_varmap_coefficients", &SparseGP::write_varmap_coefficients) .def_readwrite("Kuu_jitter", &SparseGP::Kuu_jitter) .def_readonly("complexity_penalty", &SparseGP::complexity_penalty) .def_readonly("data_fit", &SparseGP::data_fit) @@ -199,6 +210,9 @@ PYBIND11_MODULE(_C_flare, m) { .def_readonly("Kuu_kernels", &SparseGP::Kuu_kernels) .def_readonly("Kuf", &SparseGP::Kuf) .def_readonly("Kuf_kernels", &SparseGP::Kuf_kernels) + .def_readwrite("Kuf_e_noise_Kfu", &SparseGP::Kuf_e_noise_Kfu) + .def_readwrite("Kuf_f_noise_Kfu", &SparseGP::Kuf_f_noise_Kfu) + .def_readwrite("Kuf_s_noise_Kfu", &SparseGP::Kuf_s_noise_Kfu) .def_readonly("alpha", &SparseGP::alpha) .def_readonly("Kuu_inverse", &SparseGP::Kuu_inverse) .def_readonly("Sigma", &SparseGP::Sigma) @@ -207,4 +221,15 @@ PYBIND11_MODULE(_C_flare, m) { .def_readonly("y", &SparseGP::y) .def_static("to_json", &SparseGP::to_json) .def_static("from_json", &SparseGP::from_json); + + py::class_(m, "ParallelSGP") + .def(py::init<>()) + .def(py::init, double, double, double>()) + .def_readwrite("finalize_MPI", &ParallelSGP::finalize_MPI) + .def("build", &ParallelSGP::build) + .def("set_hyperparameters", &ParallelSGP::set_hyperparameters) + .def("compute_likelihood_stable", &ParallelSGP::compute_likelihood_stable) + .def("compute_likelihood_gradient_stable", &ParallelSGP::compute_likelihood_gradient_stable) + .def("predict_local_uncertainties", &ParallelSGP::predict_local_uncertainties) + .def("predict_on_structures", &ParallelSGP::predict_on_structures); } diff --git a/src/flare_pp/bffs/parallel_sgp.cpp b/src/flare_pp/bffs/parallel_sgp.cpp new file mode 100644 index 00000000..d8c7ea2a --- /dev/null +++ b/src/flare_pp/bffs/parallel_sgp.cpp @@ -0,0 +1,1375 @@ +#include "parallel_sgp.h" +#include +#include // Random shuffle +#include +#include // File operations +#include // setprecision +#include +#include // Iota + +#include +#include +#include + + +#define MAXLINE 1024 + + +ParallelSGP ::ParallelSGP() {} + +ParallelSGP ::ParallelSGP(std::vector kernels, double energy_noise, + double force_noise, double stress_noise) { + + this->kernels = kernels; + n_kernels = kernels.size(); + Kuu_jitter = 1e-8; // default value + label_count = Eigen::VectorXd::Zero(1); + + // Count hyperparameters. + int n_hyps = 0; + for (int i = 0; i < kernels.size(); i++) { + n_hyps += kernels[i]->kernel_hyperparameters.size(); + } + + // Set the kernel hyperparameters. + hyperparameters = Eigen::VectorXd::Zero(n_hyps + 3); + Eigen::VectorXd hyps_curr; + int hyp_counter = 0; + for (int i = 0; i < kernels.size(); i++) { + hyps_curr = kernels[i]->kernel_hyperparameters; + + for (int j = 0; j < hyps_curr.size(); j++) { + hyperparameters(hyp_counter) = hyps_curr(j); + hyp_counter++; + } + } + + // Set the noise hyperparameters. 
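For readers skimming the new `ParallelSGP` constructor above: the hyperparameter vector holds all kernel hyperparameters in kernel order, and the assignments just below append the three noise parameters at the tail. A compact sketch of that layout (function and argument names here are illustrative, not the class API):

```
// Assemble [kernel hyps..., energy_noise, force_noise, stress_noise],
// mirroring the counting done in the ParallelSGP constructor.
#include <Eigen/Dense>
#include <vector>

Eigen::VectorXd assemble_hyps(const std::vector<Eigen::VectorXd> &kernel_hyps,
                              double energy_noise, double force_noise,
                              double stress_noise) {
  int n_hyps = 0;
  for (const auto &h : kernel_hyps)
    n_hyps += h.size();

  Eigen::VectorXd hyperparameters(n_hyps + 3);
  int counter = 0;
  for (const auto &h : kernel_hyps)
    for (int j = 0; j < h.size(); j++)
      hyperparameters(counter++) = h(j);

  // Noise hyperparameters occupy the last three slots.
  hyperparameters(n_hyps) = energy_noise;
  hyperparameters(n_hyps + 1) = force_noise;
  hyperparameters(n_hyps + 2) = stress_noise;
  return hyperparameters;
}
```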
+ hyperparameters(n_hyps) = energy_noise; + hyperparameters(n_hyps + 1) = force_noise; + hyperparameters(n_hyps + 2) = stress_noise; + + this->energy_noise = energy_noise; + this->force_noise = force_noise; + this->stress_noise = stress_noise; + + // Initialize kernel lists. + Eigen::MatrixXd empty_matrix; + for (int i = 0; i < kernels.size(); i++) { + Kuu_kernels.push_back(empty_matrix); + Kuf_kernels.push_back(empty_matrix); + } +} + +// Destructor +ParallelSGP ::~ParallelSGP() { + if (finalize_MPI) blacs::finalize(); +} + +void ParallelSGP ::add_training_structure(const Structure &structure) { + int n_energy = structure.energy.size(); + int n_force = structure.forces.size(); + int n_stress = structure.stresses.size(); + int n_struc_labels = n_energy + n_force + n_stress; + int n_atoms = structure.noa; + + // No updating Kuf + + // Update labels. + label_count.conservativeResize(training_structures.size() + 2); + label_count(training_structures.size() + 1) = n_labels + n_struc_labels; + y.conservativeResize(n_labels + n_struc_labels); + y.segment(n_labels, n_energy) = structure.energy; + y.segment(n_labels + n_energy, n_force) = structure.forces; + y.segment(n_labels + n_energy + n_force, n_stress) = structure.stresses; + + // Update noise. + noise_vector.conservativeResize(n_labels + n_struc_labels); + noise_vector.segment(n_labels, n_energy) = + Eigen::VectorXd::Constant(n_energy, 1 / (energy_noise * energy_noise)); + noise_vector.segment(n_labels + n_energy, n_force) = + Eigen::VectorXd::Constant(n_force, 1 / (force_noise * force_noise)); + noise_vector.segment(n_labels + n_energy + n_force, n_stress) = + Eigen::VectorXd::Constant(n_stress, 1 / (stress_noise * stress_noise)); + + // Save "1" vector for energy, force and stress noise, for likelihood gradient calculation + e_noise_one.conservativeResize(n_labels + n_struc_labels); + f_noise_one.conservativeResize(n_labels + n_struc_labels); + s_noise_one.conservativeResize(n_labels + n_struc_labels); + + e_noise_one.segment(n_labels, n_struc_labels) = Eigen::VectorXd::Zero(n_struc_labels); + f_noise_one.segment(n_labels, n_struc_labels) = Eigen::VectorXd::Zero(n_struc_labels); + s_noise_one.segment(n_labels, n_struc_labels) = Eigen::VectorXd::Zero(n_struc_labels); + + e_noise_one.segment(n_labels, n_energy) = Eigen::VectorXd::Ones(n_energy); + f_noise_one.segment(n_labels + n_energy, n_force) = Eigen::VectorXd::Ones(n_force); + s_noise_one.segment(n_labels + n_energy + n_force, n_stress) = Eigen::VectorXd::Ones(n_stress); + + // Update label count. + n_energy_labels += n_energy; + n_force_labels += n_force; + n_stress_labels += n_stress; + n_labels += n_struc_labels; + + // Store training structure. + n_strucs += 1; +} + +Eigen::VectorXi ParallelSGP ::sparse_indices_by_type(int n_types, + std::vector species, const std::vector atoms) { + // Compute the number of sparse envs of each type + // TODO: support two-body/three-body descriptors + Eigen::VectorXi n_envs_by_type = Eigen::VectorXi::Zero(n_types); + for (int a = 0; a < atoms.size(); a++) { + int s = species[atoms[a]]; + n_envs_by_type(s)++; + } + return n_envs_by_type; +} + +void ParallelSGP ::add_specific_environments(const Structure &structure, + const std::vector> atoms, bool update) { + + // Gather clusters with central atom in the given list. 
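The label bookkeeping in `add_training_structure` above follows a fixed per-structure ordering, energies then forces then stresses, with the noise vector storing the inverse variances block by block and the e/f/s indicator ("one") vectors marking block membership for the later likelihood gradient. A simplified sketch, showing the noise blocks and only the energy indicator:

```
// Append one structure's noise entries and energy-indicator block,
// following the ordering used in add_training_structure. Names and the
// reduced argument list are illustrative only.
#include <Eigen/Dense>

void append_noise(Eigen::VectorXd &noise, Eigen::VectorXd &e_one,
                  int n_energy, int n_force, int n_stress,
                  double sigma_e, double sigma_f, double sigma_s) {
  int old_size = noise.size();
  int n_new = n_energy + n_force + n_stress;
  noise.conservativeResize(old_size + n_new);
  e_one.conservativeResize(old_size + n_new);
  e_one.segment(old_size, n_new).setZero();

  // Inverse variances, block by block: energy, then forces, then stresses.
  noise.segment(old_size, n_energy).setConstant(1 / (sigma_e * sigma_e));
  noise.segment(old_size + n_energy, n_force).setConstant(1 / (sigma_f * sigma_f));
  noise.segment(old_size + n_energy + n_force, n_stress)
      .setConstant(1 / (sigma_s * sigma_s));

  // Indicator of the energy block (force/stress indicators are analogous).
  e_one.segment(old_size, n_energy).setOnes();
}
```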
+ std::vector>> indices_1; + for (int i = 0; i < n_kernels; i++){ + int n_types = structure.descriptors[i].n_types; + std::vector> indices_2; + for (int j = 0; j < n_types; j++){ + int n_clusters = structure.descriptors[i].n_clusters_by_type[j]; + std::vector indices_3; + for (int k = 0; k < n_clusters; k++){ + int atom_index_1 = structure.descriptors[i].atom_indices[j](k); + for (int l = 0; l < atoms[i].size(); l++){ + int atom_index_2 = atoms[i][l]; + if (atom_index_1 == atom_index_2){ + indices_3.push_back(k); + } + } + } + indices_2.push_back(indices_3); + } + indices_1.push_back(indices_2); + } + + // Create cluster descriptors. + std::vector cluster_descriptors; + for (int i = 0; i < n_kernels; i++) { + ClusterDescriptor cluster_descriptor = + ClusterDescriptor(structure.descriptors[i], indices_1[i]); + cluster_descriptors.push_back(cluster_descriptor); + } +// local_sparse_descriptors.push_back(cluster_descriptors); + if (update) { + // Store sparse environments. + for (int i = 0; i < n_kernels; i++) { + sparse_descriptors[i].add_clusters_by_type(structure.descriptors[i], + indices_1[i]); + } + } else { + local_sparse_descriptors.push_back(cluster_descriptors); + } +} + +void ParallelSGP ::add_global_noise(int n_energy, int n_force, int n_stress) { + + int n_struc_labels = n_energy + n_force + n_stress; + + // Update noise. + global_noise_vector.conservativeResize(global_n_labels + n_struc_labels); + global_noise_vector.segment(global_n_labels, n_energy) = + Eigen::VectorXd::Constant(n_energy, 1 / (energy_noise * energy_noise)); + global_noise_vector.segment(global_n_labels + n_energy, n_force) = + Eigen::VectorXd::Constant(n_force, 1 / (force_noise * force_noise)); + global_noise_vector.segment(global_n_labels + n_energy + n_force, n_stress) = + Eigen::VectorXd::Constant(n_stress, 1 / (stress_noise * stress_noise)); + + global_n_energy_labels += n_energy; + global_n_force_labels += n_force; + global_n_stress_labels += n_stress; + global_n_labels += n_struc_labels; + +} + +void ParallelSGP::build(const std::vector &training_strucs, + double cutoff, std::vector descriptor_calculators, + const std::vector>> &training_sparse_indices, + int n_types, bool update) { + + // initialization + u_size_single_kernel = {}; + + // Initialize kernel lists. 
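The selection logic in `add_specific_environments` above reduces to keeping, per descriptor type, the clusters whose central atom appears in the requested sparse-atom list. A stripped-down illustration with plain index vectors (the real code walks the descriptor and ClusterDescriptor classes):

```
// Keep the cluster indices whose central atom is in the sparse-atom list.
// This is only an illustration of the index matching, with plain vectors
// standing in for the descriptor data structures.
#include <algorithm>
#include <vector>

std::vector<int> select_clusters(const std::vector<int> &central_atoms,
                                 const std::vector<int> &sparse_atoms) {
  std::vector<int> cluster_indices;
  for (int k = 0; k < (int)central_atoms.size(); k++) {
    // central_atoms[k] is the atom at the center of cluster k of this type.
    if (std::find(sparse_atoms.begin(), sparse_atoms.end(), central_atoms[k]) !=
        sparse_atoms.end())
      cluster_indices.push_back(k);
  }
  return cluster_indices;
}
```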
+ Eigen::MatrixXd empty_matrix; + Kuu_kernels = {}; + Kuf_kernels = {}; + for (int i = 0; i < kernels.size(); i++) { + Kuu_kernels.push_back(empty_matrix); + Kuf_kernels.push_back(empty_matrix); + } + + // initialize BLACS + blacs::initialize(); + + timer.tic(); + + // Get the number of processes + int world_size; + MPI_Comm_size(MPI_COMM_WORLD, &world_size); + + // Get the rank of the process + int world_rank; + MPI_Comm_rank(MPI_COMM_WORLD, &world_rank); + + // Compute the dimensions of the matrices Kuf and Kuu + f_size = 0; + for (int i = 0; i < training_strucs.size(); i++) { + f_size += training_strucs[i].energy.size() + training_strucs[i].forces.size() + training_strucs[i].stresses.size(); + } + if (f_size % world_size == 0) { + f_size_per_proc = f_size / world_size; + } else { + f_size_per_proc = f_size / world_size + 1; + } + + u_size = 0; + for (int k = 0; k < n_kernels; k++) { + int u_kern = 0; + for (int i = 0; i < training_sparse_indices[k].size(); i++) { + u_kern += training_sparse_indices[k][i].size(); + } + u_size += u_kern; + u_size_single_kernel.push_back(u_kern); + } + u_size_per_proc = u_size / world_size; + + // Compute the range of structures covered by the current rank + nmin_struc = world_rank * f_size_per_proc; + nmin_envs = world_rank * u_size_per_proc; + if (world_rank == world_size - 1) { + nmax_struc = f_size; + nmax_envs = u_size; + } else { + nmax_struc = (world_rank + 1) * f_size_per_proc; + nmax_envs = (world_rank + 1) * u_size_per_proc; + } + + timer.toc("initialize build"); + + // load and distribute training structures, compute descriptors + load_local_training_data(training_strucs, cutoff, descriptor_calculators, training_sparse_indices, n_types, update); + + // compute kernel matrices from training data + compute_kernel_matrices(training_strucs); + + update_matrices_QR(); + + // TODO: finalize BLACS + //blacs::finalize(); + +} + +/* ------------------------------------------------------------------------- + * Load training data and compute descriptors + * ------------------------------------------------------------------------- */ + +void ParallelSGP::load_local_training_data(const std::vector &training_strucs, + double cutoff, std::vector descriptor_calculators, + const std::vector>> &training_sparse_indices, + int n_types, bool update) { + + timer.tic(); + + // Distribute the training structures and sparse envs + Structure struc; + int cum_f = 0; + + global_n_labels = 0; + global_n_energy_labels = 0; + global_n_force_labels = 0; + global_n_stress_labels = 0; + global_noise_vector = Eigen::VectorXd::Zero(0); + + n_energy_labels = 0; + n_force_labels = 0; + n_stress_labels = 0; + n_labels = 0; + n_strucs = 0; + + label_count = Eigen::VectorXd::Zero(1); + y = Eigen::VectorXd::Zero(0); + noise_vector = Eigen::VectorXd::Zero(0); + e_noise_one = Eigen::VectorXd::Zero(0); + f_noise_one = Eigen::VectorXd::Zero(0); + s_noise_one = Eigen::VectorXd::Zero(0); + + local_label_indices = {}; + local_label_size = 0; + local_sparse_descriptors = {}; + if (!update) { + sparse_descriptors = {}; + } + + // Not clean up training_structures + std::vector new_local_training_structure_indices; + timer.toc("initialize loading"); + + // Compute the total number of clusters of each type of each kernel + n_struc_clusters_by_type = {}; + std::vector> n_clusters_by_type; + for (int i = 0; i < n_kernels; i++) { + std::vector n_clusters_kern; + for (int s = 0; s < n_types; s++) { + n_clusters_kern.push_back(0); + } + n_clusters_by_type.push_back(n_clusters_kern); + 
n_struc_clusters_by_type.push_back({}); + } + + int num_build_strucs = 0; + for (int t = 0; t < training_strucs.size(); t++) { + timer.tic(); + int label_size = training_strucs[t].n_labels(); + int noa = training_strucs[t].noa; + + int n_energy = 1; + int n_forces = 3 * noa; + int n_stress = 6; + add_global_noise(n_energy, n_forces, n_stress); // for b + + // Compute the total number of clusters of each type + for (int i = 0; i < n_kernels; i++) { + Eigen::VectorXi n_envs_by_type = sparse_indices_by_type(n_types, + training_strucs[t].species, training_sparse_indices[i][t]); + n_struc_clusters_by_type[i].push_back(n_envs_by_type); + for (int s = 0; s < n_types; s++) n_clusters_by_type[i][s] += n_envs_by_type(s); + } + timer.toc("add_global_noise & sparse_indices_by_type", blacs::mpirank); + + // Collect local training structures for A, Kuf + // Check if the current struc belongs to the current process + if (nmin_struc < cum_f + label_size && cum_f < nmax_struc) { + timer.tic(); + bool build_struc = true; + if (update) { // if just adding a few new strucs (appened in training_strucs) + // check if the current struc is already in the list + std::vector::iterator local_struc_index = + std::find(local_training_structure_indices.begin(), + local_training_structure_indices.end(), t); + + // if it is, then use the existing one, otherwise build it up + + if (local_struc_index != local_training_structure_indices.end()) { + int local_struc_ind = std::distance(local_training_structure_indices.begin(), local_struc_index); + struc = training_structures[local_struc_ind]; + build_struc = false; + } + } + + if (build_struc) { + // compute descriptors for the current struc + struc = Structure(training_strucs[t].cell, training_strucs[t].species, + training_strucs[t].positions, cutoff, descriptor_calculators); + + struc.energy = training_strucs[t].energy; + struc.forces = training_strucs[t].forces; + struc.stresses = training_strucs[t].stresses; + + training_structures.push_back(struc); + num_build_strucs += 1; + } + + // add the current struc to the local list + add_training_structure(struc); + new_local_training_structure_indices.push_back(t); + + // save all the label indices that belongs to the current process + std::vector label_inds; + for (int l = 0; l < label_size; l++) { + if (cum_f + l >= nmin_struc && cum_f + l < nmax_struc) { + label_inds.push_back(l); + } + } + local_label_indices.push_back(label_inds); + local_label_size += label_size; + + // compute sparse descriptors if the first label of this struc belongs to + // the current process, to avoid multiple procs add the same sparse envs + if (!update) { + if (nmin_struc <= cum_f && cum_f < nmax_struc) { + std::vector> sparse_atoms = {}; + for (int i = 0; i < n_kernels; i++) { + sparse_atoms.push_back(training_sparse_indices[i][t]); + } + add_specific_environments(struc, sparse_atoms, update); + } + } + } + + if (update) { // add the sparse envs to the global "sparse_descriptors" + if (t == (training_strucs.size() - 1)) { + // check if the current struc is already in the list + std::vector::iterator local_struc_index = + std::find(new_local_training_structure_indices.begin(), + new_local_training_structure_indices.end(), t); + + // if it is, then use the existing one, otherwise build it up + if (local_struc_index != new_local_training_structure_indices.end()) { + struc = training_structures[training_structures.size() - 1]; + } else { + // compute descriptors for the current struc + struc = Structure(training_strucs[t].cell, 
training_strucs[t].species, + training_strucs[t].positions, cutoff, descriptor_calculators); + + struc.energy = training_strucs[t].energy; + struc.forces = training_strucs[t].forces; + struc.stresses = training_strucs[t].stresses; + } + + std::vector> sparse_atoms = {}; + for (int i = 0; i < n_kernels; i++) { + sparse_atoms.push_back(training_sparse_indices[i][t]); + } + + add_specific_environments(struc, sparse_atoms, update); + num_build_strucs += 1; + } + } + + cum_f += label_size; + } + + // remove structures in the old list that are not in the new list + for (int t = local_training_structure_indices.size() - 1; t >= 0; t--) { + int t0 = local_training_structure_indices[t]; + if (std::find(new_local_training_structure_indices.begin(), + new_local_training_structure_indices.end(), t0) + == new_local_training_structure_indices.end()) { + + training_structures.erase(training_structures.begin() + t); + } + } + + local_training_structure_indices = new_local_training_structure_indices; + std::cout << "Rank " << blacs::mpirank << " build " << num_build_strucs << " new strucs" << std::endl; + + blacs::barrier(); + timer.toc("compute structure descriptors", blacs::mpirank); + + // Gather all sparse descriptors to each process + if (!update) + gather_sparse_descriptors(n_clusters_by_type, training_strucs); +} + +/* ------------------------------------------------------------------------- + * Gather distributed sparse descriptors + * ------------------------------------------------------------------------- */ + +void ParallelSGP::gather_sparse_descriptors(std::vector> n_clusters_by_type, + const std::vector &training_strucs) { + + timer.tic(); + for (int i = 0; i < n_kernels; i++) { + // Assign global sparse descritors + int n_descriptors = training_structures[0].descriptors[i].n_descriptors; + int n_types = training_structures[0].descriptors[i].n_types; + + int cum_f, local_u, cum_u; + std::vector descriptors; + std::vector descriptor_norms, cutoff_values; + for (int s = 0; s < n_types; s++) { + if (n_clusters_by_type[i][s] == 0) { // throw error if there is no cluster + std::stringstream errmsg; + errmsg << "Number of clusters of kernel " << i << " and type " << s << " is 0"; + throw std::logic_error(errmsg.str()); + } + + DistMatrix dist_descriptors(n_clusters_by_type[i][s], n_descriptors); + DistMatrix dist_descriptor_norms(n_clusters_by_type[i][s], 1); + DistMatrix dist_cutoff_values(n_clusters_by_type[i][s], 1); + dist_descriptors = [](int i, int j){return 0.0;}; + dist_descriptor_norms = [](int i, int j){return 0.0;}; + dist_cutoff_values = [](int i, int j){return 0.0;}; + blacs::barrier(); + + cum_f = 0; + cum_u = 0; + local_u = 0; + bool lock = true; + for (int t = 0; t < training_strucs.size(); t++) { + if (nmin_struc <= cum_f && cum_f < nmax_struc) { + ClusterDescriptor cluster_descriptor = local_sparse_descriptors[local_u][i]; + for (int j = 0; j < n_struc_clusters_by_type[i][t](s); j++) { + for (int d = 0; d < n_descriptors; d++) { + dist_descriptors.set(cum_u + j, d, + cluster_descriptor.descriptors[s](j, d), lock); + dist_descriptor_norms.set(cum_u + j, 0, + cluster_descriptor.descriptor_norms[s](j), lock); + dist_cutoff_values.set(cum_u + j, 0, + cluster_descriptor.cutoff_values[s](j), lock); + } + } + local_u += 1; + } + cum_u += n_struc_clusters_by_type[i][t](s); + int label_size = training_strucs[t].n_labels(); + cum_f += label_size; + } + dist_descriptors.fence(); + dist_descriptor_norms.fence(); + dist_cutoff_values.fence(); + + int nrows = n_clusters_by_type[i][s]; + int 
ncols = n_descriptors; + Eigen::MatrixXd type_descriptors = Eigen::MatrixXd::Zero(nrows, ncols); + Eigen::VectorXd type_descriptor_norms = Eigen::VectorXd::Zero(nrows); + Eigen::VectorXd type_cutoff_values = Eigen::VectorXd::Zero(nrows); + + dist_descriptors.allgather(&type_descriptors(0, 0), 0, 0, nrows, ncols); + dist_descriptor_norms.allgather(&type_descriptor_norms(0, 0), 0, 0, nrows, 1); + dist_cutoff_values.allgather(&type_cutoff_values(0, 0), 0, 0, nrows, 1); + + descriptors.push_back(type_descriptors); + descriptor_norms.push_back(type_descriptor_norms); + cutoff_values.push_back(type_cutoff_values); + } + + // Store sparse environments. + std::vector cumulative_type_count = {0}; + int n_clusters = 0; + for (int s = 0; s < n_types; s++) { + cumulative_type_count.push_back(cumulative_type_count[s] + n_clusters_by_type[i][s]); + n_clusters += n_clusters_by_type[i][s]; + } + + ClusterDescriptor cluster_desc; + cluster_desc.initialize_cluster(n_types, n_descriptors); + cluster_desc.descriptors = descriptors; + cluster_desc.descriptor_norms = descriptor_norms; + cluster_desc.cutoff_values = cutoff_values; + cluster_desc.n_clusters_by_type = n_clusters_by_type[i]; + cluster_desc.cumulative_type_count = cumulative_type_count; + cluster_desc.n_clusters = n_clusters; + sparse_descriptors.push_back(cluster_desc); + } + + timer.toc("gather_sparse_descriptors", blacs::mpirank); +} + +/* ------------------------------------------------------------------------- + * Compute kernel matrices and alpha + * ------------------------------------------------------------------------- */ +void ParallelSGP ::stack_Kuf() { + // Update Kuf kernels. + int local_f_size = nmax_struc - nmin_struc; + Kuf_local = Eigen::MatrixXd::Zero(u_size, local_f_size); + + int count = 0; + for (int i = 0; i < Kuf_kernels.size(); i++) { + int size = Kuf_kernels[i].rows(); + Kuf_local.block(count, 0, size, local_f_size) = Kuf_kernels[i]; + count += size; + } + blacs::barrier(); +} + +void ParallelSGP::compute_kernel_matrices(const std::vector &training_strucs) { + timer.tic(); + + // Build block of A, y, Kuu using distributed training structures + std::vector kuf, kuu; + int cum_f = 0; + int cum_u = 0; + int local_f_size = nmax_struc - nmin_struc; + Kuf_local = Eigen::MatrixXd::Zero(u_size, local_f_size); + + for (int i = 0; i < n_kernels; i++) { + int u_kern = u_size_single_kernel[i]; // sparse set size of kernel i + assert(u_kern == sparse_descriptors[i].n_clusters); + Eigen::MatrixXd kuf_i = Eigen::MatrixXd::Zero(u_kern, local_f_size); + cum_f = 0; + for (int t = 0; t < training_structures.size(); t++) { + int f_size_i = local_label_indices[t].size(); + Eigen::MatrixXd kern_t = kernels[i]->envs_struc( + sparse_descriptors[i], + training_structures[t].descriptors[i], + kernels[i]->kernel_hyperparameters); + + // Remove columns of kern_t that is not assigned to the current processor + Eigen::MatrixXd kern_local = Eigen::MatrixXd::Zero(u_kern, f_size_i); + for (int l = 0; l < f_size_i; l++) { + kern_local.block(0, l, u_kern, 1) = kern_t.block(0, local_label_indices[t][l], u_kern, 1); + } + kuf_i.block(0, cum_f, u_kern, f_size_i) = kern_local; + cum_f += f_size_i; + } + //kuf.push_back(kuf_i); + + // TODO: change here into stack_Kuf + Kuf_local.block(cum_u, 0, u_kern, local_f_size) = kuf_i; + + Kuf_kernels[i] = kuf_i; + + Kuu_kernels[i] = kernels[i]->envs_envs( + sparse_descriptors[i], + sparse_descriptors[i], + kernels[i]->kernel_hyperparameters); + cum_u += u_kern; + } + + // Only keep the chunk of noise_vector assigned to the 
current proc + cum_f = 0; + cum_u = 0; + int cum_f_struc = 0; + local_noise_vector = Eigen::VectorXd::Zero(local_f_size); + local_e_noise_one = Eigen::VectorXd::Zero(local_f_size); + local_f_noise_one = Eigen::VectorXd::Zero(local_f_size); + local_s_noise_one = Eigen::VectorXd::Zero(local_f_size); + local_labels = Eigen::VectorXd::Zero(local_f_size); + for (int t = 0; t < training_structures.size(); t++) { + int f_size_i = local_label_indices[t].size(); + for (int l = 0; l < f_size_i; l++) { + local_noise_vector(cum_f + l) = noise_vector(cum_f_struc + local_label_indices[t][l]); + local_e_noise_one(cum_f + l) = e_noise_one(cum_f_struc + local_label_indices[t][l]); + local_f_noise_one(cum_f + l) = f_noise_one(cum_f_struc + local_label_indices[t][l]); + local_s_noise_one(cum_f + l) = s_noise_one(cum_f_struc + local_label_indices[t][l]); + local_labels(cum_f + l) = y(cum_f_struc + local_label_indices[t][l]); + } + cum_f += f_size_i; + cum_f_struc += training_strucs[t].n_labels(); + } + timer.toc("compute_kernel_matrice", blacs::mpirank); + + blacs::barrier(); +} + +void ParallelSGP::update_matrices_QR() { + timer.tic(); + // Store square root of noise vector. + Eigen::VectorXd noise_vector_sqrt = sqrt(local_noise_vector.array()); + Eigen::VectorXd global_noise_vector_sqrt = sqrt(global_noise_vector.array()); + + // Synchronize, wait until all training structures are ready on all processors + blacs::barrier(); + timer.toc("build local kuf kuu", blacs::mpirank); + + // Create distributed matrices + // specify the scope of the DistMatrix + timer.tic(); + + std::cout << "f_size=" << f_size << " , u_size=" << u_size << std::endl; + DistMatrix A(f_size + u_size, u_size); // use the default blocking + DistMatrix b(f_size + u_size, 1); + DistMatrix Kuu_dist(u_size, u_size); + A = [](int i, int j){return 0.0;}; + b = [](int i, int j){return 0.0;}; + Kuu_dist = [](int i, int j){return 0.0;}; + blacs::barrier(); + + bool lock = true; + int cum_u = 0; + // Assign sparse set kernel matrix Kuu + for (int i = 0; i < n_kernels; i++) { + for (int r = 0; r < Kuu_kernels[i].rows(); r++) { + for (int c = 0; c < Kuu_kernels[i].cols(); c++) { + if (Kuu_dist.islocal(r + cum_u, c + cum_u)) { // only set the local part + Kuu_dist.set(r + cum_u, c + cum_u, Kuu_kernels[i](r, c), lock); + } + } + } + cum_u += Kuu_kernels[i].rows(); + } + Kuu_dist.fence(); + + Kuu = Eigen::MatrixXd::Zero(u_size, u_size); + Kuu_dist.allgather(&Kuu(0, 0), 0, 0, u_size, u_size); + + timer.toc("build Kuu_dist", blacs::mpirank); + + // Cholesky decomposition of Kuu and its inverse. + timer.tic(); + Eigen::LLT chol( + Kuu + Kuu_jitter * Eigen::MatrixXd::Identity(Kuu.rows(), Kuu.cols())); + + // Get the inverse of Kuu from Cholesky decomposition. 
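The factorization above, and the inverse formed from it just below, follow a standard pattern: add a small jitter to Kuu, take its Cholesky factor L, invert L with a triangular solve, and build Kuu^-1 = L^-T * L^-1. A serial Eigen restatement for reference:

```
// Kuu + jitter*I = L L^T, so L^-1 (triangular solve against the identity)
// gives Kuu^-1 = L^-T L^-1. Mirrors the steps taken in update_matrices_QR.
#include <Eigen/Dense>

void kuu_inverse_from_cholesky(const Eigen::MatrixXd &Kuu, double jitter,
                               Eigen::MatrixXd &L_inv,
                               Eigen::MatrixXd &Kuu_inverse) {
  Eigen::MatrixXd eye = Eigen::MatrixXd::Identity(Kuu.rows(), Kuu.cols());
  Eigen::LLT<Eigen::MatrixXd> chol(Kuu + jitter * eye);
  L_inv = chol.matrixL().solve(eye);        // triangular solve: L * L_inv = I
  Kuu_inverse = L_inv.transpose() * L_inv;  // (L L^T)^-1
}
```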
+ Eigen::MatrixXd Kuu_eye = Eigen::MatrixXd::Identity(Kuu.rows(), Kuu.cols()); + Eigen::MatrixXd L = chol.matrixL(); + L_inv = chol.matrixL().solve(Kuu_eye); + Kuu_inverse = L_inv.transpose() * L_inv; + L_diag = L_inv.diagonal(); + + timer.toc("cholesky, tri_inv, matmul", blacs::mpirank); + + // Assign Lambda * Kfu to A + timer.tic(); + + int cum_f = 0; + int local_f_full = 0; + int local_f = 0; + Eigen::MatrixXd noise_kfu = noise_vector_sqrt.asDiagonal() * Kuf_local.transpose(); + Eigen::VectorXd noise_labels = noise_vector_sqrt.asDiagonal() * local_labels; + blacs::barrier(); + A.collect(&noise_kfu(0, 0), 0, 0, f_size, u_size, f_size_per_proc, u_size, nmax_struc - nmin_struc); + A.fence(); + b.collect(&noise_labels(0), 0, 0, f_size, 1, f_size_per_proc, 1, nmax_struc - nmin_struc); + b.fence(); + + timer.toc("set A & b", blacs::mpirank); + + // Copy L.T to A using scatter function + timer.tic(); + Eigen::MatrixXd L_T; + int mb, nb, lld; + if (blacs::mpirank == 0) { + L_T = L.transpose(); + mb = nb = lld = u_size; + } else { + mb = nb = lld = 0; + } + blacs::barrier(); + A.scatter(&L_T(0,0), f_size, 0, u_size, u_size); + A.fence(); + + timer.toc("set L.T to A", blacs::mpirank); + + // QR factorize A to compute alpha + timer.tic(); + DistMatrix QR(u_size + f_size, u_size); + std::vector tau; + std::tie(QR, tau) = A.qr(); + QR.fence(); + + timer.toc("QR", blacs::mpirank); + + // Directly use triangular_solve to get alpha: R * alpha = Q_b + timer.tic(); + + DistMatrix Qb_dist = QR.Q_matmul(b, tau, 'L', 'T'); // Q_b = Q^T * b + Qb_dist.fence(); + Eigen::MatrixXd Q_b_mat = Eigen::MatrixXd::Zero(u_size, 1); + Qb_dist.allgather(&Q_b_mat(0, 0), 0, 0, u_size, 1); + Eigen::VectorXd Q_b = Q_b_mat.col(0); + + timer.toc("Qb", blacs::mpirank); + + timer.tic(); + Eigen::MatrixXd R = Eigen::MatrixXd::Zero(u_size, u_size); // Upper triangular R from QR + QR.allgather(&R(0, 0), 0, 0, u_size, u_size); // Here the lower triangular part of R is not zero + + // Using Lapack triangular solver to temporarily avoid the numerical issue + // with Scalapack block algorithm with ill-conditioned matrix + R_inv = R.triangularView().solve(Kuu_eye); + R_inv_diag = R_inv.diagonal(); + alpha = R_inv * Q_b; + + timer.toc("get alpha", blacs::mpirank); + blacs::barrier(); +} + +void ParallelSGP ::set_hyperparameters(Eigen::VectorXd hyps) { + timer.tic(); + + // Reset Kuu and Kuf matrices. 
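For reference, the distributed assembly and QR above solve the stacked least-squares problem with A = [Lambda^(1/2) Kfu ; L^T] and b = [Lambda^(1/2) y ; 0], whose solution alpha = R^-1 (Q^T b) coincides with (Kuu + Kuf Lambda Kfu)^-1 Kuf Lambda y, the sparse GP mean weights. A serial Eigen sketch of the same solve (inputs are illustrative; the parallel code does this through ScaLAPACK and a DistMatrix):

```
// Serial restatement of the stacked QR solve in update_matrices_QR.
#include <Eigen/Dense>

Eigen::VectorXd solve_alpha(const Eigen::MatrixXd &Kfu,    // f_size x u_size
                            const Eigen::MatrixXd &L,      // Kuu = L L^T
                            const Eigen::VectorXd &noise,  // 1/sigma^2 per label
                            const Eigen::VectorXd &y) {
  const int f_size = Kfu.rows(), u_size = Kfu.cols();
  Eigen::VectorXd noise_sqrt = noise.cwiseSqrt();

  // A = [Lambda^(1/2) Kfu ; L^T], b = [Lambda^(1/2) y ; 0].
  Eigen::MatrixXd A(f_size + u_size, u_size);
  A.topRows(f_size) = noise_sqrt.asDiagonal() * Kfu;
  A.bottomRows(u_size) = L.transpose();
  Eigen::VectorXd b = Eigen::VectorXd::Zero(f_size + u_size);
  b.head(f_size) = noise_sqrt.asDiagonal() * y;

  // QR of A, then alpha = R^-1 (Q^T b), as in the distributed version.
  Eigen::HouseholderQR<Eigen::MatrixXd> qr(A);
  Eigen::MatrixXd thinQ =
      qr.householderQ() * Eigen::MatrixXd::Identity(f_size + u_size, u_size);
  Eigen::VectorXd Qtb = thinQ.transpose() * b;
  // Only the upper triangle of the packed QR matrix is the R factor.
  Eigen::MatrixXd R = qr.matrixQR().topRows(u_size);
  return R.triangularView<Eigen::Upper>().solve(Qtb);  // alpha
}
```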
+ int n_hyps, hyp_index = 0; + Eigen::VectorXd new_hyps; + + std::vector Kuu_grad, Kuf_grad; + for (int i = 0; i < n_kernels; i++) { + n_hyps = kernels[i]->kernel_hyperparameters.size(); + new_hyps = hyps.segment(hyp_index, n_hyps); + + Kuu_grad = kernels[i]->Kuu_grad(sparse_descriptors[i], Kuu_kernels[i], new_hyps); + Kuf_grad = kernels[i]->Kuf_grad(sparse_descriptors[i], training_structures, + i, Kuf_kernels[i], new_hyps); + + Kuu_kernels[i] = Kuu_grad[0]; + Kuf_kernels[i] = Kuf_grad[0]; + + kernels[i]->set_hyperparameters(new_hyps); + hyp_index += n_hyps; + } + timer.toc("set_hyp: update Kuf Kuu", blacs::mpirank); + + // Stack Kuf_local + timer.tic(); + stack_Kuf(); + timer.toc("set_hyp: stack Kuf Kuu"); + + // Update noise vector + timer.tic(); + hyperparameters = hyps; + energy_noise = hyps(hyp_index); + force_noise = hyps(hyp_index + 1); + stress_noise = hyps(hyp_index + 2); + + local_noise_vector = 1 / (energy_noise * energy_noise) * local_e_noise_one + + 1 / (force_noise * force_noise) * local_f_noise_one + + 1 / (stress_noise * stress_noise) * local_s_noise_one; + blacs::barrier(); + + // TODO: global_n_labels == f_size + global_noise_vector = Eigen::VectorXd::Zero(global_n_labels); + DistMatrix noise_vector_dist(f_size, 1); + noise_vector_dist.collect(&local_noise_vector(0), 0, 0, f_size, 1, f_size_per_proc, 1, nmax_struc - nmin_struc); + noise_vector_dist.fence(); + noise_vector_dist.gather(&global_noise_vector(0), 0, 0, f_size, 1); + blacs::barrier(); + + timer.toc("set_hyp: update noise"); + + // Update remaining matrices. + update_matrices_QR(); +} + +/* ------------------------------------------------------------------------- + * Predict mean and uncertainties + * ------------------------------------------------------------------------- */ + +void ParallelSGP ::predict_local_uncertainties(Structure &test_structure) { + int n_atoms = test_structure.noa; + int n_out = 1 + 3 * n_atoms + 6; + + int n_sparse = Kuu_inverse.rows(); + Eigen::MatrixXd kernel_mat = Eigen::MatrixXd::Zero(n_sparse, n_out); + int count = 0; + for (int i = 0; i < Kuu_kernels.size(); i++) { + int size = sparse_descriptors[i].n_clusters; + kernel_mat.block(count, 0, size, n_out) = kernels[i]->envs_struc( + sparse_descriptors[i], test_structure.descriptors[i], + kernels[i]->kernel_hyperparameters); + count += size; + } + + test_structure.mean_efs = kernel_mat.transpose() * alpha; + std::vector local_uncertainties = + this->compute_cluster_uncertainties(test_structure); + test_structure.local_uncertainties = local_uncertainties; + +} + +std::vector +ParallelSGP ::predict_on_structures(std::vector struc_list, + double cutoff, std::vector descriptor_calculators) { + + int n_test = struc_list.size(); + int n_test_per_proc, nmin_test, nmax_test; + + // Get the number of processes + int world_size; + MPI_Comm_size(MPI_COMM_WORLD, &world_size); + + // Get the rank of the process + int world_rank; + MPI_Comm_rank(MPI_COMM_WORLD, &world_rank); + + // Compute the dimensions of the matrices Kuf and Kuu + if (n_test % world_size == 0) { + n_test_per_proc = n_test / world_size; + } else { + n_test_per_proc = n_test / world_size + 1; + } + + // Compute the range of structures covered by the current rank + nmin_test = world_rank * n_test_per_proc; + if (world_rank == world_size - 1) { + nmax_test = n_test; + } else { + nmax_test = (world_rank + 1) * n_test_per_proc; + } + + // Compute the total number of labels + int n_test_labels = 0; + int n_test_atoms = 0; + for (int t = 0; t < n_test; t++) { + n_test_labels += 1 + 
struc_list[t].noa * 3 + 6; + n_test_atoms += struc_list[t].noa; + } + + // Create a long array to store predictions of all structures + Eigen::VectorXd mean_efs = Eigen::VectorXd::Zero(n_test_labels); + std::vector local_uncertainties; + for (int i = 0; i < n_kernels; i++) { + local_uncertainties.push_back(Eigen::VectorXd::Zero(n_test_atoms)); + } + blacs::barrier(); + + // Compute e/f/s and uncertainties for the local structures + int count_efs = 0; + int count_unc = 0; + for (int t = 0; t < n_test; t++) { + Structure struc0 = struc_list[t]; + int n_curr_labels = 1 + struc0.noa * 3 + 6; + if (nmin_test <= t && t < nmax_test) { + // compute descriptors + Structure struc(struc0.cell, struc0.species, struc0.positions, cutoff, descriptor_calculators); + // predict mean_efs and uncertainties + predict_local_uncertainties(struc); + mean_efs.segment(count_efs, n_curr_labels) = struc.mean_efs; + for (int i = 0; i < n_kernels; i++) { + local_uncertainties[i].segment(count_unc, struc.noa) = struc.local_uncertainties[i]; + } + } + count_efs += n_curr_labels; + count_unc += struc0.noa; + } + blacs::barrier(); + + // Collect all results to rank 0 + Eigen::VectorXd all_mean_efs = Eigen::VectorXd::Zero(n_test_labels); + MPI_Reduce(mean_efs.data(), all_mean_efs.data(), n_test_labels, MPI_DOUBLE, MPI_SUM, 0, MPI_COMM_WORLD); + + std::vector all_local_uncertainties; + for (int i = 0; i < n_kernels; i++) { + Eigen::VectorXd curr_unc = Eigen::VectorXd::Zero(n_test_atoms); + MPI_Reduce(local_uncertainties[i].data(), curr_unc.data(), n_test_atoms, MPI_DOUBLE, MPI_SUM, 0, MPI_COMM_WORLD); + all_local_uncertainties.push_back(curr_unc); + } + + // Assign the results to each struc + count_efs = 0; + count_unc = 0; + for (int t = 0; t < n_test; t++) { + int n_curr_labels = 1 + struc_list[t].noa * 3 + 6; + struc_list[t].mean_efs = all_mean_efs.segment(count_efs, n_curr_labels); + struc_list[t].local_uncertainties = {}; + for (int i = 0; i < n_kernels; i++) { + struc_list[t].local_uncertainties.push_back(all_local_uncertainties[i].segment(count_unc, struc_list[t].noa)); + } + count_efs += n_curr_labels; + count_unc += struc_list[t].noa; + } + return struc_list; +} + +/* ------------------------------------------------------------------------- + * Compute likelihood + * ------------------------------------------------------------------------- */ + +void ParallelSGP ::compute_likelihood_stable() { + timer.tic(); + // initialize BLACS + blacs::initialize(); + + DistMatrix Kfu_dist(f_size, u_size); + Eigen::MatrixXd Kfu_local = Kuf_local.transpose(); + blacs::barrier(); + Kfu_dist.collect(&Kfu_local(0, 0), 0, 0, f_size, u_size, f_size_per_proc, u_size, nmax_struc - nmin_struc); + Kfu_dist.fence(); + + DistMatrix alpha_dist(u_size, 1); + alpha_dist.scatter(&alpha(0), 0, 0, u_size, 1); + alpha_dist.fence(); + + DistMatrix K_alpha_dist(f_size, 1); + K_alpha_dist = Kfu_dist.matmul(alpha_dist, 1.0, 'N', 'N'); + K_alpha_dist.fence(); + Eigen::VectorXd K_alpha = Eigen::VectorXd::Zero(f_size); + blacs::barrier(); + K_alpha_dist.gather(&K_alpha(0), 0, 0, f_size, 1); + K_alpha_dist.fence(); + + DistMatrix y_dist(f_size, 1); + y_dist.collect(&local_labels(0), 0, 0, f_size, 1, f_size_per_proc, 1, nmax_struc - nmin_struc); + y_dist.fence(); + Eigen::VectorXd y_global = Eigen::VectorXd::Zero(f_size); + y_dist.gather(&y_global(0), 0, 0, f_size, 1); + y_dist.fence(); + + y_K_alpha = Eigen::VectorXd::Zero(f_size); + + // Compute log marginal likelihood + log_marginal_likelihood = 0; + if (blacs::mpirank == 0) { + y_K_alpha = y_global - 
K_alpha; + data_fit = + -(1. / 2.) * y_global.transpose() * global_noise_vector.cwiseProduct(y_K_alpha); + constant_term = -(1. / 2.) * global_n_labels * log(2 * M_PI); + + // Compute complexity penalty. + double noise_det = - 2 * (global_n_energy_labels * log(abs(energy_noise)) + + global_n_force_labels * log(abs(force_noise)) + + global_n_stress_labels * log(abs(stress_noise))); + + assert(L_diag.size() == R_inv_diag.size()); + double Kuu_inv_det = 0; + double sigma_inv_det = 0; + for (int i = 0; i < L_diag.size(); i++) { + Kuu_inv_det -= 2 * log(abs(L_diag(i))); + sigma_inv_det += 2 * log(abs(R_inv_diag(i))); + } + + complexity_penalty = (1. / 2.) * (noise_det + Kuu_inv_det + sigma_inv_det); + log_marginal_likelihood = complexity_penalty + data_fit + constant_term; + } + blacs::barrier(); + MPI_Bcast(&log_marginal_likelihood, 1, MPI_DOUBLE, 0, MPI_COMM_WORLD); + timer.toc("Compute likelihood", blacs::mpirank); +} + +/* ----------------------------------------------------------------------- + * Compute likelihood gradient of hyperparameters + * ------------------------------------------------------------------------- */ + +double ParallelSGP ::compute_likelihood_gradient_stable(bool precomputed_KnK) { + // initialize BLACS + blacs::initialize(); + + // Compute likelihood + compute_likelihood_stable(); + + Sigma = R_inv * R_inv.transpose(); + + // Compute likelihood gradient of kernel hyps such as signal variance + int n_hyps_total = hyperparameters.size(); + likelihood_gradient = Eigen::VectorXd::Zero(n_hyps_total); + likelihood_gradient += compute_like_grad_of_kernel_hyps(); + + // Compute likelihood gradient of energy, force, stress noises + likelihood_gradient.segment(n_hyps_total - 3, 3) += compute_like_grad_of_noise(precomputed_KnK); + + return log_marginal_likelihood; +} + +Eigen::VectorXd ParallelSGP ::compute_like_grad_of_kernel_hyps() { + timer.tic(); + + DistMatrix Kfu_dist(f_size, u_size); + Eigen::MatrixXd Kfu_local = Kuf_local.transpose(); + blacs::barrier(); + Kfu_dist.collect(&Kfu_local(0, 0), 0, 0, f_size, u_size, f_size_per_proc, u_size, nmax_struc - nmin_struc); + Kfu_dist.fence(); + + DistMatrix alpha_dist(u_size, 1); + alpha_dist.scatter(&alpha(0), 0, 0, u_size, 1); + alpha_dist.fence(); + + timer.toc("collect Kfu and alpha", blacs::mpirank); + + // Compute Kuu and Kuf matrices and gradients. + int n_hyps_total = hyperparameters.size(); + std::vector Kuu_grad, Kuf_grad, Kuu_grads, Kuf_grads; + int n_hyps, hyp_index = 0, grad_index = 0; + Eigen::VectorXd hyps_curr; + + int count = 0; + Eigen::VectorXd complexity_grad = Eigen::VectorXd::Zero(n_hyps_total); + Eigen::VectorXd datafit_grad = Eigen::VectorXd::Zero(n_hyps_total); + Eigen::VectorXd likelihood_grad = Eigen::VectorXd::Zero(n_hyps_total); + for (int i = 0; i < n_kernels; i++) { + timer.tic(); + n_hyps = kernels[i]->kernel_hyperparameters.size(); + hyps_curr = hyperparameters.segment(hyp_index, n_hyps); + int size = Kuu_kernels[i].rows(); + + // Compute the kernel matrix and grad matrix of a single kernel. 
The size is not (u_size, u_size) + Kuu_grad = kernels[i]->Kuu_grad(sparse_descriptors[i], Kuu_kernels[i], hyps_curr); + Kuf_grad = kernels[i]->Kuf_grad(sparse_descriptors[i], + training_structures, i, Kuf_kernels[i], hyps_curr); + + Eigen::MatrixXd Kuu_i = Kuu_grad[0]; + timer.toc("Kuu_grad & Kuf_grad", blacs::mpirank); + + for (int j = 0; j < n_hyps; j++) { + timer.tic(); + Eigen::MatrixXd dKuu = Eigen::MatrixXd::Zero(u_size, u_size); + dKuu.block(count, count, size, size) = Kuu_grad[j + 1]; + + // Compute locally noise_one * Kuf_grad.transpose() + // TODO: only apply for inner product kernel? + int local_f_size = nmax_struc - nmin_struc; + Eigen::MatrixXd dKfu_local = Eigen::MatrixXd::Zero(local_f_size, u_size); + dKfu_local.block(0, count, local_f_size, size) = Kuf_grad[j + 1].transpose(); + n_sparse = u_size; + Eigen::MatrixXd dK_noise_K = SparseGP::compute_dKnK(i); + timer.toc("compute dKnK", blacs::mpirank); + + // Derivative of complexity over sigma + timer.tic(); + if (blacs::mpirank == 0) { + Eigen::MatrixXd Pi_mat = dK_noise_K + dK_noise_K.transpose() + dKuu; + complexity_grad(hyp_index + j) += 1./2. * (Kuu_i.inverse() * Kuu_grad[j + 1]).trace() - 1./2. * (Pi_mat * Sigma).trace(); + std::cout << "complex_grad 1 = " << 1./2. * (Kuu_i.inverse() * Kuu_grad[j + 1]).trace() << std::endl; + std::cout << "max Kuu_inv = " << Kuu_i.inverse().maxCoeff() << ", min = " << Kuu_i.inverse().minCoeff() << std::endl; + std::cout << "complex_grad 2 = " << 1./2. * (Pi_mat * Sigma).trace() << std::endl; + } + timer.toc("Pi_mat and complexity_grad", blacs::mpirank); + + // Derivative of data_fit over sigma + timer.tic(); + DistMatrix dKfu_dist(f_size, u_size); + dKfu_dist.collect(&dKfu_local(0, 0), 0, 0, f_size, u_size, f_size_per_proc, u_size, nmax_struc - nmin_struc); + dKfu_dist.fence(); + + DistMatrix dK_alpha_dist(f_size, 1); + dK_alpha_dist = dKfu_dist.matmul(alpha_dist, 1.0, 'N', 'N'); + dK_alpha_dist.fence(); + Eigen::VectorXd dK_alpha = Eigen::VectorXd::Zero(f_size); + dK_alpha_dist.gather(&dK_alpha(0), 0, 0, f_size, 1); + dK_alpha_dist.fence(); + + if (blacs::mpirank == 0) { + datafit_grad(hyp_index + j) += + dK_alpha.transpose() * global_noise_vector.cwiseProduct(y_K_alpha); + datafit_grad(hyp_index + j) += + - 1./2. * alpha.transpose() * dKuu * alpha; + likelihood_grad(hyp_index + j) += complexity_grad(hyp_index + j) + datafit_grad(hyp_index + j); + std::cout << "datafit_grad 1 = " << dK_alpha.transpose() * global_noise_vector.cwiseProduct(y_K_alpha) << std::endl; + std::cout << "datafit_grad 2 = " << - 1./2. 
* alpha.transpose() * dKuu * alpha << std::endl; + } + timer.toc("datafit_grad of kernel hyps", blacs::mpirank); + } + count += size; + hyp_index += n_hyps; + } + assert(hyp_index == n_hyps_total - 3); + + blacs::barrier(); + MPI_Bcast(likelihood_grad.data(), likelihood_grad.size(), MPI_DOUBLE, 0, MPI_COMM_WORLD); + + return likelihood_grad; +} + +Eigen::VectorXd ParallelSGP ::compute_like_grad_of_noise(bool precomputed_KnK) { + // Derivative of complexity over noise + double en3 = energy_noise * energy_noise * energy_noise; + double fn3 = force_noise * force_noise * force_noise; + double sn3 = stress_noise * stress_noise * stress_noise; + + Eigen::VectorXd complexity_grad = Eigen::VectorXd::Zero(3); + Eigen::VectorXd datafit_grad = Eigen::VectorXd::Zero(3); + Eigen::VectorXd likelihood_grad = Eigen::VectorXd::Zero(3); + + compute_KnK(precomputed_KnK); + blacs::barrier(); + + // Derivative of data_fit over noise + timer.tic(); + + DistMatrix e_noise_one_dist(f_size, 1); + e_noise_one_dist.collect(&local_e_noise_one(0), 0, 0, f_size, 1, f_size_per_proc, 1, nmax_struc - nmin_struc); + e_noise_one_dist.fence(); + Eigen::VectorXd global_e_noise_one = Eigen::VectorXd::Zero(f_size); + e_noise_one_dist.gather(&global_e_noise_one(0), 0, 0, f_size, 1); + e_noise_one_dist.fence(); + + DistMatrix f_noise_one_dist(f_size, 1); + f_noise_one_dist.collect(&local_f_noise_one(0), 0, 0, f_size, 1, f_size_per_proc, 1, nmax_struc - nmin_struc); + f_noise_one_dist.fence(); + Eigen::VectorXd global_f_noise_one = Eigen::VectorXd::Zero(f_size); + f_noise_one_dist.gather(&global_f_noise_one(0), 0, 0, f_size, 1); + f_noise_one_dist.fence(); + + DistMatrix s_noise_one_dist(f_size, 1); + s_noise_one_dist.collect(&local_s_noise_one(0), 0, 0, f_size, 1, f_size_per_proc, 1, nmax_struc - nmin_struc); + s_noise_one_dist.fence(); + Eigen::VectorXd global_s_noise_one = Eigen::VectorXd::Zero(f_size); + s_noise_one_dist.gather(&global_s_noise_one(0), 0, 0, f_size, 1); + s_noise_one_dist.fence(); + + timer.toc("collect and gather e/f/s_noise_one", blacs::mpirank); + + timer.tic(); + if (blacs::mpirank == 0) { + complexity_grad(0) = - global_n_energy_labels / energy_noise + + (KnK_e * Sigma).trace() / en3; + complexity_grad(1) = - global_n_force_labels / force_noise + + (KnK_f * Sigma).trace() / fn3; + complexity_grad(2) = - global_n_stress_labels / stress_noise + + (KnK_s * Sigma).trace() / sn3; + + datafit_grad(0) = y_K_alpha.transpose() * global_e_noise_one.cwiseProduct(y_K_alpha); + datafit_grad(0) /= en3; + datafit_grad(1) = y_K_alpha.transpose() * global_f_noise_one.cwiseProduct(y_K_alpha); + datafit_grad(1) /= fn3; + datafit_grad(2) = y_K_alpha.transpose() * global_s_noise_one.cwiseProduct(y_K_alpha); + datafit_grad(2) /= sn3; + + likelihood_grad(0) += complexity_grad(0) + datafit_grad(0); + likelihood_grad(1) += complexity_grad(1) + datafit_grad(1); + likelihood_grad(2) += complexity_grad(2) + datafit_grad(2); + } + blacs::barrier(); + MPI_Bcast(likelihood_grad.data(), likelihood_grad.size(), MPI_DOUBLE, 0, MPI_COMM_WORLD); + timer.toc("datafit/complexity grad over noise", blacs::mpirank); + + return likelihood_grad; +} + +void ParallelSGP ::precompute_KnK() { + timer.tic(); + + // clear memory + Kuf_e_noise_Kfu.clear(); + Kuf_f_noise_Kfu.clear(); + Kuf_s_noise_Kfu.clear(); + + Kuf_e_noise_Kfu.shrink_to_fit(); + Kuf_f_noise_Kfu.shrink_to_fit(); + Kuf_s_noise_Kfu.shrink_to_fit(); + + for (int i = 0; i < n_kernels; i++) { + for (int j = 0; j < n_kernels; j++) { + Kuf_e_noise_Kfu.push_back(Eigen::MatrixXd::Zero(0,0)); + 
Kuf_f_noise_Kfu.push_back(Eigen::MatrixXd::Zero(0,0)); + Kuf_s_noise_Kfu.push_back(Eigen::MatrixXd::Zero(0,0)); + } + } + + for (int i = 0; i < n_kernels; i++) { + Eigen::VectorXd hyps_i = kernels[i]->kernel_hyperparameters; + assert(hyps_i.size() == 1); + + int u_size_kern = Kuf_kernels[i].rows(); + + // Compute Kuf * e/f/s_noise_one_local and collect to distributed matrix + Eigen::MatrixXd enK_local = local_e_noise_one.asDiagonal() * Kuf_kernels[i].transpose(); + Eigen::MatrixXd fnK_local = local_f_noise_one.asDiagonal() * Kuf_kernels[i].transpose(); + Eigen::MatrixXd snK_local = local_s_noise_one.asDiagonal() * Kuf_kernels[i].transpose(); + blacs::barrier(); + + DistMatrix enK_dist(f_size, u_size_kern); + enK_dist.collect(&enK_local(0, 0), 0, 0, f_size, u_size_kern, f_size_per_proc, u_size_kern, nmax_struc - nmin_struc); + enK_dist.fence(); + + DistMatrix fnK_dist(f_size, u_size_kern); + fnK_dist.collect(&fnK_local(0, 0), 0, 0, f_size, u_size_kern, f_size_per_proc, u_size_kern, nmax_struc - nmin_struc); + fnK_dist.fence(); + + DistMatrix snK_dist(f_size, u_size_kern); + snK_dist.collect(&snK_local(0, 0), 0, 0, f_size, u_size_kern, f_size_per_proc, u_size_kern, nmax_struc - nmin_struc); + snK_dist.fence(); + + for (int j = 0; j < n_kernels; j++) { + Eigen::VectorXd hyps_j = kernels[j]->kernel_hyperparameters; + assert(hyps_j.size() == 1); + + double sig4 = hyps_i(0) * hyps_i(0) * hyps_j(0) * hyps_j(0); + + // Compute and return Kuf * e/f/s_noise_one * Kuf.transpose() + int u_size_kern_j = Kuf_kernels[j].rows(); + DistMatrix Kfu_dist(f_size, u_size_kern_j); + Eigen::MatrixXd Kfu_local = Kuf_kernels[j].transpose(); + Kfu_dist.collect(&Kfu_local(0, 0), 0, 0, f_size, u_size_kern_j, f_size_per_proc, u_size_kern_j, nmax_struc - nmin_struc); + Kfu_dist.fence(); + + // Compute Kn * Kuf.tranpose() + DistMatrix KnK_e_dist(u_size_kern_j, u_size_kern); + KnK_e_dist = Kfu_dist.matmul(enK_dist, 1.0, 'T', 'N'); + KnK_e_dist.fence(); + + DistMatrix KnK_f_dist(u_size_kern_j, u_size_kern); + KnK_f_dist = Kfu_dist.matmul(fnK_dist, 1.0, 'T', 'N'); + KnK_f_dist.fence(); + + DistMatrix KnK_s_dist(u_size_kern_j, u_size_kern); + KnK_s_dist = Kfu_dist.matmul(snK_dist, 1.0, 'T', 'N'); + KnK_s_dist.fence(); + + // Gather to get the serial matrix + Eigen::MatrixXd KnK_e_kern = Eigen::MatrixXd::Zero(u_size_kern_j, u_size_kern); + Eigen::MatrixXd KnK_f_kern = Eigen::MatrixXd::Zero(u_size_kern_j, u_size_kern); + Eigen::MatrixXd KnK_s_kern = Eigen::MatrixXd::Zero(u_size_kern_j, u_size_kern); + blacs::barrier(); + + KnK_e_dist.gather(&KnK_e_kern(0,0), 0, 0, u_size_kern_j, u_size_kern); + KnK_e_dist.fence(); + KnK_e_kern /= sig4; + + KnK_f_dist.gather(&KnK_f_kern(0,0), 0, 0, u_size_kern_j, u_size_kern); + KnK_f_dist.fence(); + KnK_f_kern /= sig4; + + KnK_s_dist.gather(&KnK_s_kern(0,0), 0, 0, u_size_kern_j, u_size_kern); + KnK_s_dist.fence(); + KnK_s_kern /= sig4; + + Kuf_e_noise_Kfu[j * n_kernels + i] = KnK_e_kern; + Kuf_f_noise_Kfu[j * n_kernels + i] = KnK_f_kern; + Kuf_s_noise_Kfu[j * n_kernels + i] = KnK_s_kern; + } + } + timer.toc("precompute_KnK", blacs::mpirank); +} + + +void ParallelSGP ::compute_KnK(bool precomputed) { + timer.tic(); + + KnK_e = Eigen::MatrixXd::Zero(u_size, u_size); + KnK_f = Eigen::MatrixXd::Zero(u_size, u_size); + KnK_s = Eigen::MatrixXd::Zero(u_size, u_size); + + if (precomputed) { + int count_i = 0, count_ij = 0; + for (int i = 0; i < n_kernels; i++) { + Eigen::VectorXd hyps_i = kernels[i]->kernel_hyperparameters; + assert(hyps_i.size() == 1); + int size_i = Kuf_kernels[i].rows(); + int 
count_j = 0; + for (int j = 0; j < n_kernels; j++) { + Eigen::VectorXd hyps_j = kernels[j]->kernel_hyperparameters; + assert(hyps_j.size() == 1); + int size_j = Kuf_kernels[j].rows(); + + double sig4 = hyps_i(0) * hyps_i(0) * hyps_j(0) * hyps_j(0); + + KnK_e.block(count_i, count_j, size_i, size_j) += Kuf_e_noise_Kfu[count_ij] * sig4; + KnK_f.block(count_i, count_j, size_i, size_j) += Kuf_f_noise_Kfu[count_ij] * sig4; + KnK_s.block(count_i, count_j, size_i, size_j) += Kuf_s_noise_Kfu[count_ij] * sig4; + + count_ij += 1; + count_j += size_j; + } + count_i += size_i; + } + blacs::barrier(); + } else { + // Compute and return Kuf * e/f/s_noise_one * Kuf.transpose() + DistMatrix Kfu_dist(f_size, u_size); + Eigen::MatrixXd Kfu_local = Kuf_local.transpose(); + Kfu_dist.collect(&Kfu_local(0, 0), 0, 0, f_size, u_size, f_size_per_proc, u_size, nmax_struc - nmin_struc); + Kfu_dist.fence(); + + // Compute Kuf * e/f/s_noise_one_local and collect to distributed matrix + Eigen::MatrixXd enK_local = local_e_noise_one.asDiagonal() * Kuf_local.transpose(); + Eigen::MatrixXd fnK_local = local_f_noise_one.asDiagonal() * Kuf_local.transpose(); + Eigen::MatrixXd snK_local = local_s_noise_one.asDiagonal() * Kuf_local.transpose(); + blacs::barrier(); + + DistMatrix enK_dist(f_size, u_size); + enK_dist.collect(&enK_local(0, 0), 0, 0, f_size, u_size, f_size_per_proc, u_size, nmax_struc - nmin_struc); + enK_dist.fence(); + + DistMatrix fnK_dist(f_size, u_size); + fnK_dist.collect(&fnK_local(0, 0), 0, 0, f_size, u_size, f_size_per_proc, u_size, nmax_struc - nmin_struc); + fnK_dist.fence(); + + DistMatrix snK_dist(f_size, u_size); + snK_dist.collect(&snK_local(0, 0), 0, 0, f_size, u_size, f_size_per_proc, u_size, nmax_struc - nmin_struc); + snK_dist.fence(); + + // Compute Kn * Kuf.tranpose() + DistMatrix KnK_e_dist(u_size, u_size); + KnK_e_dist = Kfu_dist.matmul(enK_dist, 1.0, 'T', 'N'); + KnK_e_dist.fence(); + + DistMatrix KnK_f_dist(u_size, u_size); + KnK_f_dist = Kfu_dist.matmul(fnK_dist, 1.0, 'T', 'N'); + KnK_f_dist.fence(); + + DistMatrix KnK_s_dist(u_size, u_size); + KnK_s_dist = Kfu_dist.matmul(snK_dist, 1.0, 'T', 'N'); + KnK_s_dist.fence(); + + // Gather to get the serial matrix + KnK_e_dist.gather(&KnK_e(0,0), 0, 0, u_size, u_size); + KnK_e_dist.fence(); + + KnK_f_dist.gather(&KnK_f(0,0), 0, 0, u_size, u_size); + KnK_f_dist.fence(); + + KnK_s_dist.gather(&KnK_s(0,0), 0, 0, u_size, u_size); + KnK_s_dist.fence(); + } + timer.toc("compute_KnK", blacs::mpirank); +} diff --git a/src/flare_pp/bffs/parallel_sgp.h b/src/flare_pp/bffs/parallel_sgp.h new file mode 100644 index 00000000..ad2600aa --- /dev/null +++ b/src/flare_pp/bffs/parallel_sgp.h @@ -0,0 +1,126 @@ +#ifndef PARALLEL_SGP_H +#define PARALLEL_SGP_H + +#include "descriptor.h" +#include "kernel.h" +#include "structure.h" +#include "sparse_gp.h" +#include +#include +#include "json.h" +#include "utils.h" + +class ParallelSGP : public SparseGP { +public: + // Training and sparse points. 
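+  // Each MPI rank stores only the slice of the training set assigned to it
+  // (the local_* members below), while the sparse/inducing set is held in full
+  // on every rank.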
+  std::vector<std::vector<ClusterDescriptor>> local_sparse_descriptors;
+  std::vector<std::vector<int>> local_label_indices;
+  int local_label_size;
+
+  // Parallel parameters
+  int u_size, u_size_per_proc;
+  std::vector<int> u_size_single_kernel;
+  int f_size, f_size_single_kernel, f_size_per_proc;
+  int nmin_struc, nmax_struc, nmin_envs, nmax_envs;
+  std::vector<std::vector<int>> n_struc_clusters_by_type;
+  int global_n_labels = 0;
+  int global_n_energy_labels = 0;
+  int global_n_force_labels = 0;
+  int global_n_stress_labels = 0;
+  std::vector<int> local_training_structure_indices;
+  bool finalize_MPI = true;
+
+  utils::Timer timer;
+
+  Eigen::MatrixXd Kuf_local;
+
+  // Constructors.
+  ParallelSGP();
+
+  /**
+   Basic parallel sparse GP constructor. This class inherits from the SparseGP class and accepts
+   the same input parameters.
+
+   @param kernels A list of Kernel objects, e.g. NormalizedInnerProduct, SquaredExponential.
+        Note the number of kernels should be equal to the number of descriptor calculators.
+   @param energy_noise Noise hyperparameter for the total energy.
+   @param force_noise Noise hyperparameter for the atomic forces.
+   @param stress_noise Noise hyperparameter for the total stress.
+   */
+  ParallelSGP(std::vector<Kernel *> kernels, double energy_noise,
+              double force_noise, double stress_noise);
+
+  // Destructor
+  virtual ~ParallelSGP();
+
+  void add_global_noise(int n_energy, int n_force, int n_stress);
+  Eigen::VectorXd global_noise_vector, local_noise_vector, local_e_noise_one, local_f_noise_one, local_s_noise_one;
+  Eigen::MatrixXd local_labels;
+  void add_training_structure(const Structure &structure);
+
+  Eigen::VectorXi sparse_indices_by_type(int n_types, std::vector<int> species, const std::vector<int> atoms);
+  void add_specific_environments(const Structure &, std::vector<std::vector<int>>, bool update);
+  void add_local_specific_environments(const Structure &structure, const std::vector<int> atoms);
+  void add_global_specific_environments(const Structure &structure, const std::vector<int> atoms);
+  void predict_local_uncertainties(Structure &structure);
+
+  std::vector<Structure> predict_on_structures(std::vector<Structure> struc_list,
+      double cutoff, std::vector<Descriptor *> descriptor_calculators);
+
+  /**
+   Method for constructing the SGP model from a training dataset.
+
+   @param training_strucs A list of Structure objects
+   @param cutoff The cutoff for the SGP
+   @param descriptor_calculators A list of Descriptor objects, e.g. B2, B3, ...
+   @param training_sparse_indices A list of indices of sparse environments in each training structure
+   @param n_types An integer to specify the number of types. For the B2 descriptor, n_types is equal to the
+        number of species
+   */
+  void build(const std::vector<Structure> &training_strucs,
+             double cutoff, std::vector<Descriptor *> descriptor_calculators,
+             const std::vector<std::vector<std::vector<int>>> &training_sparse_indices,
+             int n_types, bool update = false);
+
+  /**
+   Method for loading the training data in a distributed way. Each process loads a portion of the whole training
+   data set, and loads the whole sparse set.
+
+   @param training_strucs A list of Structure objects
+   @param cutoff The cutoff for the SGP
+   @param descriptor_calculators A list of Descriptor objects, e.g. B2, B3, ...
+   @param training_sparse_indices A list of indices of sparse environments in each training structure
+   @param n_types An integer to specify the number of types.
For B2 descriptor, n_type is equal to the + number of species + */ + void load_local_training_data(const std::vector &training_strucs, + double cutoff, std::vector descriptor_calculators, + const std::vector>> &training_sparse_indices, + int n_types, bool update = false); + + void gather_sparse_descriptors(std::vector> n_clusters_by_type, + const std::vector &training_strucs); + + /** + Method for computing kernel matrices and vectors + + @param training_strucs A list of Structure objects + */ + void compute_kernel_matrices(const std::vector &training_strucs); + + void set_hyperparameters(Eigen::VectorXd hyps); + void stack_Kuf(); + void update_matrices_QR(); + + Eigen::MatrixXd varmap_coeffs; // for debugging. TODO: remove this line + + double compute_likelihood_gradient_stable(bool precomputed_KnK); + Eigen::VectorXd y_K_alpha; + void compute_likelihood_stable(); + void compute_KnK(bool precomputed_KnK); + void precompute_KnK(); + Eigen::VectorXd compute_like_grad_of_kernel_hyps(); + Eigen::VectorXd compute_like_grad_of_noise(bool precomputed_KnK); + +}; +#endif diff --git a/src/flare_pp/bffs/sparse_gp.cpp b/src/flare_pp/bffs/sparse_gp.cpp index 7bcc3632..a02bf538 100644 --- a/src/flare_pp/bffs/sparse_gp.cpp +++ b/src/flare_pp/bffs/sparse_gp.cpp @@ -5,6 +5,7 @@ #include // setprecision #include #include // Iota +#include SparseGP ::SparseGP() {} @@ -52,6 +53,9 @@ SparseGP ::SparseGP(std::vector kernels, double energy_noise, } } +// Destructor +SparseGP ::~SparseGP() {} + void SparseGP ::initialize_sparse_descriptors(const Structure &structure) { if (sparse_descriptors.size() != 0) return; @@ -130,12 +134,14 @@ SparseGP ::compute_cluster_uncertainties(const Structure &structure) { } void SparseGP ::add_specific_environments(const Structure &structure, - const std::vector atoms) { + const std::vector> atoms) { + + initialize_sparse_descriptors(structure); // Gather clusters with central atom in the given list. std::vector>> indices_1; for (int i = 0; i < n_kernels; i++){ - sparse_indices[i].push_back(atoms); // for each kernel the added atoms are the same + sparse_indices[i].push_back(atoms[i]); int n_types = structure.descriptors[i].n_types; std::vector> indices_2; @@ -144,8 +150,8 @@ void SparseGP ::add_specific_environments(const Structure &structure, std::vector indices_3; for (int k = 0; k < n_clusters; k++){ int atom_index_1 = structure.descriptors[i].atom_indices[j](k); - for (int l = 0; l < atoms.size(); l++){ - int atom_index_2 = atoms[l]; + for (int l = 0; l < atoms[i].size(); l++){ + int atom_index_2 = atoms[i][l]; if (atom_index_1 == atom_index_2){ indices_3.push_back(k); } @@ -180,6 +186,8 @@ void SparseGP ::add_specific_environments(const Structure &structure, void SparseGP ::add_uncertain_environments(const Structure &structure, const std::vector &n_added) { + initialize_sparse_descriptors(structure); + // Compute cluster uncertainties. std::vector> sorted_indices = sort_clusters_by_uncertainty(structure); @@ -240,6 +248,8 @@ void SparseGP ::add_uncertain_environments(const Structure &structure, void SparseGP ::add_random_environments(const Structure &structure, const std::vector &n_added) { + initialize_sparse_descriptors(structure); + // Randomly select environments without replacement. 
std::vector> envs1; for (int i = 0; i < structure.descriptors.size(); i++) { // NOTE: n_kernels might be diff from descriptors number @@ -429,14 +439,14 @@ void SparseGP ::update_Kuf( } if (training_structures[j].forces.size() != 0) { - kern_mat.block(u_ind, label_count(j) + current_count, n3, - n_atoms * 3) = - Kuf_kernels[i].block(n1, label_count(j) + current_count, n3, - n_atoms * 3); - kern_mat.block(u_ind + n3, label_count(j) + current_count, n4, - n_atoms * 3) = - envs_struc_kernels.block(n2, 1, n4, n_atoms * 3); - current_count += n_atoms * 3; + std::vector atom_indices = training_atom_indices[j]; + for (int a = 0; a < atom_indices.size(); a++) { + kern_mat.block(u_ind, label_count(j) + current_count, n3, 3) = + Kuf_kernels[i].block(n1, label_count(j) + current_count, n3, 3); + kern_mat.block(u_ind + n3, label_count(j) + current_count, n4, 3) = + envs_struc_kernels.block(n2, 1 + atom_indices[a] * 3, n4, 3); + current_count += 3; + } } if (training_structures[j].stresses.size() != 0) { @@ -454,41 +464,36 @@ void SparseGP ::update_Kuf( } } -void SparseGP ::add_training_structure(const Structure &structure) { - +void SparseGP ::add_training_structure(const Structure &structure, + const std::vector atom_indices) { initialize_sparse_descriptors(structure); + int n_atoms = structure.noa; int n_energy = structure.energy.size(); - int n_force = structure.forces.size(); + int n_force = 0; + std::vector atoms; + if (atom_indices[0] == -1) { // add all atoms + n_force = structure.forces.size(); + for (int i = 0; i < n_atoms; i++) { + atoms.push_back(i); + } + } else { + atoms = atom_indices; + n_force = atoms.size() * 3; + } + training_atom_indices.push_back(atoms); int n_stress = structure.stresses.size(); int n_struc_labels = n_energy + n_force + n_stress; - int n_atoms = structure.noa; - - // Update Kuf kernels. - Eigen::MatrixXd envs_struc_kernels; - for (int i = 0; i < n_kernels; i++) { - int n_sparse = sparse_descriptors[i].n_clusters; - - envs_struc_kernels = - kernels[i]->envs_struc(sparse_descriptors[i], structure.descriptors[i], - kernels[i]->kernel_hyperparameters); - - Kuf_kernels[i].conservativeResize(n_sparse, n_labels + n_struc_labels); - Kuf_kernels[i].block(0, n_labels, n_sparse, n_energy) = - envs_struc_kernels.block(0, 0, n_sparse, n_energy); - Kuf_kernels[i].block(0, n_labels + n_energy, n_sparse, n_force) = - envs_struc_kernels.block(0, 1, n_sparse, n_force); - Kuf_kernels[i].block(0, n_labels + n_energy + n_force, n_sparse, n_stress) = - envs_struc_kernels.block(0, 1 + n_atoms * 3, n_sparse, n_sparse); - } // Update labels. label_count.conservativeResize(training_structures.size() + 2); label_count(training_structures.size() + 1) = n_labels + n_struc_labels; y.conservativeResize(n_labels + n_struc_labels); y.segment(n_labels, n_energy) = structure.energy; - y.segment(n_labels + n_energy, n_force) = structure.forces; y.segment(n_labels + n_energy + n_force, n_stress) = structure.stresses; + for (int a = 0; a < atoms.size(); a++) { + y.segment(n_labels + n_energy + a * 3, 3) = structure.forces.segment(atoms[a] * 3, 3); + } // Update noise. 
noise_vector.conservativeResize(n_labels + n_struc_labels); @@ -499,6 +504,41 @@ void SparseGP ::add_training_structure(const Structure &structure) { noise_vector.segment(n_labels + n_energy + n_force, n_stress) = Eigen::VectorXd::Constant(n_stress, 1 / (stress_noise * stress_noise)); + // Save "1" vector for energy, force and stress noise, for likelihood gradient calculation + e_noise_one.conservativeResize(n_labels + n_struc_labels); + f_noise_one.conservativeResize(n_labels + n_struc_labels); + s_noise_one.conservativeResize(n_labels + n_struc_labels); + + e_noise_one.segment(n_labels, n_struc_labels) = Eigen::VectorXd::Zero(n_struc_labels); + f_noise_one.segment(n_labels, n_struc_labels) = Eigen::VectorXd::Zero(n_struc_labels); + s_noise_one.segment(n_labels, n_struc_labels) = Eigen::VectorXd::Zero(n_struc_labels); + + e_noise_one.segment(n_labels, n_energy) = Eigen::VectorXd::Ones(n_energy); + f_noise_one.segment(n_labels + n_energy, n_force) = Eigen::VectorXd::Ones(n_force); + s_noise_one.segment(n_labels + n_energy + n_force, n_stress) = Eigen::VectorXd::Ones(n_stress); + + // Update Kuf kernels. + Eigen::MatrixXd envs_struc_kernels; + for (int i = 0; i < n_kernels; i++) { + int n_sparse = sparse_descriptors[i].n_clusters; + + envs_struc_kernels = // contain all atoms + kernels[i]->envs_struc(sparse_descriptors[i], structure.descriptors[i], + kernels[i]->kernel_hyperparameters); + + Kuf_kernels[i].conservativeResize(n_sparse, n_labels + n_struc_labels); + Kuf_kernels[i].block(0, n_labels, n_sparse, n_energy) = + envs_struc_kernels.block(0, 0, n_sparse, n_energy); + Kuf_kernels[i].block(0, n_labels + n_energy + n_force, n_sparse, n_stress) = + envs_struc_kernels.block(0, 1 + n_atoms * 3, n_sparse, n_stress); + + // Only add forces from `atoms` + for (int a = 0; a < atoms.size(); a++) { + Kuf_kernels[i].block(0, n_labels + n_energy + a * 3, n_sparse, 3) = + envs_struc_kernels.block(0, 1 + atoms[a] * 3, n_sparse, 3); // if n_energy=0, we can not use n_energy but 1 + } + } + // Update label count. n_energy_labels += n_energy; n_force_labels += n_force; @@ -562,7 +602,7 @@ void SparseGP ::update_matrices_QR() { // QR decompose A. Eigen::HouseholderQR qr(A); - Eigen::VectorXd Q_b = qr.householderQ().transpose() * b; + Eigen::VectorXd Q_b = (qr.householderQ().transpose() * b).segment(0, Kuf.cols()); R_inv = qr.matrixQR().block(0, 0, Kuu.cols(), Kuu.cols()) .triangularView() .solve(Kuu_eye); @@ -691,6 +731,264 @@ void SparseGP ::compute_likelihood_stable() { log_marginal_likelihood = complexity_penalty + data_fit + constant_term; } +double SparseGP ::compute_likelihood_gradient_stable(bool precomputed_KnK) { + + double duration = 0; + std::chrono::high_resolution_clock::time_point t1, t2; + t1 = std::chrono::high_resolution_clock::now(); + + Eigen::VectorXd K_alpha = Kuf.transpose() * alpha; + Eigen::VectorXd y_K_alpha = y - K_alpha; + data_fit = + -(1. / 2.) * y.transpose() * noise_vector.cwiseProduct(y_K_alpha); + constant_term = -(1. / 2.) * n_labels * log(2 * M_PI); + + t2 = std::chrono::high_resolution_clock::now(); + duration = (double) std::chrono::duration_cast( t2 - t1 ).count(); + std::cout << "Time: likelihood datafit " << duration << std::endl; + t1 = std::chrono::high_resolution_clock::now(); + + // Compute complexity penalty. 
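+  // The penalty assembled below corresponds to
+  //   -1/2 log|Qff + Lambda| = 1/2 ( -log|Lambda| + log|Kuu| + log|Sigma| ),
+  // with Sigma = (Kuu + Kuf Lambda^-1 Kfu)^-1; the two log-determinants are
+  // accumulated from the diagonals of L^-1 (Cholesky factor of Kuu) and of
+  // R^-1 from the QR factorization.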
+ double noise_det = - 2 * (n_energy_labels * log(abs(energy_noise)) + + n_force_labels * log(abs(force_noise)) + + n_stress_labels * log(abs(stress_noise))); + + assert(L_diag.size() == R_inv_diag.size()); + double Kuu_inv_det = 0; + double sigma_inv_det = 0; +//#pragma omp parallel for + for (int i = 0; i < L_diag.size(); i++) { + Kuu_inv_det -= 2 * log(abs(L_diag(i))); + sigma_inv_det += 2 * log(abs(R_inv_diag(i))); + } + + complexity_penalty = (1. / 2.) * (noise_det + Kuu_inv_det + sigma_inv_det); + log_marginal_likelihood = complexity_penalty + data_fit + constant_term; + + t2 = std::chrono::high_resolution_clock::now(); + duration = (double) std::chrono::duration_cast( t2 - t1 ).count(); + std::cout << "Time: likelihood complexity " << duration << std::endl; + t1 = std::chrono::high_resolution_clock::now(); + + // Compute Kuu and Kuf matrices and gradients. + int n_hyps_total = hyperparameters.size(); + + //Eigen::MatrixXd Kuu_mat = Eigen::MatrixXd::Zero(n_sparse, n_sparse); + //Eigen::MatrixXd Kuf_mat = Eigen::MatrixXd::Zero(n_sparse, n_labels); + + std::vector Kuu_grad, Kuf_grad, Kuu_grads, Kuf_grads; + + int n_hyps, hyp_index = 0, grad_index = 0; + Eigen::VectorXd hyps_curr; + + int count = 0; + Eigen::VectorXd complexity_grad = Eigen::VectorXd::Zero(n_hyps_total); + Eigen::VectorXd datafit_grad = Eigen::VectorXd::Zero(n_hyps_total); + likelihood_gradient = Eigen::VectorXd::Zero(n_hyps_total); + for (int i = 0; i < n_kernels; i++) { + n_hyps = kernels[i]->kernel_hyperparameters.size(); + hyps_curr = hyperparameters.segment(hyp_index, n_hyps); + int size = Kuu_kernels[i].rows(); + + Kuu_grad = kernels[i]->Kuu_grad(sparse_descriptors[i], Kuu_kernels[i], hyps_curr); + if (!precomputed_KnK) { + Kuf_grad = kernels[i]->Kuf_grad(sparse_descriptors[i], training_structures, + i, Kuf_kernels[i], hyps_curr); + } + + //Kuu_mat.block(count, count, size, size) = Kuu_grad[0]; + //Kuf_mat.block(count, 0, size, n_labels) = Kuf_grad[0]; + Eigen::MatrixXd Kuu_i = Kuu_grad[0]; + + for (int j = 0; j < n_hyps; j++) { + Kuu_grads.push_back(Eigen::MatrixXd::Zero(n_sparse, n_sparse)); + Kuu_grads[hyp_index + j].block(count, count, size, size) = + Kuu_grad[j + 1]; + + if (!precomputed_KnK) { + Kuf_grads.push_back(Eigen::MatrixXd::Zero(n_sparse, n_labels)); + Kuf_grads[hyp_index + j].block(count, 0, size, n_labels) = + Kuf_grad[j + 1]; + } + t2 = std::chrono::high_resolution_clock::now(); + duration = (double) std::chrono::duration_cast( t2 - t1 ).count(); + std::cout << "Time: Kuu_grad Kuf_grad " << duration << std::endl; + t1 = std::chrono::high_resolution_clock::now(); + + // Compute Pi matrix and save as an intermediate variable + + Eigen::MatrixXd dK_noise_K; + if (precomputed_KnK) { + dK_noise_K = compute_dKnK(i); + } else { + Eigen::MatrixXd noise_diag = noise_vector.asDiagonal(); + dK_noise_K = Kuf_grads[hyp_index + j] * noise_diag * Kuf.transpose(); + } + Eigen::MatrixXd Pi_mat = dK_noise_K + dK_noise_K.transpose() + Kuu_grads[hyp_index + j]; + + // Derivative of complexity over sigma + // TODO: the 2nd term is not very stable numerically, because dK_noise_K is very large, and Kuu_grads is small + complexity_grad(hyp_index + j) += 1./2. * (Kuu_i.inverse() * Kuu_grad[j + 1]).trace() - 1./2. 
* (Pi_mat * Sigma).trace(); + + t2 = std::chrono::high_resolution_clock::now(); + duration = (double) std::chrono::duration_cast( t2 - t1 ).count(); + std::cout << "Time: dC/dsigma " << duration << std::endl; + t1 = std::chrono::high_resolution_clock::now(); + + // Derivative of data_fit over sigma + Eigen::VectorXd dK_alpha; + if (precomputed_KnK) { + Eigen::MatrixXd dKuf = Eigen::MatrixXd::Zero(n_sparse, n_labels); + dKuf.block(count, 0, size, n_labels) = Kuf.block(count, 0, size, n_labels); + dK_alpha = (2. / hyps_curr(j)) * dKuf.transpose() * alpha; + } else { + dK_alpha = Kuf_grads[hyp_index + j].transpose() * alpha; + } + + datafit_grad(hyp_index + j) += dK_alpha.transpose() * noise_vector.cwiseProduct(y_K_alpha); + datafit_grad(hyp_index + j) += - 1./2. * alpha.transpose() * Kuu_grads[hyp_index + j] * alpha; + + t2 = std::chrono::high_resolution_clock::now(); + duration = (double) std::chrono::duration_cast( t2 - t1 ).count(); + std::cout << "Time: dD/dsigma " << duration << std::endl; + t1 = std::chrono::high_resolution_clock::now(); + + likelihood_gradient(hyp_index + j) += complexity_grad(hyp_index + j) + datafit_grad(hyp_index + j); + } + + count += size; + hyp_index += n_hyps; + } + + // Derivative of complexity over noise + double en3 = energy_noise * energy_noise * energy_noise; + double fn3 = force_noise * force_noise * force_noise; + double sn3 = stress_noise * stress_noise * stress_noise; + + compute_KnK(precomputed_KnK); + complexity_grad(hyp_index + 0) = - e_noise_one.sum() / energy_noise + + (KnK_e * Sigma).trace() / en3; + complexity_grad(hyp_index + 1) = - f_noise_one.sum() / force_noise + + (KnK_f * Sigma).trace() / fn3; + complexity_grad(hyp_index + 2) = - s_noise_one.sum() / stress_noise + + (KnK_s * Sigma).trace() / sn3; + + t2 = std::chrono::high_resolution_clock::now(); + duration = (double) std::chrono::duration_cast( t2 - t1 ).count(); + std::cout << "Time: dC/dnoise " << duration << std::endl; + t1 = std::chrono::high_resolution_clock::now(); + + // Derivative of data_fit over noise + datafit_grad(hyp_index + 0) = y_K_alpha.transpose() * e_noise_one.cwiseProduct(y_K_alpha); + datafit_grad(hyp_index + 0) /= en3; + datafit_grad(hyp_index + 1) = y_K_alpha.transpose() * f_noise_one.cwiseProduct(y_K_alpha); + datafit_grad(hyp_index + 1) /= fn3; + datafit_grad(hyp_index + 2) = y_K_alpha.transpose() * s_noise_one.cwiseProduct(y_K_alpha); + datafit_grad(hyp_index + 2) /= sn3; + + t2 = std::chrono::high_resolution_clock::now(); + duration = (double) std::chrono::duration_cast( t2 - t1 ).count(); + std::cout << "Time: dD/dnoise " << duration << std::endl; + t1 = std::chrono::high_resolution_clock::now(); + + likelihood_gradient(hyp_index + 0) += complexity_grad(hyp_index + 0) + datafit_grad(hyp_index + 0); + likelihood_gradient(hyp_index + 1) += complexity_grad(hyp_index + 1) + datafit_grad(hyp_index + 1); + likelihood_gradient(hyp_index + 2) += complexity_grad(hyp_index + 2) + datafit_grad(hyp_index + 2); + + return log_marginal_likelihood; + +} + +void SparseGP ::precompute_KnK() { + Kuf_e_noise_Kfu = {}; + Kuf_f_noise_Kfu = {}; + Kuf_s_noise_Kfu = {}; + for (int i = 0; i < n_kernels; i++) { + Eigen::VectorXd hyps_i = kernels[i]->kernel_hyperparameters; + assert(hyps_i.size() == 1); + + for (int j = 0; j < n_kernels; j++) { + Eigen::VectorXd hyps_j = kernels[j]->kernel_hyperparameters; + assert(hyps_j.size() == 1); + + double sig4 = hyps_i(0) * hyps_i(0) * hyps_j(0) * hyps_j(0); + + Kuf_e_noise_Kfu.push_back(Kuf_kernels[i] * e_noise_one.asDiagonal() * 
Kuf_kernels[j].transpose() / sig4); + Kuf_f_noise_Kfu.push_back(Kuf_kernels[i] * f_noise_one.asDiagonal() * Kuf_kernels[j].transpose() / sig4); + Kuf_s_noise_Kfu.push_back(Kuf_kernels[i] * s_noise_one.asDiagonal() * Kuf_kernels[j].transpose() / sig4); + } + } +} + +void SparseGP ::compute_KnK(bool precomputed) { + if (precomputed) { + KnK_e = Eigen::MatrixXd::Zero(n_sparse, n_sparse); + KnK_f = Eigen::MatrixXd::Zero(n_sparse, n_sparse); + KnK_s = Eigen::MatrixXd::Zero(n_sparse, n_sparse); + + int count_i = 0, count_ij = 0; + for (int i = 0; i < n_kernels; i++) { + Eigen::VectorXd hyps_i = kernels[i]->kernel_hyperparameters; + assert(hyps_i.size() == 1); + int size_i = Kuu_kernels[i].rows(); + int count_j = 0; + for (int j = 0; j < n_kernels; j++) { + Eigen::VectorXd hyps_j = kernels[j]->kernel_hyperparameters; + assert(hyps_j.size() == 1); + int size_j = Kuu_kernels[j].rows(); + + double sig4 = hyps_i(0) * hyps_i(0) * hyps_j(0) * hyps_j(0); + + KnK_e.block(count_i, count_j, size_i, size_j) += Kuf_e_noise_Kfu[count_ij] * sig4; + KnK_f.block(count_i, count_j, size_i, size_j) += Kuf_f_noise_Kfu[count_ij] * sig4; + KnK_s.block(count_i, count_j, size_i, size_j) += Kuf_s_noise_Kfu[count_ij] * sig4; + + count_ij += 1; + count_j += size_j; + } + count_i += size_i; + } + } else { + KnK_e = Kuf * e_noise_one.asDiagonal() * Kuf.transpose(); + KnK_f = Kuf * f_noise_one.asDiagonal() * Kuf.transpose(); + KnK_s = Kuf * s_noise_one.asDiagonal() * Kuf.transpose(); + } +} + +Eigen::MatrixXd SparseGP ::compute_dKnK(int i) { + Eigen::MatrixXd dKnK = Eigen::MatrixXd::Zero(n_sparse, n_sparse); + + int count_ij = i * n_kernels; + Eigen::VectorXd hyps_i = kernels[i]->kernel_hyperparameters; + assert(hyps_i.size() == 1); + + int count_i = 0; + for (int r = 0; r < i; r++) { + count_i += Kuu_kernels[r].rows(); + } + int size_i = Kuu_kernels[i].rows(); + + int count_j = 0; + for (int j = 0; j < n_kernels; j++) { + Eigen::VectorXd hyps_j = kernels[j]->kernel_hyperparameters; + assert(hyps_j.size() == 1); + int size_j = Kuu_kernels[j].rows(); + + double sig3 = 2 * hyps_i(0) * hyps_j(0) * hyps_j(0); + double sig3e = sig3 / (energy_noise * energy_noise); + double sig3f = sig3 / (force_noise * force_noise); + double sig3s = sig3 / (stress_noise * stress_noise); + + dKnK.block(count_i, count_j, size_i, size_j) += Kuf_e_noise_Kfu[count_ij] * sig3e; + dKnK.block(count_i, count_j, size_i, size_j) += Kuf_f_noise_Kfu[count_ij] * sig3f; + dKnK.block(count_i, count_j, size_i, size_j) += Kuf_s_noise_Kfu[count_ij] * sig3s; + + count_ij += 1; + count_j += size_j; + } + return dKnK; +} + void SparseGP ::compute_likelihood() { if (n_labels == 0) { std::cout << "Warning: The likelihood is being computed without any " @@ -744,9 +1042,9 @@ SparseGP ::compute_likelihood_gradient(const Eigen::VectorXd &hyperparameters) { hyps_curr = hyperparameters.segment(hyp_index, n_hyps); int size = Kuu_kernels[i].rows(); - Kuu_grad = kernels[i]->Kuu_grad(sparse_descriptors[i], Kuu, hyps_curr); + Kuu_grad = kernels[i]->Kuu_grad(sparse_descriptors[i], Kuu_kernels[i], hyps_curr); Kuf_grad = kernels[i]->Kuf_grad(sparse_descriptors[i], training_structures, - i, Kuf, hyps_curr); + i, Kuf_kernels[i], hyps_curr); Kuu_mat.block(count, count, size, size) = Kuu_grad[0]; Kuf_mat.block(count, 0, size, n_labels) = Kuf_grad[0]; @@ -770,41 +1068,14 @@ SparseGP ::compute_likelihood_gradient(const Eigen::VectorXd &hyperparameters) { .inverse(); // Construct updated noise vector and gradients. 
- Eigen::VectorXd noise_vec = Eigen::VectorXd::Zero(n_labels); - Eigen::VectorXd e_noise_grad = Eigen::VectorXd::Zero(n_labels); - Eigen::VectorXd f_noise_grad = Eigen::VectorXd::Zero(n_labels); - Eigen::VectorXd s_noise_grad = Eigen::VectorXd::Zero(n_labels); - double sigma_e = hyperparameters(hyp_index); double sigma_f = hyperparameters(hyp_index + 1); double sigma_s = hyperparameters(hyp_index + 2); - int current_count = 0; - for (int i = 0; i < training_structures.size(); i++) { - int n_atoms = training_structures[i].noa; - - if (training_structures[i].energy.size() != 0) { - noise_vec(current_count) = sigma_e * sigma_e; - e_noise_grad(current_count) = 2 * sigma_e; - current_count += 1; - } - - if (training_structures[i].forces.size() != 0) { - noise_vec.segment(current_count, n_atoms * 3) = - Eigen::VectorXd::Constant(n_atoms * 3, sigma_f * sigma_f); - f_noise_grad.segment(current_count, n_atoms * 3) = - Eigen::VectorXd::Constant(n_atoms * 3, 2 * sigma_f); - current_count += n_atoms * 3; - } - - if (training_structures[i].stresses.size() != 0) { - noise_vec.segment(current_count, 6) = - Eigen::VectorXd::Constant(6, sigma_s * sigma_s); - s_noise_grad.segment(current_count, 6) = - Eigen::VectorXd::Constant(6, 2 * sigma_s); - current_count += 6; - } - } + Eigen::VectorXd noise_vec = sigma_e * sigma_e * e_noise_one + sigma_f * sigma_f * f_noise_one + sigma_s * sigma_s * s_noise_one; + Eigen::VectorXd e_noise_grad = 2 * sigma_e * e_noise_one; + Eigen::VectorXd f_noise_grad = 2 * sigma_f * f_noise_one; + Eigen::VectorXd s_noise_grad = 2 * sigma_s * s_noise_one; // Compute Qff and Qff grads. Eigen::MatrixXd Qff_plus_lambda = @@ -839,6 +1110,7 @@ SparseGP ::compute_likelihood_gradient(const Eigen::VectorXd &hyperparameters) { // Compute log determinant from the diagonal of U. 
complexity_penalty = 0; +#pragma omp parallel for reduction(+:complexity_penalty) for (int i = 0; i < Qff_plus_lambda.rows(); i++) { complexity_penalty += -log(abs(Qff_plus_lambda(i, i))); } @@ -855,9 +1127,9 @@ SparseGP ::compute_likelihood_gradient(const Eigen::VectorXd &hyperparameters) { Eigen::MatrixXd Qff_inv_grad; for (int i = 0; i < n_hyps_total; i++) { Qff_inv_grad = Qff_inverse * Qff_grads[i]; - likelihood_gradient(i) = - -Qff_inv_grad.trace() + y.transpose() * Qff_inv_grad * Q_inv_y; - likelihood_gradient(i) /= 2; + double complexity_grad = - Qff_inv_grad.trace(); + double datafit_grad = y.transpose() * Qff_inv_grad * Q_inv_y; + likelihood_gradient(i) = (complexity_grad + datafit_grad) / 2.; } return log_marginal_likelihood; @@ -868,14 +1140,18 @@ void SparseGP ::set_hyperparameters(Eigen::VectorXd hyps) { int n_hyps, hyp_index = 0; Eigen::VectorXd new_hyps; + double duration = 0; + std::chrono::high_resolution_clock::time_point t1, t2; + t1 = std::chrono::high_resolution_clock::now(); + std::vector Kuu_grad, Kuf_grad; for (int i = 0; i < n_kernels; i++) { n_hyps = kernels[i]->kernel_hyperparameters.size(); new_hyps = hyps.segment(hyp_index, n_hyps); - Kuu_grad = kernels[i]->Kuu_grad(sparse_descriptors[i], Kuu, new_hyps); + Kuu_grad = kernels[i]->Kuu_grad(sparse_descriptors[i], Kuu_kernels[i], new_hyps); Kuf_grad = kernels[i]->Kuf_grad(sparse_descriptors[i], training_structures, - i, Kuf, new_hyps); + i, Kuf_kernels[i], new_hyps); Kuu_kernels[i] = Kuu_grad[0]; Kuf_kernels[i] = Kuf_grad[0]; @@ -884,114 +1160,47 @@ void SparseGP ::set_hyperparameters(Eigen::VectorXd hyps) { hyp_index += n_hyps; } + t2 = std::chrono::high_resolution_clock::now(); + duration = (double) std::chrono::duration_cast( t2 - t1 ).count(); + std::cout << "Time: set_hyp: update Kuf Kuu " << duration << std::endl; + t1 = std::chrono::high_resolution_clock::now(); + stack_Kuu(); stack_Kuf(); + t2 = std::chrono::high_resolution_clock::now(); + duration = (double) std::chrono::duration_cast( t2 - t1 ).count(); + std::cout << "Time: set_hyp: stack Kuf Kuu " << duration << std::endl; + t1 = std::chrono::high_resolution_clock::now(); + + hyperparameters = hyps; energy_noise = hyps(hyp_index); force_noise = hyps(hyp_index + 1); stress_noise = hyps(hyp_index + 2); - int current_count = 0; - for (int i = 0; i < training_structures.size(); i++) { - int n_atoms = training_structures[i].noa; - - if (training_structures[i].energy.size() != 0) { - noise_vector(current_count) = 1 / (energy_noise * energy_noise); - current_count += 1; - } - - if (training_structures[i].forces.size() != 0) { - noise_vector.segment(current_count, n_atoms * 3) = - Eigen::VectorXd::Constant(n_atoms * 3, - 1 / (force_noise * force_noise)); - current_count += n_atoms * 3; - } + noise_vector = 1 / (energy_noise * energy_noise) * e_noise_one + + 1 / (force_noise * force_noise) * f_noise_one + + 1 / (stress_noise * stress_noise) * s_noise_one; - if (training_structures[i].stresses.size() != 0) { - noise_vector.segment(current_count, 6) = - Eigen::VectorXd::Constant(6, 1 / (stress_noise * stress_noise)); - current_count += 6; - } - } + t2 = std::chrono::high_resolution_clock::now(); + duration = (double) std::chrono::duration_cast( t2 - t1 ).count(); + std::cout << "Time: set_hyp: update noise " << duration << std::endl; + t1 = std::chrono::high_resolution_clock::now(); // Update remaining matrices. 
update_matrices_QR(); -} - -void SparseGP::write_mapping_coefficients(std::string file_name, - std::string contributor, - int kernel_index) { - // Compute mapping coefficients. - Eigen::MatrixXd mapping_coeffs = - kernels[kernel_index]->compute_mapping_coefficients(*this, kernel_index); + t2 = std::chrono::high_resolution_clock::now(); + duration = (double) std::chrono::duration_cast( t2 - t1 ).count(); + std::cout << "Time: set_hyp: update qr " << duration << std::endl; + t1 = std::chrono::high_resolution_clock::now(); - // Make beta file. - std::ofstream coeff_file; - coeff_file.open(file_name); - - // Record the date. - time_t now = std::time(0); - std::string t(ctime(&now)); - coeff_file << "DATE: "; - coeff_file << t.substr(0, t.length() - 1) << " "; - - // Record the contributor. - coeff_file << "CONTRIBUTOR: "; - coeff_file << contributor << "\n"; - - // Write descriptor information to file. - int coeff_size = mapping_coeffs.row(0).size(); - training_structures[0].descriptor_calculators[kernel_index]->write_to_file( - coeff_file, coeff_size); - - // Write beta vectors to file. - coeff_file << std::scientific << std::setprecision(16); - - int count = 0; - for (int i = 0; i < mapping_coeffs.rows(); i++) { - Eigen::VectorXd coeff_vals = mapping_coeffs.row(i); - - // Start a new line for each beta. - if (count != 0) { - coeff_file << "\n"; - } - - for (int j = 0; j < coeff_vals.size(); j++) { - double coeff_val = coeff_vals[j]; - - // Pad with 2 spaces if positive, 1 if negative. - if (coeff_val > 0) { - coeff_file << " "; - } else { - coeff_file << " "; - } - - coeff_file << coeff_vals[j]; - count++; - - // New line if 5 numbers have been added. - if (count == 5) { - count = 0; - coeff_file << "\n"; - } - } - } - - coeff_file.close(); } -void SparseGP::write_varmap_coefficients( - std::string file_name, std::string contributor, int kernel_index) { - - // TODO: merge this function with write_mapping_coeff, - // add an option in the function above for mapping "mean" or "var" - - // Compute mapping coefficients. - //Eigen::MatrixXd varmap_coeffs = - varmap_coeffs = - kernels[kernel_index]->compute_varmap_coefficients(*this, kernel_index); +void SparseGP::write_mapping_coefficients(std::string file_name, + std::string contributor, std::vector kernel_indices, + std::string map_type) { // Make beta file. std::ofstream coeff_file; @@ -1007,44 +1216,63 @@ void SparseGP::write_varmap_coefficients( coeff_file << "CONTRIBUTOR: "; coeff_file << contributor << "\n"; - // Write descriptor information to file. - int coeff_size = varmap_coeffs.row(0).size(); - training_structures[0].descriptor_calculators[kernel_index]-> - write_to_file(coeff_file, coeff_size); - - // Write beta vectors to file. - coeff_file << std::scientific << std::setprecision(16); + // Write the number of kernels/descriptors to map + coeff_file << kernel_indices.size(); - int count = 0; - for (int i = 0; i < varmap_coeffs.rows(); i++) { - Eigen::VectorXd coeff_vals = varmap_coeffs.row(i); + for (int k = 0; k < kernel_indices.size(); k++) { + int kernel_index = kernel_indices[k]; - // Start a new line for each beta. - if (count != 0) { - coeff_file << "\n"; + // Compute mapping coefficients. 
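+    // map_type selects which coefficients are written for this kernel:
+    // "potential" writes the mean (beta) coefficients, "uncertainty" writes
+    // the variance-mapping coefficients.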
+ Eigen::MatrixXd mapping_coeffs; + if (map_type == std::string("potential")) { + mapping_coeffs = + kernels[kernel_index]->compute_mapping_coefficients(*this, kernel_index); + } else if (map_type == std::string("uncertainty")) { + mapping_coeffs = + kernels[kernel_index]->compute_varmap_coefficients(*this, kernel_index); } - - for (int j = 0; j < coeff_vals.size(); j++) { - double coeff_val = coeff_vals[j]; - - // Pad with 2 spaces if positive, 1 if negative. - if (coeff_val > 0) { - coeff_file << " "; - } else { - coeff_file << " "; - } - - coeff_file << coeff_vals[j]; - count++; - - // New line if 5 numbers have been added. - if (count == 5) { + + // Write descriptor information to file. + int coeff_size = mapping_coeffs.row(0).size(); + training_structures[0].descriptor_calculators[kernel_index]->write_to_file( + coeff_file, coeff_size); + + // Write beta vectors to file. + coeff_file << std::scientific << std::setprecision(16); + + int count = 0; + for (int i = 0; i < mapping_coeffs.rows(); i++) { + Eigen::VectorXd coeff_vals = mapping_coeffs.row(i); + + // Start a new line for each beta. + if (count != 0) { count = 0; coeff_file << "\n"; } + + for (int j = 0; j < coeff_vals.size(); j++) { + double coeff_val = coeff_vals[j]; + + // Pad with 2 spaces if positive, 1 if negative. + if (coeff_val > 0) { + coeff_file << " "; + } else { + coeff_file << " "; + } + + coeff_file << coeff_vals[j]; + count++; + + // New line if 5 numbers have been added. + if (count == 5) { + count = 0; + if (i != mapping_coeffs.rows() - 1 || j != coeff_vals.size() - 1) { + coeff_file << "\n"; + } + } + } } } - coeff_file.close(); } diff --git a/src/flare_pp/bffs/sparse_gp.h b/src/flare_pp/bffs/sparse_gp.h index 284dea82..91d95f08 100644 --- a/src/flare_pp/bffs/sparse_gp.h +++ b/src/flare_pp/bffs/sparse_gp.h @@ -17,6 +17,8 @@ class SparseGP { std::vector kernels; std::vector Kuu_kernels, Kuf_kernels; Eigen::MatrixXd Kuu, Kuf; + std::vector Kuf_e_noise_Kfu, Kuf_f_noise_Kfu, Kuf_s_noise_Kfu; + Eigen::MatrixXd KnK_e, KnK_f, KnK_s; int n_kernels = 0; double Kuu_jitter; @@ -27,10 +29,11 @@ class SparseGP { // Training and sparse points. std::vector sparse_descriptors; std::vector training_structures; + std::vector> training_atom_indices; std::vector>> sparse_indices; // Label attributes. - Eigen::VectorXd noise_vector, y, label_count; + Eigen::VectorXd noise_vector, y, label_count, e_noise_one, f_noise_one, s_noise_one; int n_energy_labels = 0, n_force_labels = 0, n_stress_labels = 0, n_sparse = 0, n_labels = 0, n_strucs = 0; double energy_noise, force_noise, stress_noise; @@ -42,14 +45,27 @@ class SparseGP { // Constructors. SparseGP(); + + // Destructors. + virtual ~SparseGP(); + + /** + Basic Sparse GP constructor. + + @param kernels A list of Kernel objects, e.g. NormalizedInnerProduct, SquaredExponential. + Note the number of kernels should be equal to the number of descriptor calculators. + @param energy_noise Noise hyperparameter for total energy. + @param force_noise Noise hyperparameter for atomic forces. + @param stress_noise Noise hyperparameter for total stress. 
+ */ SparseGP(std::vector kernels, double energy_noise, double force_noise, double stress_noise); void initialize_sparse_descriptors(const Structure &structure); void add_all_environments(const Structure &structure); - void add_specific_environments(const Structure &structure, - const std::vector atoms); + virtual void add_specific_environments(const Structure &structure, + const std::vector> atoms); void add_random_environments(const Structure &structure, const std::vector &n_added); void add_uncertain_environments(const Structure &structure, @@ -59,33 +75,35 @@ class SparseGP { std::vector> sort_clusters_by_uncertainty(const Structure &structure); - void add_training_structure(const Structure &structure); + virtual void add_training_structure(const Structure &structure, const std::vector atom_indices = {-1}); + void update_Kuu(const std::vector &cluster_descriptors); void update_Kuf(const std::vector &cluster_descriptors); void stack_Kuu(); - void stack_Kuf(); + virtual void stack_Kuf(); - void update_matrices_QR(); + virtual void update_matrices_QR(); void predict_mean(Structure &structure); void predict_SOR(Structure &structure); void predict_DTC(Structure &structure); - void predict_local_uncertainties(Structure &structure); + virtual void predict_local_uncertainties(Structure &structure); - void compute_likelihood_stable(); - void compute_likelihood(); + virtual void compute_likelihood_stable(); + virtual double compute_likelihood_gradient_stable(bool precomputed_KnK = false); + virtual void precompute_KnK(); + virtual void compute_KnK(bool precomputed = false); + virtual Eigen::MatrixXd compute_dKnK(int i); - double compute_likelihood_gradient(const Eigen::VectorXd &hyperparameters); - void set_hyperparameters(Eigen::VectorXd hyps); + virtual void compute_likelihood(); - void write_mapping_coefficients(std::string file_name, - std::string contributor, - int kernel_index); + double compute_likelihood_gradient(const Eigen::VectorXd &hyperparameters); + virtual void set_hyperparameters(Eigen::VectorXd hyps); Eigen::MatrixXd varmap_coeffs; // for debugging. TODO: remove this line - void write_varmap_coefficients(std::string file_name, - std::string contributor, - int kernel_index); + void write_mapping_coefficients(std::string file_name, + std::string contributor, std::vector kernel_indices, + std::string map_type = std::string("potential")); // TODO: Make kernels jsonable. NLOHMANN_DEFINE_TYPE_INTRUSIVE(SparseGP, hyperparameters, kernels, diff --git a/src/flare_pp/descriptors/b1.cpp b/src/flare_pp/descriptors/b1.cpp new file mode 100644 index 00000000..63064d3e --- /dev/null +++ b/src/flare_pp/descriptors/b1.cpp @@ -0,0 +1,353 @@ +#include "b1.h" +#include "cutoffs.h" +#include "descriptor.h" +#include "radial.h" +#include "structure.h" +#include "y_grad.h" +#include // File operations +#include // setprecision +#include + +B1 ::B1() {} + +B1 ::B1(const std::string &radial_basis, const std::string &cutoff_function, + const std::vector &radial_hyps, + const std::vector &cutoff_hyps, + const std::vector &descriptor_settings) { + + this->radial_basis = radial_basis; + this->cutoff_function = cutoff_function; + this->radial_hyps = radial_hyps; + this->cutoff_hyps = cutoff_hyps; + this->descriptor_settings = descriptor_settings; + + set_radial_basis(radial_basis, this->radial_pointer); + set_cutoff(cutoff_function, this->cutoff_pointer); +} + +void B1 ::write_to_file(std::ofstream &coeff_file, int coeff_size) { + coeff_file << "\n" << "B1" << "\n"; + + // Report radial basis set. 
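+  // Full header written by this method: the radial basis name, then
+  // "n_species n_max l_max coeff_size", the cutoff function name, and the
+  // cutoff radius (to 2 decimal places).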
+ coeff_file << radial_basis << "\n"; + + // Record number of species, nmax, lmax, and the cutoff. + int n_species = descriptor_settings[0]; + int n_max = descriptor_settings[1]; + int l_max = 0; + double cutoff = radial_hyps[1]; + + coeff_file << n_species << " " << n_max << " " << l_max << " "; + coeff_file << coeff_size << "\n"; + coeff_file << cutoff_function << "\n"; + + // Report cutoff to 2 decimal places. + coeff_file << std::fixed << std::setprecision(2); + coeff_file << cutoff << "\n"; +} + +DescriptorValues B1 ::compute_struc(Structure &structure) { + + // Initialize descriptor values. + DescriptorValues desc = DescriptorValues(); + + // Compute single bond values. + Eigen::MatrixXd single_bond_vals, force_dervs, neighbor_coords; + Eigen::VectorXi unique_neighbor_count, cumulative_neighbor_count, + descriptor_indices; + + int nos = descriptor_settings[0]; + int N = descriptor_settings[1]; + int lmax = 0; + + b1_single_bond(single_bond_vals, force_dervs, neighbor_coords, + unique_neighbor_count, cumulative_neighbor_count, + descriptor_indices, radial_pointer, cutoff_pointer, nos, + N, lmax, radial_hyps, cutoff_hyps, structure); + + // Compute descriptor values. + Eigen::MatrixXd B1_vals, B1_force_dervs; + Eigen::VectorXd B1_norms, B1_force_dots; + + compute_b1(B1_vals, B1_force_dervs, B1_norms, B1_force_dots, single_bond_vals, + force_dervs, unique_neighbor_count, cumulative_neighbor_count, + descriptor_indices, nos, N, lmax); + + // Gather species information. + int noa = structure.noa; + Eigen::VectorXi species_count = Eigen::VectorXi::Zero(nos); + Eigen::VectorXi neighbor_count = Eigen::VectorXi::Zero(nos); + for (int i = 0; i < noa; i++) { + int s = structure.species[i]; + int n_neigh = unique_neighbor_count(i); + species_count(s)++; + neighbor_count(s) += n_neigh; + } + + // Initialize arrays. + int n_d = B1_vals.cols(); + desc.n_descriptors = n_d; + desc.n_types = nos; + desc.n_atoms = noa; + desc.volume = structure.volume; + desc.cumulative_type_count.push_back(0); + for (int s = 0; s < nos; s++) { + int n_s = species_count(s); + int n_neigh = neighbor_count(s); + + // Record species and neighbor count. + desc.n_clusters_by_type.push_back(n_s); + desc.cumulative_type_count.push_back(desc.cumulative_type_count[s] + n_s); + desc.n_clusters += n_s; + desc.n_neighbors_by_type.push_back(n_neigh); + + desc.descriptors.push_back(Eigen::MatrixXd::Zero(n_s, n_d)); + desc.descriptor_force_dervs.push_back( + Eigen::MatrixXd::Zero(n_neigh * 3, n_d)); + desc.neighbor_coordinates.push_back(Eigen::MatrixXd::Zero(n_neigh, 3)); + + desc.cutoff_values.push_back(Eigen::VectorXd::Ones(n_s)); + desc.cutoff_dervs.push_back(Eigen::VectorXd::Zero(n_neigh * 3)); + desc.descriptor_norms.push_back(Eigen::VectorXd::Zero(n_s)); + desc.descriptor_force_dots.push_back(Eigen::VectorXd::Zero(n_neigh * 3)); + + desc.neighbor_counts.push_back(Eigen::VectorXi::Zero(n_s)); + desc.cumulative_neighbor_counts.push_back(Eigen::VectorXi::Zero(n_s)); + desc.atom_indices.push_back(Eigen::VectorXi::Zero(n_s)); + desc.neighbor_indices.push_back(Eigen::VectorXi::Zero(n_neigh)); + } + + // Assign to structure. 
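For reference, the header block emitted by `B1::write_to_file` above looks like the following for a hypothetical two-species model with `n_max = 8` and a 5.0 cutoff; `<coeff_size>` is the per-row coefficient count passed in by `write_mapping_coefficients`, and the beta values then follow, five per line in 16-digit scientific notation:
```
B1
chebyshev
2 8 0 <coeff_size>
quadratic
5.00
```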
+ Eigen::VectorXi species_counter = Eigen::VectorXi::Zero(nos); + Eigen::VectorXi neighbor_counter = Eigen::VectorXi::Zero(nos); + for (int i = 0; i < noa; i++) { + int s = structure.species[i]; + int s_count = species_counter(s); + int n_neigh = unique_neighbor_count(i); + int n_count = neighbor_counter(s); + int cum_neigh = cumulative_neighbor_count(i); + + desc.descriptors[s].row(s_count) = B1_vals.row(i); + desc.descriptor_force_dervs[s].block(n_count * 3, 0, n_neigh * 3, n_d) = + B1_force_dervs.block(cum_neigh * 3, 0, n_neigh * 3, n_d); + desc.neighbor_coordinates[s].block(n_count, 0, n_neigh, 3) = + neighbor_coords.block(cum_neigh, 0, n_neigh, 3); + + desc.descriptor_norms[s](s_count) = B1_norms(i); + desc.descriptor_force_dots[s].segment(n_count * 3, n_neigh * 3) = + B1_force_dots.segment(cum_neigh * 3, n_neigh * 3); + + desc.neighbor_counts[s](s_count) = n_neigh; + desc.cumulative_neighbor_counts[s](s_count) = n_count; + desc.atom_indices[s](s_count) = i; + desc.neighbor_indices[s].segment(n_count, n_neigh) = + descriptor_indices.segment(cum_neigh, n_neigh); + + species_counter(s)++; + neighbor_counter(s) += n_neigh; + } + + return desc; +} + +void compute_b1(Eigen::MatrixXd &B1_vals, Eigen::MatrixXd &B1_force_dervs, + Eigen::VectorXd &B1_norms, Eigen::VectorXd &B1_force_dots, + const Eigen::MatrixXd &single_bond_vals, + const Eigen::MatrixXd &single_bond_force_dervs, + const Eigen::VectorXi &unique_neighbor_count, + const Eigen::VectorXi &cumulative_neighbor_count, + const Eigen::VectorXi &descriptor_indices, int nos, int N, + int lmax) { + + int n_atoms = single_bond_vals.rows(); + int n_neighbors = cumulative_neighbor_count(n_atoms); + int n_radial = nos * N; + assert(lmax == 0); // for b1, lmax = m = 0 + int n_harmonics = (lmax + 1) * (lmax + 1); + int n_bond = n_radial * n_harmonics; + int n_d = n_radial; + + // Initialize arrays. + B1_vals = Eigen::MatrixXd::Zero(n_atoms, n_d); + B1_force_dervs = Eigen::MatrixXd::Zero(n_neighbors * 3, n_d); + B1_norms = Eigen::VectorXd::Zero(n_atoms); + B1_force_dots = Eigen::VectorXd::Zero(n_neighbors * 3); + +#pragma omp parallel for + for (int atom = 0; atom < n_atoms; atom++) { + int n_atom_neighbors = unique_neighbor_count(atom); + int force_start = cumulative_neighbor_count(atom) * 3; + int n1, n2, l, m, n1_l, n2_l; + int counter = 0; + for (int n1 = 0; n1 < n_radial; n1++) { + n1_l = n1 * n_harmonics; + B1_vals(atom, counter) += single_bond_vals(atom, n1_l); + + // Store force derivatives. + for (int n = 0; n < n_atom_neighbors; n++) { + for (int comp = 0; comp < 3; comp++) { + int ind = force_start + n * 3 + comp; + B1_force_dervs(ind, counter) += + single_bond_force_dervs(ind, n1_l); + } + } + counter++; + } + // Compute descriptor norm and force dot products. 
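Putting the pieces of `compute_struc` together, a minimal usage sketch for the new `B1` calculator (hyperparameter values are illustrative, the basis and cutoff names follow the existing tests, and container template arguments are the editor's reading of the collapsed diff):
```
#include "b1.h"
#include "structure.h"
#include <vector>

// Evaluate B1 descriptors for a structure that already has cell, species and
// positions assigned; vals.descriptors[s] holds one row per atom of species s.
void evaluate_b1(Structure &struc) {
  std::vector<double> radial_hyps = {0.0, 5.0};  // radial_hyps[1] is the cutoff
  std::vector<double> cutoff_hyps;
  std::vector<int> settings = {2, 8};            // n_species, n_max (l_max is fixed to 0)
  B1 b1_calc("chebyshev", "quadratic", radial_hyps, cutoff_hyps, settings);
  DescriptorValues vals = b1_calc.compute_struc(struc);
  int n_d = vals.n_descriptors;                  // n_species * n_max = 16 here
}
```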
+ B1_norms(atom) = sqrt(B1_vals.row(atom).dot(B1_vals.row(atom))); + B1_force_dots.segment(force_start, n_atom_neighbors * 3) = + B1_force_dervs.block(force_start, 0, n_atom_neighbors * 3, n_d) * + B1_vals.row(atom).transpose(); + } +} + +void b1_single_bond( + Eigen::MatrixXd &single_bond_vals, Eigen::MatrixXd &force_dervs, + Eigen::MatrixXd &neighbor_coordinates, Eigen::VectorXi &neighbor_count, + Eigen::VectorXi &cumulative_neighbor_count, + Eigen::VectorXi &neighbor_indices, + std::function &, std::vector &, double, + int, std::vector)> + radial_function, + std::function &, double, double, + std::vector)> + cutoff_function, + int nos, int N, int lmax, const std::vector &radial_hyps, + const std::vector &cutoff_hyps, const Structure &structure) { + + int n_atoms = structure.noa; + int n_neighbors = structure.n_neighbors; + assert(lmax == 0); // for b1, lmax = m = 0 + + // TODO: Make rcut an attribute of the descriptor calculator. + double rcut = radial_hyps[1]; + + // Count atoms inside the descriptor cutoff. + neighbor_count = Eigen::VectorXi::Zero(n_atoms); + Eigen::VectorXi store_neighbors = Eigen::VectorXi::Zero(n_neighbors); +#pragma omp parallel for + for (int i = 0; i < n_atoms; i++) { + int i_neighbors = structure.neighbor_count(i); + int rel_index = structure.cumulative_neighbor_count(i); + for (int j = 0; j < i_neighbors; j++) { + int current_count = neighbor_count(i); + int neigh_index = rel_index + j; + double r = structure.relative_positions(neigh_index, 0); + // Check that atom is within descriptor cutoff. + if (r <= rcut) { + int struc_index = structure.structure_indices(neigh_index); + // Update neighbor list. + store_neighbors(rel_index + current_count) = struc_index; + neighbor_count(i)++; + } + } + } + + // Count cumulative number of unique neighbors. + cumulative_neighbor_count = Eigen::VectorXi::Zero(n_atoms + 1); + for (int i = 1; i < n_atoms + 1; i++) { + cumulative_neighbor_count(i) += + cumulative_neighbor_count(i - 1) + neighbor_count(i - 1); + } + + // Record neighbor indices. + int bond_neighbors = cumulative_neighbor_count(n_atoms); + neighbor_indices = Eigen::VectorXi::Zero(bond_neighbors); +#pragma omp parallel for + for (int i = 0; i < n_atoms; i++) { + int i_neighbors = neighbor_count(i); + int ind1 = cumulative_neighbor_count(i); + int ind2 = structure.cumulative_neighbor_count(i); + for (int j = 0; j < i_neighbors; j++) { + neighbor_indices(ind1 + j) = store_neighbors(ind2 + j); + } + } + + // Initialize single bond arrays. + int number_of_harmonics = (lmax + 1) * (lmax + 1); + int no_bond_vals = N * number_of_harmonics; + int single_bond_size = no_bond_vals * nos; + + single_bond_vals = Eigen::MatrixXd::Zero(n_atoms, single_bond_size); + force_dervs = Eigen::MatrixXd::Zero(bond_neighbors * 3, single_bond_size); + neighbor_coordinates = Eigen::MatrixXd::Zero(bond_neighbors, 3); + +#pragma omp parallel for + for (int i = 0; i < n_atoms; i++) { + int i_neighbors = structure.neighbor_count(i); + int rel_index = structure.cumulative_neighbor_count(i); + int neighbor_index = cumulative_neighbor_count(i); + + // Initialize radial and spherical harmonic vectors. 
+ std::vector g = std::vector(N, 0); + std::vector gx = std::vector(N, 0); + std::vector gy = std::vector(N, 0); + std::vector gz = std::vector(N, 0); + + std::vector h = std::vector(number_of_harmonics, 0); + std::vector hx = std::vector(number_of_harmonics, 0); + std::vector hy = std::vector(number_of_harmonics, 0); + std::vector hz = std::vector(number_of_harmonics, 0); + + double x, y, z, r, bond, bond_x, bond_y, bond_z, g_val, gx_val, gy_val, + gz_val, h_val; + int s, neigh_index, descriptor_counter, unique_ind; + for (int j = 0; j < i_neighbors; j++) { + neigh_index = rel_index + j; + r = structure.relative_positions(neigh_index, 0); + if (r > rcut) + continue; // Skip if outside cutoff. + x = structure.relative_positions(neigh_index, 1); + y = structure.relative_positions(neigh_index, 2); + z = structure.relative_positions(neigh_index, 3); + s = structure.neighbor_species(neigh_index); + + // Store neighbor coordinates. + neighbor_coordinates(neighbor_index, 0) = x; + neighbor_coordinates(neighbor_index, 1) = y; + neighbor_coordinates(neighbor_index, 2) = z; + + // Compute radial basis values and spherical harmonics. + calculate_radial(g, gx, gy, gz, radial_function, cutoff_function, x, y, z, + r, rcut, N, radial_hyps, cutoff_hyps); + get_Y(h, hx, hy, hz, x, y, z, lmax); + + // Store the products and their derivatives. + descriptor_counter = s * no_bond_vals; + + for (int radial_counter = 0; radial_counter < N; radial_counter++) { + // Retrieve radial values. + g_val = g[radial_counter]; + gx_val = gx[radial_counter]; + gy_val = gy[radial_counter]; + gz_val = gz[radial_counter]; + + // Compute single bond value. + int angular_counter = 0; + h_val = h[angular_counter]; + bond = g_val * h_val; + + // Calculate derivatives with the product rule. + bond_x = gx_val * h_val + g_val * hx[angular_counter]; + bond_y = gy_val * h_val + g_val * hy[angular_counter]; + bond_z = gz_val * h_val + g_val * hz[angular_counter]; + + // Update single bond arrays. + single_bond_vals(i, descriptor_counter) += bond; + + force_dervs(neighbor_index * 3, descriptor_counter) += bond_x; + force_dervs(neighbor_index * 3 + 1, descriptor_counter) += bond_y; + force_dervs(neighbor_index * 3 + 2, descriptor_counter) += bond_z; + + descriptor_counter++; + } + neighbor_index++; + } + } +} + +// TODO: Implement. 
+nlohmann::json B1 ::return_json(){ + nlohmann::json j; + return j; +} diff --git a/src/flare_pp/descriptors/b1.h b/src/flare_pp/descriptors/b1.h new file mode 100644 index 00000000..e3907b44 --- /dev/null +++ b/src/flare_pp/descriptors/b1.h @@ -0,0 +1,60 @@ +#ifndef B1_H +#define B1_H + +#include "descriptor.h" +#include +#include + +class Structure; + +class B1 : public Descriptor { +public: + std::function &, std::vector &, double, int, + std::vector)> + radial_pointer; + std::function &, double, double, + std::vector)> + cutoff_pointer; + std::string radial_basis, cutoff_function; + std::vector radial_hyps, cutoff_hyps; + std::vector descriptor_settings; + int K = 1; // Body order + + B1(); + + B1(const std::string &radial_basis, const std::string &cutoff_function, + const std::vector &radial_hyps, + const std::vector &cutoff_hyps, + const std::vector &descriptor_settings); + + DescriptorValues compute_struc(Structure &structure); + + void write_to_file(std::ofstream &coeff_file, int coeff_size); + + nlohmann::json return_json(); +}; + +void compute_b1(Eigen::MatrixXd &B1_vals, Eigen::MatrixXd &B1_force_dervs, + Eigen::VectorXd &B1_norms, Eigen::VectorXd &B1_force_dots, + const Eigen::MatrixXd &single_bond_vals, + const Eigen::MatrixXd &single_bond_force_dervs, + const Eigen::VectorXi &unique_neighbor_count, + const Eigen::VectorXi &cumulative_neighbor_count, + const Eigen::VectorXi &descriptor_indices, int nos, int N, + int lmax); + +void b1_single_bond( + Eigen::MatrixXd &single_bond_vals, Eigen::MatrixXd &force_dervs, + Eigen::MatrixXd &neighbor_coordinates, Eigen::VectorXi &neighbor_count, + Eigen::VectorXi &cumulative_neighbor_count, + Eigen::VectorXi &neighbor_indices, + std::function &, std::vector &, double, + int, std::vector)> + radial_function, + std::function &, double, double, + std::vector)> + cutoff_function, + int nos, int N, int lmax, const std::vector &radial_hyps, + const std::vector &cutoff_hyps, const Structure &structure); + +#endif diff --git a/src/flare_pp/descriptors/b2.h b/src/flare_pp/descriptors/b2.h index 17df31c9..2371d79d 100644 --- a/src/flare_pp/descriptors/b2.h +++ b/src/flare_pp/descriptors/b2.h @@ -22,6 +22,7 @@ class B2 : public Descriptor { std::vector descriptor_settings; std::string descriptor_name = "B2"; + int K = 2; // Body order /** Matrix of cutoff values, with element (i, j) corresponding to the cutoff * assigned to the species pair (i, j), where i is the central species diff --git a/src/flare_pp/descriptors/b2_norm.cpp b/src/flare_pp/descriptors/b2_norm.cpp index be440424..ac67eebb 100644 --- a/src/flare_pp/descriptors/b2_norm.cpp +++ b/src/flare_pp/descriptors/b2_norm.cpp @@ -154,7 +154,7 @@ void compute_b2_norm( for (int i = 0; i < n_atoms; i++){ double norm_val = B2_norms_1(i); double norm_val_3 = norm_val * norm_val * norm_val; - B2_vals.row(i) = B2_vals_1.row(i) / norm_val; + B2_vals.row(i) = B2_vals_1.row(i) / norm_val; // TODO: norm_val can be 0 B2_norms(i) = 1; int n_atom_neighbors = unique_neighbor_count(i); int force_start = cumulative_neighbor_count(i) * 3; diff --git a/src/flare_pp/descriptors/b3.cpp b/src/flare_pp/descriptors/b3.cpp index d2325823..db72f36e 100644 --- a/src/flare_pp/descriptors/b3.cpp +++ b/src/flare_pp/descriptors/b3.cpp @@ -1,9 +1,10 @@ #include "b3.h" +#include "bk.h" #include "cutoffs.h" #include "descriptor.h" #include "radial.h" #include "structure.h" -#include "wigner3j.h" +#include "coeffs.h" #include "y_grad.h" #include @@ -20,7 +21,7 @@ B3 ::B3(const std::string &radial_basis, const std::string 
&cutoff_function, this->cutoff_hyps = cutoff_hyps; this->descriptor_settings = descriptor_settings; - wigner3j_coeffs = compute_coeffs(descriptor_settings[2]); + wigner3j_coeffs = compute_coeffs(3, descriptor_settings[2]); set_radial_basis(radial_basis, this->radial_pointer); set_cutoff(cutoff_function, this->cutoff_pointer); @@ -40,11 +41,13 @@ DescriptorValues B3 ::compute_struc(Structure &structure) { int nos = descriptor_settings[0]; int N = descriptor_settings[1]; int lmax = descriptor_settings[2]; + double cutoff = radial_hyps[1]; + Eigen::MatrixXd cutoffs = Eigen::MatrixXd::Constant(nos, nos, cutoff); complex_single_bond(single_bond_vals, force_dervs, neighbor_coords, unique_neighbor_count, cumulative_neighbor_count, descriptor_indices, radial_pointer, cutoff_pointer, nos, - N, lmax, radial_hyps, cutoff_hyps, structure); + N, lmax, radial_hyps, cutoff_hyps, structure, cutoffs); // Compute descriptor values. Eigen::MatrixXd B3_vals, B3_force_dervs; @@ -237,150 +240,6 @@ void compute_B3(Eigen::MatrixXd &B3_vals, Eigen::MatrixXd &B3_force_dervs, } } -void complex_single_bond( - Eigen::MatrixXcd &single_bond_vals, Eigen::MatrixXcd &force_dervs, - Eigen::MatrixXd &neighbor_coordinates, Eigen::VectorXi &neighbor_count, - Eigen::VectorXi &cumulative_neighbor_count, - Eigen::VectorXi &neighbor_indices, - std::function &, std::vector &, double, - int, std::vector)> - radial_function, - std::function &, double, double, - std::vector)> - cutoff_function, - int nos, int N, int lmax, const std::vector &radial_hyps, - const std::vector &cutoff_hyps, const Structure &structure) { - - int n_atoms = structure.noa; - int n_neighbors = structure.n_neighbors; - - // TODO: Make rcut an attribute of the descriptor calculator. - double rcut = radial_hyps[1]; - - // Count atoms inside the descriptor cutoff. - neighbor_count = Eigen::VectorXi::Zero(n_atoms); - Eigen::VectorXi store_neighbors = Eigen::VectorXi::Zero(n_neighbors); -#pragma omp parallel for - for (int i = 0; i < n_atoms; i++) { - int i_neighbors = structure.neighbor_count(i); - int rel_index = structure.cumulative_neighbor_count(i); - for (int j = 0; j < i_neighbors; j++) { - int current_count = neighbor_count(i); - int neigh_index = rel_index + j; - double r = structure.relative_positions(neigh_index, 0); - // Check that atom is within descriptor cutoff. - if (r <= rcut) { - int struc_index = structure.structure_indices(neigh_index); - // Update neighbor list. - store_neighbors(rel_index + current_count) = struc_index; - neighbor_count(i)++; - } - } - } - - // Count cumulative number of unique neighbors. - cumulative_neighbor_count = Eigen::VectorXi::Zero(n_atoms + 1); - for (int i = 1; i < n_atoms + 1; i++) { - cumulative_neighbor_count(i) += - cumulative_neighbor_count(i - 1) + neighbor_count(i - 1); - } - - // Record neighbor indices. - int bond_neighbors = cumulative_neighbor_count(n_atoms); - neighbor_indices = Eigen::VectorXi::Zero(bond_neighbors); -#pragma omp parallel for - for (int i = 0; i < n_atoms; i++) { - int i_neighbors = neighbor_count(i); - int ind1 = cumulative_neighbor_count(i); - int ind2 = structure.cumulative_neighbor_count(i); - for (int j = 0; j < i_neighbors; j++) { - neighbor_indices(ind1 + j) = store_neighbors(ind2 + j); - } - } - - // Initialize single bond arrays. 
- int number_of_harmonics = (lmax + 1) * (lmax + 1); - int no_bond_vals = N * number_of_harmonics; - int single_bond_size = no_bond_vals * nos; - - single_bond_vals = Eigen::MatrixXcd::Zero(n_atoms, single_bond_size); - force_dervs = Eigen::MatrixXcd::Zero(bond_neighbors * 3, single_bond_size); - neighbor_coordinates = Eigen::MatrixXd::Zero(bond_neighbors, 3); - -#pragma omp parallel for - for (int i = 0; i < n_atoms; i++) { - int i_neighbors = structure.neighbor_count(i); - int rel_index = structure.cumulative_neighbor_count(i); - int neighbor_index = cumulative_neighbor_count(i); - - // Initialize radial and spherical harmonic vectors. - std::vector g = std::vector(N, 0); - std::vector gx = std::vector(N, 0); - std::vector gy = std::vector(N, 0); - std::vector gz = std::vector(N, 0); - - Eigen::VectorXcd h, hx, hy, hz; - - double x, y, z, r, g_val, gx_val, gy_val, gz_val; - std::complex bond, bond_x, bond_y, bond_z, h_val; - int s, neigh_index, descriptor_counter, unique_ind; - for (int j = 0; j < i_neighbors; j++) { - neigh_index = rel_index + j; - r = structure.relative_positions(neigh_index, 0); - if (r > rcut) - continue; // Skip if outside cutoff. - x = structure.relative_positions(neigh_index, 1); - y = structure.relative_positions(neigh_index, 2); - z = structure.relative_positions(neigh_index, 3); - s = structure.neighbor_species(neigh_index); - - // Store neighbor coordinates. - neighbor_coordinates(neighbor_index, 0) = x; - neighbor_coordinates(neighbor_index, 1) = y; - neighbor_coordinates(neighbor_index, 2) = z; - - // Compute radial basis values and spherical harmonics. - calculate_radial(g, gx, gy, gz, radial_function, cutoff_function, x, y, z, - r, rcut, N, radial_hyps, cutoff_hyps); - get_complex_Y(h, hx, hy, hz, x, y, z, lmax); - - // Store the products and their derivatives. - descriptor_counter = s * no_bond_vals; - - for (int radial_counter = 0; radial_counter < N; radial_counter++) { - // Retrieve radial values. - g_val = g[radial_counter]; - gx_val = gx[radial_counter]; - gy_val = gy[radial_counter]; - gz_val = gz[radial_counter]; - - for (int angular_counter = 0; angular_counter < number_of_harmonics; - angular_counter++) { - - // Compute single bond value. - h_val = h(angular_counter); - bond = g_val * h_val; - - // Calculate derivatives with the product rule. - bond_x = gx_val * h_val + g_val * hx(angular_counter); - bond_y = gy_val * h_val + g_val * hy(angular_counter); - bond_z = gz_val * h_val + g_val * hz(angular_counter); - - // Update single bond arrays. - single_bond_vals(i, descriptor_counter) += bond; - - force_dervs(neighbor_index * 3, descriptor_counter) += bond_x; - force_dervs(neighbor_index * 3 + 1, descriptor_counter) += bond_y; - force_dervs(neighbor_index * 3 + 2, descriptor_counter) += bond_z; - - descriptor_counter++; - } - } - neighbor_index++; - } - } -} - // TODO: Implement. 
nlohmann::json B3 ::return_json(){ nlohmann::json j; diff --git a/src/flare_pp/descriptors/b3.h b/src/flare_pp/descriptors/b3.h index b6da7ee8..ee620104 100644 --- a/src/flare_pp/descriptors/b3.h +++ b/src/flare_pp/descriptors/b3.h @@ -21,6 +21,7 @@ class B3 : public Descriptor { Eigen::VectorXd wigner3j_coeffs; std::string descriptor_name = "B3"; + int K = 3; // Body order B3(); @@ -43,18 +44,4 @@ void compute_B3(Eigen::MatrixXd &B3_vals, Eigen::MatrixXd &B3_force_dervs, const Eigen::VectorXi &descriptor_indices, int nos, int N, int lmax, const Eigen::VectorXd &wigner3j_coeffs); -void complex_single_bond( - Eigen::MatrixXcd &single_bond_vals, Eigen::MatrixXcd &force_dervs, - Eigen::MatrixXd &neighbor_coordinates, Eigen::VectorXi &neighbor_count, - Eigen::VectorXi &cumulative_neighbor_count, - Eigen::VectorXi &neighbor_indices, - std::function &, std::vector &, double, - int, std::vector)> - radial_function, - std::function &, double, double, - std::vector)> - cutoff_function, - int nos, int N, int lmax, const std::vector &radial_hyps, - const std::vector &cutoff_hyps, const Structure &structure); - #endif diff --git a/src/flare_pp/descriptors/bk.cpp b/src/flare_pp/descriptors/bk.cpp new file mode 100644 index 00000000..b08df3b6 --- /dev/null +++ b/src/flare_pp/descriptors/bk.cpp @@ -0,0 +1,460 @@ +#include "bk.h" +#include "cutoffs.h" +#include "descriptor.h" +#include "omp.h" +#include "radial.h" +#include "structure.h" +#include "coeffs.h" +#include "indices.h" +#include "y_grad.h" +#include // File operations +#include // setprecision +#include + +Bk ::Bk() {} + +Bk ::Bk(const std::string &radial_basis, const std::string &cutoff_function, + const std::vector &radial_hyps, + const std::vector &cutoff_hyps, + const std::vector &descriptor_settings) { + + this->radial_basis = radial_basis; + this->cutoff_function = cutoff_function; + this->radial_hyps = radial_hyps; + this->cutoff_hyps = cutoff_hyps; + this->descriptor_settings = descriptor_settings; // nos, K, nmax, lmax + + // check if lmax = 0 when K = 1 + if (this->descriptor_settings[1] == 1 && this->descriptor_settings[3] != 0) { + std::cout << "Warning: change lmax to 0 because K = 1" << std::endl; + this->descriptor_settings[3] = 0; + } + + nu = compute_indices(descriptor_settings); + std::cout << "nu size: " << nu.size() << std::endl; + coeffs = compute_coeffs(descriptor_settings[1], descriptor_settings[3]); + + set_radial_basis(radial_basis, this->radial_pointer); + set_cutoff(cutoff_function, this->cutoff_pointer); + + // Create cutoff matrix. 
+ int n_species = descriptor_settings[0]; + double cutoff_val = radial_hyps[1]; + cutoffs = Eigen::MatrixXd::Constant(n_species, n_species, cutoff_val); +} + +Bk ::Bk(const std::string &radial_basis, const std::string &cutoff_function, + const std::vector &radial_hyps, + const std::vector &cutoff_hyps, + const std::vector &descriptor_settings, + const Eigen::MatrixXd &cutoffs) { + + this->radial_basis = radial_basis; + this->cutoff_function = cutoff_function; + this->radial_hyps = radial_hyps; + this->cutoff_hyps = cutoff_hyps; + this->descriptor_settings = descriptor_settings; // nos, K, nmax, lmax + + // check if lmax = 0 when K = 1 + if (this->descriptor_settings[1] == 1 && this->descriptor_settings[3] != 0) { + std::cout << "Warning: change lmax to 0 because K = 1" << std::endl; + this->descriptor_settings[3] = 0; + } + + nu = compute_indices(descriptor_settings); + std::cout << "nu size: " << nu.size() << std::endl; + coeffs = compute_coeffs(descriptor_settings[1], descriptor_settings[3]); + + set_radial_basis(radial_basis, this->radial_pointer); + set_cutoff(cutoff_function, this->cutoff_pointer); + + // Assign cutoff matrix. + this->cutoffs = cutoffs; +} + +void Bk ::write_to_file(std::ofstream &coeff_file, int coeff_size) { + int n_species = descriptor_settings[0]; + int K = descriptor_settings[1]; + int n_max = descriptor_settings[2]; + int l_max = descriptor_settings[3]; + + coeff_file << "\n" << "B" << K << "\n"; + + // Report radial basis set. + coeff_file << radial_basis << "\n"; + + // Record number of species, nmax, lmax, and the cutoff. + double cutoff = radial_hyps[1]; + + coeff_file << n_species << " " << K << " " << n_max << " " << l_max << " "; + coeff_file << coeff_size << "\n"; + coeff_file << cutoff_function << "\n"; + + // Report cutoffs to 2 decimal places. + coeff_file << std::fixed << std::setprecision(2); + for (int i = 0; i < n_species; i ++){ + for (int j = 0; j < n_species; j ++){ + coeff_file << cutoffs(i, j) << " "; + } + } + coeff_file << "\n"; +} + +DescriptorValues Bk ::compute_struc(Structure &structure) { + + // Initialize descriptor values. + DescriptorValues desc = DescriptorValues(); + + // Compute single bond values. + Eigen::MatrixXcd single_bond_vals, force_dervs; + Eigen::MatrixXd neighbor_coords; + Eigen::VectorXi unique_neighbor_count, cumulative_neighbor_count, + descriptor_indices; + + int nos = descriptor_settings[0]; + int K = descriptor_settings[1]; + int N = descriptor_settings[2]; + int lmax = descriptor_settings[3]; + + complex_single_bond(single_bond_vals, force_dervs, neighbor_coords, + unique_neighbor_count, cumulative_neighbor_count, + descriptor_indices, radial_pointer, cutoff_pointer, nos, + N, lmax, radial_hyps, cutoff_hyps, structure, cutoffs); + + // Compute descriptor values. + Eigen::MatrixXd Bk_vals, Bk_force_dervs; + Eigen::VectorXd Bk_norms, Bk_force_dots; + + compute_Bk(Bk_vals, Bk_force_dervs, Bk_norms, Bk_force_dots, single_bond_vals, + force_dervs, unique_neighbor_count, cumulative_neighbor_count, + descriptor_indices, nu, nos, K, N, lmax, coeffs); + + // Gather species information. + int noa = structure.noa; + Eigen::VectorXi species_count = Eigen::VectorXi::Zero(nos); + Eigen::VectorXi neighbor_count = Eigen::VectorXi::Zero(nos); + for (int i = 0; i < noa; i++) { + int s = structure.species[i]; + int n_neigh = unique_neighbor_count(i); + species_count(s)++; + neighbor_count(s) += n_neigh; + } + + // Initialize arrays. 
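The `Bk` settings vector is `{n_species, K, n_max, l_max}`, so a single class covers the B1/B2/B3 body orders, and the second constructor accepts a per-species-pair cutoff matrix. A sketch mirroring the settings used in the updated tests (values are illustrative):
```
#include "bk.h"
#include <Eigen/Dense>
#include <vector>

void make_bk_calculators() {
  std::vector<double> radial_hyps = {0.0, 5.0};
  std::vector<double> cutoff_hyps;
  // {n_species, K, n_max, l_max}; for K = 1 the constructor resets l_max to 0.
  Bk b1_like("chebyshev", "quadratic", radial_hyps, cutoff_hyps, {2, 1, 4, 0});
  Bk b2_like("chebyshev", "quadratic", radial_hyps, cutoff_hyps, {2, 2, 4, 3});
  Bk b3_like("chebyshev", "quadratic", radial_hyps, cutoff_hyps, {2, 3, 2, 2});

  // Per-pair cutoffs: element (i, j) is the cutoff for central species i and
  // environment species j.
  Eigen::MatrixXd cutoffs(2, 2);
  cutoffs << 5.0, 4.0, 4.0, 3.5;
  Bk b2_pairwise("chebyshev", "quadratic", radial_hyps, cutoff_hyps,
                 {2, 2, 4, 3}, cutoffs);
}
```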
+ int n_d = Bk_vals.cols(); + desc.n_descriptors = n_d; + desc.n_types = nos; + desc.n_atoms = noa; + desc.volume = structure.volume; + desc.cumulative_type_count.push_back(0); + for (int s = 0; s < nos; s++) { + int n_s = species_count(s); + int n_neigh = neighbor_count(s); + + // Record species and neighbor count. + desc.n_clusters_by_type.push_back(n_s); + desc.cumulative_type_count.push_back(desc.cumulative_type_count[s] + n_s); + desc.n_clusters += n_s; + desc.n_neighbors_by_type.push_back(n_neigh); + + desc.descriptors.push_back(Eigen::MatrixXd::Zero(n_s, n_d)); + desc.descriptor_force_dervs.push_back( + Eigen::MatrixXd::Zero(n_neigh * 3, n_d)); + desc.neighbor_coordinates.push_back(Eigen::MatrixXd::Zero(n_neigh, 3)); + + desc.cutoff_values.push_back(Eigen::VectorXd::Ones(n_s)); + desc.cutoff_dervs.push_back(Eigen::VectorXd::Zero(n_neigh * 3)); + desc.descriptor_norms.push_back(Eigen::VectorXd::Zero(n_s)); + desc.descriptor_force_dots.push_back(Eigen::VectorXd::Zero(n_neigh * 3)); + + desc.neighbor_counts.push_back(Eigen::VectorXi::Zero(n_s)); + desc.cumulative_neighbor_counts.push_back(Eigen::VectorXi::Zero(n_s)); + desc.atom_indices.push_back(Eigen::VectorXi::Zero(n_s)); + desc.neighbor_indices.push_back(Eigen::VectorXi::Zero(n_neigh)); + } + + // Assign to structure. + Eigen::VectorXi species_counter = Eigen::VectorXi::Zero(nos); + Eigen::VectorXi neighbor_counter = Eigen::VectorXi::Zero(nos); + for (int i = 0; i < noa; i++) { + int s = structure.species[i]; + int s_count = species_counter(s); + int n_neigh = unique_neighbor_count(i); + int n_count = neighbor_counter(s); + int cum_neigh = cumulative_neighbor_count(i); + + desc.descriptors[s].row(s_count) = Bk_vals.row(i); + desc.descriptor_force_dervs[s].block(n_count * 3, 0, n_neigh * 3, n_d) = + Bk_force_dervs.block(cum_neigh * 3, 0, n_neigh * 3, n_d); + desc.neighbor_coordinates[s].block(n_count, 0, n_neigh, 3) = + neighbor_coords.block(cum_neigh, 0, n_neigh, 3); + + desc.descriptor_norms[s](s_count) = Bk_norms(i); + desc.descriptor_force_dots[s].segment(n_count * 3, n_neigh * 3) = + Bk_force_dots.segment(cum_neigh * 3, n_neigh * 3); + + desc.neighbor_counts[s](s_count) = n_neigh; + desc.cumulative_neighbor_counts[s](s_count) = n_count; + desc.atom_indices[s](s_count) = i; + desc.neighbor_indices[s].segment(n_count, n_neigh) = + descriptor_indices.segment(cum_neigh, n_neigh); + + species_counter(s)++; + neighbor_counter(s) += n_neigh; + } + + return desc; +} + +void compute_Bk(Eigen::MatrixXd &Bk_vals, Eigen::MatrixXd &Bk_force_dervs, + Eigen::VectorXd &Bk_norms, Eigen::VectorXd &Bk_force_dots, + const Eigen::MatrixXcd &single_bond_vals, + const Eigen::MatrixXcd &single_bond_force_dervs, + const Eigen::VectorXi &unique_neighbor_count, + const Eigen::VectorXi &cumulative_neighbor_count, + const Eigen::VectorXi &descriptor_indices, + std::vector> nu, int nos, int K, int N, + int lmax, const Eigen::VectorXd &coeffs) { + + int n_atoms = single_bond_vals.rows(); + int n_neighbors = cumulative_neighbor_count(n_atoms); + + // The value of last counter is the number of descriptors + std::vector last_index = nu[nu.size()-1]; + int n_d = last_index[last_index.size()-1] + 1; + + // Initialize arrays. 
+ Bk_vals = Eigen::MatrixXd::Zero(n_atoms, n_d); + Bk_force_dervs = Eigen::MatrixXd::Zero(n_neighbors * 3, n_d); + Bk_norms = Eigen::VectorXd::Zero(n_atoms); + Bk_force_dots = Eigen::VectorXd::Zero(n_neighbors * 3); + +#pragma omp parallel for + for (int atom = 0; atom < n_atoms; atom++) { + int n_atom_neighbors = unique_neighbor_count(atom); + int force_start = cumulative_neighbor_count(atom) * 3; + for (int i = 0; i < nu.size(); i++) { + std::vector nu_list = nu[i]; + std::vector single_bond_index = std::vector(nu_list.end() - 2 - K, nu_list.end() - 2); // Get n1_l, n2_l, n3_l, etc. + // Forward + std::complex A_fwd = 1; + Eigen::VectorXcd dA = Eigen::VectorXcd::Ones(K); + for (int t = 0; t < K - 1; t++) { + A_fwd *= single_bond_vals(atom, single_bond_index[t]); + dA(t + 1) *= A_fwd; + } + // Backward + std::complex A_bwd = 1; + for (int t = K - 1; t > 0; t--) { + A_bwd *= single_bond_vals(atom, single_bond_index[t]); + dA(t - 1) *= A_bwd; + } + std::complex A = A_fwd * single_bond_vals(atom, single_bond_index[K - 1]); + + int counter = nu_list[nu_list.size() - 1]; + int m_index = nu_list[nu_list.size() - 2]; + Bk_vals(atom, counter) += real(coeffs(m_index) * A); + + // Store force derivatives. + for (int n = 0; n < n_atom_neighbors; n++) { + for (int comp = 0; comp < 3; comp++) { + int ind = force_start + n * 3 + comp; + std::complex dA_dr = 0; + for (int t = 0; t < K; t++) { + dA_dr += dA(t) * single_bond_force_dervs(ind, single_bond_index[t]); + } + Bk_force_dervs(ind, counter) += + real(coeffs(m_index) * dA_dr); + } + } + } + // Compute descriptor norm and force dot products. + Bk_norms(atom) = sqrt(Bk_vals.row(atom).dot(Bk_vals.row(atom))); + Bk_force_dots.segment(force_start, n_atom_neighbors * 3) = + Bk_force_dervs.block(force_start, 0, n_atom_neighbors * 3, n_d) * + Bk_vals.row(atom).transpose(); + } +} + +void complex_single_bond( + Eigen::MatrixXcd &single_bond_vals, Eigen::MatrixXcd &force_dervs, + Eigen::MatrixXd &neighbor_coordinates, Eigen::VectorXi &neighbor_count, + Eigen::VectorXi &cumulative_neighbor_count, + Eigen::VectorXi &neighbor_indices, + std::function &, std::vector &, double, + int, std::vector)> + radial_function, + std::function &, double, double, + std::vector)> + cutoff_function, + int nos, int N, int lmax, const std::vector &radial_hyps, + const std::vector &cutoff_hyps, const Structure &structure, + const Eigen::MatrixXd &cutoffs) { + + int n_atoms = structure.noa; + int n_neighbors = structure.n_neighbors; + + // Count atoms inside the descriptor cutoff. + neighbor_count = Eigen::VectorXi::Zero(n_atoms); + Eigen::VectorXi store_neighbors = Eigen::VectorXi::Zero(n_neighbors); +#pragma omp parallel for + for (int i = 0; i < n_atoms; i++) { + int i_neighbors = structure.neighbor_count(i); + int rel_index = structure.cumulative_neighbor_count(i); + int central_species = structure.species[i]; + for (int j = 0; j < i_neighbors; j++) { + int current_count = neighbor_count(i); + int neigh_index = rel_index + j; + int neighbor_species = structure.neighbor_species(neigh_index); + double rcut = cutoffs(central_species, neighbor_species); + double r = structure.relative_positions(neigh_index, 0); + // Check that atom is within descriptor cutoff. + if (r <= rcut) { + int struc_index = structure.structure_indices(neigh_index); + // Update neighbor list. + store_neighbors(rel_index + current_count) = struc_index; + neighbor_count(i)++; + } + } + } + + // Count cumulative number of unique neighbors. 
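The forward/backward pass in `compute_Bk` above builds, for each index tuple, the product A of K single-bond values together with all K "leave-one-out" products dA(t) in a single O(K) sweep, which is exactly what the chain rule needs for the force derivatives. A standalone illustration of the trick (not project code):
```
#include <complex>
#include <iostream>
#include <vector>

int main() {
  std::vector<std::complex<double>> a = {{1.0, 2.0}, {0.5, -1.0}, {3.0, 0.0}};
  int K = static_cast<int>(a.size());
  std::vector<std::complex<double>> dA(K, 1.0);
  std::complex<double> fwd = 1.0, bwd = 1.0;
  for (int t = 0; t < K - 1; t++) { fwd *= a[t]; dA[t + 1] *= fwd; }  // prefix products
  for (int t = K - 1; t > 0; t--) { bwd *= a[t]; dA[t - 1] *= bwd; }  // suffix products
  // dA[t] is now the product of all a[j] with j != t, i.e. d(a0*...*a(K-1))/d a[t].
  for (int t = 0; t < K; t++) std::cout << dA[t] << "\n";
  return 0;
}
```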
+ cumulative_neighbor_count = Eigen::VectorXi::Zero(n_atoms + 1); + for (int i = 1; i < n_atoms + 1; i++) { + cumulative_neighbor_count(i) += + cumulative_neighbor_count(i - 1) + neighbor_count(i - 1); + } + + // Record neighbor indices. + int bond_neighbors = cumulative_neighbor_count(n_atoms); + neighbor_indices = Eigen::VectorXi::Zero(bond_neighbors); +#pragma omp parallel for + for (int i = 0; i < n_atoms; i++) { + int i_neighbors = neighbor_count(i); + int ind1 = cumulative_neighbor_count(i); + int ind2 = structure.cumulative_neighbor_count(i); + for (int j = 0; j < i_neighbors; j++) { + neighbor_indices(ind1 + j) = store_neighbors(ind2 + j); + } + } + + // Initialize single bond arrays. + int number_of_harmonics = (lmax + 1) * (lmax + 1); + int no_bond_vals = N * number_of_harmonics; + int single_bond_size = no_bond_vals * nos; + + single_bond_vals = Eigen::MatrixXcd::Zero(n_atoms, single_bond_size); + force_dervs = Eigen::MatrixXcd::Zero(bond_neighbors * 3, single_bond_size); + neighbor_coordinates = Eigen::MatrixXd::Zero(bond_neighbors, 3); + +#pragma omp parallel for + for (int i = 0; i < n_atoms; i++) { + int i_neighbors = structure.neighbor_count(i); + int rel_index = structure.cumulative_neighbor_count(i); + int neighbor_index = cumulative_neighbor_count(i); + int central_species = structure.species[i]; + + // Initialize radial hyperparameters. + std::vector new_radial_hyps = radial_hyps; + + // Initialize radial and spherical harmonic vectors. + std::vector g = std::vector(N, 0); + std::vector gx = std::vector(N, 0); + std::vector gy = std::vector(N, 0); + std::vector gz = std::vector(N, 0); + + Eigen::VectorXcd h, hx, hy, hz; + + double x, y, z, r, g_val, gx_val, gy_val, gz_val; + std::complex bond, bond_x, bond_y, bond_z, h_val; + int s, neigh_index, descriptor_counter, unique_ind; + for (int j = 0; j < i_neighbors; j++) { + neigh_index = rel_index + j; + int neighbor_species = structure.neighbor_species(neigh_index); + double rcut = cutoffs(central_species, neighbor_species); + r = structure.relative_positions(neigh_index, 0); + if (r > rcut) + continue; // Skip if outside cutoff. + x = structure.relative_positions(neigh_index, 1); + y = structure.relative_positions(neigh_index, 2); + z = structure.relative_positions(neigh_index, 3); + s = structure.neighbor_species(neigh_index); + + // Reset the endpoint of the radial basis set. + new_radial_hyps[1] = rcut; + + // Store neighbor coordinates. + neighbor_coordinates(neighbor_index, 0) = x; + neighbor_coordinates(neighbor_index, 1) = y; + neighbor_coordinates(neighbor_index, 2) = z; + + // Compute radial basis values and spherical harmonics. + calculate_radial(g, gx, gy, gz, radial_function, cutoff_function, x, y, z, + r, rcut, N, new_radial_hyps, cutoff_hyps); + get_complex_Y(h, hx, hy, hz, x, y, z, lmax); + + // Store the products and their derivatives. + descriptor_counter = s * no_bond_vals; + + for (int radial_counter = 0; radial_counter < N; radial_counter++) { + // Retrieve radial values. + g_val = g[radial_counter]; + gx_val = gx[radial_counter]; + gy_val = gy[radial_counter]; + gz_val = gz[radial_counter]; + + for (int angular_counter = 0; angular_counter < number_of_harmonics; + angular_counter++) { + + // Compute single bond value. + h_val = h(angular_counter); + bond = g_val * h_val; + + // Calculate derivatives with the product rule. 
+ bond_x = gx_val * h_val + g_val * hx(angular_counter); + bond_y = gy_val * h_val + g_val * hy(angular_counter); + bond_z = gz_val * h_val + g_val * hz(angular_counter); + + // Update single bond arrays. + single_bond_vals(i, descriptor_counter) += bond; + + force_dervs(neighbor_index * 3, descriptor_counter) += bond_x; + force_dervs(neighbor_index * 3 + 1, descriptor_counter) += bond_y; + force_dervs(neighbor_index * 3 + 2, descriptor_counter) += bond_z; + + descriptor_counter++; + } + } + neighbor_index++; + } + } +} + +void to_json(nlohmann::json& j, const Bk & p){ + j = nlohmann::json{ + {"radial_basis", p.radial_basis}, + {"cutoff_function", p.cutoff_function}, + {"radial_hyps", p.radial_hyps}, + {"cutoff_hyps", p.cutoff_hyps}, + {"descriptor_settings", p.descriptor_settings}, + {"cutoffs", p.cutoffs}, + {"descriptor_name", p.descriptor_name} + }; +} + +void from_json(const nlohmann::json& j, Bk & p){ + p = Bk( + j.at("radial_basis"), + j.at("cutoff_function"), + j.at("radial_hyps"), + j.at("radial_hyps"), + j.at("descriptor_settings"), + j.at("cutoffs") + ); +} + +nlohmann::json Bk ::return_json(){ + nlohmann::json j; + to_json(j, *this); + return j; +} diff --git a/src/flare_pp/descriptors/bk.h b/src/flare_pp/descriptors/bk.h new file mode 100644 index 00000000..a00bec6d --- /dev/null +++ b/src/flare_pp/descriptors/bk.h @@ -0,0 +1,87 @@ +#ifndef BK_H +#define BK_H + +#include "descriptor.h" +#include +#include +#include +#include "json.h" + +class Structure; + +class Bk : public Descriptor { +public: + std::function &, std::vector &, double, int, + std::vector)> + radial_pointer; + std::function &, double, double, + std::vector)> + cutoff_pointer; + std::string radial_basis, cutoff_function; + std::vector radial_hyps, cutoff_hyps; + std::vector descriptor_settings; + Eigen::VectorXd coeffs; + std::vector> nu; + + std::string descriptor_name = "Bk"; + + /** Matrix of cutoff values, with element (i, j) corresponding to the cutoff + * assigned to the species pair (i, j), where i is the central species + * and j is the environment species. + */ + Eigen::MatrixXd cutoffs; + + Bk(); + + Bk(const std::string &radial_basis, const std::string &cutoff_function, + const std::vector &radial_hyps, + const std::vector &cutoff_hyps, + const std::vector &descriptor_settings); + + /** + * Construct the Bk descriptor with distinct cutoffs for each pair of species. + */ + Bk(const std::string &radial_basis, const std::string &cutoff_function, + const std::vector &radial_hyps, + const std::vector &cutoff_hyps, + const std::vector &descriptor_settings, + const Eigen::MatrixXd &cutoffs); + + DescriptorValues compute_struc(Structure &structure); + void write_to_file(std::ofstream &coeff_file, int coeff_size); + nlohmann::json return_json(); +}; + +void compute_Bk(Eigen::MatrixXd &Bk_vals, Eigen::MatrixXd &Bk_force_dervs, + Eigen::VectorXd &Bk_norms, Eigen::VectorXd &Bk_force_dots, + const Eigen::MatrixXcd &single_bond_vals, + const Eigen::MatrixXcd &single_bond_force_dervs, + const Eigen::VectorXi &unique_neighbor_count, + const Eigen::VectorXi &cumulative_neighbor_count, + const Eigen::VectorXi &descriptor_indices, + std::vector> nu, int nos, int K, int N, + int lmax, const Eigen::VectorXd &coeffs); + +/** + * Compute single bond vector with different cutoffs assigned to different + * pairs of elements. 
+ */ +void complex_single_bond( + Eigen::MatrixXcd &single_bond_vals, Eigen::MatrixXcd &force_dervs, + Eigen::MatrixXd &neighbor_coordinates, Eigen::VectorXi &neighbor_count, + Eigen::VectorXi &cumulative_neighbor_count, + Eigen::VectorXi &neighbor_indices, + std::function &, std::vector &, double, + int, std::vector)> + radial_function, + std::function &, double, double, + std::vector)> + cutoff_function, + int nos, int N, int lmax, const std::vector &radial_hyps, + const std::vector &cutoff_hyps, const Structure &structure, + const Eigen::MatrixXd &cutoffs); + +void to_json(nlohmann::json& j, const Bk & p); +void from_json(const nlohmann::json& j, Bk & p); + +#endif diff --git a/src/flare_pp/descriptors/wigner3j.cpp b/src/flare_pp/descriptors/coeffs.cpp similarity index 99% rename from src/flare_pp/descriptors/wigner3j.cpp rename to src/flare_pp/descriptors/coeffs.cpp index d3ad1bb6..05690aed 100644 --- a/src/flare_pp/descriptors/wigner3j.cpp +++ b/src/flare_pp/descriptors/coeffs.cpp @@ -1,8 +1,33 @@ -#include "wigner3j.h" +#include "coeffs.h" #include -// See compute_wigner.py for the calculation of these coefficients. -Eigen::VectorXd compute_coeffs(int lmax) { +Eigen::VectorXd compute_coeffs(int K, int lmax) { + if (K == 1) { + return coeffs_K1(lmax); + } else if (K == 2) { + return coeffs_K2(lmax); + } else if (K == 3) { + return coeffs_K3(lmax); + } else { + std::cout << "Not implemented." << std::endl; + } + +} + +Eigen::VectorXd coeffs_K1(int lmax){ + Eigen::VectorXd k1_coef = Eigen::VectorXd::Ones(1); + return k1_coef; +} + +Eigen::VectorXd coeffs_K2(int lmax){ + Eigen::VectorXd k2_coef = Eigen::VectorXd::Zero(2); + k2_coef(0) = 1; // TODO: need to check this + k2_coef(1) = -1; + return k2_coef; +} + +Eigen::VectorXd coeffs_K3(int lmax) { + // See compute_wigner.py for the calculation of these coefficients. 
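How the contraction tables fit together: `compute_indices` (in `indices.cpp` below) enumerates the (n, l, m) tuples, with each entry ending in `(m_index, counter)`, while `compute_coeffs` supplies the coefficient looked up at `m_index`; the descriptor dimension is the last counter plus one, as used in `compute_Bk` above. A small sketch, assuming the containers are `std::vector<std::vector<int>>` and `Eigen::VectorXd` (template arguments are collapsed in the diff):
```
#include "coeffs.h"
#include "indices.h"
#include <Eigen/Dense>
#include <vector>

void inspect_tables() {
  std::vector<int> settings = {2, 2, 4, 3};  // n_species, K, n_max, l_max
  std::vector<std::vector<int>> nu = compute_indices(settings);
  Eigen::VectorXd coeffs = compute_coeffs(settings[1], settings[3]);
  int n_descriptors = nu.back().back() + 1;  // value of the last counter, plus one
}
```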
Eigen::VectorXd wigner3j_coeffs = Eigen::VectorXd::Zero(pow((lmax + 1), 6)); if (lmax == 0) { wigner3j_coeffs(0) = 1; diff --git a/src/flare_pp/descriptors/coeffs.h b/src/flare_pp/descriptors/coeffs.h new file mode 100644 index 00000000..d9514ac9 --- /dev/null +++ b/src/flare_pp/descriptors/coeffs.h @@ -0,0 +1,14 @@ +#ifndef COEFFS +#define COEFFS +#include + +Eigen::VectorXd compute_coeffs(int K, int lmax); + +Eigen::VectorXd coeffs_K1(int lmax); +Eigen::VectorXd coeffs_K2(int lmax); + +// Wigner 3j coefficients generated for l = 0, 1, 2, 3 using +// sympy.physics.wigner.wigner_3j +Eigen::VectorXd coeffs_K3(int lmax); + +#endif diff --git a/src/flare_pp/descriptors/descriptor.h b/src/flare_pp/descriptors/descriptor.h index f7217c3d..6c7b7456 100644 --- a/src/flare_pp/descriptors/descriptor.h +++ b/src/flare_pp/descriptors/descriptor.h @@ -14,6 +14,7 @@ class Descriptor { Descriptor(); std::string descriptor_name; + std::vector descriptor_settings; virtual DescriptorValues compute_struc(Structure &structure) = 0; diff --git a/src/flare_pp/descriptors/indices.cpp b/src/flare_pp/descriptors/indices.cpp new file mode 100644 index 00000000..60d6866a --- /dev/null +++ b/src/flare_pp/descriptors/indices.cpp @@ -0,0 +1,99 @@ +#include "indices.h" +#include +#include +#include + +std::vector> compute_indices(const std::vector &descriptor_settings) { + int nos = descriptor_settings[0]; + int K = descriptor_settings[1]; + int nmax = descriptor_settings[2]; + int lmax = descriptor_settings[3]; + + int n_radial = nos * nmax; + if (K == 1) { + assert(lmax == 0); + return K1(n_radial); + } else if (K == 2) { + return K2(n_radial, lmax); + } else if (K == 3) { + return K3(n_radial, lmax); + } else { + return K3(n_radial, lmax); + } +} + +std::vector> K1(int n_radial) { + int n1, n1_l; + std::vector> index_list; + int counter = 0; + for (int n1 = 0; n1 < n_radial; n1++) { + n1_l = n1; + int m_index = 0; + index_list.push_back({n1, n1_l, m_index, counter}); + counter++; + } + return index_list; +} + +std::vector> K2(int n_radial, int lmax) { + int n1, n2, l, m, n1_l, n2_l; + int n_harmonics = (lmax + 1) * (lmax + 1); + std::vector> index_list; + int counter = 0; + for (int n1 = 0; n1 < n_radial; n1++) { + for (int n2 = n1; n2 < n_radial; n2++) { + for (int l = 0; l < (lmax + 1); l++) { + for (int m = 0; m < (2 * l + 1); m++) { + n1_l = n1 * n_harmonics + (l * l + m); + n2_l = n2 * n_harmonics + (l * l + (2 * l - m)); + int m_index = (m + l) % 2; + index_list.push_back({n1, n2, l, m, n1_l, n2_l, m_index, counter}); + } + counter++; + } + } + } + return index_list; +} + +std::vector> K3(int n_radial, int lmax) { + int n1, n2, n3, l1, l2, l3, m1, m2, m3, n1_l, n2_l, n3_l; + int n_harmonics = (lmax + 1) * (lmax + 1); + std::vector> index_list; + int counter = 0; + for (int n1 = 0; n1 < n_radial; n1++) { + for (int n2 = n1; n2 < n_radial; n2++) { + for (int n3 = n2; n3 < n_radial; n3++) { + for (int l1 = 0; l1 < (lmax + 1); l1++) { + int ind_1 = pow(lmax + 1, 4) * l1 * l1; + for (int l2 = 0; l2 < (lmax + 1); l2++) { + int ind_2 = ind_1 + pow(lmax + 1, 2) * l2 * l2 * (2 * l1 + 1); + for (int l3 = 0; l3 < (lmax + 1); l3++) { + if ((abs(l1 - l2) > l3) || (l3 > l1 + l2)) + continue; + int ind_3 = ind_2 + l3 * l3 * (2 * l2 + 1) * (2 * l1 + 1); + for (int m1 = 0; m1 < (2 * l1 + 1); m1++) { + n1_l = n1 * n_harmonics + (l1 * l1 + m1); + int ind_4 = ind_3 + m1 * (2 * l3 + 1) * (2 * l2 + 1); + for (int m2 = 0; m2 < (2 * l2 + 1); m2++) { + n2_l = n2 * n_harmonics + (l2 * l2 + m2); + int ind_5 = ind_4 + m2 * (2 * l3 + 1); 
+ for (int m3 = 0; m3 < (2 * l3 + 1); m3++) { + if (m1 + m2 + m3 - l1 - l2 - l3 != 0) + continue; + n3_l = n3 * n_harmonics + (l3 * l3 + m3); + + int m_index = ind_5 + m3; + index_list.push_back({n1, n2, n3, l1, l2, l3, m1, m2, m3, n1_l, n2_l, n3_l, m_index, counter}); + } + } + } + counter++; + } + } + } + } + } + } + return index_list; +} diff --git a/src/flare_pp/descriptors/indices.h b/src/flare_pp/descriptors/indices.h new file mode 100644 index 00000000..31e32a31 --- /dev/null +++ b/src/flare_pp/descriptors/indices.h @@ -0,0 +1,12 @@ +#ifndef INDICES +#define INDICES +#include + +// Indices of (n1,l1,m1), ..., (nK,lK,mK) for B term from A_nlm + +std::vector> compute_indices(const std::vector &descriptor_settings); + +std::vector> K1(int n_radial); +std::vector> K2(int n_radial, int lmax); +std::vector> K3(int n_radial, int lmax); +#endif diff --git a/src/flare_pp/descriptors/wigner3j.h b/src/flare_pp/descriptors/wigner3j.h deleted file mode 100644 index 2261e3f5..00000000 --- a/src/flare_pp/descriptors/wigner3j.h +++ /dev/null @@ -1,10 +0,0 @@ -#ifndef WIGNER3J -#define WIGNER3J -#include - -// Wigner 3j coefficients generated for l = 0, 1, 2, 3 using -// sympy.physics.wigner.wigner_3j - -Eigen::VectorXd compute_coeffs(int lmax); - -#endif diff --git a/src/flare_pp/kernels/kernel.h b/src/flare_pp/kernels/kernel.h index 433b3b06..ad8d66e4 100644 --- a/src/flare_pp/kernels/kernel.h +++ b/src/flare_pp/kernels/kernel.h @@ -49,15 +49,15 @@ class Kernel { virtual Eigen::MatrixXd compute_varmap_coefficients(const SparseGP &gp_model, int kernel_index) = 0; - std::vector Kuu_grad(const ClusterDescriptor &envs, - const Eigen::MatrixXd &Kuu, - const Eigen::VectorXd &hyps); - - std::vector Kuf_grad(const ClusterDescriptor &envs, - const std::vector &strucs, - int kernel_index, - const Eigen::MatrixXd &Kuf, - const Eigen::VectorXd &hyps); + virtual std::vector Kuu_grad(const ClusterDescriptor &envs, + const Eigen::MatrixXd &Kuu, + const Eigen::VectorXd &hyps); + + virtual std::vector Kuf_grad(const ClusterDescriptor &envs, + const std::vector &strucs, + int kernel_index, + const Eigen::MatrixXd &Kuf, + const Eigen::VectorXd &hyps); virtual void set_hyperparameters(Eigen::VectorXd hyps) = 0; diff --git a/src/flare_pp/kernels/normalized_dot_product.cpp b/src/flare_pp/kernels/normalized_dot_product.cpp index 24df45c8..1e42f94f 100644 --- a/src/flare_pp/kernels/normalized_dot_product.cpp +++ b/src/flare_pp/kernels/normalized_dot_product.cpp @@ -579,16 +579,11 @@ NormalizedDotProduct ::Kuu_grad(const ClusterDescriptor &envs, std::vector kernel_gradients; // Compute Kuu. - Eigen::MatrixXd Kuu_new = Kuu; - Kuu_new /= sig2; - Kuu_new *= new_hyps(0) * new_hyps(0); + Eigen::MatrixXd Kuu_new = Kuu * (new_hyps(0) * new_hyps(0) / sig2); + kernel_gradients.push_back(Kuu_new); // Compute sigma gradient. - Eigen::MatrixXd sigma_gradient = Kuu; - sigma_gradient /= sig2; - sigma_gradient *= 2 * new_hyps(0); - - kernel_gradients.push_back(Kuu_new); + Eigen::MatrixXd sigma_gradient = Kuu * (2 * new_hyps(0) / sig2); kernel_gradients.push_back(sigma_gradient); return kernel_gradients; @@ -603,15 +598,11 @@ NormalizedDotProduct ::Kuf_grad(const ClusterDescriptor &envs, std::vector kernel_gradients; // Compute Kuf. - Eigen::MatrixXd Kuf_new = Kuf; - Kuf_new /= sig2; - Kuf_new *= new_hyps(0) * new_hyps(0); + Eigen::MatrixXd Kuf_new = Kuf * (new_hyps(0) * new_hyps(0) / sig2); kernel_gradients.push_back(Kuf_new); // Compute sigma gradient. 
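In `compute_mapping_coefficients` and `compute_varmap_coefficients` below, the sequential `beta_count` counter is replaced by a closed-form index so the descriptor loops can run under `#pragma omp parallel for`. A standalone check (not project code) that the triangular formula reproduces the old sequential counter:
```
#include <cassert>

int main() {
  int p_size = 7;  // arbitrary descriptor length
  int sequential = 0;
  for (int k = 0; k < p_size; k++) {
    for (int l = k; l < p_size; l++) {
      int closed_form = (2 * p_size - k + 1) * k / 2 + l - k;
      assert(closed_form == sequential);
      sequential++;
    }
  }
  return 0;
}
```
The variance-mapping version iterates over the full square, so its index is simply `k * p_size + l`.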
- Eigen::MatrixXd sigma_gradient = Kuf; - sigma_gradient /= sig2; - sigma_gradient *= 2 * new_hyps(0); + Eigen::MatrixXd sigma_gradient = Kuf * (2 * new_hyps(0) / sig2); kernel_gradients.push_back(sigma_gradient); return kernel_gradients; @@ -671,14 +662,15 @@ NormalizedDotProduct ::compute_mapping_coefficients(const SparseGP &gp_model, continue; double alpha_val = gp_model.alpha(alpha_ind + c_types + j); - int beta_count = 0; // First loop over descriptor values. +#pragma omp parallel for for (int k = 0; k < p_size; k++) { double p_ik = p_current(k) / p_norm; // Second loop over descriptor values. for (int l = k; l < p_size; l++) { + int beta_count = std::round((2 * p_size - k + 1) * k / 2 + l - k); double p_il = p_current(l) / p_norm; double beta_val = sig2 * p_ik * p_il * alpha_val; @@ -689,7 +681,6 @@ NormalizedDotProduct ::compute_mapping_coefficients(const SparseGP &gp_model, mapping_coeffs(i, beta_count) += beta_val; } - beta_count++; } } } @@ -753,14 +744,15 @@ Eigen::MatrixXd NormalizedDotProduct ::compute_varmap_coefficients( double Kuu_inv_ij_normed = Kuu_inv_ij / pi_norm / pj_norm; // double Sigma_ij = gp_model.Sigma(K_ind + i, K_ind + j); // double Sigma_ij_normed = Sigma_ij / pi_norm / pj_norm; - int beta_count = 0; // First loop over descriptor values. +#pragma omp parallel for for (int k = 0; k < p_size; k++) { double p_ik = pi_current(k); // Second loop over descriptor values. for (int l = 0; l < p_size; l++){ + int beta_count = k * p_size + l; double p_jl = pj_current(l); // Update beta vector. @@ -771,7 +763,6 @@ Eigen::MatrixXd NormalizedDotProduct ::compute_varmap_coefficients( mapping_coeffs(s, beta_count) += sig2; // the self kernel term } - beta_count++; } } } diff --git a/src/flare_pp/structure.cpp b/src/flare_pp/structure.cpp index 7dbdc3de..413f7c74 100644 --- a/src/flare_pp/structure.cpp +++ b/src/flare_pp/structure.cpp @@ -164,6 +164,24 @@ double Structure ::get_single_sweep_cutoff() { return single_sweep_cutoff; } +int Structure::n_energy() const { + return energy.size(); +} + +int Structure::n_forces() const { + return forces.size(); +} + +int Structure::n_stresses() const { + return stresses.size(); +} + +int Structure::n_labels() const { + int n_energy = energy.size(); + int n_forces = forces.size(); + int n_stresses = stresses.size(); + return n_energy + n_forces + n_stresses; +} void Structure ::to_json(std::string file_name, const Structure & struc){ std::ofstream struc_file(file_name); diff --git a/src/flare_pp/structure.h b/src/flare_pp/structure.h index 428edcf5..da20c7b5 100644 --- a/src/flare_pp/structure.h +++ b/src/flare_pp/structure.h @@ -100,6 +100,10 @@ class Structure { double get_single_sweep_cutoff(); void compute_neighbors(); void compute_descriptors(); + int n_energy() const; + int n_forces() const; + int n_stresses() const; + int n_labels() const; NLOHMANN_DEFINE_TYPE_INTRUSIVE(Structure, neighbor_count, cutoff, cumulative_neighbor_count, structure_indices, neighbor_species, diff --git a/src/flare_pp/utils.cpp b/src/flare_pp/utils.cpp new file mode 100644 index 00000000..275541c9 --- /dev/null +++ b/src/flare_pp/utils.cpp @@ -0,0 +1,186 @@ +#include "utils.h" +#include + +#define MAXLINE 1024 + + +template +void utils::split(const std::string &s, char delim, Out result) { + std::istringstream iss(s); + std::string item; + while (std::getline(iss, item, delim)) { + if (item.length() > 0) *result++ = item; + } +} + +std::vector utils::split(const std::string &s, char delim) { + /* Convert a line of string into a list + * Similar to the 
python method str.split() + */ + std::vector elems; + split(s, delim, std::back_inserter(elems)); + return elems; +} + +std::tuple, std::vector>>> +utils::read_xyz(std::string filename, std::map species_map) { + + std::ifstream file(filename); + int n_atoms, atom_ind; + Eigen::MatrixXd cell, positions; + Eigen::VectorXd energy, forces, stress; + std::vector species; + std::vector sparse_inds; + int pos_col = 0; + int forces_col = 0; + + std::vector structure_list; + std::vector> sparse_inds_0; + std::vector values; + int new_frame_line = 0; + + int i; + + if (file.is_open()) { + std::string line; + while (std::getline(file, line)) { + values = split(line, ' '); + if (values.size() == 1) { + // the 1st line of a block, the number of atoms in a frame + n_atoms = std::stoi(values[0]); + cell = Eigen::MatrixXd::Zero(3, 3); + positions = Eigen::MatrixXd::Zero(n_atoms, 3); + forces = Eigen::VectorXd::Zero(n_atoms * 3); + energy = Eigen::VectorXd::Zero(1);; + stress = Eigen::VectorXd::Zero(6); + species = std::vector(n_atoms, 0); + sparse_inds = {}; + atom_ind = 0; + pos_col = 0; + forces_col = 0; + new_frame_line = 0; + } else if (new_frame_line == 0) { + // the 2nd line of a block, including cell, energy, stress, sparse indices + int v = 0; + while (v < values.size()) { + if (values[v].find(std::string("Lattice")) != std::string::npos) { + // Example: Lattice="1 0 0 0 1 0 0 0 1" + cell(0, 0) = std::stod(values[v].substr(9, values[v].length() - 9)); + cell(0, 1) = std::stod(values[v + 1]); + cell(0, 2) = std::stod(values[v + 2]); + cell(1, 0) = std::stod(values[v + 3]); + cell(1, 1) = std::stod(values[v + 4]); + cell(1, 2) = std::stod(values[v + 5]); + cell(2, 0) = std::stod(values[v + 6]); + cell(2, 1) = std::stod(values[v + 7]); + cell(2, 2) = std::stod(values[v + 8].substr(0, values[v + 8].length() - 1)); + v += 9; + } else if (values[v].find(std::string("energy")) != std::string::npos \ + && values[v].find(std::string("free_energy")) == std::string::npos) { + // Example: energy=-2.0 + energy(0) = std::stod(values[v].substr(7, values[v].length() - 7)); + v++; + } else if (values[v].find(std::string("stress")) != std::string::npos) { + // Example: stress="1 0 0 0 1 0 0 0 1" + stress(0) = - std::stod(values[v].substr(8, values[v].length() - 8)); // xx + stress(1) = - std::stod(values[v + 1]); // xy + stress(2) = - std::stod(values[v + 2]); // xz + stress(3) = - std::stod(values[v + 4]); // yy + stress(4) = - std::stod(values[v + 5]); // yz + stress(5) = - std::stod(values[v + 8].substr(0, values[v + 8].length() - 1)); // zz + v += 9; + } else if (values[v].find(std::string("sparse_indices")) != std::string::npos) { + // Example: sparse_indices="0 2 4 6" or sparse_indices="2" + size_t n = std::count(values[v].begin(), values[v].end(), '"'); + assert(n <= 1); + if (n == 0) { // Example: sparse_indices=2 or sparse_indices= + if (values[v].length() > 15) { + sparse_inds.push_back(std::stoi(values[v].substr(15, values[v].length() - 15))); + } + v++; + } else if (n == 1) { // Example: sparse_indices="0 2 4 6" + sparse_inds.push_back(std::stoi(values[v].substr(16, values[v].length() - 16))); + v++; + while (values[v].find(std::string("\"")) == std::string::npos) { + sparse_inds.push_back(std::stoi(values[v])); + v++; + } + sparse_inds.push_back(std::stoi(values[v].substr(0, values[v].length() - 1))); + v++; + } + } else if (values[v].find(std::string("Properties")) != std::string::npos) { + // Example: Properties=species:S:1:pos:R:3:forces:R:3:magmoms:R:1 + std::string str = values[v]; + std::vector 
props = split(str, ':'); + bool find_pos = false; + bool find_forces = false; + for (int p = 0; p < props.size(); p += 3) { + // Find the starting column of positions + if (props[p].find(std::string("pos")) == std::string::npos) { + if (!find_pos) pos_col += std::stoi(props[p + 2]); + } else { + find_pos = true; + } + // Find the starting column of forces + if (props[p].find(std::string("forces")) == std::string::npos) { + if (!find_forces) forces_col += std::stoi(props[p + 2]); + } else { + find_forces = true; + } + } + v++; + } else { + v++; + } + } + new_frame_line = 1; + } else if (new_frame_line > 0) { + // the rest n_atoms lines of a block, with format "symbol x y z fx fy fz" + species[atom_ind] = species_map[values[0]]; + positions(atom_ind, 0) = std::stod(values[pos_col + 0]); + positions(atom_ind, 1) = std::stod(values[pos_col + 1]); + positions(atom_ind, 2) = std::stod(values[pos_col + 2]); + forces(3 * atom_ind + 0) = std::stod(values[forces_col + 0]); + forces(3 * atom_ind + 1) = std::stod(values[forces_col + 1]); + forces(3 * atom_ind + 2) = std::stod(values[forces_col + 2]); + atom_ind++; + + if (new_frame_line == n_atoms) { + Structure structure(cell, species, positions); + structure.energy = energy; + structure.forces = forces; + structure.stresses = stress; + structure_list.push_back(structure); + sparse_inds_0.push_back(sparse_inds); // TODO: multiple kernels with different sparse inds + } + + new_frame_line++; + } else { + // raise error + printf("Unknown line!!!"); + } + } + file.close(); + } + std::vector>> sparse_inds_list; + sparse_inds_list.push_back(sparse_inds_0); + return std::make_tuple(structure_list, sparse_inds_list); +} + +utils::Timer::Timer() {} + +void utils::Timer::tic() { + t_start = std::chrono::high_resolution_clock::now(); +} + +void utils::Timer::toc(const char* code_name) { + t_end = std::chrono::high_resolution_clock::now(); + duration = (double) std::chrono::duration_cast( t_end - t_start ).count(); + std::cout << "Time: " << code_name << " " << duration << " ms" << std::endl; +} + +void utils::Timer::toc(const char* code_name, int rank) { + t_end = std::chrono::high_resolution_clock::now(); + duration = (double) std::chrono::duration_cast( t_end - t_start ).count(); + std::cout << "Rank " << rank << " Time: " << code_name << " " << duration << " ms" << std::endl; +} diff --git a/src/flare_pp/utils.h b/src/flare_pp/utils.h new file mode 100644 index 00000000..a5ef12ab --- /dev/null +++ b/src/flare_pp/utils.h @@ -0,0 +1,63 @@ +#include "structure.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef UTILS_H +#define UTILS_H + +namespace utils { + + /** + Read an .xyz file and return a list of `Structure` objects, with a list of + indices of sparse atoms in the structures. + + @param filename A string to specify the name of the .xyz file, which contains + all the DFT frames with energy, forces and stress. + @param species_map A `std::map` object that maps chemical symbol (string) to + species code in the `Structure` object. E.g. + ``` + std::map species_map = {{"H", 0,}, {"He", 1,}}; + ``` + @return A tuple of `Structure` list and sparse indices list. + */ + std::tuple, std::vector>>> + read_xyz(std::string filename, std::map species_map); + + template + void split(const std::string &s, char delim, Out result); + + /** + A useful function that mimic Python's `split` method of a string. + + @param s A string of type `std::string`. 
+   @param delim The delimiter to separate the string `s`.
+   @return A list of strings separated by the delimiter.
+   */
+  std::vector<std::string> split(const std::string &s, char delim);
+
+  class Timer;
+}
+
+class utils::Timer {
+public:
+  Timer();
+
+  double duration = 0;
+  std::chrono::high_resolution_clock::time_point t_start, t_end;
+  void tic();
+  void toc(const char*);
+  void toc(const char*, int rank);
+};
+
+#endif
diff --git a/tests/get_sgp.py b/tests/get_sgp.py
index 3230ea81..d08c3966 100644
--- a/tests/get_sgp.py
+++ b/tests/get_sgp.py
@@ -1,38 +1,56 @@
 import numpy as np
-from flare_pp._C_flare import SparseGP, NormalizedDotProduct, B2, Structure
+from flare_pp._C_flare import SparseGP, NormalizedDotProduct, Bk, Structure
 from flare_pp.sparse_gp import SGP_Wrapper
 from flare_pp.sparse_gp_calculator import SGP_Calculator
+from flare_pp.parallel_sgp import ParSGP_Wrapper
 from flare.ase.atoms import FLARE_Atoms
 from ase import Atoms
 from ase.calculators.lj import LennardJones
+from ase.calculators.singlepoint import SinglePointCalculator
 from ase.build import make_supercell

-# Define kernel.
+# Define kernels.
+sigma = 1.0
+power = 2
+kernel1 = NormalizedDotProduct(sigma, power)
+
 sigma = 2.0
 power = 2
-kernel = NormalizedDotProduct(sigma, power)
+kernel2 = NormalizedDotProduct(sigma, power)
+
+sigma = 3.0
+power = 2
+kernel3 = NormalizedDotProduct(sigma, power)
+
+# Define calculators.
+species_map = {6: 0, 8: 1}

-# Define B2 calculator.
 cutoff = 5.0
 cutoff_function = "quadratic"
 radial_basis = "chebyshev"
 radial_hyps = [0.0, cutoff]
 cutoff_hyps = []
-descriptor_settings = [2, 8, 3]
-b2_calc = B2(radial_basis, cutoff_function, radial_hyps, cutoff_hyps,
-             descriptor_settings)
+
+settings = [len(species_map), 1, 4, 0]
+calc1 = Bk(radial_basis, cutoff_function, radial_hyps, cutoff_hyps, settings)
+
+settings = [len(species_map), 2, 4, 3]
+calc2 = Bk(radial_basis, cutoff_function, radial_hyps, cutoff_hyps, settings)
+
+settings = [len(species_map), 3, 2, 2]
+calc3 = Bk(radial_basis, cutoff_function, radial_hyps, cutoff_hyps, settings)

 # Define remaining parameters for the SGP wrapper.
sigma_e = 0.001 sigma_f = 0.05 sigma_s = 0.006 -species_map = {6: 0, 8: 1} single_atom_energies = {0: -5, 1: -6} variance_type = "local" max_iterations = 20 -opt_method = "L-BFGS-B" -bounds = [(None, None), (sigma_e, None), (None, None), (None, None)] +#opt_method = "L-BFGS-B" +opt_method = "BFGS" +bounds = [(None, None), (None, None), (None, None), (sigma_e, None), (None, None), (None, None)] def get_random_atoms(a=2.0, sc_size=2, numbers=[6, 8], @@ -50,21 +68,38 @@ def get_random_atoms(a=2.0, sc_size=2, numbers=[6, 8], multiplier = np.identity(3) * sc_size atoms = make_supercell(unit_cell, multiplier) atoms.positions += (2 * np.random.rand(len(atoms), 3) - 1) * 0.1 + # modify the symbols + random_numbers = np.array(numbers)[np.random.choice(2, len(atoms), replace=True)].tolist() + atoms.numbers = random_numbers flare_atoms = FLARE_Atoms.from_ase_atoms(atoms) + calc = SinglePointCalculator(flare_atoms) + calc.results["energy"] = np.random.rand() + calc.results["forces"] = np.random.rand(len(atoms), 3) + calc.results["stress"] = np.random.rand(6) + flare_atoms.calc = calc + return flare_atoms def get_empty_sgp(): empty_sgp = SGP_Wrapper( - [kernel], [b2_calc], cutoff, sigma_e, sigma_f, sigma_s, species_map, - single_atom_energies=single_atom_energies, variance_type=variance_type, - opt_method=opt_method, bounds=bounds, max_iterations=max_iterations + [kernel1, kernel2, kernel3], + [calc1, calc2, calc3], + cutoff, + sigma_e, + sigma_f, + sigma_s, + species_map, + single_atom_energies=single_atom_energies, + variance_type=variance_type, + opt_method=opt_method, + bounds=bounds, + max_iterations=max_iterations, ) return empty_sgp - def get_updated_sgp(): training_structure = get_random_atoms() training_structure.calc = LennardJones() @@ -74,8 +109,14 @@ def get_updated_sgp(): stress = training_structure.get_stress() sgp = get_empty_sgp() - sgp.update_db(training_structure, forces, custom_range=(1, 2, 3, 4, 5), - energy=energy, stress=stress, mode="specific") + sgp.update_db( + training_structure, + forces, + custom_range=[(1, 2, 3, 4, 5) for k in range(len(sgp.descriptor_calculators))], + energy=energy, + stress=stress, + mode="specific", + ) return sgp @@ -85,3 +126,36 @@ def get_sgp_calc(): sgp_calc = SGP_Calculator(sgp) return sgp_calc + + +def get_empty_parsgp(): + empty_sgp = ParSGP_Wrapper( + [kernel1, kernel2, kernel3], + [calc1, calc2, calc3], + cutoff, + sigma_e, + sigma_f, + sigma_s, + species_map, + single_atom_energies=single_atom_energies, + variance_type=variance_type, + opt_method=opt_method, + bounds=bounds, + max_iterations=max_iterations, + ) + + return empty_sgp + +def get_training_data(): + # Make random structure. 
+ sgp = get_empty_sgp() + n_frames = 5 + training_strucs = [] + training_sparse_indices = [[] for i in range(len(sgp.descriptor_calculators))] + for n in range(n_frames): + train_structure = get_random_atoms(a=2.0, sc_size=2, numbers=list(species_map.keys())) + n_atoms = len(train_structure) + training_strucs.append(train_structure) + for k in range(len(sgp.descriptor_calculators)): + training_sparse_indices[k].append(np.random.randint(0, n_atoms, n_atoms // 2).tolist()) + return training_strucs, training_sparse_indices diff --git a/tests/test_parallel_sgp.py b/tests/test_parallel_sgp.py new file mode 100644 index 00000000..d4ae27fe --- /dev/null +++ b/tests/test_parallel_sgp.py @@ -0,0 +1,119 @@ +import os, sys, shutil, time +import numpy as np +import pytest + +from ase.io import read, write +from ase.calculators import lammpsrun +from flare import struc +from flare.ase.atoms import FLARE_Atoms +from flare.lammps import lammps_calculator + +from flare_pp.sparse_gp import SGP_Wrapper +from flare_pp.sparse_gp_calculator import SGP_Calculator +from flare_pp.parallel_sgp import ParSGP_Wrapper +from flare_pp._C_flare import NormalizedDotProduct, Bk, SparseGP, Structure + +from .get_sgp import get_random_atoms, species_map, get_empty_parsgp, get_empty_sgp, get_training_data + +# we set the same seed for different ranks, +# so no need to broadcast the structures + +np.random.seed(10) + +def test_update_db(): + """Check that the covariance matrices have the correct size after the + sparse GP is updated.""" + + # build a non-empty parallel sgp + sgp = get_empty_parsgp() + training_strucs, training_sparse_indices = get_training_data() + sgp.build(training_strucs, training_sparse_indices, update=False) + + # add a new structure + train_structure = get_random_atoms(a=2.0, sc_size=2, numbers=list(species_map.keys())) + sgp.update_db(train_structure, custom_range=[1, 2, 3], mode="uncertain") + + u_size = 0 + for k in range(len(sgp.descriptor_calculators)): + u_size_kern = 0 + for inds in sgp.training_sparse_indices[k]: + u_size_kern += len(inds) + u_size += u_size_kern + + assert sgp.sparse_gp.Kuu.shape[0] == u_size + assert sgp.sparse_gp.alpha.shape[0] == u_size + + sgp.sparse_gp.finalize_MPI = False + +def test_train(): + """Check that the hyperparameters and likelihood are updated when the + train method is called.""" + + # TODO: add sparse_gp and compare the results + + #from flare_pp.sparse_gp import compute_negative_likelihood_grad_stable + #new_hyps = np.array(sgp.hyps) + 1 + ## + ##tic = time.time() + #compute_negative_likelihood_grad_stable(new_hyps, sgp.sparse_gp, precomputed=False) + ##toc = time.time() + #print("compute_negative_likelihood_grad_stable TIME:", toc - tic) + + # build a non-empty parallel sgp + sgp = get_empty_parsgp() + training_strucs, training_sparse_indices = get_training_data() + sgp.build(training_strucs, training_sparse_indices, update=False) + + hyps_init = tuple(sgp.hyps) + sgp.train() + hyps_post = tuple(sgp.hyps) + + #assert hyps_init != hyps_post + assert sgp.likelihood != 0.0 + + sgp.sparse_gp.finalize_MPI = False + +def test_predict(): + # build a non-empty parallel sgp + training_strucs, training_sparse_indices = get_training_data() + sgp = get_empty_parsgp() + sgp.build(training_strucs, training_sparse_indices, update=False) + + # build serial sgp with the same training data set + sgp_serial = get_empty_sgp() + for t in range(len(training_strucs)): + sgp_serial.update_db( + training_strucs[t], + training_strucs[t].forces, + 
            custom_range=[training_sparse_indices[k][t] for k in range(len(sgp_serial.descriptor_calculators))],
+            energy=training_strucs[t].potential_energy,
+            stress=training_strucs[t].stress,
+            mode="specific",
+        )
+    sgp_serial.sparse_gp.update_matrices_QR()
+
+    # generate testing data
+    n_frames = 5
+    test_strucs = []
+    for n in range(n_frames):
+        atoms = get_random_atoms(a=2.0, sc_size=2, numbers=list(species_map.keys()))
+        test_strucs.append(atoms)
+
+    # predict on testing data
+    sgp.predict_on_structures(test_strucs)
+    for n in range(n_frames):
+        par_energy = test_strucs[n].get_potential_energy()
+        par_forces = test_strucs[n].get_forces()
+        par_stress = test_strucs[n].get_stress()
+
+        test_strucs[n].calc = SGP_Calculator(sgp_serial)
+
+        ser_energy = test_strucs[n].get_potential_energy()
+        ser_forces = test_strucs[n].get_forces()
+        ser_stress = test_strucs[n].get_stress()
+
+        assert np.allclose(par_energy, ser_energy)
+        assert np.allclose(par_forces, ser_forces)
+        assert np.allclose(par_stress, ser_stress)
+
+    sgp.sparse_gp.finalize_MPI = True
diff --git a/tests/test_sparse_gp.py b/tests/test_sparse_gp.py
index 32b1b6ad..0bb4c6b4 100644
--- a/tests/test_sparse_gp.py
+++ b/tests/test_sparse_gp.py
@@ -8,9 +8,10 @@
 from flare_pp.sparse_gp import SGP_Wrapper
 from flare_pp.sparse_gp_calculator import SGP_Calculator

-from flare.ase.calculator import FLARE_Calculator
 from flare.ase.atoms import FLARE_Atoms

+from ase.io import read, write
+from ase.calculators import lammpsrun
 from ase.calculators.lj import LennardJones
 from ase.build import bulk

@@ -23,7 +24,7 @@ def test_update_db():
     sparse GP is updated."""

     # Create a labelled structure.
-    custom_range = [1, 2, 3]
+    custom_range = [2, 2, 3]
     training_structure = get_random_atoms()
     training_structure.calc = LennardJones()
     forces = training_structure.get_forces()
@@ -32,13 +33,18 @@
     # Update the SGP.
     sgp = get_empty_sgp()
-    sgp.update_db(training_structure, forces, custom_range, energy, stress,
-                  mode="specific"
-    )
+    sgp.update_db(
+        training_structure,
+        forces,
+        custom_range,
+        energy,
+        stress,
+        mode="uncertain",
+    )

     n_envs = len(custom_range)
     n_atoms = len(training_structure)
-    assert sgp.sparse_gp.Kuu.shape[0] == n_envs
+    assert sgp.sparse_gp.Kuu.shape[0] == np.sum(custom_range)
     assert sgp.sparse_gp.Kuf.shape[1] == 1 + n_atoms * 3 + 6
@@ -131,3 +137,169 @@
     # Check that they're the same.
     max_abs_diff = np.max(np.abs(forces - forces_2))
     assert max_abs_diff < 1e-8
+
+def test_coeff():
+    sgp_py = get_updated_sgp()
+
+    # Dump potential coefficient file
+    sgp_py.write_mapping_coefficients("lmp.flare", "A", [0, 1, 2])
+
+    # Dump uncertainty coefficient file
+    # here the new kernel needs to be returned, otherwise the kernel won't be found in the current module
+    new_kern = sgp_py.write_varmap_coefficients("beta_var.txt", "B", [0, 1, 2])
+
+    assert (
+        sgp_py.sparse_gp.sparse_indices[0] == sgp_py.sgp_var.sparse_indices[0]
+    ), "the sparse_gp and sgp_var don't have the same training data"
+
+    for s in range(len(sgp_py.species_map)):
+        org_desc = sgp_py.sparse_gp.sparse_descriptors[0].descriptors[s]
+        new_desc = sgp_py.sgp_var.sparse_descriptors[0].descriptors[s]
+        if not np.allclose(org_desc, new_desc):  # the atomic order might change
+            assert np.allclose(org_desc.shape, new_desc.shape)
+            for i in range(org_desc.shape[0]):
+                flag = False
+                for j in range(
+                    new_desc.shape[0]
+                ):  # seek in new_desc for matching of org_desc
+                    if np.allclose(org_desc[i], new_desc[j]):
+                        flag = True
+                        break
+                assert flag, "the sparse_gp and sgp_var don't have the same descriptors"
+
+
+@pytest.mark.skipif(
+    not os.environ.get("lmp", False),
+    reason=(
+        "lmp not found "
+        "in environment: Please install LAMMPS "
+        "and set the $lmp env. "
+        "variable to point to the executable."
+    ),
+)
+def test_lammps():
+    # create ASE calc
+    lmp_command = os.environ.get("lmp")
+    specorder = ["C", "O"]
+    pot_file = "lmp.flare"
+    params = {
+        "command": lmp_command,
+        "pair_style": "flare",
+        "pair_coeff": [f"* * {pot_file}"],
+    }
+    files = [pot_file]
+    lmp_calc = lammpsrun.LAMMPS(
+        label=f"tmp",
+        keep_tmp_files=True,
+        tmp_dir="./tmp/",
+        parameters=params,
+        files=files,
+        specorder=specorder,
+    )
+
+    test_atoms = get_random_atoms(a=2.0, sc_size=2, numbers=[6, 8], set_seed=12345)
+    test_atoms.calc = lmp_calc
+    lmp_f = test_atoms.get_forces()
+    lmp_e = test_atoms.get_potential_energy()
+    lmp_s = test_atoms.get_stress()
+
+    print("GP predicting")
+    test_atoms.calc = None
+    test_atoms = FLARE_Atoms.from_ase_atoms(test_atoms)
+    sgp_calc = get_sgp_calc()
+    test_atoms.calc = sgp_calc
+    sgp_f = test_atoms.get_forces()
+    sgp_e = test_atoms.get_potential_energy()
+    sgp_s = test_atoms.get_stress()
+
+    print("Energy")
+    print(lmp_e, sgp_e)
+    assert np.allclose(lmp_e, sgp_e)
+
+    print("Forces")
+    print(np.concatenate([lmp_f, sgp_f], axis=1))
+    assert np.allclose(lmp_f, sgp_f)
+
+    print("Stress")
+    print(lmp_s)
+    print(sgp_s)
+    assert np.allclose(lmp_s, sgp_s)
+
+
+@pytest.mark.skipif(
+    not os.environ.get("lmp", False),
+    reason=(
+        "lmp not found "
+        "in environment: Please install LAMMPS "
+        "and set the $lmp env. "
+        "variable to point to the executable."
+    ),
+)
+def test_lammps_uncertainty():
+    # create ASE calc
+    lmp_command = os.environ.get("lmp")
+    specorder = ["C", "O"]
+    pot_file = "lmp.flare"
+    params = {
+        "command": lmp_command,
+        "pair_style": "flare",
+        "pair_coeff": [f"* * {pot_file}"],
+    }
+    files = [pot_file]
+    lmp_calc = lammpsrun.LAMMPS(
+        label=f"tmp",
+        keep_tmp_files=True,
+        tmp_dir="./tmp/",
+        parameters=params,
+        files=files,
+        specorder=specorder,
+    )
+
+    test_atoms = get_random_atoms(a=2.0, sc_size=2, numbers=[6, 8], set_seed=54321)
+
+    # compute uncertainty
+    in_lmp = """
+atom_style atomic
+units metal
+boundary p p p
+atom_modify sort 0 0.0
+
+read_data data.lammps
+
+### interactions
+pair_style flare
+pair_coeff * * lmp.flare
+mass 1 1.008000
+mass 2 4.002602
+
+### run
+fix fix_nve all nve
+compute unc all flare/std/atom beta_var.txt
+dump dump_all all custom 1 traj.lammps id type x y z vx vy vz fx fy fz c_unc[1] c_unc[2] c_unc[3]
+thermo_style custom step temp press cpu pxx pyy pzz pxy pxz pyz ke pe etotal vol lx ly lz atoms
+thermo_modify flush yes format float %23.16g
+thermo 1
+run 0
+"""
+    os.chdir("tmp")
+    write("data.lammps", test_atoms, format="lammps-data")
+    with open("in.lammps", "w") as f:
+        f.write(in_lmp)
+    shutil.copyfile("../beta_var.txt", "./beta_var.txt")
+    os.system(f"{lmp_command} < in.lammps > log.lammps")
+    unc_atoms = read("traj.lammps", format="lammps-dump-text")
+    sgp_py = get_updated_sgp()
+    lmp_stds = [unc_atoms.get_array(f"c_unc[{i+1}]") / sgp_py.hyps[i] for i in range(len(sgp_py.descriptor_calculators))]
+    lmp_stds = np.squeeze(lmp_stds).T
+
+    # Test mapped variance (need to use sgp_var)
+    test_atoms.calc = None
+    test_atoms = FLARE_Atoms.from_ase_atoms(test_atoms)
+    sgp_calc = get_sgp_calc()
+    test_atoms.calc = sgp_calc
+    test_atoms.calc.gp_model.sparse_gp = sgp_py.sgp_var
+    test_atoms.calc.reset()
+    sgp_stds = test_atoms.calc.get_uncertainties(test_atoms)
+    print(sgp_stds)
+    print(lmp_stds)
+    assert np.allclose(sgp_stds, lmp_stds)
diff --git a/tests/test_utils.py b/tests/test_utils.py
new file mode 100644
index 00000000..72b395df
--- /dev/null
+++ b/tests/test_utils.py
@@ -0,0 +1,12 @@
+import os
+import numpy as np
+import pytest
+import sys
+from flare_pp.utils import add_sparse_indices_to_xyz
+
+def test_utils():
+    add_sparse_indices_to_xyz(
+        xyz_file_in="all_dft_frames.xyz",
+        ind_file="all_select_idx.txt",
+        xyz_file_out="dft_data.xyz",
+    )
diff --git a/timing/CMakeLists.txt b/timing/CMakeLists.txt
index 22f2c0f2..c301fecf 100644
--- a/timing/CMakeLists.txt
+++ b/timing/CMakeLists.txt
@@ -1,9 +1,14 @@
 # make executable
-add_executable(time time_single_bond.cpp)
-target_include_directories(time PUBLIC ${ACE_INCLUDE_DIR})
-target_link_libraries(time PUBLIC flare_pp)
+#add_executable(time time_single_bond.cpp)
+#target_include_directories(time PUBLIC ${ACE_INCLUDE_DIR})
+#target_link_libraries(time PUBLIC flare_pp)
+#
+#
+#add_executable(benchmark benchmark_B2.cpp)
+#target_include_directories(benchmark PUBLIC ${ACE_INCLUDE_DIR})
+#target_link_libraries(benchmark PUBLIC flare_pp)
-
-add_executable(benchmark benchmark_B2.cpp)
-target_include_directories(benchmark PUBLIC ${ACE_INCLUDE_DIR})
-target_link_libraries(benchmark PUBLIC flare_pp)
+add_executable(time_mpi mpi_construction.cpp)
+#target_include_directories(time_mpi PUBLIC ${ACE_INCLUDE_DIR})
+include_directories(../src/flare_pp)
+target_link_libraries(time_mpi PUBLIC flare)
diff --git a/timing/mpi_construction.cpp b/timing/mpi_construction.cpp
new file mode 100644
index 00000000..a105c14a
--- /dev/null
+++ b/timing/mpi_construction.cpp
@@ -0,0 +1,213 @@
+#include <algorithm>
+#include <chrono>
+#include <iostream>
+#include <string>
+#include <vector>
+#include "omp.h"
+#include "mpi.h"
+
+#include "b2.h"
+#include "parallel_sgp.h"
+#include "sparse_gp.h"
+#include "structure.h"
+#include "normalized_dot_product.h"
+
+int main(int argc, char* argv[]) {
+  double sigma_e = 1;
+  double sigma_f = 2;
+  double sigma_s = 3;
+
+  int n_atoms = 100;
+  int n_envs = 5;
+  int n_strucs = std::stoi(argv[1]);
+  int n_species = 2;
+  int n_types = n_species;
+
+  double cell_size = 10;
+  double cutoff = cell_size / 2;
+
+  int N = 8;
+  int L = 3;
+  std::string radial_string = "chebyshev";
+  std::string cutoff_string = "cosine";
+  std::vector<double> radial_hyps{0, cutoff};
+  std::vector<double> cutoff_hyps;
+  std::vector<int> descriptor_settings{n_species, N, L};
+
+  std::vector<Descriptor *> dc;
+  B2 ps(radial_string, cutoff_string, radial_hyps, cutoff_hyps,
+        descriptor_settings);
+  dc.push_back(&ps);
+
+  blacs::initialize();
+
+  double sigma = 2.0;
+  int power = 2;
+  NormalizedDotProduct kernel_norm = NormalizedDotProduct(sigma, power);
+  std::vector<Kernel *> kernels;
+  kernels.push_back(&kernel_norm);
+  ParallelSGP parallel_sgp = ParallelSGP(kernels, sigma_e, sigma_f, sigma_s);
+  SparseGP sparse_gp = SparseGP(kernels, sigma_e, sigma_f, sigma_s);
+
+  // Build kernel matrices for the parallel sgp
+  std::vector<Structure> training_strucs;
+  std::vector<std::vector<std::vector<int>>> sparse_indices = {{}};
+  Eigen::MatrixXd cell, positions;
+  Eigen::VectorXd labels;
+  std::vector<int> species, sparse_inds;
+
+  for (int t = 0; t < n_strucs; t++) {
+    Eigen::MatrixXd cell = Eigen::MatrixXd::Identity(3, 3) * cell_size;
+
+    // Make random positions
+    Eigen::MatrixXd positions = Eigen::MatrixXd::Random(n_atoms, 3) * cell_size / 2;
+    MPI_Bcast(positions.data(), n_atoms * 3, MPI_DOUBLE, 0, MPI_COMM_WORLD);
+
+    // Make random labels
+    Eigen::VectorXd labels = Eigen::VectorXd::Random(1 + n_atoms * 3 + 6);
+    MPI_Bcast(labels.data(), 1 + n_atoms * 3 + 6, MPI_DOUBLE, 0, MPI_COMM_WORLD);
+
+    // Make random species.
+    std::vector<int> species;
+    for (int i = 0; i < n_atoms; i++) {
+      species.push_back(rand() % n_species);
+    }
+    MPI_Bcast(species.data(), n_atoms, MPI_INT, 0, MPI_COMM_WORLD);
+
+    // Make random sparse envs
+    std::vector<int> env_inds;
+    for (int i = 0; i < n_atoms; i++) env_inds.push_back(i);
+    std::random_shuffle(env_inds.begin(), env_inds.end());
+    std::vector<int> sparse_inds;
+    for (int i = 0; i < n_envs; i++) sparse_inds.push_back(env_inds[i]);
+    MPI_Bcast(sparse_inds.data(), n_envs, MPI_INT, 0, MPI_COMM_WORLD);
+    sparse_indices[0].push_back(sparse_inds);
+
+    Structure struc(cell, species, positions);
+    struc.energy = labels.segment(0, 1);
+    struc.forces = labels.segment(1, n_atoms * 3);
+    struc.stresses = labels.segment(1 + n_atoms * 3, 6);
+    training_strucs.push_back(struc);
+  }
+
+  std::cout << "Start building" << std::endl;
+  double duration = 0;
+  std::chrono::high_resolution_clock::time_point t1 = std::chrono::high_resolution_clock::now();
+
+  parallel_sgp.build(training_strucs, cutoff, dc, sparse_indices, n_types);
+
+  std::chrono::high_resolution_clock::time_point t2 = std::chrono::high_resolution_clock::now();
+  duration += (double) std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1).count();
+  std::cout << "Rank: " << blacs::mpirank << ", time: " << duration << " ms" << std::endl;
+
+  if (blacs::mpirank == 0) {
+    // Build sparse_gp (non-parallel)
+    for (int t = 0; t < n_strucs; t++) {
+      cell = training_strucs[t].cell;
+      positions = training_strucs[t].positions;
+      species = training_strucs[t].species;
+      sparse_inds = sparse_indices[0][t];
+
+      Structure train_struc(cell, species, positions, cutoff, dc);
+      train_struc.energy = training_strucs[t].energy;
+      train_struc.forces = training_strucs[t].forces;
+      train_struc.stresses = training_strucs[t].stresses;
+
+      sparse_gp.add_training_structure(train_struc);
+      sparse_gp.add_specific_environments(train_struc, sparse_inds);
+    }
+    sparse_gp.update_matrices_QR();
+    std::cout << "Done QR for sparse_gp" << std::endl;
+
+    // Check that the kernel matrices are consistent
+    std::cout << "begin comparing n_clusters" << std::endl;
+    assert(parallel_sgp.sparse_descriptors[0].n_clusters == sparse_gp.Sigma.rows());
+    assert(parallel_sgp.sparse_descriptors[0].n_clusters == sparse_gp.sparse_descriptors[0].n_clusters);
+    std::cout << "done comparing n_clusters" << std::endl;
+    for (int t = 0; t < parallel_sgp.sparse_descriptors[0].n_types; t++) {
+      for (int r = 0; r < parallel_sgp.sparse_descriptors[0].descriptors[t].rows(); r++) {
+        double par_desc_norm = parallel_sgp.sparse_descriptors[0].descriptor_norms[t](r);
+        double sgp_desc_norm = sparse_gp.sparse_descriptors[0].descriptor_norms[t](r);
+
+        if (std::abs(par_desc_norm - sgp_desc_norm) > 1e-6) {
+          std::cout << "++ t=" << t << ", r=" << r;
+          std::cout << ", par_desc_norm=" << par_desc_norm << ", sgp_desc_norm=" << sgp_desc_norm << std::endl;
+//          throw std::runtime_error("descriptors do not match");
+        }
+
+        for (int c = 0; c < parallel_sgp.sparse_descriptors[0].descriptors[t].cols(); c++) {
+          double par_desc = parallel_sgp.sparse_descriptors[0].descriptors[t](r, c);
+          double sgp_desc = sparse_gp.sparse_descriptors[0].descriptors[t](r, c);
+
+          if (std::abs(par_desc - sgp_desc) > 1e-6) {
+            std::cout << "*** t=" << t << ", r=" << r << " c=" << c;
+            std::cout << ", par_desc=" << par_desc << ", sgp_desc=" << sgp_desc << std::endl;
+//            throw std::runtime_error("descriptors do not match");
+          }
+        }
+      }
+    }
+    std::cout << "Checked matrix shape" << std::endl;
+    std::cout << "parallel_sgp.Kuu(0, 0)=" << parallel_sgp.Kuu(0, 0) << std::endl;
+
+    for (int r = 0; r < parallel_sgp.Kuu.rows(); r++) {
+      for (int c = 0; c < parallel_sgp.Kuu.rows(); c++) {
+        // Sometimes the discrepancy is between 1e-6 and 1e-5
+        if (std::abs(parallel_sgp.Kuu(r, c) - sparse_gp.Kuu(r, c)) > 1e-6) {
+          throw std::runtime_error("Kuu does not match");
+        }
+      }
+    }
+    std::cout << "Kuu matches" << std::endl;
+
+    for (int r = 0; r < parallel_sgp.Kuu_inverse.rows(); r++) {
+      for (int c = 0; c < parallel_sgp.Kuu_inverse.rows(); c++) {
+        std::cout << parallel_sgp.Kuu_inverse(r, c) << std::endl;
+        if (std::abs(parallel_sgp.Kuu_inverse(r, c) - sparse_gp.Kuu_inverse(r, c)) > 1e-5) {
+          throw std::runtime_error("Kuu_inverse does not match");
+        }
+      }
+    }
+    std::cout << "Kuu_inverse matches" << std::endl;
+
+    for (int r = 0; r < parallel_sgp.alpha.size(); r++) {
+      if (std::abs(parallel_sgp.alpha(r) - sparse_gp.alpha(r)) > 1e-6) {
+        std::cout << "alpha: r=" << r << " " << parallel_sgp.alpha(r) << " " << sparse_gp.alpha(r) << std::endl;
+        throw std::runtime_error("alpha does not match");
+      }
+    }
+    std::cout << "alpha matches" << std::endl;
+
+    // Check that predictions on a testing structure are consistent
+    cell = Eigen::MatrixXd::Identity(3, 3) * cell_size;
+    positions = Eigen::MatrixXd::Random(n_atoms, 3) * cell_size / 2;
+    // Make random species.
+    species.clear();
+    for (int i = 0; i < n_atoms; i++) {
+      species.push_back(rand() % n_species);
+    }
+    Structure test_struc(cell, species, positions, cutoff, dc);
+    parallel_sgp.predict_local_uncertainties(test_struc);
+    Structure test_struc_copy(test_struc.cell, test_struc.species, test_struc.positions, cutoff, dc);
+    sparse_gp.predict_local_uncertainties(test_struc_copy);
+
+    for (int r = 0; r < test_struc.mean_efs.size(); r++) {
+      if (std::abs(test_struc.mean_efs(r) - test_struc_copy.mean_efs(r)) > 1e-5) {
+        std::cout << "mean_efs: r=" << r << " " << test_struc.mean_efs(r) << " " << test_struc_copy.mean_efs(r) << std::endl;
+        throw std::runtime_error("mean_efs does not match");
+      }
+    }
+    std::cout << "mean_efs matches" << std::endl;
+
+    for (int i = 0; i < test_struc.local_uncertainties.size(); i++) {
+      for (int r = 0; r < test_struc.local_uncertainties[i].size(); r++) {
+        if (std::abs(test_struc.local_uncertainties[i](r) - test_struc_copy.local_uncertainties[i](r)) > 1e-5) {
+          throw std::runtime_error("local_unc does not match");
+        }
+      }
+    }
+    std::cout << "local_unc matches" << std::endl;
+  }
+
+  return 0;
+}
diff --git a/timing/multinode_mpi.sh b/timing/multinode_mpi.sh
new file mode 100644
index 00000000..5403179a
--- /dev/null
+++ b/timing/multinode_mpi.sh
@@ -0,0 +1,16 @@
+#!/bin/sh
+#SBATCH -n 36
+#SBATCH -N 2
+#SBATCH -t 10:00:00
+#SBATCH -p kozinsky
+#SBATCH --mem-per-cpu=5000
+#SBATCH --mail-type=ALL
+
+
+module load cmake/3.17.3-fasrc01
+module load gcc/9.3.0-fasrc01 openmpi/4.0.5-fasrc01
+module load intel-mkl/2019.5.281-fasrc01
+module load eigen/3.3.7-fasrc01
+
+export OMP_NUM_THREADS=1
+mpirun -n 36 ./time_mpi 100
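For reference, the snippet below sketches how the `utils::read_xyz` reader and `utils::Timer` declared in `src/flare_pp/utils.h` might be driven from a small standalone program. The file name `read_xyz_example.cpp` and the `main` routine are illustrative only and are not part of this patch; the sketch assumes the program is compiled and linked against the `flare` library in the same way as the `time_mpi` timing executable.
```
// read_xyz_example.cpp -- illustrative usage sketch, not part of the patch.
#include <iostream>
#include <map>
#include <string>
#include <tuple>
#include <vector>

#include "structure.h"
#include "utils.h"

int main() {
  // Map chemical symbols to the integer species codes used by Structure.
  std::map<std::string, int> species_map = {{"H", 0}, {"He", 1}};

  utils::Timer timer;
  timer.tic();

  // Parse all DFT frames and the per-kernel sparse indices from the xyz file.
  auto parsed = utils::read_xyz("dft_data.xyz", species_map);
  std::vector<Structure> frames = std::get<0>(parsed);
  std::vector<std::vector<std::vector<int>>> sparse_indices = std::get<1>(parsed);

  timer.toc("read_xyz");  // prints "Time: read_xyz <duration> ms"

  for (size_t i = 0; i < frames.size(); i++) {
    std::cout << "Frame " << i << ": energy = " << frames[i].energy(0)
              << ", sparse atoms = " << sparse_indices[0][i].size() << std::endl;
  }
  return 0;
}
```
The returned sparse index list is indexed first by kernel and then by frame, matching how `sparse_indices` is passed to `ParallelSGP::build` in `timing/mpi_construction.cpp`.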