[CUDA] [AutoDiff] Fix CUDA data layout and stack alignment (#918)
yuanming-hu authored May 4, 2020

1 parent 4d228b9 commit 89fc51e
Showing 7 changed files with 81 additions and 31 deletions.
7 changes: 4 additions & 3 deletions taichi/backends/cpu/jit_cpu.cpp
@@ -362,12 +362,13 @@ std::unique_ptr<JITSession> create_llvm_jit_session_cpu(Arch arch) {
TI_ERROR("LLVM TargetMachineBuilder has failed.");
jtmb = std::make_unique<JITTargetMachineBuilder>(std::move(*JTMB));

auto DL = jtmb->getDefaultDataLayoutForTarget();
if (!DL) {
auto data_layout = jtmb->getDefaultDataLayoutForTarget();
if (!data_layout) {
TI_ERROR("LLVM TargetMachineBuilder has failed when getting data layout.");
}

return std::make_unique<JITSessionCPU>(std::move(*jtmb), std::move(*DL));
return std::make_unique<JITSessionCPU>(std::move(*jtmb),
std::move(*data_layout));
}

TLANG_NAMESPACE_END
30 changes: 12 additions & 18 deletions taichi/backends/cuda/jit_cuda.cpp
@@ -76,9 +76,10 @@ class JITModuleCUDA : public JITModule {

 class JITSessionCUDA : public JITSession {
  public:
-  llvm::DataLayout DL;
+  llvm::DataLayout data_layout;
 
-  explicit JITSessionCUDA(llvm::DataLayout data_layout) : DL(data_layout) {
+  explicit JITSessionCUDA(llvm::DataLayout data_layout)
+      : data_layout(data_layout) {
   }
 
   virtual JITModule *add_module(std::unique_ptr<llvm::Module> M) override {
@@ -101,7 +102,7 @@ class JITSessionCUDA : public JITSession {
   }
 
   virtual llvm::DataLayout get_data_layout() override {
-    return DL;
+    return data_layout;
   }
 
   static std::string compile_module_to_ptx(
@@ -136,16 +137,16 @@ std::string convert(std::string new_name) {

 std::string JITSessionCUDA::compile_module_to_ptx(
     std::unique_ptr<llvm::Module> &module) {
-  // TODO: enabling this leads to LLVM error 'comdat global value has private
-  // linkage'
+  // Part of this function is borrowed from Halide::CodeGen_PTX_Dev.cpp
+  // TODO: enabling this leads to LLVM error "comdat global value has private
+  // linkage"
   /*
   if (llvm::verifyModule(*module, &llvm::errs())) {
     module->print(llvm::errs(), nullptr);
     TI_ERROR("Module broken");
   }
   */
 
-  // Part of this function is borrowed from Halide::CodeGen_PTX_Dev.cpp
   using namespace llvm;
 
   for (auto &f : module->globals())
@@ -273,18 +274,11 @@ std::string JITSessionCUDA::compile_module_to_ptx(

 std::unique_ptr<JITSession> create_llvm_jit_session_cuda(Arch arch) {
   TI_ASSERT(arch == Arch::cuda);
-  // TODO: assuming CUDA has the same data layout as the host arch
-  std::unique_ptr<llvm::orc::JITTargetMachineBuilder> jtmb;
-  auto JTMB = llvm::orc::JITTargetMachineBuilder::detectHost();
-  if (!JTMB)
-    TI_ERROR("LLVM TargetMachineBuilder has failed.");
-  jtmb = std::make_unique<llvm::orc::JITTargetMachineBuilder>(std::move(*JTMB));
-
-  auto DL = jtmb->getDefaultDataLayoutForTarget();
-  if (!DL) {
-    TI_ERROR("LLVM TargetMachineBuilder has failed when getting data layout.");
-  }
-  return std::make_unique<JITSessionCUDA>(DL.get());
+  // https://docs.nvidia.com/cuda/nvvm-ir-spec/index.html#data-layout
+  auto data_layout = llvm::DataLayout(
+      "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-"
+      "f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64");
+  return std::make_unique<JITSessionCUDA>(data_layout);
 }
 #else
 std::unique_ptr<JITSession> create_llvm_jit_session_cuda(Arch arch) {
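Note on the jit_cuda.cpp change above: instead of reusing the data layout detected for the host CPU, the CUDA JIT session is now constructed with the 64-bit NVVM IR data-layout string from the NVIDIA docs linked in the comment, so the struct offsets and alignments LLVM computes match the device ABI. A rough sketch of how such a layout string is typically attached to a module before PTX generation (not code from this commit; the helper name and the triple handling are illustrative):

#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Module.h"

// Hypothetical helper: give a module the NVVM data layout (and, typically,
// the NVPTX target triple) so that later alignment/offset queries and the
// PTX backend agree on how data is laid out on the device.
void apply_nvvm_data_layout(llvm::Module &module) {
  const char *nvvm_layout =
      "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-"
      "f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64";
  module.setDataLayout(llvm::DataLayout(nvvm_layout));
  module.setTargetTriple("nvptx64-nvidia-cuda");
}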
3 changes: 2 additions & 1 deletion taichi/jit/jit_session.cpp
@@ -1,5 +1,6 @@
#include "taichi/jit/jit_session.h"

#include "llvm/IR/DataLayout.h"
#include "jit_session.h"

TLANG_NAMESPACE_BEGIN

10 changes: 5 additions & 5 deletions taichi/runtime/llvm/runtime.cpp
@@ -1289,25 +1289,25 @@ void taichi_printf(LLVMRuntime *runtime, const char *format, Args &&... args) {
extern "C" { // local stack operations

Ptr stack_top_primal(Ptr stack, std::size_t element_size) {
auto n = *(i32 *)stack;
return stack + sizeof(i32) + (n - 1) * 2 * element_size;
auto n = *(u64 *)stack;
return stack + sizeof(u64) + (n - 1) * 2 * element_size;
}

Ptr stack_top_adjoint(Ptr stack, std::size_t element_size) {
return stack_top_primal(stack, element_size) + element_size;
}

void stack_init(Ptr stack) {
*(i32 *)stack = 0;
*(u64 *)stack = 0;
}

void stack_pop(Ptr stack) {
auto &n = *(i32 *)stack;
auto &n = *(u64 *)stack;
n--;
}

void stack_push(Ptr stack, size_t max_num_elements, std::size_t element_size) {
i32 &n = *(i32 *)stack;
u64 &n = *(u64 *)stack;
n += 1;
// TODO: assert n <= max_elements
std::memset(stack_top_primal(stack, element_size), 0, element_size * 2);
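Note on the runtime.cpp change above: the AutoDiff local stack is a buffer that starts with a size counter and is followed by up to max_num_elements entries, each holding a primal value and its adjoint (2 * element_size bytes). Widening the counter from i32 to u64 makes the header 8 bytes, so 8-byte elements such as f64 stay naturally aligned, which the new f64 tests below exercise on CUDA. A small self-contained check of that alignment argument (plain C++, not code from this commit):

#include <cassert>
#include <cstddef>
#include <cstdint>

// Byte offset of the k-th primal slot, given the size of the counter kept
// at the front of the stack buffer and the per-value element size.
std::size_t primal_offset(std::size_t counter_size,
                          std::size_t k,
                          std::size_t element_size) {
  return counter_size + k * 2 * element_size;
}

int main() {
  // Old layout: a 4-byte i32 counter puts the first f64 slot at offset 4,
  // which is not 8-byte aligned.
  assert(primal_offset(sizeof(std::int32_t), 0, sizeof(double)) % 8 != 0);
  // New layout: an 8-byte u64 counter keeps every f64 slot 8-byte aligned.
  for (std::size_t k = 0; k < 4; k++)
    assert(primal_offset(sizeof(std::uint64_t), k, sizeof(double)) % 8 == 0);
  return 0;
}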
3 changes: 2 additions & 1 deletion taichi/transforms/make_adjoint.cpp
@@ -242,7 +242,8 @@ class MakeAdjoint : public IRVisitor {
       accumulate(stmt->operand,
                  mul(adjoint(stmt), div(constant(0.5f), sqrt(stmt->operand))));
     } else if (stmt->op_type == UnaryOpType::cast_value) {
-      if (is_real(stmt->cast_type) && is_real(stmt->operand->ret_type.data_type)) {
+      if (is_real(stmt->cast_type) &&
+          is_real(stmt->operand->ret_type.data_type)) {
         accumulate(stmt->operand, adjoint(stmt));
       }
     } else if (stmt->op_type == UnaryOpType::logic_not) {
57 changes: 56 additions & 1 deletion tests/python/test_ad_if.py
@@ -111,6 +111,35 @@ def func():
     assert x.grad[1] == 1
 
 
+@ti.require(ti.extension.adstack)
+@ti.all_archs_with(default_fp=ti.f64)
+def test_ad_if_parallel_f64():
+    x = ti.var(ti.f64, shape=2)
+    y = ti.var(ti.f64, shape=2)
+
+    ti.root.lazy_grad()
+
+    @ti.kernel
+    def func():
+        for i in range(2):
+            t = x[i]
+            if t > 0:
+                y[i] = t
+            else:
+                y[i] = 2 * t
+
+    x[0] = 0
+    x[1] = 1
+    y.grad[0] = 1
+    y.grad[1] = 1
+
+    func()
+    func.grad()
+
+    assert x.grad[0] == 2
+    assert x.grad[1] == 1
+
+
 @ti.require(ti.extension.adstack)
 @ti.all_archs
 def test_ad_if_parallel_complex():
@@ -140,7 +169,33 @@ def func():
     assert x.grad[1] == -0.25
 
 
-# TODO: test f64 stack
+@ti.require(ti.extension.adstack)
+@ti.all_archs_with(default_fp=ti.f64)
+def test_ad_if_parallel_complex_f64():
+    x = ti.var(ti.f64, shape=2)
+    y = ti.var(ti.f64, shape=2)
+
+    ti.root.lazy_grad()
+
+    @ti.kernel
+    def func():
+        ti.parallelize(1)
+        for i in range(2):
+            t = 0.0
+            if x[i] > 0:
+                t = 1 / x[i]
+            y[i] = t
+
+    x[0] = 0
+    x[1] = 2
+    y.grad[0] = 1
+    y.grad[1] = 1
+
+    func()
+    func.grad()
+
+    assert x.grad[0] == 0
+    assert x.grad[1] == -0.25
 
 
 @ti.host_arch_only
2 changes: 0 additions & 2 deletions tests/python/test_element_wise.py
@@ -40,7 +40,6 @@ def func():
             assert a[None][i, j] == approx(expected)
 
 
-
 @ti.host_arch_only
 def _test_matrix_element_wise_binary(dtype, n, m, ti_func, math_func):
     a = ti.Matrix(n, m, dt=dtype, shape=())
@@ -96,7 +95,6 @@ def test_matrix_element_wise_binary():
         _test_matrix_element_wise_binary(ti.i32, n, m, ti.raw_mod, _c_mod)
 
 
-
 def test_matrix_element_wise_unary():
     seed(233)
     for n, m in [(5, 4), (3, 1)]:
