From d25bab548681cd08c297597f27e6f6c98d0740ae Mon Sep 17 00:00:00 2001
From: Jan Schlicht
Date: Mon, 5 Jun 2023 08:30:20 +0200
Subject: [PATCH] feat: parameters for 'MatMul'

Changes 'MatMul' to provide parameters for weights when used in ML.

The right-hand operand becomes a trainable parameter only when it is a
graph initializer with at least two dimensions. Any other 'MatMul' keeps
the previous two-input conversion.
---
 onnx2torch/node_converters/matmul.py | 62 ++++++++++++++++++++++++----
 1 file changed, 56 insertions(+), 6 deletions(-)

diff --git a/onnx2torch/node_converters/matmul.py b/onnx2torch/node_converters/matmul.py
index 32d5aa5..81ea151 100644
--- a/onnx2torch/node_converters/matmul.py
+++ b/onnx2torch/node_converters/matmul.py
@@ -2,27 +2,77 @@
     'OnnxMatMul',
 ]
 
+import math
+
 import torch
 from torch import nn
 
 from onnx2torch.node_converters.registry import add_converter
 from onnx2torch.onnx_graph import OnnxGraph
 from onnx2torch.onnx_node import OnnxNode
+from onnx2torch.utils.common import OnnxMapping
 from onnx2torch.utils.common import OnnxToTorchModule
 from onnx2torch.utils.common import OperationConverterResult
 from onnx2torch.utils.common import onnx_mapping_from_node
 
 
-class OnnxMatMul(nn.Module, OnnxToTorchModule):  # pylint: disable=missing-class-docstring
-    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:  # pylint: disable=missing-function-docstring
-        return torch.matmul(x, y)
+class OnnxMatMul(nn.Module, OnnxToTorchModule):
+    """Multiply the input by a learnable weight matrix.
+
+    ``forward`` returns ``torch.matmul(input, weight)`` with no transpose, so
+    ``weight`` is stored with shape ``(in_features, out_features)``.
+    """
+
+    weight: torch.Tensor
+
+    def __init__(self, in_features: int, out_features: int) -> None:
+        super().__init__()
+
+        self.weight = nn.Parameter(torch.empty((in_features, out_features)))
+        # torch counts dim 0 as fan_out, which is the input dimension in this
+        # layout; together with a=sqrt(5) the bound is 1 / sqrt(in_features).
+        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5), mode='fan_out')
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:  # pylint: disable=redefined-builtin
+        """Return the matrix product of ``input`` and ``weight``."""
+        return torch.matmul(input, self.weight)
+
+
+class _OnnxMatMulTwoInputs(nn.Module, OnnxToTorchModule):
+    """Multiply two graph values when the right operand cannot be a weight."""
+
+    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        """Return the matrix product of ``x`` and ``y``."""
+        return torch.matmul(x, y)
 
 
 @add_converter(operation_type='MatMul', version=1)
 @add_converter(operation_type='MatMul', version=9)
 @add_converter(operation_type='MatMul', version=13)
-def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:  # pylint: disable=unused-argument
+def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:
+    """Convert an ONNX ``MatMul`` node.
+
+    A right-hand operand that is a graph initializer with at least two
+    dimensions becomes the trainable ``weight`` of ``OnnxMatMul``; any other
+    operand keeps the two-input module mapped with ``onnx_mapping_from_node``.
+    """
+    a_name, b_name = node.input_values[0], node.input_values[1]
+
+    weights = graph.initializers[b_name].to_torch() if b_name in graph.initializers else None
+    if weights is None or weights.dim() < 2:
+        # Nothing to promote to a parameter (non-initializer operand), or an
+        # operand with fewer than two dimensions to read the features from.
+        return OperationConverterResult(
+            torch_module=_OnnxMatMulTwoInputs(),
+            onnx_mapping=onnx_mapping_from_node(node=node),
+        )
+
+    torch_module = OnnxMatMul(in_features=weights.shape[-2], out_features=weights.shape[-1])
+    with torch.no_grad():
+        # Replaces the randomly initialised tensor, including its shape for batched weights.
+        torch_module.weight.data = weights
+
     return OperationConverterResult(
-        torch_module=OnnxMatMul(),
-        onnx_mapping=onnx_mapping_from_node(node=node),
+        torch_module=torch_module,
+        onnx_mapping=OnnxMapping(inputs=(a_name,), outputs=node.output_values),
     )