Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Metal] Add AD stack to support the local mutables when using AutoDiff #1257

Merged
merged 3 commits into from
Jun 17, 2020

Conversation

k-ye
Copy link
Member

@k-ye k-ye commented Jun 16, 2020

I noticed that Metal hasn't supported this yet. It seems like all I need to do is adding a small stack to support the complex AD kernels? I guess OpenGL can do similar things by simulating a per-thread stack, with each entry being a (primal, adjoint) pair (@archibate )

Related issue = #581

[Click here for the format server]


@k-ye k-ye changed the title [Metal] Add AD stack to support the local mutables [Metal] Add AD stack to support the local mutables when using AutoDiff Jun 16, 2020
@codecov
Copy link

codecov bot commented Jun 16, 2020

Codecov Report

❗ No coverage uploaded for pull request base (master@035bcb9). Click here to learn what that means.
The diff coverage is n/a.

Impacted file tree graph

@@            Coverage Diff            @@
##             master    #1257   +/-   ##
=========================================
  Coverage          ?   65.65%           
=========================================
  Files             ?       36           
  Lines             ?     5066           
  Branches          ?      928           
=========================================
  Hits              ?     3326           
  Misses            ?     1569           
  Partials          ?      171           

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 035bcb9...a088c3b. Read the comment docs.

Copy link
Collaborator

@archibate archibate left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank for adding this, a few question about AD to think of OpenGL ones.

TI_ASSERT(stmt->width() == 1);

const auto &var_name = stmt->raw_name();
emit("byte {}[{}];", var_name, stmt->size_in_bytes());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So we have to use byte array here? Then it might be hard for OpenGL which doesn't support pointers or unions. Or might lose 64-bit support when AD.

Copy link
Member Author

@k-ye k-ye Jun 17, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this will almost always be used for f32, so maybe supporting that is good enough? However, note that in CUDA/Metal, the stack size is stored as an integer at the head of this byte array. I guess OpenGL can either maintain this integer separately, or just use a float for simplicity?

auto *stack = stmt->stack->as<StackAllocaStmt>();
const auto &stack_name = stack->raw_name();
const auto elem_size = stack->element_size_in_bytes();
emit("mtl_ad_stack_push({}, {});", stack_name, elem_size);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I'm thinking right, elem_size will always be 4 in Metal since it only support 32-bit data?

Copy link
Member Author

@k-ye k-ye Jun 17, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I'd say it's almost always, since it doesn't prevent us from using more complicated types (For example, one day we may support composite types like complex {f32, f32}?)
Metal only supports 32-bit primitive data types, but it's fine to have composite types.

Comment on lines 6 to 10
#define METAL_BEGIN_RUNTIME_STRUCTS_DEF constexpr auto kMetalAdStackSourceCode =
#define METAL_END_RUNTIME_STRUCTS_DEF ;
#else
#define METAL_BEGIN_RUNTIME_STRUCTS_DEF
#define METAL_END_RUNTIME_STRUCTS_DEF
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Really ad-hoc and cool guard!

@archibate archibate self-assigned this Jun 16, 2020
Copy link
Contributor

@xumingkuan xumingkuan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool! Just a nit.

#ifdef TI_INSIDE_METAL_CODEGEN

#ifndef TI_METAL_NESTED_INCLUDE
#define METAL_BEGIN_RUNTIME_STRUCTS_DEF constexpr auto kMetalAdStackSourceCode =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Shall we use the TI_ prefix?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TLDR is no... This is a somewhat obscure mechanism to support both including these this code as C++ code, and stringifying them as real Metal kernel/shader code. Every code in taichi/backends/metal/shaders is following this style. The macros here are only meant to be used by the Metal backend, and not shared to a broader scope.

Copy link
Member

@yuanming-hu yuanming-hu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great, thanks! In the future, we may need to consider supporting stacks with dynamic size as well.

@k-ye k-ye merged commit c8636bc into taichi-dev:master Jun 17, 2020
@k-ye k-ye deleted the ad branch June 17, 2020 13:42
@FantasyVR FantasyVR mentioned this pull request Jun 20, 2020
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants