-
Notifications
You must be signed in to change notification settings - Fork 117
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
convert model error #292
Comments
emm,这是因为你的模型没法过 import torch
import torch.nn as nn
import torch.nn.functional as F
def new_layer_norm(self, x):
normalized_shape = x.shape[1:]
return F.layer_norm(x, normalized_shape, self.weight.view(-1,1,1).expand(normalized_shape), self.bias.view(-1,1,1).expand(normalized_shape), self.eps)
def patch_layer_norm(model):
for name, module in model.named_modules():
if type(module).__name__ == 'LayerNorm2d':
module.forward = new_layer_norm.get(module, type(module))
patch_layer_norm(model) |
确实我发现是trace的问题, 我也打算用pytorch.ln试试, 但是还是不死心想问问能不能支持下这个 :) |
@YilanWang 如果你要走前面那条路也行的,需要实现LayerNormFunction的symbolic方法,参见https://pytorch.org/docs/master/onnx.html#static-symbolic-method 以及 https://pytorch.org/docs/master/onnx.html#c-operators ,就是需要把g.Op里面的方法名换成aten::layer_norm,然后参数前面按照这个函数的方式来传。 |
@YilanWang 深入研究了一下,发现这个确实是可以支持的,当然还有个问题就是这个graph是不能save的,所以需要先trace再convert,这样可以跳过save这一步。 |
@YilanWang 结合 #293 应该已经可以了,试下下面的代码 model.eval()
graph = torch.jit.trace(model, dummy_input)
converter = TFLiteConverter(graph, dummy_input, ...)
converter.convert() |
我在convert https://github.com/megvii-research/NAFNet 这个模型时, 发现在: https://github.com/megvii-research/NAFNet/blob/2b4af71ebe098a92a75910c233a3965a3e93ede4/basicsr/models/archs/arch_util.py#L291 这里, 会报这样的error:
RuntimeError: Could not export Python function call 'LayerNormFunction'. Remove calls to Python functions before export. Did you forget to add @script or @script_method annotation? If this is a nn.ModuleList, add it to __constants__:
请问我要如何解决呢?
The text was updated successfully, but these errors were encountered: