[TIR] Support AllocateConst nodes in TensorIR scheduling flow #12489
auto block = block_realize->block;
block.CopyOnWrite()->body =
    tir::AllocateConst(var, dtype, extents, constant_map[var], block->body);
n->body = BlockRealize(block_realize->iter_values, block_realize->predicate, block);
Please note this change. This places `AllocateConst` at the beginning of the body of `BlockRealize`. I found that putting `BlockRealize` as the body of `AllocateConst` leads to many issues, since many places in TIR code assume that the body of a PrimFunc starts with `BlockRealize`.
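To make the two placements concrete, here is a minimal sketch using toy statement nodes. The `Stmt` struct and the helpers `Make`, `PlaceConstInsideBlock`, `PlaceConstOutsideBlock`, and `StartsWithBlockRealize` are hypothetical names for illustration only, not TVM classes or functions, and the toy collapses `BlockRealize` and its `Block` into a single node.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>

// Toy stand-in for a TIR statement; NOT the real tvm::tir classes.
struct Stmt {
  std::string kind;            // e.g. "BlockRealize", "AllocateConst", "Compute"
  std::shared_ptr<Stmt> body;  // nested statement, or nullptr for a leaf
};
using StmtPtr = std::shared_ptr<Stmt>;

// Hypothetical constructor helper for the toy nodes.
StmtPtr Make(const std::string& kind, StmtPtr body = nullptr) {
  return std::make_shared<Stmt>(Stmt{kind, std::move(body)});
}

// Placement described in the comment above: the constant allocation wraps the
// block's original body, so the block node stays outermost.
StmtPtr PlaceConstInsideBlock(const StmtPtr& block_realize) {
  return Make(block_realize->kind, Make("AllocateConst", block_realize->body));
}

// Alternative placement: the constant allocation wraps the block itself.
StmtPtr PlaceConstOutsideBlock(const StmtPtr& block_realize) {
  return Make("AllocateConst", block_realize);
}

// The assumption mentioned above: the function body starts with BlockRealize.
bool StartsWithBlockRealize(const StmtPtr& func_body) {
  return func_body != nullptr && func_body->kind == "BlockRealize";
}
```

In this toy model, only the first placement keeps `StartsWithBlockRealize` true for the resulting function body, which is the property the comment says existing TIR code relies on.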
LGTM
Thanks @masahi
…#12489)
* [TIR] Support AllocConstantNode in CreatePrimFunc
* Handle AllocConstantNode in LeafBlockRemovalPlan
* Properly handle AllocConstNode in BufferAllocationLocator
* handle AllocateConst in EstimateFlops
* remove NDArray printing
* doc update
* add test
* cpplint
* Removed dependency on link-params attribute from target
* Restored NDArray printing to unbreak test
TIR-level constants, represented by `AllocateConst` nodes, were introduced in #8509. They are generated if `link-params = 1`.

At Relay level, `link_params = True` keeps constant tensors bound to the function body, rather than lifting them to function parameters during `FuseOps`. So we end up with the following Relay prim func, for example:

(If `link-params = False`, which is the default, the weight also becomes a parameter of the function.)

I found that the TensorIR-related components below do not currently support such TIR-level constants:
* `CreatePrimFunc`: It assumes that all tensors will be passed as parameters. This doesn't hold if `link-params = 1`.
* `PlanAndUpdateBufferAllocationLocation`: Its handling of `AllocateConst` nodes is incorrect. It creates a duplicate allocation for a constant already allocated by an `AllocateConst` node, and `StorageRewrite` then complains because there are two allocations of the same-named constant.
* `LeafBlockRemovalPlan` (used by `compute_at` and `compute_inline`): It assumes that the body of `Block` begins with `SeqStmt`. This may not be the case, since I placed `AllocateConst` at the beginning of a body, like below

The most important change is the introduction of the `CreatePrimFuncWithConstants` function. It reuses the `tir::BindParams` pass added in #8509 to create a PrimFunc with `AllocateConst` nodes. Since this new function takes an array of `runtime::NDArray` as an additional argument, I had to change all places where `CreatePrimFunc` is used.

cc @junrushao1994 @Hzfengsy @vinx13 @csullivan
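As a rough illustration of the idea behind binding constants into a PrimFunc, the sketch below moves parameters that have constant data out of the signature and records them as embedded constants. Everything here is a toy: `ToyPrimFunc` and `BindConstants` are hypothetical names, and the name-keyed map is an assumption made for readability (the description above says the real function takes an array of `runtime::NDArray`). It does not reproduce the actual `tir::BindParams` implementation.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Toy stand-in for a PrimFunc; NOT the real tvm::tir::PrimFunc.
struct ToyPrimFunc {
  std::vector<std::string> params;        // tensors passed in by the caller
  std::vector<std::string> const_allocs;  // tensors embedded like AllocateConst nodes
};

// Hypothetical helper: each parameter that has constant data bound to it is
// dropped from the signature and recorded as an embedded constant instead.
// Parameters without bound data stay in the signature, in their original order.
ToyPrimFunc BindConstants(const ToyPrimFunc& func,
                          const std::map<std::string, std::vector<float>>& constants) {
  ToyPrimFunc result;
  result.const_allocs = func.const_allocs;
  for (const std::string& name : func.params) {
    if (constants.count(name) > 0) {
      result.const_allocs.push_back(name);  // becomes an in-body constant
    } else {
      result.params.push_back(name);        // remains a function parameter
    }
  }
  return result;
}
```

With an empty constants map the toy function is returned unchanged, which mirrors the default case described above where the weight stays a function parameter.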