diff --git a/deepmd_mace/mace.py b/deepmd_mace/mace.py index 85513b4..d1811af 100644 --- a/deepmd_mace/mace.py +++ b/deepmd_mace/mace.py @@ -15,6 +15,7 @@ BaseModel, ) from deepmd.pt.model.model.transform_output import ( + atomic_virial_corr, communicate_extended_output, ) from deepmd.pt.utils.nlist import ( @@ -564,7 +565,7 @@ def forward_lower_common( mapping: Optional[torch.Tensor] = None, fparam: Optional[torch.Tensor] = None, aparam: Optional[torch.Tensor] = None, - do_atomic_virial: bool = False, # noqa: ARG002 + do_atomic_virial: bool = False, comm_dict: Optional[dict[str, torch.Tensor]] = None, ) -> dict[str, torch.Tensor]: """Forward lower common pass of the model. @@ -714,6 +715,12 @@ def forward_lower_common( ) @ extended_coord_ff.unsqueeze(-2).to( extended_coord_.dtype, ) + if do_atomic_virial: + extended_virial_corr = atomic_virial_corr( + extended_coord_ff.unsqueeze(0), + atom_energy.view(1, nloc, 1), + ) + atomic_virial = atomic_virial + extended_virial_corr force = force.view(1, nall, 3).to(extended_coord_.dtype) virial = ( torch.sum(atomic_virial, dim=0).view(1, 9).to(extended_coord_.dtype) diff --git a/tests/test_model.py b/tests/test_model.py index b4fd8e7..769d518 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -270,6 +270,7 @@ def test_forward(self) -> None: "box": cell, "aparam": aparam, "fparam": fparam, + "do_atomic_virial": True, } if test_spin: input_dict["spin"] = spin @@ -282,6 +283,7 @@ def test_forward(self) -> None: "aparam": aparam, "fparam": fparam, "mapping": mapping_large, + "do_atomic_virial": True, } if test_spin: input_dict_lower["extended_spin"] = spin_ext