This repository has been archived by the owner on Apr 24, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add initial modules that work with tensor maps
to wrap existing torch modules in a flexible way
- Loading branch information
1 parent
71b3dfd
commit f1d0477
Showing
3 changed files
with
319 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
try: | ||
import torch # noqa: F401 | ||
|
||
HAS_TORCH = True | ||
except ImportError: | ||
HAS_TORCH = False | ||
|
||
if HAS_TORCH: | ||
from .module_tensor import LinearTensorMap, ModuleTensorMap # noqa: F401 | ||
|
||
__all__ = ["LinearTensorMap", "ModuleTensorMap"] | ||
else: | ||
__all__ = [] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
try: | ||
from metatensor.torch import Labels, LabelsEntry, TensorBlock, TensorMap | ||
|
||
HAS_METATENSOR_TORCH = True | ||
except ImportError: | ||
from metatensor import Labels, LabelsEntry, TensorBlock, TensorMap | ||
|
||
HAS_METATENSOR_TORCH = False | ||
|
||
from copy import deepcopy | ||
from typing import List, Optional | ||
|
||
import torch | ||
from torch.nn import Linear, Module, ModuleDict | ||
|
||
|
||
@torch.jit.interface | ||
class ModuleTensorMapInterface(torch.nn.Module): | ||
""" | ||
This interface required for TorchScript to index the ModuleDict with non-literals | ||
in ModuleTensorMap. TorchScript interprets the forward signature | ||
.. clode-block:: | ||
forward(self, tensor: TensorMap) -> TensorMap: | ||
as | ||
.. clode-block:: | ||
forward(self, input: torch.Tensor) -> torch.Tensor: | ||
Note that the typings and argument names must match exactly so that an interface is | ||
correctly implemented. | ||
Reference | ||
--------- | ||
https://github.com/pytorch/pytorch/pull/45716 | ||
""" | ||
|
||
def forward(self, input: torch.Tensor) -> torch.Tensor: | ||
pass | ||
|
||
|
||
class ModuleTensorMap(Module): | ||
""" | ||
ModuleDictTensorMap is needed when the keys correspond to different features sizes | ||
:param module_map: a dictionary of modules with tensor map keys as dict keys | ||
each module is applied on a block | ||
:param out_tensor: a tensor map that has the output properties of the Labels | ||
""" | ||
|
||
def __init__(self, module_map: ModuleDict, out_tensor: Optional[TensorMap] = None): | ||
super().__init__() | ||
self._module_map = module_map | ||
self._out_tensor = out_tensor | ||
|
||
@classmethod | ||
def from_module( | ||
cls, | ||
in_keys: Labels, | ||
module: Module, | ||
many_to_one: bool = True, | ||
out_tensor: Optional[TensorMap] = None, | ||
): | ||
""" | ||
:param module: the mode that is applied on each block | ||
:param many_to_one: specifies if a separate module for each block is used | ||
:param out_tensor: a tensor map that has the output properties of the Labels | ||
and keys | ||
""" | ||
module_map = ModuleDict() | ||
for key in in_keys: | ||
module_key = ModuleTensorMap.module_key(key) | ||
if many_to_one: | ||
module_map[module_key] = module | ||
else: | ||
module_map[module_key] = deepcopy(module) | ||
|
||
return cls(module_map, out_tensor) | ||
|
||
def forward(self, tensor: TensorMap) -> TensorMap: | ||
out_blocks: List[TensorBlock] = [] | ||
for key, block in tensor.items(): | ||
out_block = self.forward_block(key, block) | ||
|
||
for parameter, gradient in block.gradients(): | ||
if len(gradient.gradients_list()) != 0: | ||
raise NotImplementedError( | ||
"gradients of gradients are not supported" | ||
) | ||
out_block.add_gradient( | ||
parameter=parameter, | ||
gradient=self.forward_block(key, gradient), | ||
) | ||
out_blocks.append(out_block) | ||
|
||
return TensorMap(tensor.keys, out_blocks) | ||
|
||
def forward_block(self, key: LabelsEntry, block: TensorBlock) -> TensorBlock: | ||
module_key: str = ModuleTensorMap.module_key(key) | ||
module: ModuleTensorMapInterface = self._module_map[module_key] | ||
out_values = module.forward(block.values) | ||
if self._out_tensor is None: | ||
properties = Labels.range("_", out_values.shape[-1]) | ||
else: | ||
properties = self._out_tensor.block(key).properties | ||
return TensorBlock( | ||
values=out_values, | ||
properties=properties, | ||
components=block.components, | ||
samples=block.samples, | ||
) | ||
|
||
@property | ||
def module_map(self) -> ModuleDict: | ||
""" | ||
dictionary that maps hashable tensor map keys to a module | ||
""" | ||
return self._module_map | ||
|
||
@property | ||
def out_tensor(self) -> Optional[TensorMap]: | ||
""" | ||
dictionary that maps hashable tensor map keys to property abels | ||
""" | ||
return self._out_tensor | ||
|
||
@staticmethod | ||
def module_key(key: LabelsEntry) -> str: | ||
return str(key) | ||
|
||
|
||
class LinearTensorMap(ModuleTensorMap): | ||
""" | ||
:param in_tensor: the input tensor map that is used do determine the keys and | ||
shape of the output | ||
:param out_tensor: a tensor map that has the output properties of the Labels and | ||
shape of output | ||
""" | ||
|
||
def __init__( | ||
self, | ||
in_tensor: TensorMap, | ||
out_tensor: TensorMap, | ||
bias: bool = True, | ||
): | ||
module_map = ModuleDict() | ||
for key, block in in_tensor.items(): | ||
module_key = ModuleTensorMap.module_key(key) | ||
module = Linear( | ||
len(block.samples), | ||
len(block.properties), | ||
bias, | ||
block.values.device, | ||
block.values.dtype, | ||
) | ||
module_map[module_key] = module | ||
|
||
super().__init__(module_map, out_tensor) | ||
|
||
@classmethod | ||
def from_module( | ||
cls, | ||
in_keys: Labels, | ||
in_features: int, | ||
out_features: int, | ||
bias: bool = True, | ||
device: torch.device = None, | ||
dtype: torch.dtype = None, | ||
many_to_one: bool = True, | ||
out_tensor: Optional[TensorMap] = None, | ||
): | ||
module = Linear(in_features, out_features, bias, device, dtype) | ||
return ModuleTensorMap.from_module(in_keys, module, many_to_one, out_tensor) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
import pytest | ||
import torch | ||
|
||
from equisolve.nn import LinearTensorMap | ||
|
||
|
||
try: | ||
from metatensor.torch import Labels, TensorBlock, TensorMap, allclose_raise | ||
|
||
HAS_METATENSOR_TORCH = True | ||
except ImportError: | ||
from metatensor import Labels, TensorBlock, TensorMap, allclose_raise | ||
|
||
HAS_METATENSOR_TORCH = False | ||
|
||
|
||
# TODO copy paste from utilies, make this more generic and put it in a common place | ||
def random_single_block_no_components_tensor_map(): | ||
""" | ||
Create a dummy tensor map to be used in tests. This is the same one as the | ||
tensor map used in `tensor.rs` tests. | ||
""" | ||
block_1 = TensorBlock( | ||
values=torch.rand(4, 2), | ||
samples=Labels( | ||
["sample", "structure"], | ||
torch.tensor([[0, 0], [1, 1], [2, 2], [3, 3]], dtype=torch.int32), | ||
), | ||
components=[], | ||
properties=Labels(["properties"], torch.tensor([[0], [1]], dtype=torch.int32)), | ||
) | ||
positions_gradient = TensorBlock( | ||
values=torch.rand(7, 3, 2), | ||
samples=Labels( | ||
["sample", "structure", "center"], | ||
torch.tensor( | ||
[ | ||
[0, 0, 1], | ||
[0, 0, 2], | ||
[1, 1, 0], | ||
[1, 1, 1], | ||
[1, 1, 2], | ||
[2, 2, 0], | ||
[3, 3, 0], | ||
], | ||
dtype=torch.int32, | ||
), | ||
), | ||
components=[ | ||
Labels(["direction"], torch.tensor([[0], [1], [2]], dtype=torch.int32)) | ||
], | ||
properties=block_1.properties, | ||
) | ||
block_1.add_gradient("positions", positions_gradient) | ||
|
||
cell_gradient = TensorBlock( | ||
values=torch.rand(4, 6, 2), | ||
samples=Labels( | ||
["sample", "structure"], | ||
torch.tensor([[0, 0], [1, 1], [2, 2], [3, 3]], dtype=torch.int32), | ||
), | ||
components=[ | ||
Labels( | ||
["direction_xx_yy_zz_yz_xz_xy"], | ||
torch.tensor([[0], [1], [2], [3], [4], [5]], dtype=torch.int32), | ||
) | ||
], | ||
properties=block_1.properties, | ||
) | ||
block_1.add_gradient("cell", cell_gradient) | ||
|
||
return TensorMap(Labels.single(), [block_1]) | ||
|
||
|
||
class TestModuleTensorMap: | ||
@pytest.fixture(autouse=True) | ||
def set_random_generator(self): | ||
"""Set the random generator to same seed before each test is run. | ||
Otherwise test behaviour is dependend on the order of the tests | ||
in this file and the number of parameters of the test. | ||
""" | ||
torch.random.manual_seed(122578741812) | ||
|
||
@pytest.mark.parametrize( | ||
"tensor", | ||
[ | ||
random_single_block_no_components_tensor_map(), | ||
], | ||
) | ||
def test_linear_module(self, tensor): | ||
tensor_module = LinearTensorMap.from_module( | ||
tensor.keys, in_features=len(tensor[0].properties), out_features=5 | ||
) | ||
with torch.no_grad(): | ||
out_tensor = tensor_module(tensor) | ||
|
||
for key, block in tensor.items(): | ||
module = tensor_module.module_map[LinearTensorMap.module_key(key)] | ||
with torch.no_grad(): | ||
ref_values = module(block.values) | ||
out_block = out_tensor.block(key) | ||
assert torch.allclose(ref_values, out_block.values) | ||
|
||
for parameter, gradient in block.gradients(): | ||
with torch.no_grad(): | ||
ref_gradient_values = module(gradient.values) | ||
assert torch.allclose( | ||
ref_gradient_values, out_block.gradient(parameter).values | ||
) | ||
|
||
@pytest.mark.parametrize( | ||
"tensor", | ||
[ | ||
random_single_block_no_components_tensor_map(), | ||
], | ||
) | ||
@pytest.mark.skipif( | ||
not (HAS_METATENSOR_TORCH), reason="requires metatensor-torch to be run" | ||
) | ||
def test_torchscript_linear_module(self, tensor): | ||
tensor_module = LinearTensorMap.from_module( | ||
tensor.keys, in_features=len(tensor[0].properties), out_features=5 | ||
) | ||
ref_tensor = tensor_module(tensor) | ||
|
||
tensor_module_script = torch.jit.script(tensor_module) | ||
out_tensor = tensor_module_script(tensor) | ||
|
||
allclose_raise(ref_tensor, out_tensor) |