Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[type] Support basic custom int/float types on metal #2145

Merged
merged 11 commits into from
Jan 9, 2021
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
249 changes: 220 additions & 29 deletions taichi/backends/metal/codegen_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,14 @@ std::string buffer_to_name(BuffersEnum b) {
return {};
}

// Returns true iff the return type of |s| is a PointerType that reports itself
// as a bit pointer. Any other return type yields false.
bool is_ret_type_bit_pointer(Stmt *s) {
  // cast<>() is used on purpose: as<>() would fail when a global tmp is
  // injected, because the return type is then not a PointerType.
  auto *ptr_ty = s->ret_type->cast<PointerType>();
  if (ptr_ty == nullptr) {
    return false;
  }
  return ptr_ty->is_bit_pointer();
}

class KernelCodegen : public IRVisitor {
private:
enum class Section {
Expand Down Expand Up @@ -128,10 +136,17 @@ class KernelCodegen : public IRVisitor {
generate_kernels();

std::string source_code;
for (const auto s : kAllSections) {
source_code += section_appenders_.find(s)->second.lines();
source_code += '\n';
}
source_code += section_appenders_.at(Section::Headers).lines();

source_code += "namespace {\n";
source_code += section_appenders_.at(Section::Structs).lines();
source_code += section_appenders_.at(Section::KernelFuncs).lines();
source_code += "} // namespace\n";
source_code += section_appenders_.at(Section::Kernels).lines();
// for (const auto s : kAllSections) {
k-ye marked this conversation as resolved.
Show resolved Hide resolved
// source_code += section_appenders_.find(s)->second.lines();
// source_code += '\n';
// }
return source_code;
}

Expand Down Expand Up @@ -189,21 +204,6 @@ class KernelCodegen : public IRVisitor {
kRootBufferName);
}

// Emits the Metal statement that obtains child |stmt->chid| from the parent
// variable named by |stmt->input_ptr|. The result is bound to a variable named
// after |stmt|.
void visit(GetChStmt *stmt) override {
  auto *child = stmt->output_snode;
  // E.g. `parent.get*(runtime, mem_alloc)`
  const auto getter_expr =
      fmt::format("{}.get{}({}, {})", stmt->input_ptr->raw_name(), stmt->chid,
                  kRuntimeVarName, kMemAllocVarName);
  if (!child->is_place()) {
    // Non-place child: keep the returned SNode struct by value.
    emit(R"({} {} = {};)", child->node_type_name, stmt->raw_name(),
         getter_expr);
    return;
  }
  // Place child: take the `.val` member as a device pointer to the element.
  emit(R"(device {}* {} = {}.val;)", metal_data_type_name(child->dt),
       stmt->raw_name(), getter_expr);
}

void visit(LinearizeStmt *stmt) override {
std::string val = "0";
for (int i = 0; i < (int)stmt->inputs.size(); i++) {
Expand All @@ -229,8 +229,29 @@ class KernelCodegen : public IRVisitor {
}
const auto *sn = stmt->snode;
const auto snty = sn->type;
if (snty == SNodeType::bit_struct) {
// Example *bit_struct* struct generated on Metal:
//
// struct Sx {
// // bit_struct
// Sx(device byte *b, ...) : base(b) {}
// device byte *base;
// };
emit("auto {} = {}.base;", stmt->raw_name(), parent);
return;
}
const std::string index_name = stmt->input_index->raw_name();

// Example SNode struct generated on Metal:
//
// struct S1 {
// // dense
// S1(device byte *addr, ...) { rep_.init(addr); }
// S1_ch children(int i) { return {rep_.addr() + (i * elem_stride)}; }
// inline void activate(int i) { rep_.activate(i); }
// ...
// private:
// SNodeRep_dense rep_;
// };
if (stmt->activate) {
TI_ASSERT(is_supported_sparse_type(snty));
emit("{}.activate({});", parent, index_name);
Expand All @@ -239,6 +260,32 @@ class KernelCodegen : public IRVisitor {
parent, index_name);
}

// Emits the Metal statement that resolves child |stmt->output_snode| of the
// parent referenced by |stmt->input_ptr|.
//
// * bit_struct parent: emits an `SNodeBitPointer` built from the parent's base
//   pointer and the child's bit offset inside the bit struct.
// * otherwise: emits a call to the parent's generated `get<chid>()` accessor.
void visit(GetChStmt *stmt) override {
  auto *in_snode = stmt->input_snode;
  auto *out_snode = stmt->output_snode;
  if (in_snode->type == SNodeType::bit_struct) {
    TI_ASSERT(stmt->ret_type->as<PointerType>()->is_bit_pointer());
    const auto *bit_struct_ty = in_snode->dt->cast<BitStructType>();
    // cast<>() yields nullptr on a type mismatch; fail loudly here instead of
    // dereferencing a null pointer below.
    TI_ASSERT(bit_struct_ty != nullptr);
    const auto bit_offset =
        bit_struct_ty->get_member_bit_offset(in_snode->child_id(out_snode));
    // stmt->input_ptr is the "base" member in the generated SNode struct.
    emit("SNodeBitPointer {}({}, /*offset=*/{});", stmt->raw_name(),
         stmt->input_ptr->raw_name(), bit_offset);
    return;
  }
  // E.g. `parent.get*(runtime, mem_alloc)`
  const auto get_call =
      fmt::format("{}.get{}({}, {})", stmt->input_ptr->raw_name(), stmt->chid,
                  kRuntimeVarName, kMemAllocVarName);
  if (out_snode->is_place()) {
    // Place child: take the `.val` member as a device pointer to the element.
    emit(R"(device {}* {} = {}.val;)", metal_data_type_name(out_snode->dt),
         stmt->raw_name(), get_call);
  } else {
    // Non-place child: keep the returned SNode struct by value.
    emit(R"({} {} = {};)", out_snode->node_type_name, stmt->raw_name(),
         get_call);
  }
}

void visit(SNodeOpStmt *stmt) override {
const std::string result_var = stmt->raw_name();
const auto opty = stmt->op_type;
Expand Down Expand Up @@ -292,13 +339,23 @@ class KernelCodegen : public IRVisitor {

// Emits the Metal code for a global store of |stmt->data| through |stmt->ptr|.
//
// Plain pointers are written with a direct dereference. Bit pointers cannot be
// dereferenced and are routed to handle_bit_pointer_global_store().
void visit(GlobalStoreStmt *stmt) override {
  TI_ASSERT(stmt->width() == 1);

  // The store must be emitted exactly once, and only after the pointer kind is
  // known. An unconditional `*ptr = data` before this check would duplicate
  // the store for plain pointers and dereference a bit pointer directly.
  if (!is_ret_type_bit_pointer(stmt->ptr)) {
    emit(R"(*{} = {};)", stmt->ptr->raw_name(), stmt->data->raw_name());
    return;
  }
  handle_bit_pointer_global_store(stmt);
}

// Emits the Metal code for a global load through |stmt->ptr|. The loaded value
// is bound to a `const auto` variable named after |stmt|.
//
// Plain pointers are read with a direct dereference. Bit pointers are read via
// the expression built by construct_bit_pointer_global_load().
void visit(GlobalLoadStmt *stmt) override {
  TI_ASSERT(stmt->width() == 1);
  // The result variable is declared exactly once, below. Emitting an
  // additional typed `T name = *ptr;` here would redefine the same symbol in
  // the generated code and dereference a bit pointer directly.
  std::string rhs_expr;
  if (!is_ret_type_bit_pointer(stmt->ptr)) {
    rhs_expr = fmt::format("*{}", stmt->ptr->raw_name());
  } else {
    rhs_expr = construct_bit_pointer_global_load(stmt);
  }
  emit("const auto {} = {};", stmt->raw_name(), rhs_expr);
}

void visit(ArgLoadStmt *stmt) override {
Expand Down Expand Up @@ -457,7 +514,6 @@ class KernelCodegen : public IRVisitor {

void visit(AtomicOpStmt *stmt) override {
TI_ASSERT(stmt->width() == 1);
const auto dt = stmt->val->element_type();
const auto op_type = stmt->op_type;
std::string op_name;
bool handle_float = false;
Expand All @@ -475,6 +531,11 @@ class KernelCodegen : public IRVisitor {
TI_NOT_IMPLEMENTED;
}

if (is_ret_type_bit_pointer(stmt->dest)) {
handle_bit_pointer_atomics(stmt);
return;
}

std::string val_var = stmt->val->raw_name();
// TODO(k-ye): This is not a very reliable way to detect if we're in TLS
// xlogues...
Expand All @@ -488,7 +549,7 @@ class KernelCodegen : public IRVisitor {
emit("if ({} == 0) {{", kKernelTidInSimdgroupName);
current_appender().push_indent();
}

const auto dt = stmt->val->element_type();
if (dt->is_primitive(PrimitiveTypeID::i32)) {
emit(
"const auto {} = atomic_fetch_{}_explicit((device atomic_int*){}, "
Expand Down Expand Up @@ -626,9 +687,11 @@ class KernelCodegen : public IRVisitor {
if (std::holds_alternative<Stmt *>(entry)) {
auto *arg_stmt = std::get<Stmt *>(entry);
const auto dt = arg_stmt->element_type();
TI_ASSERT_INFO(dt->is_primitive(PrimitiveTypeID::i32) ||
dt->is_primitive(PrimitiveTypeID::f32),
"print() only supports i32 or f32 scalars for now.");
TI_ASSERT_INFO(
dt->is_primitive(PrimitiveTypeID::i32) ||
dt->is_primitive(PrimitiveTypeID::u32) ||
dt->is_primitive(PrimitiveTypeID::f32),
"print() only supports i32, u32 or f32 scalars for now.");
emit("{}.pm_set_{}({}, {});", msg_var_name, data_type_name(dt), i,
arg_stmt->raw_name());
} else {
Expand Down Expand Up @@ -773,6 +836,133 @@ class KernelCodegen : public IRVisitor {
emit_kernel_args_struct();
}

// Emits the Metal code that stores |stmt->data| through the bit pointer
// |stmt->ptr|.
//
// The pointee must be a CustomIntType or a CustomFloatType. A custom float is
// first converted to its digits (custom int) representation. Any other pointee
// type hits TI_NOT_IMPLEMENTED.
void handle_bit_pointer_global_store(GlobalStoreStmt *stmt) {
  auto *bit_ptr_ty = stmt->ptr->ret_type->as<PointerType>();
  TI_ASSERT(bit_ptr_ty->is_bit_pointer());
  auto *pointee = bit_ptr_ty->get_pointee_type();

  // |storage_cit| is the integer type that actually gets written;
  // |value_expr| is the Metal expression producing the integer to write.
  CustomIntType *storage_cit = nullptr;
  std::string value_expr;
  if (auto *as_cit = pointee->cast<CustomIntType>()) {
    storage_cit = as_cit;
    value_expr = stmt->data->raw_name();
  } else if (auto *as_cft = pointee->cast<CustomFloatType>()) {
    validate_cft_for_metal(as_cft);
    auto *digits_ty = as_cft->get_digits_type()->as<CustomIntType>();
    storage_cit = digits_ty;
    value_expr = construct_float_to_custom_int_expr(
        stmt->data, as_cft->get_scale(), digits_ty);
  } else {
    TI_NOT_IMPLEMENTED;
  }

  // Type of |stmt->ptr| is SNodeBitPointer
  const auto bit_width = storage_cit->get_num_bits();
  if (!is_full_bits(bit_width)) {
    // The call is emitted across three lines: pointer, value, bit count.
    emit("mtl_set_partial_bits({},", stmt->ptr->raw_name());
    emit(" {},", value_expr);
    emit(" /*bits=*/{});", bit_width);
    return;
  }
  emit("mtl_set_full_bits({}, {});", stmt->ptr->raw_name(), value_expr);
}

// Builds the Metal expression that yields the value loaded through the bit
// pointer |stmt->ptr|, and returns it as a string.
//
// * CustomIntType pointee: the integer load expression itself.
// * CustomFloatType pointee: the digits are loaded as an integer and then
//   scaled back, i.e. `float(digits_expr) * scale`.
// Any other pointee type hits TI_NOT_IMPLEMENTED.
std::string construct_bit_pointer_global_load(GlobalLoadStmt *stmt) const {
  auto *bit_ptr_ty = stmt->ptr->ret_type->as<PointerType>();
  TI_ASSERT(bit_ptr_ty->is_bit_pointer());
  auto *pointee = bit_ptr_ty->get_pointee_type();

  if (auto *as_cit = pointee->cast<CustomIntType>()) {
    return construct_load_as_custom_int(stmt->ptr, as_cit);
  }
  if (auto *as_cft = pointee->cast<CustomFloatType>()) {
    validate_cft_for_metal(as_cft);
    auto *digits_ty = as_cft->get_digits_type()->as<CustomIntType>();
    const auto digits_expr = construct_load_as_custom_int(stmt->ptr, digits_ty);
    // See LLVM backend's reconstruct_custom_float()
    return fmt::format("(static_cast<float>({}) * {})", digits_expr,
                       as_cft->get_scale());
  }
  TI_NOT_IMPLEMENTED;
  return "";
}

// Emits the Metal code for an atomic op whose destination is a bit pointer.
// The value returned by the emitted atomic call is bound to a `const auto`
// variable named after |stmt|.
//
// Only AtomicOpType::add is accepted; anything else raises a TI_ERROR. The
// pointee must be a CustomIntType or a CustomFloatType. A custom float operand
// is first converted to its digits (custom int) representation. Any other
// pointee type hits TI_NOT_IMPLEMENTED.
void handle_bit_pointer_atomics(AtomicOpStmt *stmt) {
  TI_ERROR_IF(stmt->op_type != AtomicOpType::add,
              "Only atomic add is supported for bit pointer types");
  // Type of |dest_ptr| is SNodeBitPointer
  const auto *dest_ptr = stmt->dest;
  auto *ptr_type = dest_ptr->ret_type->as<PointerType>();
  TI_ASSERT(ptr_type->is_bit_pointer());
  auto *pointee_type = ptr_type->get_pointee_type();
  CustomIntType *cit = nullptr;
  std::string val_expr;
  if (auto *cit_cast = pointee_type->cast<CustomIntType>()) {
    cit = cit_cast;
    val_expr = stmt->val->raw_name();
  } else if (auto *cft = pointee_type->cast<CustomFloatType>()) {
    // Reject custom floats this backend cannot encode, exactly as the
    // bit-pointer store and load paths do, instead of silently emitting code
    // for an unsupported layout.
    validate_cft_for_metal(cft);
    cit = cft->get_digits_type()->as<CustomIntType>();
    val_expr =
        construct_float_to_custom_int_expr(stmt->val, cft->get_scale(), cit);
  } else {
    TI_NOT_IMPLEMENTED;
  }
  const auto num_bits = cit->get_num_bits();
  if (is_full_bits(num_bits)) {
    emit("const auto {} = mtl_atomic_add_full_bits({}, {});",
         stmt->raw_name(), dest_ptr->raw_name(), val_expr);
  } else {
    // The call is emitted across three lines: pointer, value, bit count.
    emit("const auto {} = mtl_atomic_add_partial_bits({},", stmt->raw_name(),
         dest_ptr->raw_name());
    emit(" {},", val_expr);
    emit(" /*bits=*/{});", num_bits);
  }
}

// Builds the Metal expression that converts the float value of |val_stmt| into
// the digits of a custom float, and returns it as a string:
//
//   mtl_float_to_custom_int<T>(/*inv_scale=*/<1/scale> * <val>)
//
// where T is the Metal name of |digits_cit|'s compute type. The original note
// describes the result as `int(val_stmt * (1.0f / scale) + 0.5f)`.
std::string construct_float_to_custom_int_expr(
    const Stmt *val_stmt,
    float64 scale,
    CustomIntType *digits_cit) const {
  DataType compute_dt(digits_cit->get_compute_type()->as<PrimitiveType>());
  // The reciprocal is computed in double on the host and narrowed to float.
  const auto inv_scale = static_cast<float>(1.0 / scale);
  // A single expression is returned rather than statements with intermediate
  // variables: |val_stmt| could be used multiple times, and temporaries named
  // after it would result in symbol redefinitions.
  const std::string target_ty = metal_data_type_name(compute_dt);
  return fmt::format("mtl_float_to_custom_int<{}>(/*inv_scale=*/{} * {})",
                     target_ty, inv_scale, val_stmt->raw_name());
}

// Builds the Metal expression that loads the custom integer |cit| through the
// bit pointer |bit_ptr_stmt|, and returns it as a string. The template argument
// of the emitted helper is the Metal name of |cit|'s compute type.
std::string construct_load_as_custom_int(const Stmt *bit_ptr_stmt,
                                         CustomIntType *cit) const {
  DataType compute_dt(cit->get_compute_type()->as<PrimitiveType>());
  const auto bit_width = cit->get_num_bits();
  if (!is_full_bits(bit_width)) {
    // Partial width: the helper additionally receives the bit count.
    return fmt::format("mtl_get_partial_bits<{}>({}, {})",
                       metal_data_type_name(compute_dt),
                       bit_ptr_stmt->raw_name(), bit_width);
  }
  return fmt::format("mtl_get_full_bits<{}>({})",
                     metal_data_type_name(compute_dt),
                     bit_ptr_stmt->raw_name());
}

// Rejects custom float types that this backend does not handle:
//  * types with an exponent type hit TI_NOT_IMPLEMENTED;
//  * types whose compute type is not f32 raise a TI_ERROR.
void validate_cft_for_metal(CustomFloatType *cft) const {
  const bool has_exponent = (cft->get_exponent_type() != nullptr);
  if (has_exponent) {
    TI_NOT_IMPLEMENTED;
  }
  if (cft->get_compute_type()->as<PrimitiveType>() != PrimitiveType::f32) {
    TI_ERROR("Metal only supports 32-bit float");
  }
}

// Returns true iff |bits| equals the number of bits in a uint32_t.
static bool is_full_bits(int bits) {
  constexpr int kBitsPerByte = 8;
  constexpr int kFullBitWidth =
      static_cast<int>(sizeof(uint32_t)) * kBitsPerByte;
  return bits == kFullBitWidth;
}

void emit_kernel_args_struct() {
if (ctx_attribs_.empty()) {
return;
Expand Down Expand Up @@ -924,7 +1114,8 @@ class KernelCodegen : public IRVisitor {
emit("const int {} = {} - {};", total_elems_name, end_expr, begin_expr);
ka.advisory_total_num_threads = kMaxNumThreadsGridStrideLoop;
}
ka.advisory_num_threads_per_group = stmt->block_dim;
// ka.advisory_num_threads_per_group = stmt->block_dim;
ka.advisory_num_threads_per_group = 1024;
k-ye marked this conversation as resolved.
Show resolved Hide resolved
// begin_ = thread_id + begin_expr
emit("const int begin_ = {} + {};", kKernelThreadIdName, begin_expr);
// end_ = total_elems + begin_expr
Expand Down
Loading