Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[Vulkan] Add cooperative matrix support #14817

Merged
merged 1 commit into from
May 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions python/tvm/target/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,11 +173,25 @@ def max_num_threads(self):
"""Returns the max_num_threads from the target if it exists."""
return int(self.attrs["max_num_threads"])

@property
def max_block_size_x(self):
    """Returns the max block size in x-dimension from the target if it exists."""
    raw_value = self.attrs["max_block_size_x"]
    return int(raw_value)

@property
def max_block_size_y(self):
    """Returns the max block size in y-dimension from the target if it exists."""
    raw_value = self.attrs["max_block_size_y"]
    return int(raw_value)

@property
def thread_warp_size(self):
    """Returns the thread_warp_size from the target if it exists."""
    raw_value = self.attrs["thread_warp_size"]
    return int(raw_value)

@property
def max_shared_memory_per_block(self):
    """Returns the max_shared_memory_per_block from the target if it exists.

    Raises if the target does not define the attribute (dict-style access,
    no default). NOTE(review): units are presumably bytes — confirm against
    the target attribute schema.
    """
    return int(self.attrs["max_shared_memory_per_block"])

@property
def max_function_args(self):
    """Returns the max_function_args from the target.

    Returns -1 when the target does not define ``max_function_args``.
    """
    return int(self.attrs.get("max_function_args", -1))
Expand Down Expand Up @@ -219,6 +233,13 @@ def supports_integer_dot_product(self):
def libs(self):
return list(self.attrs.get("libs", []))

@property
def supports_cooperative_matrix(self):
    """Whether the target reports cooperative matrix support.

    Returns
    -------
    support : bool
        True if the ``supports_cooperative_matrix`` attribute is present
        and truthy, False otherwise (including when the attribute is absent).
    """
    # Use an explicit boolean default instead of the list sentinel ([]) the
    # original used; bool() normalizes truthy attr values (e.g. 1) to bool.
    return bool(self.attrs.get("supports_cooperative_matrix", False))

@property
def features(self):
    """Returns a TargetFeatures object constructed over this target."""
    return TargetFeatures(self)
Expand Down
5 changes: 4 additions & 1 deletion src/runtime/vulkan/vulkan_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ VulkanDeviceProperties::VulkanDeviceProperties(const VulkanInstance& instance,

supports_integer_dot_product = device.HasExtension("VK_KHR_shader_integer_dot_product");

supports_cooperative_matrix = device.HasExtension("VK_NV_cooperative_matrix");

// The check of VK_SHADER_STAGE_COMPUTE_BIT isn't technically
// needed, since it will be set so long as at least one queue has
// VK_QUEUE_COMPUTE_BIT. Including it to avoid potential future
Expand Down Expand Up @@ -435,7 +437,8 @@ std::vector<const char*> VulkanDevice::SelectEnabledExtensions() const {
"VK_KHR_get_memory_requirements2",
"VK_KHR_dedicated_allocation",
"VK_KHR_spirv_1_4",
"VK_KHR_shader_integer_dot_product"};
"VK_KHR_shader_integer_dot_product",
"VK_NV_cooperative_matrix"};

