Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Metal] Add AD stack to support the local mutables when using AutoDiff #1257

Merged
merged 3 commits into from
Jun 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 63 additions & 1 deletion taichi/backends/metal/codegen_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ namespace {

namespace shaders {
#define TI_INSIDE_METAL_CODEGEN
#include "taichi/backends/metal/shaders/ad_stack.metal.h"
#include "taichi/backends/metal/shaders/helpers.metal.h"
#include "taichi/backends/metal/shaders/runtime_kernels.metal.h"
#undef TI_INSIDE_METAL_CODEGEN
Expand Down Expand Up @@ -551,6 +552,65 @@ class KernelCodegen : public IRVisitor {
TI_WARN("Cannot print inside Metal kernel, ignored");
}

void visit(StackAllocaStmt *stmt) override {
yuanming-hu marked this conversation as resolved.
Show resolved Hide resolved
TI_ASSERT(stmt->width() == 1);

const auto &var_name = stmt->raw_name();
emit("byte {}[{}];", var_name, stmt->size_in_bytes());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So we have to use a byte array here? That might be hard for OpenGL, which doesn't support pointers or unions. Otherwise it might lose 64-bit support when using AD.

Copy link
Member Author

@k-ye k-ye Jun 17, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this will almost always be used for f32, so maybe supporting that is good enough? However, note that in CUDA/Metal, the stack size is stored as an integer at the head of this byte array. I guess OpenGL can either maintain this integer separately, or just use a float for simplicity?

emit("mtl_ad_stack_init({});", var_name);
}

// Emits the Metal call that pops the top entry of the AD stack referenced by
// the statement.
void visit(StackPopStmt *stmt) override {
  const auto stack_name = stmt->stack->raw_name();
  emit("mtl_ad_stack_pop({});", stack_name);
}

void visit(StackPushStmt *stmt) override {
auto *stack = stmt->stack->as<StackAllocaStmt>();
const auto &stack_name = stack->raw_name();
const auto elem_size = stack->element_size_in_bytes();
emit("mtl_ad_stack_push({}, {});", stack_name, elem_size);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I'm thinking right, elem_size will always be 4 in Metal, since it only supports 32-bit data?

Copy link
Member Author

@k-ye k-ye Jun 17, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I'd say almost always, since nothing prevents us from using more complicated types (for example, one day we may support composite types like complex {f32, f32}?).
Metal only supports 32-bit primitive data types, but it's fine to have composite types.

const auto primal_name = stmt->raw_name() + "_primal_";
emit(
"thread auto* {} = reinterpret_cast<thread "
"{}*>(mtl_ad_stack_top_primal({}, {}));",
primal_name, metal_data_type_name(stmt->element_type()), stack_name,
elem_size);
emit("*{} = {};", primal_name, stmt->v->raw_name());
}

// Loads the primal value stored in the top entry of the AD stack into a Metal
// local named after the statement.
void visit(StackLoadTopStmt *stmt) override {
  auto *stack = stmt->stack->as<StackAllocaStmt>();
  const auto primal_name = stmt->raw_name() + "_primal_";
  // View the top primal slot as a pointer to the element type.
  emit(
      "thread auto* {} = reinterpret_cast<thread "
      "{}*>(mtl_ad_stack_top_primal({}, {}));",
      primal_name, metal_data_type_name(stmt->element_type()),
      stack->raw_name(), stack->element_size_in_bytes());
  // Declare the result with `auto`, the same way the StackLoadTopAdjStmt
  // visitor does. A bare assignment would refer to an identifier that this
  // visitor never declares.
  emit("auto {} = *{};", stmt->raw_name(), primal_name);
}

// Reads the adjoint stored in the top entry of the AD stack and binds it to a
// Metal local named after the statement.
void visit(StackLoadTopAdjStmt *stmt) override {
  auto *ad_stack = stmt->stack->as<StackAllocaStmt>();
  const auto result_name = stmt->raw_name();
  const auto slot_ptr_name = result_name + "_adjoint_";
  const auto elem_type_name = metal_data_type_name(stmt->element_type());
  const auto elem_bytes = ad_stack->element_size_in_bytes();

  // First a typed pointer to the adjoint slot, then the load through it.
  emit(
      "thread auto* {} = reinterpret_cast<thread "
      "{}*>(mtl_ad_stack_top_adjoint({}, {}));",
      slot_ptr_name, elem_type_name, ad_stack->raw_name(), elem_bytes);
  emit("auto {} = *{};", result_name, slot_ptr_name);
}

// Accumulates the value of stmt->v into the adjoint slot of the top entry of
// the AD stack.
void visit(StackAccAdjointStmt *stmt) override {
  auto *ad_stack = stmt->stack->as<StackAllocaStmt>();
  const auto slot_ptr_name = stmt->raw_name() + "_adjoint_";
  const auto elem_type_name = metal_data_type_name(stmt->element_type());
  const auto elem_bytes = ad_stack->element_size_in_bytes();

  // First a typed pointer to the adjoint slot, then the in-place add.
  emit(
      "thread auto* {} = reinterpret_cast<thread "
      "{}*>(mtl_ad_stack_top_adjoint({}, {}));",
      slot_ptr_name, elem_type_name, ad_stack->raw_name(), elem_bytes);
  emit("*{} += {};", slot_ptr_name, stmt->v->raw_name());
}

private:
void emit_headers() {
SectionGuard sg(this, Section::Headers);
Expand All @@ -568,6 +628,8 @@ class KernelCodegen : public IRVisitor {
emit("");
current_appender().append_raw(compiled_structs_->snode_structs_source_code);
emit("");
current_appender().append_raw(shaders::kMetalAdStackSourceCode);
emit("");
emit_kernel_args_struct();
}

Expand Down Expand Up @@ -1045,7 +1107,7 @@ FunctionType CodeGen::compile() {
config.demote_dense_struct_fors = true;
irpass::compile_to_offloads(kernel_->ir.get(), config,
/*vectorize=*/false, kernel_->grad,
/*ad_use_stack=*/false, config.print_ir);
/*ad_use_stack=*/true, config.print_ir);

KernelCodegen codegen(taichi_kernel_name_,
kernel_->program.snode_root->node_type_name, kernel_,
Expand Down
76 changes: 76 additions & 0 deletions taichi/backends/metal/shaders/ad_stack.metal.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
#include "taichi/backends/metal/shaders/prolog.h"

// Selects what METAL_BEGIN_AD_STACK_DEF / METAL_END_AD_STACK_DEF expand to.
// These two macros bracket the STR(...) section further down in this header.
// NOTE(review): STR is not defined in this view; presumably it comes from
// prolog.h and stringifies its argument -- confirm.
#ifdef TI_INSIDE_METAL_CODEGEN

#ifndef TI_METAL_NESTED_INCLUDE
// Top-level include from the codegen: the bracketed section becomes the
// initializer of the constant kMetalAdStackSourceCode.
#define METAL_BEGIN_AD_STACK_DEF constexpr auto kMetalAdStackSourceCode =
#define METAL_END_AD_STACK_DEF ;
#else
// Nested include: both macros expand to nothing, so no constant is declared
// here.
#define METAL_BEGIN_AD_STACK_DEF
#define METAL_END_AD_STACK_DEF
#endif // TI_METAL_NESTED_INCLUDE

#else

// Outside the codegen both macros expand to nothing. <cstdint> supplies
// uint32_t for the declarations below.
#include <cstdint>

#define METAL_BEGIN_AD_STACK_DEF
#define METAL_END_AD_STACK_DEF

#endif // TI_INSIDE_METAL_CODEGEN

// Autodiff stack for local mutables

// clang-format off
METAL_BEGIN_AD_STACK_DEF
STR(
// clang-format on
using AdStackPtr = thread byte *;

inline thread uint32_t *
mtl_ad_stack_n(AdStackPtr stack) {
yuanming-hu marked this conversation as resolved.
Show resolved Hide resolved
return reinterpret_cast<thread uint32_t *>(stack);
}

// Returns the address of the first data byte, located right after the
// uint32_t element counter kept at the head of the stack buffer.
inline AdStackPtr mtl_ad_stack_data(AdStackPtr stack) {
  return &stack[sizeof(uint32_t)];
}

// Resets the stack to empty by zeroing the element counter at its head.
inline void mtl_ad_stack_init(AdStackPtr stack) {
  thread uint32_t *count = mtl_ad_stack_n(stack);
  *count = 0;
}

// Returns a pointer to the primal slot of the top entry. Entries are laid out
// back to back in the data area, each one 2 * element_size bytes wide.
inline AdStackPtr mtl_ad_stack_top_primal(AdStackPtr stack,
                                          int element_size) {
  const auto count = *mtl_ad_stack_n(stack);
  // Byte offset of the last pushed entry (index count - 1).
  const auto top_offset = (count - 1) * 2 * element_size;
  return mtl_ad_stack_data(stack) + top_offset;
}

// Returns a pointer to the adjoint slot of the top entry, which sits
// element_size bytes after that entry's primal slot.
inline AdStackPtr mtl_ad_stack_top_adjoint(AdStackPtr stack,
                                           int element_size) {
  AdStackPtr primal_slot = mtl_ad_stack_top_primal(stack, element_size);
  return primal_slot + element_size;
}

// Drops the top entry by decrementing the element counter. The entry bytes
// themselves are left untouched.
inline void mtl_ad_stack_pop(AdStackPtr stack) {
  thread uint32_t *count = mtl_ad_stack_n(stack);
  *count -= 1;
}

// Pushes a new zero-filled entry: increments the element counter, then clears
// the primal and adjoint slots of the new top entry.
// Marked inline like the other helpers in this header, so that compiling the
// header as host C++ in more than one translation unit does not produce
// multiple definitions.
inline void mtl_ad_stack_push(AdStackPtr stack, int element_size) {
  thread auto &n = *mtl_ad_stack_n(stack);
  ++n;

  // One entry is a primal slot followed by an adjoint slot, so there are
  // 2 * element_size bytes to clear.
  AdStackPtr data = mtl_ad_stack_top_primal(stack, element_size);
  for (int i = 0; i < element_size * 2; ++i) {
    data[i] = 0;
  }
}
// clang-format off
)
METAL_END_AD_STACK_DEF
// clang-format on

#undef METAL_BEGIN_AD_STACK_DEF
#undef METAL_END_AD_STACK_DEF

#include "taichi/backends/metal/shaders/epilog.h"
2 changes: 1 addition & 1 deletion taichi/program/extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ bool is_supported(Arch arch, Extension ext) {
{Arch::x64, {Extension::sparse, Extension::data64, Extension::adstack}},
{Arch::arm64, {Extension::sparse, Extension::data64, Extension::adstack}},
{Arch::cuda, {Extension::sparse, Extension::data64, Extension::adstack}},
{Arch::metal, {}},
{Arch::metal, {Extension::adstack}},
{Arch::opengl, {}},
};
// if (with_opengl_extension_data64())
Expand Down
4 changes: 2 additions & 2 deletions tests/python/test_ad_if.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def func():
assert x.grad[1] == 1


@ti.require(ti.extension.adstack)
@ti.require(ti.extension.adstack, ti.extension.data64)
@ti.all_archs_with(default_fp=ti.f64)
def test_ad_if_parallel_f64():
x = ti.var(ti.f64, shape=2)
Expand Down Expand Up @@ -169,7 +169,7 @@ def func():
assert x.grad[1] == -0.25


@ti.require(ti.extension.adstack)
@ti.require(ti.extension.adstack, ti.extension.data64)
@ti.all_archs_with(default_fp=ti.f64)
def test_ad_if_parallel_complex_f64():
x = ti.var(ti.f64, shape=2)
Expand Down