Networks with two input images not supported? #14
Try something like this:

```python
import torch
import torch.nn as nn
from ptflops import get_model_complexity_info


class Siamese(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, 3, 1)
        self.conv2 = nn.Conv2d(1, 10, 3, 1)

    def forward(self, x):
        # x is a list of two input tensors, one per branch
        return self.conv1(x[0]) + self.conv2(x[1])


def prepare_input(resolution):
    # ptflops calls this with input_res; the returned dict is unpacked
    # into forward() as keyword arguments, i.e. model(x=[x1, x2])
    x1 = torch.randn(1, *resolution)
    x2 = torch.randn(1, *resolution)
    return dict(x=[x1, x2])


if __name__ == '__main__':
    model = Siamese()
    # ptflops counts multiply-accumulate operations (MACs)
    macs, params = get_model_complexity_info(model, input_res=(1, 224, 224),
                                             input_constructor=prepare_input,
                                             as_strings=True, print_per_layer_stat=False)
    print(' - MACs: ' + macs)
    print(' - Params: ' + params)
```
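As a sanity check on what this toy model should report, the counts can be worked out by hand. The sketch below uses a hypothetical helper `conv2d_counts` (not part of ptflops) that applies the standard `Conv2d` formulas for `groups=1` and batch size 1. It counts MACs without bias additions; ptflops may also count the bias additions, so its MAC figure can come out somewhat higher, while the parameter count should match exactly.

```python
def conv2d_out_size(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a Conv2d along one dimension (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def conv2d_counts(in_ch: int, out_ch: int, kernel: int, height: int, width: int,
                  stride: int = 1, padding: int = 0, bias: bool = True):
    """Return (macs, params) for one Conv2d with groups=1 and batch size 1.

    MACs exclude the bias additions.
    """
    out_h = conv2d_out_size(height, kernel, stride, padding)
    out_w = conv2d_out_size(width, kernel, stride, padding)
    macs = out_ch * out_h * out_w * in_ch * kernel * kernel
    params = out_ch * in_ch * kernel * kernel + (out_ch if bias else 0)
    return macs, params


# Each branch of the Siamese example: nn.Conv2d(1, 10, 3, 1) on a 1x224x224 input.
branch_macs, branch_params = conv2d_counts(1, 10, 3, 224, 224)
total_macs = 2 * branch_macs      # two branches; the final `+` is not a Conv2d
total_params = 2 * branch_params
print(f"MACs (conv only, no bias): {total_macs:,}")  # 8,871,120
print(f"Params: {total_params}")                     # 200
```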
Thanks, it works!
What if the two inputs have different sizes?
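@chyohoo in that case your constructor can simply ignore its argument and build each tensor with its own shape:

```python
import torch

def prepare_input(resolution):
    # `resolution` (the input_res value) is not used here
    x1 = torch.randn(1, 3, 224, 224)
    x2 = torch.randn(1, 3, 128, 128)
    return dict(x=[x1, x2])
```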
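Why this works: ptflops calls `input_constructor` with `input_res` and, when the result is a dict, unpacks it into the model's `forward` as keyword arguments, so each input's shape is whatever the constructor builds. The sketch below is a simplified, hypothetical stand-in for that dispatch (`count_with_constructor` and `TwoInputModel` are illustrative names, not library code), using shape tuples in place of tensors so it runs without PyTorch. It also shows the variant where `forward` takes two separately named inputs, in which case the dict keys must match those parameter names.

```python
from typing import Any, Callable, Dict, Tuple

Shape = Tuple[int, ...]


def count_with_constructor(model: Callable[..., Any], input_res: Shape,
                           input_constructor: Callable[[Shape], Any]) -> Any:
    """Hypothetical, simplified stand-in for how the constructor's result is consumed."""
    batch = input_constructor(input_res)
    if isinstance(batch, dict):
        return model(**batch)  # dict keys become keyword arguments of forward()
    return model(batch)        # anything else is passed as one positional input


class TwoInputModel:
    """Stands in for an nn.Module whose forward(self, x1, x2) takes two inputs."""

    def __call__(self, x1: Shape, x2: Shape) -> Tuple[Shape, Shape]:
        return x1, x2


def prepare_input(resolution: Shape) -> Dict[str, Shape]:
    # `resolution` is ignored: each input gets its own, different shape.
    # With a real model these values would be tensors, e.g. torch.randn(1, 3, 224, 224).
    return {"x1": (1, 3, 224, 224), "x2": (1, 3, 128, 128)}


shapes = count_with_constructor(TwoInputModel(), (3, 224, 224), prepare_input)
print(shapes)  # ((1, 3, 224, 224), (1, 3, 128, 128))
```

If a key does not match a parameter name (say `{"x": ...}` with this model), the call raises `TypeError`, which is a common pitfall with this pattern.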
Hi. I tried to implement the calculation following your advice:

```python
def prepare_input(resolution):
    ...

flops, macs, params = get_model_complexity_info(model, input_res=((1, 3, 224, 224), (1, 1, 224, 224)),
                                                input_constructor=prepare_input,
                                                as_strings=True, print_per_layer_stat=True, verbose=True)
```

However, I got this warning:

```
Warning: module Softmax is treated as a zero-op.
```

How can I fix it? Many thanks.
Hi! Take a look at the
Thanks for your prompt reply! It works now.
Does the batch size have to be 1, or can it be customized? I tried your prepare_input and got the following error:
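Hi! If you have an ordinary model that consumes only one input tensor `x`, the following should work for you:

```python
import sys
import torch
from ptflops import get_model_complexity_info

bs = 2  # desired batch size
# `net` is your model. The constructor ignores input_res and builds a batch of `bs`
# samples; the key "x" must match the name of the argument of net.forward().
input_constructor = lambda _: {"x": torch.randn(bs, 3, 224, 224)}
macs, params = get_model_complexity_info(net, (3, 224, 224),
                                         as_strings=True,
                                         input_constructor=input_constructor,
                                         print_per_layer_stat=True,
                                         ost=sys.stdout)
```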
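For a rough cross-check when changing `bs`: for layers such as `Conv2d`, the work grows linearly with the number of samples, so a batch of `bs` costs `bs` times the per-sample MACs. Whether the figure ptflops reports is per sample or for the whole batch may depend on the ptflops version and on how the inputs are passed, so comparing against a run with `bs = 1` is a cheap check. The hypothetical helper `conv2d_macs` below (not part of ptflops) applies the standard formula to an illustrative layer, `Conv2d(3, 64, 3, padding=1)` on 224x224 inputs, and excludes bias additions.

```python
def conv2d_macs(batch: int, in_ch: int, out_ch: int, kernel: int,
                out_h: int, out_w: int, groups: int = 1) -> int:
    """MACs of one Conv2d forward pass over a whole batch (bias additions excluded)."""
    per_position = (in_ch // groups) * kernel * kernel * out_ch
    return batch * out_h * out_w * per_position


# Illustrative layer: Conv2d(3, 64, 3, padding=1) keeps the 224x224 spatial size.
per_sample = conv2d_macs(1, 3, 64, 3, 224, 224)
for bs in (1, 2, 8):
    total = conv2d_macs(bs, 3, 64, 3, 224, 224)
    # the per-sample figure stays 86,704,128 for every batch size
    print(f"bs={bs}: {total:,} MACs, {total // bs:,} per sample")
```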