[Metal] Add AD stack to support the local mutables when using AutoDiff #1257
Codecov Report
@@ Coverage Diff @@
## master #1257 +/- ##
=========================================
Coverage ? 65.65%
=========================================
Files ? 36
Lines ? 5066
Branches ? 928
=========================================
Hits ? 3326
Misses ? 1569
Partials ? 171
Thanks for adding this. A few questions about AD, with the OpenGL backend in mind.
TI_ASSERT(stmt->width() == 1);

const auto &var_name = stmt->raw_name();
emit("byte {}[{}];", var_name, stmt->size_in_bytes());
So we have to use a byte array here? That might be hard for OpenGL, which doesn't support pointers or unions. Or we might lose 64-bit support in AD.
I guess this will almost always be used for f32, so maybe supporting that is good enough? However, note that in CUDA/Metal, the stack size is stored as an integer at the head of this byte array. I guess OpenGL can either maintain this integer separately, or just use a float for simplicity?
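To make the layout described above concrete, here is a minimal C++ sketch of a byte-array stack whose entry count is stored as an integer at the head of the array. All names (demo_stack_push and friends) are hypothetical, and the assumption that each entry holds a primal followed by an adjoint is made for illustration only; this is not the actual CUDA or Metal runtime code.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>

// Hypothetical layout, for illustration only:
//   [ int32 count | entry 0 | entry 1 | ... ]
// Each entry is assumed to hold a primal followed by an adjoint,
// each of elem_size bytes. The caller must size the array; this
// sketch does not check for overflow.
constexpr std::size_t kDemoHeaderBytes = sizeof(std::int32_t);

inline std::int32_t demo_stack_count(const unsigned char *stack) {
  std::int32_t n;
  std::memcpy(&n, stack, sizeof(n));  // read the header without aliasing issues
  return n;
}

inline void demo_stack_set_count(unsigned char *stack, std::int32_t n) {
  std::memcpy(stack, &n, sizeof(n));
}

inline void demo_stack_init(unsigned char *stack) {
  demo_stack_set_count(stack, 0);
}

// Address of the primal slot of the top entry.
inline unsigned char *demo_stack_top_primal(unsigned char *stack,
                                            std::size_t elem_size) {
  const std::int32_t n = demo_stack_count(stack);
  assert(n > 0);
  return stack + kDemoHeaderBytes +
         static_cast<std::size_t>(n - 1) * 2 * elem_size;
}

// Address of the adjoint slot of the top entry.
inline unsigned char *demo_stack_top_adjoint(unsigned char *stack,
                                             std::size_t elem_size) {
  return demo_stack_top_primal(stack, elem_size) + elem_size;
}

// Grow the stack by one zero-initialized (primal, adjoint) entry.
inline void demo_stack_push(unsigned char *stack, std::size_t elem_size) {
  demo_stack_set_count(stack, demo_stack_count(stack) + 1);
  std::memset(demo_stack_top_primal(stack, elem_size), 0, 2 * elem_size);
}

// Drop the top entry.
inline void demo_stack_pop(unsigned char *stack) {
  const std::int32_t n = demo_stack_count(stack);
  assert(n > 0);
  demo_stack_set_count(stack, n - 1);
}
```

Because the accessors take elem_size as a parameter, the same sketch would also cover a composite element such as a pair of f32 values by passing 8 instead of 4.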
auto *stack = stmt->stack->as<StackAllocaStmt>();
const auto &stack_name = stack->raw_name();
const auto elem_size = stack->element_size_in_bytes();
emit("mtl_ad_stack_push({}, {});", stack_name, elem_size);
If I'm thinking right, elem_size will always be 4 in Metal, since it only supports 32-bit data?
Yeah, I'd say it's almost always 4, since nothing prevents us from using more complicated types. (For example, one day we may support composite types like complex {f32, f32}?)
Metal only supports 32-bit primitive data types, but it's fine to have composite types.
#define METAL_BEGIN_RUNTIME_STRUCTS_DEF constexpr auto kMetalAdStackSourceCode =
#define METAL_END_RUNTIME_STRUCTS_DEF ;
#else
#define METAL_BEGIN_RUNTIME_STRUCTS_DEF
#define METAL_END_RUNTIME_STRUCTS_DEF
Really ad-hoc and cool guard!
Cool! Just a nit.
#ifdef TI_INSIDE_METAL_CODEGEN

#ifndef TI_METAL_NESTED_INCLUDE
#define METAL_BEGIN_RUNTIME_STRUCTS_DEF constexpr auto kMetalAdStackSourceCode =
nit: Shall we use the TI_ prefix?
TL;DR: no... This is a somewhat obscure mechanism that supports both including this code as C++ code and stringifying it as real Metal kernel/shader code. All code in taichi/backends/metal/shaders follows this style. The macros here are only meant to be used by the Metal backend, and are not shared with a broader scope.
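For readers unfamiliar with this pattern, the following is a rough, self-contained C++ sketch of the include-or-stringify idea. The DEMO_ macro names and kDemoSourceCode are hypothetical and are not the macros used in the Metal backend; the sketch only shows how one guard can switch the same text between compiled C++ and a string constant.

```cpp
#include <cassert>
#include <string_view>

// Hypothetical guard: when it is defined, the wrapped text becomes a string
// constant (as a code generator would want); when it is not defined, the
// same text is compiled as ordinary C++.
#define DEMO_INSIDE_CODEGEN

#ifdef DEMO_INSIDE_CODEGEN
#define DEMO_BEGIN_DEF constexpr auto kDemoSourceCode =
#define DEMO_END_DEF ;
#define DEMO_STR(...) #__VA_ARGS__
#else
#define DEMO_BEGIN_DEF
#define DEMO_END_DEF
#define DEMO_STR(...) __VA_ARGS__
#endif

// With the guard defined, this expands to
//   constexpr auto kDemoSourceCode = "int demo_add(int a, int b) ...";
// Without it, it expands to the plain function definition.
// clang-format off
DEMO_BEGIN_DEF
DEMO_STR(
  int demo_add(int a, int b) { return a + b; }
)
DEMO_END_DEF
// clang-format on

#ifdef DEMO_INSIDE_CODEGEN
// Helper for this sketch only: check that the stringified source
// mentions a given identifier.
inline bool demo_source_contains(std::string_view needle) {
  return std::string_view(kDemoSourceCode).find(needle) !=
         std::string_view::npos;
}
#endif
```

Keeping the macros local to one backend, as the comment above argues, means a short unprefixed name is unlikely to collide with anything outside that directory.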
Looks great, thanks! In the future, we may need to consider supporting stacks with dynamic size as well.
I noticed that Metal hasn't supported this yet. It seems like all I need to do is add a small stack to support the complex AD kernels? I guess OpenGL can do similar things by simulating a per-thread stack, with each entry being a (primal, adjoint) pair (@archibate).
Related issue = #581
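As a rough illustration of the per-thread stack of (primal, adjoint) pairs suggested for OpenGL, here is a C++ sketch that avoids pointers, unions, and byte reinterpretation by keeping the count in a separate integer and storing only floats. The struct, the function names, and the fixed capacity are assumptions made for this example; C++ stands in for shader code here, and a real shader would need to adapt the parameter passing.

```cpp
#include <cassert>

// Hypothetical fixed capacity for the sketch.
constexpr int kDemoMaxEntries = 16;

// One stack per thread: the count lives in its own integer, and
// entries are stored as interleaved floats (primal, adjoint).
struct DemoF32AdStack {
  int count = 0;
  float data[2 * kDemoMaxEntries] = {};
};

// Push a new primal value with a zero adjoint.
inline void demo_f32_push(DemoF32AdStack &s, float primal) {
  assert(s.count < kDemoMaxEntries);
  s.data[2 * s.count] = primal;
  s.data[2 * s.count + 1] = 0.0f;
  ++s.count;
}

// Drop the top entry.
inline void demo_f32_pop(DemoF32AdStack &s) {
  assert(s.count > 0);
  --s.count;
}

inline float demo_f32_top_primal(const DemoF32AdStack &s) {
  assert(s.count > 0);
  return s.data[2 * (s.count - 1)];
}

inline float demo_f32_top_adjoint(const DemoF32AdStack &s) {
  assert(s.count > 0);
  return s.data[2 * (s.count - 1) + 1];
}

// Accumulate a gradient contribution into the top entry's adjoint.
inline void demo_f32_acc_adjoint(DemoF32AdStack &s, float grad) {
  assert(s.count > 0);
  s.data[2 * (s.count - 1) + 1] += grad;
}
```

This trades generality for simplicity: it only handles f32 elements, which matches the expectation in the discussion above that f32 is by far the common case.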