Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Hexagon] Enable int8 vlut codegen for Relay take (LUT) operator #11693

Merged
merged 7 commits into from
Jul 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions python/tvm/topi/hexagon/injective.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ def schedule_injective(outs):
outs = [outs] if isinstance(outs, tvm.te.tensor.Tensor) else outs
s = tvm.te.create_schedule([x.op for x in outs])
tvm.te.schedule.AutoInlineInjective(s)

# Fuse axes and vectorize inner 128 elements
for x in outs:
fused = s[x].fuse(*x.op.axis)
_, inner = s[x].split(fused, factor=128)
s[x].vectorize(inner)
return s


Expand Down
147 changes: 147 additions & 0 deletions src/target/llvm/codegen_hexagon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,19 @@ class CodeGenHexagon final : public CodeGenCPU {
bool system_lib, bool dynamic_lookup, bool target_c_runtime) override;
void InitTarget(llvm::TargetMachine* tm) final;

using CodeGenCPU::VisitStmt_;
llvm::Value* VisitExpr_(const BufferLoadNode* op) override;

llvm::Module* GetModulePtr() const { return module_.get(); }

uint64_t GetTypeSizeInBits(llvm::Type* type) const {
#if TVM_LLVM_VERSION >= 100
return data_layout_->getTypeSizeInBits(type).getFixedSize();
#else
return data_layout_->getTypeSizeInBits(type);
#endif
}

protected:
void CreatePrintf(const std::string& format, llvm::ArrayRef<llvm::Value*> format_args) final;

Expand All @@ -86,6 +97,9 @@ class CodeGenHexagon final : public CodeGenCPU {

llvm::GlobalVariable* InitContextPtr(llvm::Type* type, std::string name);
llvm::Value* GetContextPtr(llvm::GlobalVariable* gv);

llvm::Value* VectorLookupLoad(Buffer buffer, DataType buffer_type, Array<PrimExpr> index);
llvm::Value* Intrinsic(llvm::Intrinsic::ID, llvm::ArrayRef<llvm::Value*> args);
};