uint32_t device_extension_prop_count;
VULKAN_CALL(vkEnumerateDeviceExtensionProperties(physical_device_, nullptr,
Expand Down
1 change: 1 addition & 0 deletions src/runtime/vulkan/vulkan_device.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ struct VulkanDeviceProperties {
bool supports_push_descriptor{false};
bool supports_dedicated_allocation{false};
bool supports_integer_dot_product{false};
bool supports_cooperative_matrix{false};
uint32_t supported_subgroup_operations{0};
uint32_t max_num_threads{1};
uint32_t thread_warp_size{1};
Expand Down
4 changes: 4 additions & 0 deletions src/runtime/vulkan/vulkan_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,10 @@ void VulkanDeviceAPI::GetTargetProperty(Device dev, const std::string& property,
*rv = prop.supports_integer_dot_product;
}

if (property == "supports_cooperative_matrix") {
*rv = prop.supports_cooperative_matrix;
}

if (property == "device_name") {
*rv = prop.device_name;
}
Expand Down
23 changes: 6 additions & 17 deletions src/target/source/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include <utility>
#include <vector>

#include "../../tir/transforms/ir_utils.h"
#include "literal/cuda_half_t.h"
#include "ptx.h"

Expand Down Expand Up @@ -1333,23 +1334,11 @@ int32_t CodeGenCUDA::GetWmmaFragmentSize(const std::string& scope, const VarNode
ICHECK(fragment_shapes.count(variable))
<< "Cannot find shape of the wmma fragment " << variable->name_hint;
std::string shape_str = fragment_shapes.at(variable);
size_t m, n, k;
size_t last_pos = 0, pos = 0;
pos = shape_str.find(", ", last_pos);
m = tvm::codegen::stoi(shape_str.substr(last_pos, pos - last_pos));
last_pos = pos + 2;
pos = shape_str.find(", ", last_pos);
n = tvm::codegen::stoi(shape_str.substr(last_pos, pos - last_pos));
last_pos = pos + 2;
k = tvm::codegen::stoi(shape_str.substr(last_pos, shape_str.length() - last_pos));
if (scope == "wmma.matrix_a") {
return size / m / k;
} else if (scope == "wmma.matrix_b") {
return size / n / k;
} else if (scope == "wmma.accumulator") {
return size / m / n;
}
return 0;
std::pair<int32_t, int32_t> dim = GetWmmaFragmentDimSize(shape_str, scope);
if (dim.first * dim.second != 0)
return size / dim.first / dim.second;
else
return 0;
}

void CodeGenCUDA::HandleVolatileLoads(const std::string& value, const BufferLoadNode* op,
Expand Down
193 changes: 178 additions & 15 deletions src/target/spirv/codegen_spirv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ void CodeGenSPIRV::InitFuncState() {
builder_.reset(new spirv::IRBuilder(spirv_support_));
builder_->InitHeader();
shared_memory_bytes_used_ = 0;
fragment_info_.clear();
}

spirv::Value CodeGenSPIRV::GetThreadIndex(const IterVar& iv, const PrimExpr& extent) {
Expand Down Expand Up @@ -394,6 +395,120 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) {
LOG(FATAL) << "SPIR-V shader cannot make extern calls. Graph contains extern \""
<< Downcast<StringImm>(op->args[0]) << "\"";
return spirv::Value();
} else if (op->op.same_as(builtin::tvm_fill_fragment())) {
ICHECK_EQ(op->args.size(), 6U);
const VarNode* buffer_node = op->args[0].as<VarNode>();
ICHECK(buffer_node && fragment_info_.count(buffer_node));
DataType ele_dtype = GetElementDataType(buffer_node);
ICHECK(ele_dtype.is_float()) << "Only floating point fragment accumulator is supported";
spirv::SType ele_stype = builder_->GetSType(ele_dtype);
spirv::SType& fragment_type = fragment_info_[buffer_node].stype;
double init = static_cast<uint64_t>(Downcast<FloatImm>(op->args[5])->value);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why cast to uint64?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't recall a good reason. Removing this cast works fine. Should I reset and re-patch?

PrimExpr prim_index = op->args[4];
spirv::Value init_val = builder_->GetCompositeConst(ele_stype, fragment_type, init);
spirv::SType ptr_type =
builder_->GetPointerType(fragment_type, fragment_info_[buffer_node].sclass);
spirv::Value index = MakeValue(prim_index);
ICHECK(var_map_.count(buffer_node));
spirv::Value ptr = builder_->StructArrayAccess(ptr_type, var_map_[buffer_node], index);
builder_->MakeInst(spv::OpStore, ptr, init_val, spv::MemoryAccessMaskNone);
return spirv::Value();

} else if (op->op.same_as(builtin::tvm_load_matrix_sync())) {
ICHECK_EQ(op->args.size(), 8U);
const VarNode* buffer_node = op->args[0].as<VarNode>();
ICHECK(buffer_node && fragment_info_.count(buffer_node));
spirv::SType& fragment_type = fragment_info_[buffer_node].stype;
PrimExpr dst_index = op->args[4];
PrimExpr src_ptr_expr = op->args[5];
int stride = static_cast<int>(Downcast<IntImm>(op->args[6])->value);
auto type_int = builder_->GetSType(DataType::Int(32));
spirv::Value stride_val = builder_->IntImm(type_int, stride);
std::string layout = (op->args[7].as<StringImmNode>())->value;
spirv::SType dst_ptr_type =
builder_->GetPointerType(fragment_type, fragment_info_[buffer_node].sclass);
spirv::Value dst_ptr =
builder_->StructArrayAccess(dst_ptr_type, var_map_[buffer_node], MakeValue(dst_index));
spirv::Value src_ptr = VisitExpr(op->args[5]);
spirv::SType type_bool = builder_->GetSType(DataType::UInt(1));
spirv::Value t_val = builder_->UIntImm(type_bool, 1);
spirv::Value f_val = builder_->UIntImm(type_bool, 0);
spirv::Value loaded =
builder_->MakeValue(spv::OpCooperativeMatrixLoadNV, fragment_type, src_ptr, stride_val,
(layout != "row_major") ? t_val : f_val);
builder_->MakeInst(spv::OpStore, dst_ptr, loaded, spv::MemoryAccessMaskNone);
return spirv::Value();
} else if (op->op.same_as(builtin::tvm_mma_sync())) {
const VarNode* buffer_d = op->args[0].as<VarNode>();
const VarNode* buffer_a = op->args[2].as<VarNode>();
const VarNode* buffer_b = op->args[4].as<VarNode>();
const VarNode* buffer_c = op->args[6].as<VarNode>();
PrimExpr index_d = op->args[1];
PrimExpr index_a = op->args[3];
PrimExpr index_b = op->args[5];
tvm::tir::ExprDeepEqual expr_equal;
PrimExpr index_c = op->args[7];
bool is_equal = ((buffer_d == buffer_c) && expr_equal(index_d, index_c));
spirv::SType& fragment_type_d = fragment_info_[buffer_d].stype;
spirv::SType& fragment_type_a = fragment_info_[buffer_a].stype;
spirv::SType& fragment_type_b = fragment_info_[buffer_b].stype;
spirv::SType& fragment_type_c = fragment_info_[buffer_c].stype;
spv::StorageClass storage = fragment_info_[buffer_d].sclass;
spirv::SType ptr_type_d = builder_->GetPointerType(fragment_type_d, storage);
spirv::SType ptr_type_a = builder_->GetPointerType(fragment_type_a, storage);
spirv::SType ptr_type_b = builder_->GetPointerType(fragment_type_b, storage);
spirv::SType ptr_type_c = builder_->GetPointerType(fragment_type_c, storage);
spirv::Value ptr_d =
builder_->StructArrayAccess(ptr_type_d, var_map_[buffer_d], MakeValue(index_d));
spirv::Value ptr_a =
builder_->StructArrayAccess(ptr_type_a, var_map_[buffer_a], MakeValue(index_a));
spirv::Value ptr_b =
builder_->StructArrayAccess(ptr_type_b, var_map_[buffer_b], MakeValue(index_b));
spirv::Value ptr_c =
is_equal ? ptr_d
: builder_->StructArrayAccess(ptr_type_c, var_map_[buffer_c], MakeValue(index_c));
uint32_t mask = spv::MemoryAccessMaskNone;
spirv::Value loaded_a = builder_->MakeValue(spv::OpLoad, fragment_type_a, ptr_a, mask);
spirv::Value loaded_b = builder_->MakeValue(spv::OpLoad, fragment_type_b, ptr_b, mask);
spirv::Value loaded_c = builder_->MakeValue(spv::OpLoad, fragment_type_c, ptr_c, mask);
masahi marked this conversation as resolved.
Show resolved Hide resolved
spirv::Value result = builder_->MakeValue(spv::OpCooperativeMatrixMulAddNV, fragment_type_d,
loaded_a, loaded_b, loaded_c);
builder_->MakeInst(spv::OpStore, ptr_d, result, spv::MemoryAccessMaskNone);
return spirv::Value();
} else if (op->op.same_as(builtin::tvm_store_matrix_sync())) {
ICHECK_EQ(op->args.size(), 8U);
const VarNode* buffer_node = op->args[0].as<VarNode>();
PrimExpr index = op->args[4];
PrimExpr buffer_ptr = op->args[5];
int stride = static_cast<int>(Downcast<IntImm>(op->args[6])->value);
auto type_int = builder_->GetSType(DataType::Int(32));
spirv::Value stride_val = builder_->IntImm(type_int, stride);
std::string layout = (op->args[7].as<StringImmNode>())->value;
spirv::Value dst_ptr = VisitExpr(op->args[5]);
spirv::SType& fragment_type = fragment_info_[buffer_node].stype;
spv::StorageClass storage = fragment_info_[buffer_node].sclass;
spirv::SType ptr_type = builder_->GetPointerType(fragment_type, storage);
spirv::Value ptr =
builder_->StructArrayAccess(ptr_type, var_map_[buffer_node], MakeValue(index));
uint32_t mask = spv::MemoryAccessMaskNone;
spirv::Value loaded = builder_->MakeValue(spv::OpLoad, fragment_type, ptr, mask);
spirv::SType type_bool = builder_->GetSType(DataType::UInt(1));
spirv::Value t_val = builder_->UIntImm(type_bool, 1);
spirv::Value f_val = builder_->UIntImm(type_bool, 0);
builder_->MakeInst(spv::OpCooperativeMatrixStoreNV, dst_ptr, loaded, stride_val,
(layout != "row_major") ? t_val : f_val);
return spirv::Value();
} else if (op->op.same_as(builtin::address_of())) {
const BufferLoadNode* load = op->args[0].as<BufferLoadNode>();
Var buffer_var = load->buffer->data;
const VarNode* buffer_node = buffer_var.get();
PrimExpr index = load->indices[0];
DataType ele_dtype = GetElementDataType(buffer_node);
spirv::SType ele_stype = builder_->GetSType(ele_dtype);
spirv::Value buffer_val = MakeValue(buffer_var);
spirv::SType ptr_type = builder_->GetPointerType(ele_stype, buffer_val.stype.storage_class);
ICHECK(var_map_.count(buffer_node));
return builder_->StructArrayAccess(ptr_type, var_map_[buffer_node], MakeValue(index));
} else {
LOG(FATAL) << "Unresolved call " << op->op;
}
Expand Down Expand Up @@ -657,22 +772,46 @@ void CodeGenSPIRV::VisitStmt_(const AllocateNode* op) {
ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU";

spirv::Value buf;
auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var));
const std::string scope = GetPtrStorageScope(op->buffer_var);
auto storage_scope = runtime::StorageScope::Create(scope);
spirv::SType etype = builder_->GetSType(op->dtype);
if (storage_scope.rank == runtime::StorageRank::kLocal) {
buf =
builder_->Allocate(etype, static_cast<uint32_t>(constant_size), spv::StorageClassFunction);
} else if (storage_scope.rank == runtime::StorageRank::kShared) {
// Shared memory
// Aligned on 4-byte boundary
int32_t aligned_constant_size = ((constant_size + 3) & ~0x3);
buf = builder_->Allocate(etype, static_cast<uint32_t>(aligned_constant_size),
spv::StorageClassWorkgroup);

size_t num_bytes = op->dtype.bytes() * op->dtype.lanes() * static_cast<uint32_t>(constant_size);
shared_memory_bytes_used_ += num_bytes;
} else {
LOG(FATAL) << "Can only allocate shared or local memory inside kernel";
runtime::StorageRank rank = storage_scope.rank;
spv::StorageClass storage_class;
const VarNode* var_node = (op->buffer_var).get();

switch (rank) {
case runtime::StorageRank::kWMMAMatrixA:
case runtime::StorageRank::kWMMAMatrixB:
case runtime::StorageRank::kWMMAAccumulator: {
ICHECK(fragment_info_.count(var_node));
fragment_info_[var_node].scope = scope;
etype = GetFragmentSType(var_node, op->dtype);
storage_class = spv::StorageClassFunction;
fragment_info_[var_node].sclass = storage_class;
ICHECK(fragment_info_.count(var_node));
const std::string& scope = fragment_info_[var_node].scope;
const std::string& shape_str = fragment_info_.at(var_node).shape;
std::pair<int32_t, int32_t> dim = GetWmmaFragmentDimSize(shape_str, scope);
int64_t size = dim.first * dim.second;
buf = builder_->Allocate(etype, static_cast<uint32_t>(constant_size) / size, storage_class);
} break;
case runtime::StorageRank::kLocal: {
storage_class = spv::StorageClassFunction;
buf = builder_->Allocate(etype, static_cast<uint32_t>(constant_size), storage_class);
} break;
case runtime::StorageRank::kShared: {
storage_class = spv::StorageClassWorkgroup;
// Shared memory
// Aligned on 4-byte boundary
int32_t aligned_constant_size = ((constant_size + 3) & ~0x3);
buf = builder_->Allocate(etype, static_cast<uint32_t>(aligned_constant_size), storage_class);

size_t num_bytes =
op->dtype.bytes() * op->dtype.lanes() * static_cast<uint32_t>(aligned_constant_size);
shared_memory_bytes_used_ += num_bytes;
} break;
default:
LOG(FATAL) << "Can only allocate shared or local memory inside kernel";
}

builder_->SetName(buf, op->buffer_var->name_hint);
Expand Down Expand Up @@ -700,6 +839,13 @@ void CodeGenSPIRV::VisitStmt_(const AttrStmtNode* op) {
const VarNode* v = op->node.as<VarNode>();
ICHECK(v);
storage_info_[v].is_volatile = true;
} else if (op->attr_key == tir::attr::buffer_bind_scope) {
const VarNode* v = op->node.as<VarNode>();
ICHECK(v);
} else if (op->attr_key == tir::attr::fragment_shape) {
const VarNode* buffer = op->node.as<VarNode>();
const StringImmNode* shape_str = op->value.as<StringImmNode>();
fragment_info_[buffer] = {shape_str->value};
}
this->VisitStmt(op->body);
}
Expand All @@ -725,5 +871,22 @@ void CodeGenSPIRV::VisitStmt_(const SeqStmtNode* op) {

void CodeGenSPIRV::VisitStmt_(const EvaluateNode* op) { MakeValue(op->value); }

spirv::SType CodeGenSPIRV::GetFragmentSType(const VarNode* buffer, const DataType& dtype) {
  // Build the SPIR-V type for a WMMA fragment buffer from its recorded shape
  // and scope, and cache the result in fragment_info_ for later lowering.
  auto it = fragment_info_.find(buffer);
  ICHECK(it != fragment_info_.end());
  FragmentInfo& info = it->second;
  std::pair<int32_t, int32_t> dims = GetWmmaFragmentDimSize(info.shape, info.scope);
  int64_t num_elems = dims.first * dims.second;
  spirv::SType stype = builder_->GetSType(dtype.with_lanes(num_elems), dims.first, dims.second);
  info.stype = stype;
  return stype;
}

DataType CodeGenSPIRV::GetElementDataType(const VarNode* buffer) {
  // The buffer must have been registered in storage_info_ beforehand.
  ICHECK(storage_info_.count(buffer));
  return storage_info_.at(buffer).element_type;
}

} // namespace codegen
} // namespace tvm
14 changes: 14 additions & 0 deletions src/target/spirv/codegen_spirv.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "../../runtime/spirv/spirv_shader.h"
Expand Down Expand Up @@ -171,6 +172,14 @@ class CodeGenSPIRV : public ExprFunctor<spirv::Value(const PrimExpr&)>,
element_type_known = true;
}
};

