Skip to content

Commit

Permalink
feat: parameters for 'MatMul'
Browse files Browse the repository at this point in the history
Changes 'MatMul' to provide parameters for weights when used in ML.
  • Loading branch information
nfnt committed Jun 5, 2023
1 parent 2ed856f commit d25bab5
Showing 1 changed file with 26 additions and 7 deletions.
33 changes: 26 additions & 7 deletions onnx2torch/node_converters/matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,46 @@
'OnnxMatMul',
]

import math

import torch
from torch import nn

from onnx2torch.node_converters.registry import add_converter
from onnx2torch.onnx_graph import OnnxGraph
from onnx2torch.onnx_node import OnnxNode
from onnx2torch.utils.common import OnnxToTorchModule
from onnx2torch.utils.common import OnnxToTorchModule, OnnxMapping
from onnx2torch.utils.common import OperationConverterResult
from onnx2torch.utils.common import onnx_mapping_from_node


class OnnxMatMul(nn.Module, OnnxToTorchModule): # pylint: disable=missing-class-docstring
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring
return torch.matmul(x, y)
weight: torch.Tensor

def __init__(self, in_features, out_features) -> None:
super().__init__()

self.weight = nn.Parameter(torch.empty((out_features, in_features)))
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

def forward(self, input: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring
return torch.matmul(input, self.weight)


@add_converter(operation_type='MatMul', version=1)
@add_converter(operation_type='MatMul', version=9)
@add_converter(operation_type='MatMul', version=13)
def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument
a_name = node.input_values[0]
b_name = node.input_values[1]

weights = graph.initializers[b_name]
weights = weights.to_torch()

torch_module = OnnxMatMul(in_features=weights.shape[1], out_features=weights.shape[0])

with torch.no_grad():
torch_module.weight.data = weights

return OperationConverterResult(
torch_module=OnnxMatMul(),
onnx_mapping=onnx_mapping_from_node(node=node),
torch_module=torch_module,
onnx_mapping=OnnxMapping(inputs=(a_name, ), outputs=node.output_values),
)

0 comments on commit d25bab5

Please sign in to comment.