void CodeGenHexagon::Init(const std::string& module_name, llvm::TargetMachine* tm,
Expand Down Expand Up @@ -281,6 +295,139 @@ CodeGenLLVM::TypedPointer CodeGenHexagon::CreateStructRefPtr(DataType t, llvm::V
return TypedPointer();
}

llvm::Value* CodeGenHexagon::Intrinsic(llvm::Intrinsic::ID IntID,
llvm::ArrayRef<llvm::Value*> args) {
llvm::Function* intf = llvm::Intrinsic::getDeclaration(module_.get(), IntID);
#if TVM_LLVM_VERSION >= 90
auto intf_callee = llvm::FunctionCallee(intf);
#else
auto intf_callee = intf;
#endif
std::vector<llvm::Value*> conv_args;
llvm::FunctionType* intf_type = intf->getFunctionType();
ICHECK(args.size() == intf_type->getNumParams());

for (int i = 0, e = args.size(); i != e; ++i) {
llvm::Value* arg = args[i];
auto* need_type = llvm::dyn_cast<llvm::VectorType>(intf_type->getParamType(i));
auto* have_type = llvm::dyn_cast<llvm::VectorType>(arg->getType());
if (need_type != nullptr && have_type != nullptr && need_type != have_type) {
int need_width = GetTypeSizeInBits(need_type);
int have_width = GetTypeSizeInBits(have_type);
if (need_width == have_width) {
if (need_width == native_vector_bits_ || need_width == 2 * native_vector_bits_) {
arg = builder_->CreateBitCast(arg, need_type);
}
} // TODO(joshherr-quic): add handling of v128i1 <-> v1024i1
}
conv_args.push_back(arg);
}
return builder_->CreateCall(intf_callee, conv_args);
}

llvm::Value* CodeGenHexagon::VisitExpr_(const BufferLoadNode* op) {
if (!op->buffer.same_as(op->buffer->data)) {
// Check if we can generate a vector lookup.
if (!op->indices[0].as<RampNode>()) {
if (auto* vlut = VectorLookupLoad(op->buffer, op->dtype, op->indices)) {
return vlut;
}
}
}
return CodeGenLLVM::VisitExpr_(op);
}

llvm::Value* CodeGenHexagon::VectorLookupLoad(Buffer buffer, DataType buffer_type,
Array<PrimExpr> indices) {
PrimExpr index = indices[0];
if (!index.dtype().is_vector()) {
return nullptr;
}

if (buffer_type.bits() != 8) return nullptr;

int table_elem_count = arith::Analyzer().Simplify(buffer->shape[0]).as<IntImmNode>()->value;
if (table_elem_count <= 0 || table_elem_count > 256) return nullptr;

auto int32 = DataType::Int(32);
auto native_vector_bytes = native_vector_bits_ / 8;

// Indexes
llvm::Value* trunc = MakeValue(Cast(index.dtype().with_bits(8), index));
llvm::Value* index_pad = CreateVecPad(trunc, native_vector_bytes);

// Values
std::vector<llvm::Value*> vloads;
DataType table_type = buffer_type.with_lanes(table_elem_count);

auto table_all =
MakeValue(BufferLoad(buffer, {
Ramp(IntImm(int32, 0), IntImm(int32, 1), table_elem_count),
}));

// The number of value vectors should be a power of 2.
int table_vec_count = llvm::PowerOf2Ceil(GetVectorBytes(table_type) / native_vector_bytes);
int table_vec_length = native_vector_bytes / buffer_type.bytes();
for (int i = 0; i != table_vec_count; ++i) {
// CreateVecSlice will generate undefs for elements outside the source vector.
vloads.push_back(CreateVecSlice(table_all, i * table_vec_length, table_vec_length));
}

#define VLO(x) Intrinsic(llvm::Intrinsic::hexagon_V6_lo_128B, {x})
#define VHI(x) Intrinsic(llvm::Intrinsic::hexagon_V6_hi_128B, {x})
#define VXOR(x, y) Intrinsic(llvm::Intrinsic::hexagon_V6_vxor_128B, {x, y})
#define VSHUFF(x) Intrinsic(llvm::Intrinsic::hexagon_V6_vshuffb_128B, {x})
#define VSPLATB(x) Intrinsic(llvm::Intrinsic::hexagon_V6_lvsplatb_128B, {x})
#define VLUT32(x, y, z) Intrinsic(llvm::Intrinsic::hexagon_V6_vlutvvbi_128B, {x, y, z})
#define VLUT32_OR(v, x, y, z) \
Intrinsic(llvm::Intrinsic::hexagon_V6_vlutvvb_oracci_128B, {v, x, y, z})

// Shuffle table bytes:
// 127, 63, 126, 62,........68, 4, 67, 3, 66, 2, 65, 1, 64, 0
std::vector<llvm::Value*> table;
for (int i = 0; i != table_vec_count; ++i) table.push_back(VSHUFF(vloads[i]));

// Get each 32 byte sub-table's output
std::vector<llvm::Value*> results;
int table_iters = table_elem_count / 32;
for (int i = 0; i < table_iters; ++i)
results.push_back(VLUT32(index_pad, table[i / 4], ConstInt32(i % 8)));

// Combine outputs
llvm::Value* result = results[0];
for (int i = 1; i < table_iters; ++i) result = VXOR(result, results[i]);

llvm::Type* res_type = result->getType();
llvm::Type* ret_type = DTypeToLLVMType(buffer_type);
if (res_type == ret_type) {
return result;
}

int res_bits = GetTypeSizeInBits(res_type);
int ret_bits = GetTypeSizeInBits(ret_type);
ICHECK_GE(res_bits, ret_bits);
if (ret_bits < res_bits) {
#if TVM_LLVM_VERSION >= 110
llvm::Type* res_byte_type = llvm::VectorType::get(t_int8_, res_bits / 8, /*Scalable*/ false);
#else
llvm::Type* res_byte_type = llvm::VectorType::get(t_int8_, res_bits / 8);
#endif
result = CreateVecSlice(builder_->CreateBitCast(result, res_byte_type), 0, ret_bits / 8);
}
if (result->getType() != ret_type) {
return builder_->CreateBitCast(result, ret_type);
}
return result;

#undef VLUT32_OR
#undef VLUT32
#undef VSPLATB
#undef VSHUFF
#undef VXOR
#undef VHI
#undef VLO
}

namespace {
DMLC_ATTRIBUTE_UNUSED std::ostream& operator<<(std::ostream& os, const llvm::Module& m) {
std::string ms;
Expand Down