// Per-buffer metadata for a WMMA (cooperative matrix) fragment allocation.
struct FragmentInfo {
  // Fragment shape string taken from the tir::attr::fragment_shape
  // annotation; parsed by GetWmmaFragmentDimSize. NOTE(review): presumably
  // of the form "m, n, k" — confirm against the producer of the attribute.
  std::string shape;
  // Storage scope of the fragment buffer (e.g. "wmma.matrix_a",
  // "wmma.matrix_b", "wmma.accumulator").
  std::string scope;
  // SPIR-V type of the fragment; filled in by GetFragmentSType.
  spirv::SType stype;
  // SPIR-V storage class chosen when the fragment buffer is allocated.
  spv::StorageClass sclass;
};

// Reset the state so it works for a new function.
void InitFuncState();
// Get the thread index
Expand All @@ -179,6 +188,9 @@ class CodeGenSPIRV : public ExprFunctor<spirv::Value(const PrimExpr&)>,
spirv::Value CreateStorageSync(const CallNode* op);
void Scalarize(const PrimExpr& e, std::function<void(int i, spirv::Value v)> f);

spirv::SType GetFragmentSType(const VarNode* buffer, const DataType& dtype);
DataType GetElementDataType(const VarNode* buffer);

// SPIRV-related capabilities of the target
SPIRVSupport spirv_support_;

Expand Down Expand Up @@ -218,6 +230,8 @@ class CodeGenSPIRV : public ExprFunctor<spirv::Value(const PrimExpr&)>,
// Running total of the number of bytes of shared memory used.
// Checked against the max_shared_memory_per_group
size_t shared_memory_bytes_used_{0};

std::unordered_map<const VarNode*, FragmentInfo> fragment_info_;
};

} // namespace codegen
Expand Down
Loading