Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

convert model error #292

Closed
YilanWang opened this issue Apr 8, 2024 · 5 comments · Fixed by #293
Closed

convert model error #292

YilanWang opened this issue Apr 8, 2024 · 5 comments · Fixed by #293
Labels
enhancement New feature or request question Further information is requested work-medium work that can be done within 1 day

Comments

@YilanWang
Copy link

我在convert https://github.com/megvii-research/NAFNet 这个模型时, 发现在: https://github.com/megvii-research/NAFNet/blob/2b4af71ebe098a92a75910c233a3965a3e93ede4/basicsr/models/archs/arch_util.py#L291 这里, 会报这样的error:

RuntimeError: Could not export Python function call 'LayerNormFunction'. Remove calls to Python functions before export. Did you forget to add @script or @script_method annotation? If this is a nn.ModuleList, add it to __constants__:

请问我要如何解决呢?

@peterjc123 peterjc123 added the question Further information is requested label Apr 9, 2024
@peterjc123
Copy link
Collaborator

peterjc123 commented Apr 9, 2024

emm,这是因为你的模型没法过jit.trace,底层原因是因为你给定的路径下的LayerNorm2d用了torch.autograd.Function来实现,这个需要自己来实现模型的转换映射逻辑。当然还有个简单的方案就是把他给换回普通的torch.nn.LayerNorm或者torch.nn.functional.layer_norm.

import torch
import torch.nn as nn
import torch.nn.functional as F

def new_layer_norm(self, x):
    normalized_shape = x.shape[1:]
    return F.layer_norm(x, normalized_shape, self.weight.view(-1,1,1).expand(normalized_shape), self.bias.view(-1,1,1).expand(normalized_shape), self.eps)

def patch_layer_norm(model):
    for name, module in model.named_modules():
        if type(module).__name__ == 'LayerNorm2d':
            module.forward = new_layer_norm.get(module, type(module))

patch_layer_norm(model)

@YilanWang
Copy link
Author

确实我发现是trace的问题, 我也打算用pytorch.ln试试, 但是还是不死心想问问能不能支持下这个 :)

@peterjc123
Copy link
Collaborator

peterjc123 commented Apr 9, 2024

@YilanWang 如果你要走前面那条路也行的,需要实现LayerNormFunction的symbolic方法,参见https://pytorch.org/docs/master/onnx.html#static-symbolic-method 以及 https://pytorch.org/docs/master/onnx.html#c-operators ,就是需要把g.Op里面的方法名换成aten::layer_norm,然后参数前面按照这个函数的方式来传。

@peterjc123
Copy link
Collaborator

@YilanWang 深入研究了一下,发现这个确实是可以支持的,当然还有个问题就是这个graph是不能save的,所以需要先trace再convert,这样可以跳过save这一步。

@peterjc123 peterjc123 added enhancement New feature or request work-medium work that can be done within 1 day labels Apr 9, 2024
@peterjc123
Copy link
Collaborator

@YilanWang 结合 #293 应该已经可以了,试下下面的代码

model.eval()
graph = torch.jit.trace(model, dummy_input)

converter = TFLiteConverter(graph, dummy_input, ...)
converter.convert()

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request question Further information is requested work-medium work that can be done within 1 day
Projects
None yet
Development

Successfully merging a pull request may close this issue.

2 participants