add tests for ModuleTensorMap
agoscinski committed Sep 15, 2023
1 parent cc3d982 commit 665688d
Showing 1 changed file with 72 additions and 1 deletion.
tests/equisolve_tests/nn/test_module_tensor.py (+72 −1)
@@ -11,7 +11,9 @@
     HAS_TORCH = False
 
 if HAS_TORCH:
-    from equisolve.nn import LinearTensorMap
+    from torch.nn import Linear, Module, ModuleDict, Sigmoid
+
+    from equisolve.nn import LinearTensorMap, ModuleTensorMap
 
 try:
     from metatensor.torch import allclose_raise
@@ -22,6 +24,18 @@
 
     HAS_METATENSOR_TORCH = False
 
+if HAS_TORCH:
+
+    class MockModule(Module):
+        def __init__(self, in_features, out_features):
+            super().__init__()
+            self._linear = Linear(in_features, out_features)
+            self._activation = Sigmoid()
+            self._last_layer = Linear(out_features, 1)
+
+        def forward(self, input: torch.Tensor) -> torch.Tensor:
+            return self._last_layer(self._activation(self._linear(input)))
+

 @pytest.mark.skipif(not (HAS_TORCH), reason="requires torch to be run")
 class TestModuleTensorMap:
@@ -33,6 +47,63 @@ def set_random_generator(self):
         """
         torch.random.manual_seed(122578741812)
 
+    @pytest.mark.parametrize(
+        "tensor",
+        [
+            random_single_block_no_components_tensor_map(
+                HAS_TORCH, HAS_METATENSOR_TORCH
+            ),
+        ],
+    )
+    def test_module_tensor(self, tensor):
+        module_map = ModuleDict()
+        for key in tensor.keys:
+            module_map[ModuleTensorMap.module_key(key)] = MockModule(
+                in_features=len(tensor.block(key).properties), out_features=5
+            )
+        tensor_module = ModuleTensorMap(module_map)
+        with torch.no_grad():
+            out_tensor = tensor_module(tensor)
+
+        for key, block in tensor.items():
+            module = module_map[ModuleTensorMap.module_key(key)]
+            with torch.no_grad():
+                ref_values = module(block.values)
+            out_block = out_tensor.block(key)
+            assert torch.allclose(ref_values, out_block.values)
+
+            for parameter, gradient in block.gradients():
+                with torch.no_grad():
+                    ref_gradient_values = module(gradient.values)
+                assert torch.allclose(
+                    ref_gradient_values, out_block.gradient(parameter).values
+                )
+
+    @pytest.mark.parametrize(
+        "tensor",
+        [
+            random_single_block_no_components_tensor_map(
+                HAS_TORCH, HAS_METATENSOR_TORCH
+            ),
+        ],
+    )
+    @pytest.mark.skipif(
+        not (HAS_METATENSOR_TORCH), reason="requires metatensor-torch to be run"
+    )
+    def test_torchscript_module_tensor(self, tensor):
+        module_map = ModuleDict()
+        for key in tensor.keys:
+            module_map[ModuleTensorMap.module_key(key)] = MockModule(
+                in_features=len(tensor.block(key).properties), out_features=5
+            )
+        tensor_module = ModuleTensorMap(module_map)
+        ref_tensor = tensor_module(tensor)
+
+        tensor_module_script = torch.jit.script(tensor_module)
+        out_tensor = tensor_module_script(tensor)
+
+        allclose_raise(ref_tensor, out_tensor)
+
     @pytest.mark.parametrize(
         "tensor",
         [

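For context, a minimal usage sketch distilled from the tests above. The ModuleTensorMap wrapper, its module_key helper, and the equisolve.nn import are taken directly from the diff; the plain Linear layer, the output width of 5, and the tensor variable (assumed to be a metatensor TensorMap) are illustrative assumptions, not part of the commit.

import torch
from torch.nn import Linear, ModuleDict

from equisolve.nn import ModuleTensorMap

# Assumption: `tensor` is an existing metatensor TensorMap, e.g. like the
# one produced by random_single_block_no_components_tensor_map in the tests.

# Register one torch module per block key, stored under the key string
# that ModuleTensorMap.module_key derives from the block's key.
module_map = ModuleDict()
for key in tensor.keys:
    n_properties = len(tensor.block(key).properties)
    module_map[ModuleTensorMap.module_key(key)] = Linear(n_properties, 5)

# The wrapper applies each module to the values (and gradients) of the
# block with the matching key and returns a new TensorMap.
tensor_module = ModuleTensorMap(module_map)
out_tensor = tensor_module(tensor)

# As test_torchscript_module_tensor checks, the wrapper can also be
# compiled with TorchScript; the scripted output should match eager mode.
scripted = torch.jit.script(tensor_module)
out_scripted = scripted(tensor)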