diff --git a/src/schnetpack/atomistic/atomwise.py b/src/schnetpack/atomistic/atomwise.py index 5443de025..98ef8a1d4 100644 --- a/src/schnetpack/atomistic/atomwise.py +++ b/src/schnetpack/atomistic/atomwise.py @@ -69,14 +69,14 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: # aggregate if self.aggregation_mode is not None: - if self.aggregation_mode == "avg": - y = y / inputs[properties.n_atoms][:, None] - idx_m = inputs[properties.idx_m] maxm = int(idx_m[-1]) + 1 y = snn.scatter_add(y, idx_m, dim_size=maxm) y = torch.squeeze(y, -1) + if self.aggregation_mode == "avg": + y = y / inputs[properties.n_atoms] + inputs[self.output_key] = y return inputs diff --git a/src/schnetpack/interfaces/ase_interface.py b/src/schnetpack/interfaces/ase_interface.py index 2758218ce..1e6ec11d1 100644 --- a/src/schnetpack/interfaces/ase_interface.py +++ b/src/schnetpack/interfaces/ase_interface.py @@ -76,9 +76,6 @@ def __init__( else: raise AtomsConverterError(f"Unrecognized precision {dtype}") - for t in self.transforms: - t.preprocessor() - def __call__(self, atoms: Atoms): """ @@ -162,7 +159,7 @@ def __init__( self.property_units = { self.energy: convert_units(energy_units, "eV"), self.forces: convert_units(forces_units, "eV/Angstrom"), - self.stress: convert_units(stress_units, "eV/A/A/A"), + self.stress: convert_units(stress_units, "eV/Ang/Ang/Ang"), } def calculate(