This repository has been archived by the owner on Apr 24, 2024. It is now read-only.

add initial modules that work with tensor maps
to wrap existing torch modules in a flexible way
agoscinski committed Sep 14, 2023
1 parent 71b3dfd commit f1d0477
Showing 3 changed files with 319 additions and 0 deletions.
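As a quick orientation (this example is not part of the commit), the intended usage of the new wrappers looks roughly like the sketch below. It assumes metatensor-torch is installed and builds a hypothetical single-block tensor map; the wrapped module, label names, and shapes are made up for illustration.

    import torch
    from metatensor.torch import Labels, TensorBlock, TensorMap

    from equisolve.nn import ModuleTensorMap

    # hypothetical input: one block of shape (n_samples=4, n_properties=3)
    block = TensorBlock(
        values=torch.rand(4, 3),
        samples=Labels.range("sample", 4),
        components=[],
        properties=Labels.range("property", 3),
    )
    tensor = TensorMap(Labels.single(), [block])

    # wrap a plain torch module so it is applied block by block; with
    # many_to_one=False every key would get its own deep copy instead
    module = torch.nn.Sequential(torch.nn.Linear(3, 7), torch.nn.Tanh())
    wrapped = ModuleTensorMap.from_module(tensor.keys, module, many_to_one=True)

    out = wrapped(tensor)
    print(out[0].values.shape)  # (4, 7), with auto-generated output properties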
src/equisolve/nn/__init__.py (13 additions, 0 deletions)
@@ -0,0 +1,13 @@
try:
    import torch  # noqa: F401

    HAS_TORCH = True
except ImportError:
    HAS_TORCH = False

if HAS_TORCH:
    from .module_tensor import LinearTensorMap, ModuleTensorMap  # noqa: F401

    __all__ = ["LinearTensorMap", "ModuleTensorMap"]
else:
    __all__ = []
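Since `equisolve.nn` stays importable without torch (the wrappers are simply not exported), downstream code could guard on `__all__`. A minimal, hypothetical check, not part of this commit:

    import equisolve.nn

    if "ModuleTensorMap" in equisolve.nn.__all__:
        from equisolve.nn import ModuleTensorMap  # torch is available
    else:
        print("torch is not installed; equisolve.nn exposes no wrappers")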
src/equisolve/nn/module_tensor.py (177 additions, 0 deletions)
@@ -0,0 +1,177 @@
try:
    from metatensor.torch import Labels, LabelsEntry, TensorBlock, TensorMap

    HAS_METATENSOR_TORCH = True
except ImportError:
    from metatensor import Labels, LabelsEntry, TensorBlock, TensorMap

    HAS_METATENSOR_TORCH = False

from copy import deepcopy
from typing import List, Optional

import torch
from torch.nn import Linear, Module, ModuleDict


@torch.jit.interface
class ModuleTensorMapInterface(torch.nn.Module):
    """
    This interface is required so that TorchScript can index the ModuleDict in
    ModuleTensorMap with non-literal keys. TorchScript interprets the forward
    signature

    .. code-block:: python

        forward(self, tensor: TensorMap) -> TensorMap:

    as

    .. code-block:: python

        forward(self, input: torch.Tensor) -> torch.Tensor:

    Note that the type hints and argument names must match exactly so that the
    interface is implemented correctly.

    Reference
    ---------
    https://github.com/pytorch/pytorch/pull/45716
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        pass


class ModuleTensorMap(Module):
    """
    Module that applies one torch module per block of a tensor map. This is
    needed when blocks with different keys have different feature sizes.

    :param module_map: a dictionary of modules, with the tensor map keys as
        dictionary keys; each module is applied to the corresponding block
    :param out_tensor: a tensor map whose blocks provide the output properties
        Labels
    """

    def __init__(self, module_map: ModuleDict, out_tensor: Optional[TensorMap] = None):
        super().__init__()
        self._module_map = module_map
        self._out_tensor = out_tensor

    @classmethod
    def from_module(
        cls,
        in_keys: Labels,
        module: Module,
        many_to_one: bool = True,
        out_tensor: Optional[TensorMap] = None,
    ):
        """
        :param in_keys: the keys of the tensor maps this module will be applied to
        :param module: the module that is applied to each block
        :param many_to_one: if True, the same module instance is shared by all
            blocks; otherwise each block uses its own deep copy of the module
        :param out_tensor: a tensor map whose blocks provide the output properties
            Labels for each key
        """
        module_map = ModuleDict()
        for key in in_keys:
            module_key = ModuleTensorMap.module_key(key)
            if many_to_one:
                module_map[module_key] = module
            else:
                module_map[module_key] = deepcopy(module)

        return cls(module_map, out_tensor)

    def forward(self, tensor: TensorMap) -> TensorMap:
        out_blocks: List[TensorBlock] = []
        for key, block in tensor.items():
            out_block = self.forward_block(key, block)

            for parameter, gradient in block.gradients():
                if len(gradient.gradients_list()) != 0:
                    raise NotImplementedError(
                        "gradients of gradients are not supported"
                    )
                out_block.add_gradient(
                    parameter=parameter,
                    gradient=self.forward_block(key, gradient),
                )
            out_blocks.append(out_block)

        return TensorMap(tensor.keys, out_blocks)

    def forward_block(self, key: LabelsEntry, block: TensorBlock) -> TensorBlock:
        module_key: str = ModuleTensorMap.module_key(key)
        module: ModuleTensorMapInterface = self._module_map[module_key]
        out_values = module.forward(block.values)
        if self._out_tensor is None:
            properties = Labels.range("_", out_values.shape[-1])
        else:
            properties = self._out_tensor.block(key).properties
        return TensorBlock(
            values=out_values,
            properties=properties,
            components=block.components,
            samples=block.samples,
        )

    @property
    def module_map(self) -> ModuleDict:
        """
        dictionary that maps hashable tensor map keys to a module
        """
        return self._module_map

    @property
    def out_tensor(self) -> Optional[TensorMap]:
        """
        tensor map whose blocks provide the output properties Labels
        """
        return self._out_tensor

    @staticmethod
    def module_key(key: LabelsEntry) -> str:
        return str(key)


class LinearTensorMap(ModuleTensorMap):
    """
    :param in_tensor: the input tensor map that is used to determine the keys
        and the input shape of each linear module
    :param out_tensor: a tensor map that provides the output properties Labels
        and the output shape for each key
    """

    def __init__(
        self,
        in_tensor: TensorMap,
        out_tensor: TensorMap,
        bias: bool = True,
    ):
        module_map = ModuleDict()
        for key, block in in_tensor.items():
            module_key = ModuleTensorMap.module_key(key)
            # the linear layer acts on the last (properties) dimension of the
            # block values, mapping the input properties to the output properties
            module = Linear(
                len(block.properties),
                len(out_tensor.block(key).properties),
                bias,
                block.values.device,
                block.values.dtype,
            )
            module_map[module_key] = module

        super().__init__(module_map, out_tensor)

    @classmethod
    def from_module(
        cls,
        in_keys: Labels,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        many_to_one: bool = True,
        out_tensor: Optional[TensorMap] = None,
    ):
        module = Linear(in_features, out_features, bias, device, dtype)
        return ModuleTensorMap.from_module(in_keys, module, many_to_one, out_tensor)
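The `@torch.jit.interface` workaround used in `module_tensor.py` can be puzzling on its own. The following standalone toy example (not part of this commit; names and shapes are made up) illustrates the TorchScript restriction it addresses: a ModuleDict can only be indexed with a non-literal key when the looked-up module is annotated with an interface type, and the call then goes through `.forward(...)` explicitly, which is also why `forward_block` above calls `module.forward(block.values)`.

    import torch
    from torch.nn import Linear, Module, ModuleDict


    @torch.jit.interface
    class ModuleInterface(torch.nn.Module):
        # the signature must match the wrapped modules' forward exactly
        def forward(self, input: torch.Tensor) -> torch.Tensor:
            pass


    class PerKeyModule(Module):
        def __init__(self):
            super().__init__()
            self.by_key = ModuleDict({"a": Linear(2, 3), "b": Linear(2, 3)})

        def forward(self, key: str, x: torch.Tensor) -> torch.Tensor:
            # without the interface annotation, scripting this non-literal
            # lookup fails
            module: ModuleInterface = self.by_key[key]
            return module.forward(x)


    scripted = torch.jit.script(PerKeyModule())
    print(scripted("a", torch.rand(4, 2)).shape)  # torch.Size([4, 3])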
tests/equisolve_tests/nn/test_module_tensor.py (129 additions, 0 deletions)
@@ -0,0 +1,129 @@
import pytest
import torch

from equisolve.nn import LinearTensorMap


try:
    from metatensor.torch import Labels, TensorBlock, TensorMap, allclose_raise

    HAS_METATENSOR_TORCH = True
except ImportError:
    from metatensor import Labels, TensorBlock, TensorMap, allclose_raise

    HAS_METATENSOR_TORCH = False


# TODO copy paste from utilities, make this more generic and put it in a common place
def random_single_block_no_components_tensor_map():
    """
    Create a dummy tensor map to be used in tests. This is the same one as the
    tensor map used in `tensor.rs` tests.
    """
    block_1 = TensorBlock(
        values=torch.rand(4, 2),
        samples=Labels(
            ["sample", "structure"],
            torch.tensor([[0, 0], [1, 1], [2, 2], [3, 3]], dtype=torch.int32),
        ),
        components=[],
        properties=Labels(["properties"], torch.tensor([[0], [1]], dtype=torch.int32)),
    )
    positions_gradient = TensorBlock(
        values=torch.rand(7, 3, 2),
        samples=Labels(
            ["sample", "structure", "center"],
            torch.tensor(
                [
                    [0, 0, 1],
                    [0, 0, 2],
                    [1, 1, 0],
                    [1, 1, 1],
                    [1, 1, 2],
                    [2, 2, 0],
                    [3, 3, 0],
                ],
                dtype=torch.int32,
            ),
        ),
        components=[
            Labels(["direction"], torch.tensor([[0], [1], [2]], dtype=torch.int32))
        ],
        properties=block_1.properties,
    )
    block_1.add_gradient("positions", positions_gradient)

    cell_gradient = TensorBlock(
        values=torch.rand(4, 6, 2),
        samples=Labels(
            ["sample", "structure"],
            torch.tensor([[0, 0], [1, 1], [2, 2], [3, 3]], dtype=torch.int32),
        ),
        components=[
            Labels(
                ["direction_xx_yy_zz_yz_xz_xy"],
                torch.tensor([[0], [1], [2], [3], [4], [5]], dtype=torch.int32),
            )
        ],
        properties=block_1.properties,
    )
    block_1.add_gradient("cell", cell_gradient)

    return TensorMap(Labels.single(), [block_1])


class TestModuleTensorMap:
    @pytest.fixture(autouse=True)
    def set_random_generator(self):
        """Set the random generator to the same seed before each test is run.
        Otherwise the test behaviour depends on the order of the tests in this
        file and on the number of parameters of each test.
        """
        torch.random.manual_seed(122578741812)

    @pytest.mark.parametrize(
        "tensor",
        [
            random_single_block_no_components_tensor_map(),
        ],
    )
    def test_linear_module(self, tensor):
        tensor_module = LinearTensorMap.from_module(
            tensor.keys, in_features=len(tensor[0].properties), out_features=5
        )
        with torch.no_grad():
            out_tensor = tensor_module(tensor)

        for key, block in tensor.items():
            module = tensor_module.module_map[LinearTensorMap.module_key(key)]
            with torch.no_grad():
                ref_values = module(block.values)
            out_block = out_tensor.block(key)
            assert torch.allclose(ref_values, out_block.values)

            for parameter, gradient in block.gradients():
                with torch.no_grad():
                    ref_gradient_values = module(gradient.values)
                assert torch.allclose(
                    ref_gradient_values, out_block.gradient(parameter).values
                )

    @pytest.mark.parametrize(
        "tensor",
        [
            random_single_block_no_components_tensor_map(),
        ],
    )
    @pytest.mark.skipif(
        not HAS_METATENSOR_TORCH, reason="requires metatensor-torch to run"
    )
    def test_torchscript_linear_module(self, tensor):
        tensor_module = LinearTensorMap.from_module(
            tensor.keys, in_features=len(tensor[0].properties), out_features=5
        )
        ref_tensor = tensor_module(tensor)

        tensor_module_script = torch.jit.script(tensor_module)
        out_tensor = tensor_module_script(tensor)

        allclose_raise(ref_tensor, out_tensor)
