Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Add the support for aten::mul operator. #2905

Merged
merged 3 commits into from
Sep 24, 2020

Conversation

zheng-ningxin
Copy link
Contributor

Add the support for aten::mul operator.
Signed-off-by: Ningxin Ningxin.Zheng@microsoft.com

Signed-off-by: Ningxin <Ningxin.Zheng@microsoft.com>
@zheng-ningxin
Copy link
Contributor Author

Issue: #2901

Signed-off-by: Ningxin <Ningxin.Zheng@microsoft.com>
# mul has the similar behaviour with add, they both request
# the input tesors to have the same shape
'aten::mul': lambda module_mask, mask: add_inshape(module_mask, mask),
'aten::mul_': lambda module_mask, mask: add_inshape(module_mask, mask),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how XXX_inshape function works when an operator has multiple inputs?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

XXX_inshape cannot have multiple inputs.
For example a += b + c + d # a, b, c, d are tensors
then the inputs of add_inshape is the result of a+b+c

FYI

>>> class tmp(torch.nn.Module):
...     def forward(self, x):
...             ones = torch.ones_like(x)
...             one_2 = torch.ones_like(x)
...             one_3 = torch.ones_like(x)
...             x += one_2 + one_3+ ones
...             return x
...
>>> graph= torch.jit.trace(tmp, torch.ones(4))
>>> torch._C._jit_pass_inline(graph.graph)
>>> graph.graph
graph(%self : __torch__.torch.nn.modules.module.Module,
      %x : Float(4)):
  %9 : int = prim::Constant[value=6]() # <stdin>:4:0
  %10 : int = prim::Constant[value=0]() # <stdin>:4:0
  %11 : Device = prim::Constant[value="cpu"]() # <stdin>:4:0
  %12 : bool = prim::Constant[value=0]() # <stdin>:4:0
  %13 : None = prim::Constant()
  %ones : Float(4) = aten::ones_like(%x, %9, %10, %11, %12, %13) # <stdin>:4:0
  %15 : int = prim::Constant[value=6]() # <stdin>:5:0
  %16 : int = prim::Constant[value=0]() # <stdin>:5:0
  %17 : Device = prim::Constant[value="cpu"]() # <stdin>:5:0
  %18 : bool = prim::Constant[value=0]() # <stdin>:5:0
  %19 : None = prim::Constant()
  %one_2 : Float(4) = aten::ones_like(%x, %15, %16, %17, %18, %19) # <stdin>:5:0
  %21 : int = prim::Constant[value=6]() # <stdin>:6:0
  %22 : int = prim::Constant[value=0]() # <stdin>:6:0
  %23 : Device = prim::Constant[value="cpu"]() # <stdin>:6:0
  %24 : bool = prim::Constant[value=0]() # <stdin>:6:0
  %25 : None = prim::Constant()
  %one_3 : Float(4) = aten::ones_like(%x, %21, %22, %23, %24, %25) # <stdin>:6:0
  %27 : int = prim::Constant[value=1]() # <stdin>:7:0
  %28 : Float(4) = aten::add(%one_2, %one_3, %27) # <stdin>:7:0
  %29 : int = prim::Constant[value=1]() # <stdin>:7:0
  %30 : Float(4) = aten::add(%28, %ones, %29) # <stdin>:7:0
  %31 : int = prim::Constant[value=1]() # <stdin>:7:0
  %32 : Float(4) = aten::add_(%x, %30, %31) # <stdin>:7:0
  return (%32)

@QuanluZhang QuanluZhang linked an issue Sep 23, 2020 that may be closed by this pull request
@QuanluZhang QuanluZhang merged commit f1b8cd2 into microsoft:master Sep 24, 2020
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

ModelSpeedUp support for aten::mul?
3 participants