
[TIR] Support AllocateConst nodes in TensorIR scheduling flow #12489

Merged

10 commits merged into apache:main on Aug 22, 2022

Conversation

@masahi (Member) commented Aug 18, 2022

TIR-level constants, represented by AllocateConst nodes, were introduced in #8509. They are generated when link-params = 1.

At the Relay level, link_params = True keeps constant tensors bound to the function body, rather than lifting them to function parameters during FuseOps. So we end up with a Relay prim func like the following:

fn (%p0: Tensor[(128, 768), uint8]) -> Tensor[(128, 768), int32] {
  nn.contrib_dense_pack(%p0, meta[relay.Constant][0], units=None, out_dtype="int32", weight_layout="NC32n4c")
} 

(If link-params = False, which is the default, the weight also becomes a parameter of the function.)

I found that the TensorIR-related components below do not currently support such TIR-level constants:

  • CreatePrimFunc: It assumes that all tensors will be passed as parameters. This doesn't hold if link-params = 1.
  • PlanAndUpdateBufferAllocationLocation: Its handling of AllocateConst nodes is incorrect. It creates a duplicated allocation for a constant already allocated by an AllocateConst node, about which StorageRewrite complains, since there are two allocations of a constant with the same name.
  • LeafBlockRemovalPlan (used by compute_at and compute_inline): It assumes that the body of a Block begins with a SeqStmt. This may not be the case, since this PR places AllocateConst nodes at the beginning of a body, as below:
def main(...) -> None:
    ...
    fused_nn_contrib_conv2d_NCHWc_constant_1 = T.allocate_const(..., "int32", [1, 16, 1, 1, 4])
    fused_constant_0 = T.allocate_const(..., "int8", [16, 16, 3, 3, 1, 4, 4])
    for i0, i1, i2, i3, i4 in T.grid(1, 16, 58, 58, 4):
       ...
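The broken assumption can be illustrated with a small stand-alone sketch (plain-Python toy types, not the actual TVM C++ classes): a pass like LeafBlockRemovalPlan has to walk past any leading AllocateConst wrappers before it can expect to find a SeqStmt.

```python
# Toy model (simplified stand-ins, not real TIR nodes) of the traversal fix:
# skip AllocateConst wrappers at the top of a block body before assuming
# the body is a SeqStmt.
from dataclasses import dataclass
from typing import List, Union

@dataclass
class SeqStmt:
    seq: List[str]  # child statements, simplified to strings

@dataclass
class AllocateConst:
    name: str
    body: Union["AllocateConst", SeqStmt]

def find_seq_stmt(body):
    """Walk past any AllocateConst nodes wrapping the block body."""
    while isinstance(body, AllocateConst):
        body = body.body
    assert isinstance(body, SeqStmt), "expected SeqStmt after constants"
    return body

block_body = AllocateConst(
    "fused_constant_0",
    AllocateConst("fused_nn_contrib_conv2d_NCHWc_constant_1",
                  SeqStmt(["loop_nest"])),
)
print(find_seq_stmt(block_body).seq)  # -> ['loop_nest']
```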

The most important change is the introduction of the CreatePrimFuncWithConstants function. It reuses the tir::BindParams pass added in #8509 to create a PrimFunc with AllocateConst nodes. Since this new function takes an array of runtime::NDArray as an additional argument, I had to update every place where CreatePrimFunc is used.
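The idea behind the new entry point can be sketched as follows (plain Python with hypothetical names, not the real TVM signature): tensors for which a constant is supplied are dropped from the parameter list and instead embedded as AllocateConst wrappers around the function body.

```python
# Simplified sketch of what CreatePrimFuncWithConstants adds over
# CreatePrimFunc: tensors backed by a supplied constant become
# AllocateConst allocations instead of function parameters.
# (Toy data structures; real TIR nodes are C++ objects.)
def create_prim_func_with_constants(tensors, constants):
    """tensors: ordered tensor names; constants: dict name -> ndarray-like."""
    params = [t for t in tensors if t not in constants]
    body = "compute_body"
    # Wrap the body in one AllocateConst per constant, innermost first.
    for name in reversed([t for t in tensors if t in constants]):
        body = ("AllocateConst", name, constants[name], body)
    return {"params": params, "body": body}

func = create_prim_func_with_constants(["p0", "weight"], {"weight": [1, 2, 3]})
print(func["params"])  # -> ['p0']
```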

cc @junrushao1994 @Hzfengsy @vinx13 @csullivan

auto block = block_realize->block;
block.CopyOnWrite()->body =
    tir::AllocateConst(var, dtype, extents, constant_map[var], block->body);
n->body = BlockRealize(block_realize->iter_values, block_realize->predicate, block);
@masahi (Member Author) commented Aug 18, 2022:
Please note this change. It places the AllocateConst at the beginning of the body of the BlockRealize. I found that putting the BlockRealize as the body of the AllocateConst leads to many issues, since many places in the TIR code assume that the body of a PrimFunc starts with a BlockRealize.
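The two nestings the comment contrasts can be modeled with a toy sketch (plain Python tuples, not real TIR nodes): the chosen layout keeps BlockRealize as the root of the PrimFunc body, which much of the TIR code expects.

```python
# Toy nesting comparison (tuples standing in for TIR nodes).
# Chosen layout: AllocateConst sits *inside* the block body, so the
# PrimFunc body still starts with BlockRealize.
def chosen_layout(const_name, compute):
    block = ("Block", ("AllocateConst", const_name, compute))
    return ("BlockRealize", block)  # root node is BlockRealize

# Rejected layout: AllocateConst wraps the BlockRealize, so the root
# node is no longer BlockRealize, breaking a common TIR assumption.
def rejected_layout(const_name, compute):
    block = ("Block", compute)
    return ("AllocateConst", const_name, ("BlockRealize", block))

print(chosen_layout("w", "compute")[0])    # -> BlockRealize
print(rejected_layout("w", "compute")[0])  # -> AllocateConst
```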

@Hzfengsy (Member) left a comment:

LGTM

@Hzfengsy Hzfengsy merged commit 8146a9b into apache:main Aug 22, 2022
@Hzfengsy: Thanks @masahi

xinetzone pushed a commit to daobook/tvm that referenced this pull request Nov 25, 2022
…#12489)

* [TIR] Support AllocConstantNode in CreatePrimFunc

* Handle AllocConstantNode in LeafBlockRemovalPlan

* Properly handle AllocConstNode in BufferAllocationLocator

* handle AllocateConst in EstimateFlops

* remove NDArray printing

* doc update

* add test

* cpplint

* Removed dependency on link-params attribute from target

* Restored NDArray printing to unbreak test