diff --git a/src/equisolve/nn/__init__.py b/src/equisolve/nn/__init__.py
new file mode 100644
index 0000000..0ada8c6
--- /dev/null
+++ b/src/equisolve/nn/__init__.py
@@ -0,0 +1,13 @@
+try:
+    import torch  # noqa: F401
+
+    HAS_TORCH = True
+except ImportError:
+    HAS_TORCH = False
+
+if HAS_TORCH:
+    from .module_tensor import LinearTensorMap, ModuleTensorMap  # noqa: F401
+
+    __all__ = ["LinearTensorMap", "ModuleTensorMap"]
+else:
+    __all__ = []
diff --git a/src/equisolve/nn/module_tensor.py b/src/equisolve/nn/module_tensor.py
new file mode 100644
index 0000000..323354b
--- /dev/null
+++ b/src/equisolve/nn/module_tensor.py
@@ -0,0 +1,177 @@
+try:
+    from metatensor.torch import Labels, LabelsEntry, TensorBlock, TensorMap
+
+    HAS_METATENSOR_TORCH = True
+except ImportError:
+    from metatensor import Labels, LabelsEntry, TensorBlock, TensorMap
+
+    HAS_METATENSOR_TORCH = False
+
+from copy import deepcopy
+from typing import List, Optional
+
+import torch
+from torch.nn import Linear, Module, ModuleDict
+
+
+@torch.jit.interface
+class ModuleTensorMapInterface(torch.nn.Module):
+    """
+    This interface is required so that TorchScript can index the ModuleDict with
+    non-literal keys in ModuleTensorMap. TorchScript interprets the forward signature
+
+    .. code-block:: python
+
+        forward(self, tensor: TensorMap) -> TensorMap:
+
+    as
+
+    .. code-block:: python
+
+        forward(self, input: torch.Tensor) -> torch.Tensor:
+
+    Note that the types and argument names must match exactly so that the interface
+    is correctly implemented.
+
+    Reference
+    ---------
+    https://github.com/pytorch/pytorch/pull/45716
+    """
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        pass
+
+
+class ModuleTensorMap(Module):
+    """
+    Applies a separate module to each block of a tensor map. This is needed when the
+    blocks of different keys have different feature sizes.
+
+    :param module_map: a ModuleDict with stringified tensor map keys as dict keys;
+        each module is applied on the block with the matching key
+
+    :param out_tensor: a tensor map whose blocks provide the output property Labels
+        for each key
+    """
+
+    def __init__(self, module_map: ModuleDict, out_tensor: Optional[TensorMap] = None):
+        super().__init__()
+        self._module_map = module_map
+        self._out_tensor = out_tensor
+
+    @classmethod
+    def from_module(
+        cls,
+        in_keys: Labels,
+        module: Module,
+        many_to_one: bool = True,
+        out_tensor: Optional[TensorMap] = None,
+    ):
+        """
+        :param module: the module that is applied on each block
+        :param many_to_one: if True, the same module instance is shared by all blocks;
+            otherwise each block gets its own deep copy of the module
+        :param out_tensor: a tensor map whose blocks provide the output property
+            Labels for each key
+        """
+        module_map = ModuleDict()
+        for key in in_keys:
+            module_key = ModuleTensorMap.module_key(key)
+            if many_to_one:
+                module_map[module_key] = module
+            else:
+                module_map[module_key] = deepcopy(module)
+
+        return cls(module_map, out_tensor)
+
+    def forward(self, tensor: TensorMap) -> TensorMap:
+        out_blocks: List[TensorBlock] = []
+        for key, block in tensor.items():
+            out_block = self.forward_block(key, block)
+
+            for parameter, gradient in block.gradients():
+                if len(gradient.gradients_list()) != 0:
+                    raise NotImplementedError(
+                        "gradients of gradients are not supported"
+                    )
+                out_block.add_gradient(
+                    parameter=parameter,
+                    gradient=self.forward_block(key, gradient),
+                )
+            out_blocks.append(out_block)
+
+        return TensorMap(tensor.keys, out_blocks)
+
+    def forward_block(self, key: LabelsEntry, block: TensorBlock) -> TensorBlock:
+        module_key: str = ModuleTensorMap.module_key(key)
+        module: ModuleTensorMapInterface = self._module_map[module_key]
+        out_values = module.forward(block.values)
+        if self._out_tensor is None:
+            properties = Labels.range("_", out_values.shape[-1])
+        else:
+            properties = self._out_tensor.block(key).properties
+        return TensorBlock(
+            values=out_values,
+            properties=properties,
+            components=block.components,
+            samples=block.samples,
+        )
+
+    @property
+    def module_map(self) -> ModuleDict:
+        """
+        dictionary that maps hashable tensor map keys to a module
+        """
+        return self._module_map
+
+    @property
+    def out_tensor(self) -> Optional[TensorMap]:
+        """
+        tensor map that provides the output property labels for each key
+        """
+        return self._out_tensor
+
+    @staticmethod
+    def module_key(key: LabelsEntry) -> str:
+        return str(key)
+
+
+class LinearTensorMap(ModuleTensorMap):
+    """
+    :param in_tensor: the input tensor map that is used to determine the keys and
+        the shape of the input
+    :param out_tensor: a tensor map whose blocks provide the output property Labels
+        and the shape of the output
+    """
+
+    def __init__(
+        self,
+        in_tensor: TensorMap,
+        out_tensor: TensorMap,
+        bias: bool = True,
+    ):
+        module_map = ModuleDict()
+        for key, block in in_tensor.items():
+            module_key = ModuleTensorMap.module_key(key)
+            module = Linear(
+                len(block.properties),
+                len(out_tensor.block(key).properties),
+                bias,
+                block.values.device,
+                block.values.dtype,
+            )
+            module_map[module_key] = module
+
+        super().__init__(module_map, out_tensor)
+
+    @classmethod
+    def from_module(
+        cls,
+        in_keys: Labels,
+        in_features: int,
+        out_features: int,
+        bias: bool = True,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+        many_to_one: bool = True,
+        out_tensor: Optional[TensorMap] = None,
+    ):
+        module = Linear(in_features, out_features, bias, device, dtype)
+        return ModuleTensorMap.from_module(in_keys, module, many_to_one, out_tensor)
diff --git a/tests/equisolve_tests/nn/test_module_tensor.py b/tests/equisolve_tests/nn/test_module_tensor.py
new file mode 100644
index 0000000..a755be2
--- /dev/null
+++ b/tests/equisolve_tests/nn/test_module_tensor.py
@@ -0,0 +1,129 @@
+import pytest
+import torch
+
+from equisolve.nn import LinearTensorMap
+
+
+try:
+    from metatensor.torch import Labels, TensorBlock, TensorMap, allclose_raise
+
+    HAS_METATENSOR_TORCH = True
+except ImportError:
+    from metatensor import Labels, TensorBlock, TensorMap, allclose_raise
+
+    HAS_METATENSOR_TORCH = False
+
+
+# TODO copy-paste from utilities, make this more generic and put it in a common place
+def random_single_block_no_components_tensor_map():
+    """
+    Create a dummy tensor map to be used in tests. This is the same one as the
+    tensor map used in the `tensor.rs` tests.
+ """ + block_1 = TensorBlock( + values=torch.rand(4, 2), + samples=Labels( + ["sample", "structure"], + torch.tensor([[0, 0], [1, 1], [2, 2], [3, 3]], dtype=torch.int32), + ), + components=[], + properties=Labels(["properties"], torch.tensor([[0], [1]], dtype=torch.int32)), + ) + positions_gradient = TensorBlock( + values=torch.rand(7, 3, 2), + samples=Labels( + ["sample", "structure", "center"], + torch.tensor( + [ + [0, 0, 1], + [0, 0, 2], + [1, 1, 0], + [1, 1, 1], + [1, 1, 2], + [2, 2, 0], + [3, 3, 0], + ], + dtype=torch.int32, + ), + ), + components=[ + Labels(["direction"], torch.tensor([[0], [1], [2]], dtype=torch.int32)) + ], + properties=block_1.properties, + ) + block_1.add_gradient("positions", positions_gradient) + + cell_gradient = TensorBlock( + values=torch.rand(4, 6, 2), + samples=Labels( + ["sample", "structure"], + torch.tensor([[0, 0], [1, 1], [2, 2], [3, 3]], dtype=torch.int32), + ), + components=[ + Labels( + ["direction_xx_yy_zz_yz_xz_xy"], + torch.tensor([[0], [1], [2], [3], [4], [5]], dtype=torch.int32), + ) + ], + properties=block_1.properties, + ) + block_1.add_gradient("cell", cell_gradient) + + return TensorMap(Labels.single(), [block_1]) + + +class TestModuleTensorMap: + @pytest.fixture(autouse=True) + def set_random_generator(self): + """Set the random generator to same seed before each test is run. + Otherwise test behaviour is dependend on the order of the tests + in this file and the number of parameters of the test. + """ + torch.random.manual_seed(122578741812) + + @pytest.mark.parametrize( + "tensor", + [ + random_single_block_no_components_tensor_map(), + ], + ) + def test_linear_module(self, tensor): + tensor_module = LinearTensorMap.from_module( + tensor.keys, in_features=len(tensor[0].properties), out_features=5 + ) + with torch.no_grad(): + out_tensor = tensor_module(tensor) + + for key, block in tensor.items(): + module = tensor_module.module_map[LinearTensorMap.module_key(key)] + with torch.no_grad(): + ref_values = module(block.values) + out_block = out_tensor.block(key) + assert torch.allclose(ref_values, out_block.values) + + for parameter, gradient in block.gradients(): + with torch.no_grad(): + ref_gradient_values = module(gradient.values) + assert torch.allclose( + ref_gradient_values, out_block.gradient(parameter).values + ) + + @pytest.mark.parametrize( + "tensor", + [ + random_single_block_no_components_tensor_map(), + ], + ) + @pytest.mark.skipif( + not (HAS_METATENSOR_TORCH), reason="requires metatensor-torch to be run" + ) + def test_torchscript_linear_module(self, tensor): + tensor_module = LinearTensorMap.from_module( + tensor.keys, in_features=len(tensor[0].properties), out_features=5 + ) + ref_tensor = tensor_module(tensor) + + tensor_module_script = torch.jit.script(tensor_module) + out_tensor = tensor_module_script(tensor) + + allclose_raise(ref_tensor, out_tensor)