From 3d40de9c3981209912525a26915b7cb615af4340 Mon Sep 17 00:00:00 2001
From: Mei Ye
Date: Thu, 18 May 2023 21:39:06 +0000
Subject: [PATCH] [Vulkan] Add cooperative matrix support

Add SPIR-V code generation for the "SPV_NV_cooperative_matrix"
extension. Add a matrix multiplication unit test.
---
 python/tvm/target/target.py                 |  21 ++
 src/runtime/vulkan/vulkan_device.cc         |   5 +-
 src/runtime/vulkan/vulkan_device.h          |   1 +
 src/runtime/vulkan/vulkan_device_api.cc     |   4 +
 src/target/source/codegen_cuda.cc           |  23 +--
 src/target/spirv/codegen_spirv.cc           | 193 ++++++++++++++++--
 src/target/spirv/codegen_spirv.h            |  14 ++
 src/target/spirv/ir_builder.cc              |  68 +++++-
 src/target/spirv/ir_builder.h               |  25 ++-
 src/target/spirv/spirv_support.cc           |   4 +
 src/target/spirv/spirv_support.h            |  14 ++
 src/target/target_kind.cc                   |   1 +
 src/tir/transforms/ir_utils.cc              |  30 +++
 src/tir/transforms/ir_utils.h               |   8 +
 .../unittest/test_target_codegen_vulkan.py  | 128 ++++++++++++
 15 files changed, 495 insertions(+), 44 deletions(-)

diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py
index 06e1776965c2..fce9f3e6becc 100644
--- a/python/tvm/target/target.py
+++ b/python/tvm/target/target.py
@@ -173,11 +173,25 @@ def max_num_threads(self):
         """Returns the max_num_threads from the target if it exists."""
         return int(self.attrs["max_num_threads"])
 
+    @property
+    def max_block_size_x(self):
+        """Returns the max block size in x-dimension from the target if it exists."""
+        return int(self.attrs["max_block_size_x"])
+
+    @property
+    def max_block_size_y(self):
+        """Returns the max block size in y-dimension from the target if it exists."""
+        return int(self.attrs["max_block_size_y"])
+
     @property
     def thread_warp_size(self):
         """Returns the thread_warp_size from the target if it exists."""
         return int(self.attrs["thread_warp_size"])
 
+    @property
+    def max_shared_memory_per_block(self):
+        return int(self.attrs["max_shared_memory_per_block"])
+
     @property
     def max_function_args(self):
         return int(self.attrs.get("max_function_args", -1))
@@ -219,6 +233,13 @@ def supports_integer_dot_product(self):
     def libs(self):
         return list(self.attrs.get("libs", []))
 
+    @property
+    def supports_cooperative_matrix(self):
+        if self.attrs.get("supports_cooperative_matrix", []):
+            return bool(self.attrs["supports_cooperative_matrix"])
+        else:
+            return False
+
     @property
     def features(self):
         return TargetFeatures(self)
diff --git a/src/runtime/vulkan/vulkan_device.cc b/src/runtime/vulkan/vulkan_device.cc
index b3e017d03418..dfc8034c85c3 100644
--- a/src/runtime/vulkan/vulkan_device.cc
+++ b/src/runtime/vulkan/vulkan_device.cc
@@ -134,6 +134,8 @@ VulkanDeviceProperties::VulkanDeviceProperties(const VulkanInstance& instance,
 
   supports_integer_dot_product = device.HasExtension("VK_KHR_shader_integer_dot_product");
 
+  supports_cooperative_matrix = device.HasExtension("VK_NV_cooperative_matrix");
+
   // The check of VK_SHADER_STAGE_COMPUTE_BIT isn't technically
   // needed, since it will be set so long at least one queue has
   // VK_QUEUE_COMPUTE_BIT.
Including it to avoid potential future @@ -435,7 +437,8 @@ std::vector VulkanDevice::SelectEnabledExtensions() const { "VK_KHR_get_memory_requirements2", "VK_KHR_dedicated_allocation", "VK_KHR_spirv_1_4", - "VK_KHR_shader_integer_dot_product"}; + "VK_KHR_shader_integer_dot_product", + "VK_NV_cooperative_matrix"}; uint32_t device_extension_prop_count; VULKAN_CALL(vkEnumerateDeviceExtensionProperties(physical_device_, nullptr, diff --git a/src/runtime/vulkan/vulkan_device.h b/src/runtime/vulkan/vulkan_device.h index 59ebf430e6e6..296483a6b104 100644 --- a/src/runtime/vulkan/vulkan_device.h +++ b/src/runtime/vulkan/vulkan_device.h @@ -88,6 +88,7 @@ struct VulkanDeviceProperties { bool supports_push_descriptor{false}; bool supports_dedicated_allocation{false}; bool supports_integer_dot_product{false}; + bool supports_cooperative_matrix{false}; uint32_t supported_subgroup_operations{0}; uint32_t max_num_threads{1}; uint32_t thread_warp_size{1}; diff --git a/src/runtime/vulkan/vulkan_device_api.cc b/src/runtime/vulkan/vulkan_device_api.cc index 93f017a5aa66..108741525602 100644 --- a/src/runtime/vulkan/vulkan_device_api.cc +++ b/src/runtime/vulkan/vulkan_device_api.cc @@ -241,6 +241,10 @@ void VulkanDeviceAPI::GetTargetProperty(Device dev, const std::string& property, *rv = prop.supports_integer_dot_product; } + if (property == "supports_cooperative_matrix") { + *rv = prop.supports_cooperative_matrix; + } + if (property == "device_name") { *rv = prop.device_name; } diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index d2131c522e38..aaf66601723d 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -33,6 +33,7 @@ #include #include +#include "../../tir/transforms/ir_utils.h" #include "literal/cuda_half_t.h" #include "ptx.h" @@ -1333,23 +1334,11 @@ int32_t CodeGenCUDA::GetWmmaFragmentSize(const std::string& scope, const VarNode ICHECK(fragment_shapes.count(variable)) << "Cannot find shape of the wmma fragment " << variable->name_hint; std::string shape_str = fragment_shapes.at(variable); - size_t m, n, k; - size_t last_pos = 0, pos = 0; - pos = shape_str.find(", ", last_pos); - m = tvm::codegen::stoi(shape_str.substr(last_pos, pos - last_pos)); - last_pos = pos + 2; - pos = shape_str.find(", ", last_pos); - n = tvm::codegen::stoi(shape_str.substr(last_pos, pos - last_pos)); - last_pos = pos + 2; - k = tvm::codegen::stoi(shape_str.substr(last_pos, shape_str.length() - last_pos)); - if (scope == "wmma.matrix_a") { - return size / m / k; - } else if (scope == "wmma.matrix_b") { - return size / n / k; - } else if (scope == "wmma.accumulator") { - return size / m / n; - } - return 0; + std::pair dim = GetWmmaFragmentDimSize(shape_str, scope); + if (dim.first * dim.second != 0) + return size / dim.first / dim.second; + else + return 0; } void CodeGenCUDA::HandleVolatileLoads(const std::string& value, const BufferLoadNode* op, diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index 2a4233b44bcf..b1fd0171910a 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -129,6 +129,7 @@ void CodeGenSPIRV::InitFuncState() { builder_.reset(new spirv::IRBuilder(spirv_support_)); builder_->InitHeader(); shared_memory_bytes_used_ = 0; + fragment_info_.clear(); } spirv::Value CodeGenSPIRV::GetThreadIndex(const IterVar& iv, const PrimExpr& extent) { @@ -394,6 +395,120 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { LOG(FATAL) << "SPIR-V shader cannot make extern calls. 
Graph contains extern \"" << Downcast(op->args[0]) << "\""; return spirv::Value(); + } else if (op->op.same_as(builtin::tvm_fill_fragment())) { + ICHECK_EQ(op->args.size(), 6U); + const VarNode* buffer_node = op->args[0].as(); + ICHECK(buffer_node && fragment_info_.count(buffer_node)); + DataType ele_dtype = GetElementDataType(buffer_node); + ICHECK(ele_dtype.is_float()) << "Only floating point fragment accumulator is supported"; + spirv::SType ele_stype = builder_->GetSType(ele_dtype); + spirv::SType& fragment_type = fragment_info_[buffer_node].stype; + double init = static_cast(Downcast(op->args[5])->value); + PrimExpr prim_index = op->args[4]; + spirv::Value init_val = builder_->GetCompositeConst(ele_stype, fragment_type, init); + spirv::SType ptr_type = + builder_->GetPointerType(fragment_type, fragment_info_[buffer_node].sclass); + spirv::Value index = MakeValue(prim_index); + ICHECK(var_map_.count(buffer_node)); + spirv::Value ptr = builder_->StructArrayAccess(ptr_type, var_map_[buffer_node], index); + builder_->MakeInst(spv::OpStore, ptr, init_val, spv::MemoryAccessMaskNone); + return spirv::Value(); + + } else if (op->op.same_as(builtin::tvm_load_matrix_sync())) { + ICHECK_EQ(op->args.size(), 8U); + const VarNode* buffer_node = op->args[0].as(); + ICHECK(buffer_node && fragment_info_.count(buffer_node)); + spirv::SType& fragment_type = fragment_info_[buffer_node].stype; + PrimExpr dst_index = op->args[4]; + PrimExpr src_ptr_expr = op->args[5]; + int stride = static_cast(Downcast(op->args[6])->value); + auto type_int = builder_->GetSType(DataType::Int(32)); + spirv::Value stride_val = builder_->IntImm(type_int, stride); + std::string layout = (op->args[7].as())->value; + spirv::SType dst_ptr_type = + builder_->GetPointerType(fragment_type, fragment_info_[buffer_node].sclass); + spirv::Value dst_ptr = + builder_->StructArrayAccess(dst_ptr_type, var_map_[buffer_node], MakeValue(dst_index)); + spirv::Value src_ptr = VisitExpr(op->args[5]); + spirv::SType type_bool = builder_->GetSType(DataType::UInt(1)); + spirv::Value t_val = builder_->UIntImm(type_bool, 1); + spirv::Value f_val = builder_->UIntImm(type_bool, 0); + spirv::Value loaded = + builder_->MakeValue(spv::OpCooperativeMatrixLoadNV, fragment_type, src_ptr, stride_val, + (layout != "row_major") ? 
t_val : f_val); + builder_->MakeInst(spv::OpStore, dst_ptr, loaded, spv::MemoryAccessMaskNone); + return spirv::Value(); + } else if (op->op.same_as(builtin::tvm_mma_sync())) { + const VarNode* buffer_d = op->args[0].as(); + const VarNode* buffer_a = op->args[2].as(); + const VarNode* buffer_b = op->args[4].as(); + const VarNode* buffer_c = op->args[6].as(); + PrimExpr index_d = op->args[1]; + PrimExpr index_a = op->args[3]; + PrimExpr index_b = op->args[5]; + tvm::tir::ExprDeepEqual expr_equal; + PrimExpr index_c = op->args[7]; + bool is_equal = ((buffer_d == buffer_c) && expr_equal(index_d, index_c)); + spirv::SType& fragment_type_d = fragment_info_[buffer_d].stype; + spirv::SType& fragment_type_a = fragment_info_[buffer_a].stype; + spirv::SType& fragment_type_b = fragment_info_[buffer_b].stype; + spirv::SType& fragment_type_c = fragment_info_[buffer_c].stype; + spv::StorageClass storage = fragment_info_[buffer_d].sclass; + spirv::SType ptr_type_d = builder_->GetPointerType(fragment_type_d, storage); + spirv::SType ptr_type_a = builder_->GetPointerType(fragment_type_a, storage); + spirv::SType ptr_type_b = builder_->GetPointerType(fragment_type_b, storage); + spirv::SType ptr_type_c = builder_->GetPointerType(fragment_type_c, storage); + spirv::Value ptr_d = + builder_->StructArrayAccess(ptr_type_d, var_map_[buffer_d], MakeValue(index_d)); + spirv::Value ptr_a = + builder_->StructArrayAccess(ptr_type_a, var_map_[buffer_a], MakeValue(index_a)); + spirv::Value ptr_b = + builder_->StructArrayAccess(ptr_type_b, var_map_[buffer_b], MakeValue(index_b)); + spirv::Value ptr_c = + is_equal ? ptr_d + : builder_->StructArrayAccess(ptr_type_c, var_map_[buffer_c], MakeValue(index_c)); + uint32_t mask = spv::MemoryAccessMaskNone; + spirv::Value loaded_a = builder_->MakeValue(spv::OpLoad, fragment_type_a, ptr_a, mask); + spirv::Value loaded_b = builder_->MakeValue(spv::OpLoad, fragment_type_b, ptr_b, mask); + spirv::Value loaded_c = builder_->MakeValue(spv::OpLoad, fragment_type_c, ptr_c, mask); + spirv::Value result = builder_->MakeValue(spv::OpCooperativeMatrixMulAddNV, fragment_type_d, + loaded_a, loaded_b, loaded_c); + builder_->MakeInst(spv::OpStore, ptr_d, result, spv::MemoryAccessMaskNone); + return spirv::Value(); + } else if (op->op.same_as(builtin::tvm_store_matrix_sync())) { + ICHECK_EQ(op->args.size(), 8U); + const VarNode* buffer_node = op->args[0].as(); + PrimExpr index = op->args[4]; + PrimExpr buffer_ptr = op->args[5]; + int stride = static_cast(Downcast(op->args[6])->value); + auto type_int = builder_->GetSType(DataType::Int(32)); + spirv::Value stride_val = builder_->IntImm(type_int, stride); + std::string layout = (op->args[7].as())->value; + spirv::Value dst_ptr = VisitExpr(op->args[5]); + spirv::SType& fragment_type = fragment_info_[buffer_node].stype; + spv::StorageClass storage = fragment_info_[buffer_node].sclass; + spirv::SType ptr_type = builder_->GetPointerType(fragment_type, storage); + spirv::Value ptr = + builder_->StructArrayAccess(ptr_type, var_map_[buffer_node], MakeValue(index)); + uint32_t mask = spv::MemoryAccessMaskNone; + spirv::Value loaded = builder_->MakeValue(spv::OpLoad, fragment_type, ptr, mask); + spirv::SType type_bool = builder_->GetSType(DataType::UInt(1)); + spirv::Value t_val = builder_->UIntImm(type_bool, 1); + spirv::Value f_val = builder_->UIntImm(type_bool, 0); + builder_->MakeInst(spv::OpCooperativeMatrixStoreNV, dst_ptr, loaded, stride_val, + (layout != "row_major") ? 
t_val : f_val); + return spirv::Value(); + } else if (op->op.same_as(builtin::address_of())) { + const BufferLoadNode* load = op->args[0].as(); + Var buffer_var = load->buffer->data; + const VarNode* buffer_node = buffer_var.get(); + PrimExpr index = load->indices[0]; + DataType ele_dtype = GetElementDataType(buffer_node); + spirv::SType ele_stype = builder_->GetSType(ele_dtype); + spirv::Value buffer_val = MakeValue(buffer_var); + spirv::SType ptr_type = builder_->GetPointerType(ele_stype, buffer_val.stype.storage_class); + ICHECK(var_map_.count(buffer_node)); + return builder_->StructArrayAccess(ptr_type, var_map_[buffer_node], MakeValue(index)); } else { LOG(FATAL) << "Unresolved call " << op->op; } @@ -657,22 +772,46 @@ void CodeGenSPIRV::VisitStmt_(const AllocateNode* op) { ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU"; spirv::Value buf; - auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); + const std::string scope = GetPtrStorageScope(op->buffer_var); + auto storage_scope = runtime::StorageScope::Create(scope); spirv::SType etype = builder_->GetSType(op->dtype); - if (storage_scope.rank == runtime::StorageRank::kLocal) { - buf = - builder_->Allocate(etype, static_cast(constant_size), spv::StorageClassFunction); - } else if (storage_scope.rank == runtime::StorageRank::kShared) { - // Shared memory - // Aligned on 4-byte boundary - int32_t aligned_constant_size = ((constant_size + 3) & ~0x3); - buf = builder_->Allocate(etype, static_cast(aligned_constant_size), - spv::StorageClassWorkgroup); - - size_t num_bytes = op->dtype.bytes() * op->dtype.lanes() * static_cast(constant_size); - shared_memory_bytes_used_ += num_bytes; - } else { - LOG(FATAL) << "Can only allocate shared or local memory inside kernel"; + runtime::StorageRank rank = storage_scope.rank; + spv::StorageClass storage_class; + const VarNode* var_node = (op->buffer_var).get(); + + switch (rank) { + case runtime::StorageRank::kWMMAMatrixA: + case runtime::StorageRank::kWMMAMatrixB: + case runtime::StorageRank::kWMMAAccumulator: { + ICHECK(fragment_info_.count(var_node)); + fragment_info_[var_node].scope = scope; + etype = GetFragmentSType(var_node, op->dtype); + storage_class = spv::StorageClassFunction; + fragment_info_[var_node].sclass = storage_class; + ICHECK(fragment_info_.count(var_node)); + const std::string& scope = fragment_info_[var_node].scope; + const std::string& shape_str = fragment_info_.at(var_node).shape; + std::pair dim = GetWmmaFragmentDimSize(shape_str, scope); + int64_t size = dim.first * dim.second; + buf = builder_->Allocate(etype, static_cast(constant_size) / size, storage_class); + } break; + case runtime::StorageRank::kLocal: { + storage_class = spv::StorageClassFunction; + buf = builder_->Allocate(etype, static_cast(constant_size), storage_class); + } break; + case runtime::StorageRank::kShared: { + storage_class = spv::StorageClassWorkgroup; + // Shared memory + // Aligned on 4-byte boundary + int32_t aligned_constant_size = ((constant_size + 3) & ~0x3); + buf = builder_->Allocate(etype, static_cast(aligned_constant_size), storage_class); + + size_t num_bytes = + op->dtype.bytes() * op->dtype.lanes() * static_cast(aligned_constant_size); + shared_memory_bytes_used_ += num_bytes; + } break; + default: + LOG(FATAL) << "Can only allocate shared or local memory inside kernel"; } builder_->SetName(buf, op->buffer_var->name_hint); @@ -700,6 +839,13 @@ void CodeGenSPIRV::VisitStmt_(const AttrStmtNode* op) { const VarNode* v = 
op->node.as(); ICHECK(v); storage_info_[v].is_volatile = true; + } else if (op->attr_key == tir::attr::buffer_bind_scope) { + const VarNode* v = op->node.as(); + ICHECK(v); + } else if (op->attr_key == tir::attr::fragment_shape) { + const VarNode* buffer = op->node.as(); + const StringImmNode* shape_str = op->value.as(); + fragment_info_[buffer] = {shape_str->value}; } this->VisitStmt(op->body); } @@ -725,5 +871,22 @@ void CodeGenSPIRV::VisitStmt_(const SeqStmtNode* op) { void CodeGenSPIRV::VisitStmt_(const EvaluateNode* op) { MakeValue(op->value); } +spirv::SType CodeGenSPIRV::GetFragmentSType(const VarNode* buffer, const DataType& dtype) { + ICHECK(fragment_info_.count(buffer)); + const std::string& scope = fragment_info_[buffer].scope; + const std::string& shape_str = fragment_info_.at(buffer).shape; + std::pair dim = GetWmmaFragmentDimSize(shape_str, scope); + int64_t size = dim.first * dim.second; + spirv::SType stype = builder_->GetSType(dtype.with_lanes(size), dim.first, dim.second); + fragment_info_[buffer].stype = stype; + return stype; +} + +DataType CodeGenSPIRV::GetElementDataType(const VarNode* buffer) { + auto it = storage_info_.find(buffer); + ICHECK(it != storage_info_.end()); + return it->second.element_type; +} + } // namespace codegen } // namespace tvm diff --git a/src/target/spirv/codegen_spirv.h b/src/target/spirv/codegen_spirv.h index f2d771070ed9..3a0336120a8f 100644 --- a/src/target/spirv/codegen_spirv.h +++ b/src/target/spirv/codegen_spirv.h @@ -34,6 +34,7 @@ #include #include #include +#include #include #include "../../runtime/spirv/spirv_shader.h" @@ -171,6 +172,14 @@ class CodeGenSPIRV : public ExprFunctor, element_type_known = true; } }; + + struct FragmentInfo { + std::string shape; + std::string scope; + spirv::SType stype; + spv::StorageClass sclass; + }; + // Reset the state so it works for a new function. void InitFuncState(); // Get the thread index @@ -179,6 +188,9 @@ class CodeGenSPIRV : public ExprFunctor, spirv::Value CreateStorageSync(const CallNode* op); void Scalarize(const PrimExpr& e, std::function f); + spirv::SType GetFragmentSType(const VarNode* buffer, const DataType& dtype); + DataType GetElementDataType(const VarNode* buffer); + // SPIRV-related capabilities of the target SPIRVSupport spirv_support_; @@ -218,6 +230,8 @@ class CodeGenSPIRV : public ExprFunctor, // Running total of the number of bytes of shared memory used. 
// Checked against the max_shared_memory_per_group size_t shared_memory_bytes_used_{0}; + + std::unordered_map fragment_info_; }; } // namespace codegen diff --git a/src/target/spirv/ir_builder.cc b/src/target/spirv/ir_builder.cc index 46c9c5869c79..545e677af9f2 100644 --- a/src/target/spirv/ir_builder.cc +++ b/src/target/spirv/ir_builder.cc @@ -60,6 +60,11 @@ void IRBuilder::InitHeader() { } #endif + if (spirv_support_.supports_cooperative_matrix) { + capabilities_used_.insert(spv::CapabilityCooperativeMatrixNV); + extensions_used_.insert("SPV_NV_cooperative_matrix"); + } + // memory model ib_.Begin(spv::OpMemoryModel) .AddSeq(spv::AddressingModelLogical, spv::MemoryModelGLSL450) @@ -74,6 +79,7 @@ void IRBuilder::InitPreDefs() { t_bool_ = DeclareType(DataType::UInt(1)); t_fp32_ = DeclareType(DataType::Float(32)); const_i32_zero_ = IntImm(t_int32_, 0); + // declare void, and void functions t_void_.id = id_counter_++; ib_.Begin(spv::OpTypeVoid).Add(t_void_).Commit(&global_); @@ -106,7 +112,7 @@ std::vector IRBuilder::Finalize() { return data; } -SType IRBuilder::GetSType(const DataType& dtype) { +SType IRBuilder::GetSType(const DataType& dtype, uint32_t row, uint32_t col) { if (dtype == DataType::Int(32)) { return t_int32_; } else if (dtype == DataType::UInt(1)) { @@ -116,15 +122,22 @@ SType IRBuilder::GetSType(const DataType& dtype) { } else if (dtype == DataType::UInt(32)) { return t_uint32_; } - uint32_t type_key; + uint64_t type_key; type_key = static_cast(dtype.code()); type_key |= static_cast(dtype.bits()) << 8U; - type_key |= static_cast(dtype.lanes()) << 16U; + if (row * col == 0) { + ICHECK((row == 0) && (col == 0)); + type_key |= static_cast(dtype.lanes()) << 16U; + } else { + type_key |= static_cast(row) << 32U; + type_key |= static_cast(col) << 40U; + } + auto it = pod_type_tbl_.find(type_key); if (it != pod_type_tbl_.end()) { return it->second; } - SType t = DeclareType(dtype); + SType t = DeclareType(dtype, row, col); pod_type_tbl_[type_key] = t; return t; } @@ -221,7 +234,13 @@ Value IRBuilder::FloatImm(const SType& dtype, double value) { return GetConst_(dtype, &data); } else { ICHECK_EQ(dtype.type.bits(), 16); - return Cast(dtype, FloatImm(GetSType(DataType::Float(32)), value)); + float fvalue = static_cast(value); + uint32_t* ptr = reinterpret_cast(&fvalue); + uint64_t data = ptr[0]; + if (data == 0) + return GetConst_(dtype, &data); + else + return Cast(dtype, FloatImm(GetSType(DataType::Float(32)), value)); } } @@ -475,7 +494,7 @@ Value IRBuilder::GetConst_(const SType& dtype, const uint64_t* pvalue) { return ret; } -SType IRBuilder::DeclareType(const DataType& dtype) { +SType IRBuilder::DeclareType(const DataType& dtype, uint32_t row, uint32_t col) { AddCapabilityFor(dtype); if (dtype.lanes() == 1) { @@ -500,7 +519,18 @@ SType IRBuilder::DeclareType(const DataType& dtype) { t.id = id_counter_++; t.type = dtype; SType base_type = GetSType(dtype.element_of()); - ib_.Begin(spv::OpTypeVector).AddSeq(t, base_type, dtype.lanes()).Commit(&global_); + + if (row * col == 0) { + ICHECK((row == 0) && (col == 0)); + ib_.Begin(spv::OpTypeVector).AddSeq(t, base_type, dtype.lanes()).Commit(&global_); + } else { + Value v_row = GetSpecConst(GetSType(DataType::UInt(32)), row); + Value v_col = GetSpecConst(GetSType(DataType::UInt(32)), col); + Value scope = UIntImm(GetSType(DataType::UInt(32)), spv::ScopeSubgroup); + ib_.Begin(spv::OpTypeCooperativeMatrixNV) + .AddSeq(t, base_type, scope, v_row, v_col) + .Commit(&global_); + } return t; } } @@ -727,6 +757,30 @@ Value 
IRBuilder::Cast(const SType& dst_type, spirv::Value value) { } } +Value IRBuilder::GetCompositeConst(const SType& ele_stype, const SType& composite_stype, + const double dval) { + auto key = std::make_pair(composite_stype.id, dval); + auto it = composite_const_tbl_.find(key); + if (it != composite_const_tbl_.end()) { + return it->second; + } + spirv::Value const_val = FloatImm(ele_stype, dval); + Value new_val = NewValue(composite_stype, kNormal); + ib_.Begin(spv::OpConstantComposite).AddSeq(composite_stype, new_val, const_val); + ib_.Commit(&global_); + composite_const_tbl_[key] = new_val; + return new_val; +} + +Value IRBuilder::GetSpecConst(const SType& dtype, uint64_t value) { + ICHECK_LE(dtype.type.bits(), 32); + Value ret = NewValue(dtype, kSpecConst); + ib_.Begin(spv::OpSpecConstant).AddSeq(dtype, ret); + ib_.Add(static_cast(value)); + ib_.Commit(&global_); + return ret; +} + #define DEFINE_BUILDER_BINARY_USIGN_OP(_OpName, _Op) \ Value IRBuilder::_OpName(Value a, Value b) { \ ICHECK_EQ(a.stype.id, b.stype.id); \ diff --git a/src/target/spirv/ir_builder.h b/src/target/spirv/ir_builder.h index d642484532f9..e92e8364ee1b 100644 --- a/src/target/spirv/ir_builder.h +++ b/src/target/spirv/ir_builder.h @@ -65,7 +65,8 @@ enum ValueKind { kPushConstantPtr, kFunction, kExtInst, - kUniformPtr + kUniformPtr, + kSpecConst, }; /*! \brief Represent the SPIRV Value */ @@ -443,7 +444,7 @@ class IRBuilder { * \param dtype The data type. * \return The corresponding spirv type. */ - SType GetSType(const tvm::DataType& dtype); + SType GetSType(const tvm::DataType& dtype, uint32_t row = 0, uint32_t col = 0); /*! * \brief Get the pointer type that points to value_type * \param value_type. @@ -592,6 +593,19 @@ class IRBuilder { Value GT(Value a, Value b); Value GE(Value a, Value b); Value Select(Value cond, Value a, Value b); + /* + * \brief Get composite constant + * \param ele_stype The value type of elements in the composite. + * \param composite_type The value type of the composite. + * \param dval The initial value for all elements in the composite. + */ + Value GetCompositeConst(const SType& ele_stype, const SType& composite_stype, double dval); + /* + * Get specialization constant + * \param dtype The content value type + * \param value The default value + */ + Value GetSpecConst(const SType& dtype, uint64_t value); private: /*! @@ -640,8 +654,9 @@ class IRBuilder { // get constant given value encoded in uint64_t Value GetConst_(const SType& dtype, const uint64_t* pvalue); + // declare type - SType DeclareType(const DataType& dtype); + SType DeclareType(const DataType& dtype, uint32_t row = 0, uint32_t col = 0); // Declare the appropriate SPIR-V capabilities and extensions to use // this data type. @@ -696,13 +711,15 @@ class IRBuilder { /*! \brief whether push constant is defined */ Value push_const_; /*! \brief map from type code to the type */ - std::unordered_map pod_type_tbl_; + std::unordered_map pod_type_tbl_; /*! \brief map from value to array type */ std::map, SType> struct_array_type_tbl_; /*! \brief map from value to its pointer type */ std::map, SType> pointer_type_tbl_; /*! \brief map from constant int to its value */ std::map, Value> const_tbl_; + /*! \brief map from floating point composite constant to its value */ + std::map, Value> composite_const_tbl_; /*! 
\brief map from name of a ExtInstImport to its value */ std::map ext_inst_tbl_; diff --git a/src/target/spirv/spirv_support.cc b/src/target/spirv/spirv_support.cc index 1b46e7f08339..a17a694da4dd 100644 --- a/src/target/spirv/spirv_support.cc +++ b/src/target/spirv/spirv_support.cc @@ -102,6 +102,10 @@ SPIRVSupport::SPIRVSupport(tvm::Target target) { } } } + // Check whether cooperative matrix is enabled in the target string. + if (target->GetAttr("supports_cooperative_matrix")) { + supports_cooperative_matrix = target->GetAttr("supports_cooperative_matrix").value(); + } } } // namespace codegen diff --git a/src/target/spirv/spirv_support.h b/src/target/spirv/spirv_support.h index 6365e576b8cf..83f92595112e 100644 --- a/src/target/spirv/spirv_support.h +++ b/src/target/spirv/spirv_support.h @@ -276,6 +276,20 @@ struct SPIRVSupport { * attempting to perform integer dot product. */ bool supports_integer_dot_product{false}; + + /*! + * \brief Whether the driver supports operations involving cooperative matrix. + * + * Vulkan extension: VK_NV_cooperative_matrix + * SPV Extension name: SPV_NV_cooperative_matrix + * SPV Capability: spv::CapabilityCooperativeMatrixNV + * + * If support is present, can perform cooperative matrix operations. If + * support is not present, codegen will throw exception on + * attempting to perform cooperative matrix. + */ + + bool supports_cooperative_matrix{false}; }; } // namespace codegen diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index 3a555e304cb0..3c4e885ef9b5 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -378,6 +378,7 @@ TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan) .add_attr_option("supports_push_descriptor") .add_attr_option("supports_dedicated_allocation") .add_attr_option("supports_integer_dot_product") + .add_attr_option("supports_cooperative_matrix") .add_attr_option("supported_subgroup_operations") // Physical device limits .add_attr_option("max_num_threads", Integer(256)) diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index b798f981f7ad..b3829529eecf 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -648,6 +648,36 @@ CollectStorageAlignAnnotation(const Stmt& body) { return std::move(collector.storage_align_); } +int Stoi(const std::string& str) { + try { + return std::stoi(str); + } catch (std::invalid_argument& e) { + LOG(FATAL) << "Cannot convert \"" << str << "\" to int"; + throw; + } +} + +std::pair GetWmmaFragmentDimSize(const std::string& shape_str, + const std::string& scope) { + size_t m, n, k; + size_t last_pos = 0, pos = 0; + pos = shape_str.find(", ", last_pos); + m = Stoi(shape_str.substr(last_pos, pos - last_pos)); + last_pos = pos + 2; + pos = shape_str.find(", ", last_pos); + n = Stoi(shape_str.substr(last_pos, pos - last_pos)); + last_pos = pos + 2; + k = Stoi(shape_str.substr(last_pos, shape_str.length() - last_pos)); + if (scope == "wmma.matrix_a") { + return std::pair(m, k); + } else if (scope == "wmma.matrix_b") { + return std::pair(k, n); + } else if (scope == "wmma.accumulator") { + return std::pair(m, n); + } + return std::pair(0, 0); +} + namespace transform { Pass ConvertSSA() { auto pass_func = [](IRModule mod, PassContext ctx) { diff --git a/src/tir/transforms/ir_utils.h b/src/tir/transforms/ir_utils.h index afaff3447233..59dc95dcd6a0 100644 --- a/src/tir/transforms/ir_utils.h +++ b/src/tir/transforms/ir_utils.h @@ -342,6 +342,14 @@ using StorageAlignAnnotation = Array; */ std::unordered_map 
CollectStorageAlignAnnotation(const Stmt& body); +/*! + * \brief Split string separated by "," to get wmma fragment dimension size. + * \param shape_str The string to split. + * \param scope The scope to match. + * \return The result pair of fragment dimension size. + */ +std::pair GetWmmaFragmentDimSize(const std::string& shape_str, + const std::string& scope); } // namespace tir } // namespace tvm diff --git a/tests/python/unittest/test_target_codegen_vulkan.py b/tests/python/unittest/test_target_codegen_vulkan.py index bfb10ca85a38..7057ff840637 100644 --- a/tests/python/unittest/test_target_codegen_vulkan.py +++ b/tests/python/unittest/test_target_codegen_vulkan.py @@ -29,6 +29,17 @@ from tvm import relay, te from tvm.topi.math import cast from tvm.script import tir as T +from tvm.tir import TensorIntrin, IntImm, Cast, Schedule +from tvm.tir.tensor_intrin.cuda import ( + WMMA_LOAD_16x16x16_F16_A_INTRIN, + WMMA_LOAD_16x16x16_F16_B_INTRIN, + WMMA_SYNC_16x16x16_f16f16f32_INTRIN, + WMMA_FILL_16x16x16_F32_INTRIN, + WMMA_STORE_16x16x16_F32_GLOBAL_INTRIN, + WMMA_SYNC_16x16x16_f16f16f16_INTRIN, + WMMA_FILL_16x16x16_F16_INTRIN, + WMMA_STORE_16x16x16_F16_GLOBAL_INTRIN, +) dtype = tvm.testing.parameter("float32", "int32", "float16", "int8") @@ -600,5 +611,122 @@ def func(A: T.Buffer((N, 2), "int32")): np.testing.assert_array_equal(a[:, 1], (np.arange(N) - offset) % divisor) +@pytest.mark.parametrize("out_dtype", ["float32", "float16"]) +def test_cooperative_matrix(out_dtype): + def get_matmul(m, n, k, out_dtype="float32"): + X = te.placeholder((m, k), name="X", dtype="float16") + W = te.placeholder((k, n), name="W", dtype="float16") + ak = te.reduce_axis((0, k), name="k") + + if out_dtype == "float32": + matmul = te.compute( + (m, n), + lambda i, j: te.sum( + X[i, ak].astype("float32") * W[ak, j].astype("float32"), + axis=ak, + ), + name="compute", + ) + else: + matmul = te.compute( + (m, n), + lambda i, j: te.sum(X[i, ak] * W[ak, j], axis=ak), + name="compute", + ) + + return te.create_prim_func([X, W, matmul]) + + M, N, K = 16, 16, 32 + func = get_matmul(M, N, K, out_dtype) + sch = Schedule(func) + block = sch.get_block("compute") + + i, j, k = sch.get_loops(block) + i_outer, i_inner = sch.split(i, factors=[None, 16]) + j_outer, j_inner = sch.split(j, factors=[None, 16]) + k_outer, k_inner = sch.split(k, factors=[None, 16]) + sch.reorder(i_outer, j_outer, k_outer, i_inner, j_inner, k_inner) + fused_outer = sch.fuse(i_outer, j_outer) + sch.bind(fused_outer, "blockIdx.x") + + def fetch_to_shared(block, idx): + block_read = sch.cache_read(block, idx, "shared") + sch.compute_at(block_read, k_outer) + warp_size = 32 + + fused = sch.fuse(*sch.get_loops(block_read)[-2:]) + + vector_size = 4 + _, f_2, f_3 = sch.split(fused, factors=[None, warp_size, vector_size]) + sch.bind(f_2, "threadIdx.x") + sch.vectorize(f_3) + + def tensorize_load(block, dim): + loops = sch.get_loops(block) + i, j = loops[-dim : (len(loops) - dim + 2)] + + i0, i1 = sch.split(i, factors=[None, 16]) + j0, j1 = sch.split(j, factors=[None, 16]) + sch.reorder(i0, j0, i1, j1) + sch.unroll(i0) + sch.unroll(j0) + return i1 + + fetch_to_shared(block, 0) + fetch_to_shared(block, 1) + + c_warp_scope = "wmma.accumulator" + a_warp_scope = "wmma.matrix_a" + b_warp_scope = "wmma.matrix_b" + + A_mat = sch.cache_read(block, 0, a_warp_scope) + B_mat = sch.cache_read(block, 1, b_warp_scope) + + loop_a = tensorize_load(A_mat, 2) + sch.tensorize(loop_a, WMMA_LOAD_16x16x16_F16_A_INTRIN) + + loop_b = tensorize_load(B_mat, 2) + sch.tensorize(loop_b, 
WMMA_LOAD_16x16x16_F16_B_INTRIN) + + store = sch.cache_write(block, 0, c_warp_scope) + sch.reverse_compute_at(store, fused_outer) + init = sch.decompose_reduction(block, sch.get_loops(block)[1]) + + intrin = WMMA_FILL_16x16x16_F32_INTRIN + if out_dtype == "float16": + intrin = WMMA_FILL_16x16x16_F16_INTRIN + sch.tensorize(sch.get_loops(init)[1], intrin) + + intrin = WMMA_STORE_16x16x16_F32_GLOBAL_INTRIN + if out_dtype == "float16": + intrin = WMMA_STORE_16x16x16_F16_GLOBAL_INTRIN + sch.tensorize(sch.get_loops(store)[1], intrin) + + intrin = WMMA_SYNC_16x16x16_f16f16f32_INTRIN + if out_dtype == "float16": + intrin = WMMA_SYNC_16x16x16_f16f16f16_INTRIN + sch.tensorize(sch.get_loops(block)[2], intrin) + + target = "vulkan -from_device=0" + tgt_attrs = tvm.target.Target(target).attrs + + if tgt_attrs.get("supports_cooperative_matrix"): + f = tvm.build(sch.mod, target=target) + + dev = tvm.device(target, 0) + + A = tvm.nd.array(np.random.randn(M, K).astype("float16"), dev) + B = tvm.nd.array(np.random.randn(K, N).astype("float16"), dev) + C = tvm.nd.array(np.random.randn(M, N).astype(out_dtype), dev) + + f(A, B, C) + + A_np = A.numpy() + B_np = B.numpy() + ref = np.dot(A_np.astype("float32"), B_np.astype("float32")) + + tvm.testing.assert_allclose(C.numpy(), ref, rtol=1e-2, atol=1e-2) + + if __name__ == "__main__": tvm.testing.main()
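---

Implementation notes

The refactoring in codegen_cuda.cc and ir_utils.cc centralizes the parsing of
the fragment shape string ("m, n, k") so that the CUDA and SPIR-V backends can
share the mapping from a buffer's storage scope to its fragment dimensions. A
minimal standalone sketch of that logic, with plain std::stoi standing in for
the patch's checked Stoi helper:

#include <cstdint>
#include <cstdio>
#include <string>
#include <utility>

// Parse "m, n, k" and return the two dimensions that describe a fragment
// in the given scope, mirroring GetWmmaFragmentDimSize.
std::pair<int32_t, int32_t> FragmentDims(const std::string& shape_str,
                                         const std::string& scope) {
  size_t p1 = shape_str.find(", ");
  size_t p2 = shape_str.find(", ", p1 + 2);
  int32_t m = std::stoi(shape_str.substr(0, p1));
  int32_t n = std::stoi(shape_str.substr(p1 + 2, p2 - p1 - 2));
  int32_t k = std::stoi(shape_str.substr(p2 + 2));
  if (scope == "wmma.matrix_a") return {m, k};     // m x k operand
  if (scope == "wmma.matrix_b") return {k, n};     // k x n operand
  if (scope == "wmma.accumulator") return {m, n};  // m x n accumulator
  return {0, 0};  // unrecognized scope; callers must treat this as failure
}

int main() {
  auto dims = FragmentDims("16, 16, 16", "wmma.matrix_a");
  std::printf("%d x %d\n", dims.first, dims.second);  // prints "16 x 16"
  return 0;
}

These dimensions are also why the AllocateNode handler divides constant_size
by dim.first * dim.second: the TIR allocation counts scalar elements, while
each element of the generated SPIR-V array is a whole cooperative matrix. For
example, 2048 fp16 elements in "wmma.matrix_a" scope with shape "16, 16, 16"
become an array of 2048 / 256 = 8 matrices.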
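In the same AllocateNode handler, shared allocations are still rounded up to a
4-byte element boundary, and shared_memory_bytes_used_ now accounts for the
rounded size rather than the original one. The round-up itself is the usual
bit trick; a small sketch under the same semantics:

#include <cassert>
#include <cstdint>

// Round an element count up to the next multiple of 4, as done for
// Workgroup-class (shared) allocations.
int32_t AlignTo4(int32_t n) { return (n + 3) & ~0x3; }

int main() {
  assert(AlignTo4(5) == 8);
  assert(AlignTo4(8) == 8);
  return 0;
}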
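In ir_builder.cc, the pod_type_tbl_ key widens from uint32_t to uint64_t
because cooperative matrix types are cached by (code, bits, rows, cols) rather
than (code, bits, lanes). The packing, spelled out as a self-contained sketch
(the function name here is illustrative):

#include <cstdint>

// Field layout of the SType cache key used by GetSType:
//   bits  0..7   dtype code
//   bits  8..15  dtype bit width
//   bits 16..31  lanes (scalar/vector types, where row == col == 0)
//   bits 32..39  rows  (cooperative matrix types)
//   bits 40..47  cols
uint64_t TypeKey(uint8_t code, uint8_t bits, uint16_t lanes,
                 uint32_t row, uint32_t col) {
  uint64_t key = code;
  key |= static_cast<uint64_t>(bits) << 8U;
  if (row == 0 && col == 0) {
    key |= static_cast<uint64_t>(lanes) << 16U;
  } else {
    key |= static_cast<uint64_t>(row) << 32U;
    key |= static_cast<uint64_t>(col) << 40U;
  }
  return key;
}

Lane counts are small in practice, so the vector encoding never reaches bit 32
and the two key forms cannot collide.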
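OpTypeCooperativeMatrixNV takes its Rows and Columns operands as ids of
integer constants, which is why DeclareType goes through GetSpecConst instead
of passing literals; the patch encodes the dimensions as OpSpecConstant
instructions. As a reference for how the builder's ib_.Begin(...).AddSeq(...)
calls land in the binary, here is a sketch of the generic SPIR-V word encoding
for a 32-bit OpSpecConstant (the opcode value comes from the SPIR-V spec; the
ids used in main are made up):

#include <cstdint>
#include <cstdio>
#include <vector>

// Generic SPIR-V instruction encoding: word 0 packs the word count into
// the high 16 bits and the opcode into the low 16 bits; operands follow.
std::vector<uint32_t> EncodeSpecConstantU32(uint32_t type_id,
                                            uint32_t result_id,
                                            uint32_t literal) {
  const uint32_t kOpSpecConstant = 50;  // spv::OpSpecConstant
  const uint32_t kWordCount = 4;  // opcode word + type + result + literal
  return {kWordCount << 16 | kOpSpecConstant, type_id, result_id, literal};
}

int main() {
  // Hypothetical ids: %6 = OpTypeInt 32 0, %7 = a 16-row dimension.
  for (uint32_t word : EncodeSpecConstantU32(6, 7, 16)) {
    std::printf("0x%08x\n", word);
  }
  return 0;
}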
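The FloatImm change is what lets GetCompositeConst work for fp16 fragments:
OpConstantComposite needs its constituent to be a constant id in the global
section, but the previous fp16 path returned an OpFConvert of an fp32
constant, which is a function-body instruction. The fill intrinsics in the
unit test initialize fragments with 0.0, and fp16 +0.0 is the all-zero bit
pattern, so zero is now emitted directly as an OpConstant while nonzero fp16
immediates keep taking the convert path. A sketch of the zero test, using
memcpy as the well-defined equivalent of the patch's reinterpret_cast:

#include <cstdint>
#include <cstring>

// True when `value` rounds to a float whose bit pattern is all zeros,
// i.e. +0.0f, which is also the fp16 encoding of +0.0.
bool IsPositiveZeroBits(double value) {
  float f = static_cast<float>(value);
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));  // well-defined type punning
  return bits == 0;
}

int main() { return IsPositiveZeroBits(0.0) && !IsPositiveZeroBits(1.0) ? 0 : 1; }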