Skip to content

Commit

Permalink
[Hexagon] Enable int8 vlut codegen for Relay take (LUT) operator (apa…
Browse files Browse the repository at this point in the history
…che#11693)

* Working 8 bit vlut for relay take operator

* Formatting

* More formatting

* clang-format on codegen_hexagon.cc

* Update for llvm api

* Add return to VisitExpr(BufferLoadNode) function

* different llvm api
  • Loading branch information
joshherr-quic authored and blackkker committed Jul 7, 2022
1 parent 48ddac1 commit d5b41f9
Show file tree
Hide file tree
Showing 2 changed files with 153 additions and 0 deletions.
6 changes: 6 additions & 0 deletions python/tvm/topi/hexagon/injective.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ def schedule_injective(outs):
outs = [outs] if isinstance(outs, tvm.te.tensor.Tensor) else outs
s = tvm.te.create_schedule([x.op for x in outs])
tvm.te.schedule.AutoInlineInjective(s)

# Fuse axes and vectorize inner 128 elements
for x in outs:
fused = s[x].fuse(*x.op.axis)
_, inner = s[x].split(fused, factor=128)
s[x].vectorize(inner)
return s


Expand Down
147 changes: 147 additions & 0 deletions src/target/llvm/codegen_hexagon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,19 @@ class CodeGenHexagon final : public CodeGenCPU {
bool system_lib, bool dynamic_lookup, bool target_c_runtime) override;
void InitTarget(llvm::TargetMachine* tm) final;

// Keep the VisitStmt_ overloads inherited through CodeGenCPU visible in this class.
using CodeGenCPU::VisitStmt_;
// Hexagon-specific handling of buffer loads.
// NOTE(review): no matching `using ...::VisitExpr_;` is visible in this hunk; declaring this
// overload hides the other base-class VisitExpr_ overloads for name lookup on CodeGenHexagon
// (virtual dispatch is unaffected) -- confirm that is intended.
llvm::Value* VisitExpr_(const BufferLoadNode* op) override;

// Returns the raw pointer held by module_ (the LLVM module being generated).
llvm::Module* GetModulePtr() const { return module_.get(); }

// Returns the size of `type` in bits as reported by data_layout_.
//
// From LLVM 10 on, DataLayout::getTypeSizeInBits returns a TypeSize object rather than a
// plain integer, so the fixed size has to be extracted explicitly. The accessor was renamed
// from getFixedSize() to getFixedValue() in later LLVM releases (the old name is deprecated
// and subsequently removed), hence the extra version branch.
uint64_t GetTypeSizeInBits(llvm::Type* type) const {
#if TVM_LLVM_VERSION >= 160
  return data_layout_->getTypeSizeInBits(type).getFixedValue();
#elif TVM_LLVM_VERSION >= 100
  return data_layout_->getTypeSizeInBits(type).getFixedSize();
#else
  return data_layout_->getTypeSizeInBits(type);
#endif
}

protected:
void CreatePrintf(const std::string& format, llvm::ArrayRef<llvm::Value*> format_args) final;

Expand All @@ -86,6 +97,9 @@ class CodeGenHexagon final : public CodeGenCPU {

llvm::GlobalVariable* InitContextPtr(llvm::Type* type, std::string name);
llvm::Value* GetContextPtr(llvm::GlobalVariable* gv);

// Tries to emit an HVX vector table lookup for a load of 8-bit elements from `buffer`
// addressed by the vector index in `indices[0]`. Returns nullptr when the lookup cannot be
// generated, in which case the caller falls back to the generic buffer load.
llvm::Value* VectorLookupLoad(Buffer buffer, DataType buffer_type, Array<PrimExpr> indices);
// Calls the LLVM intrinsic `IntID` with `args`, bitcasting vector arguments to the
// intrinsic's parameter types where the bit widths allow it.
llvm::Value* Intrinsic(llvm::Intrinsic::ID IntID, llvm::ArrayRef<llvm::Value*> args);
};

void CodeGenHexagon::Init(const std::string& module_name, llvm::TargetMachine* tm,
Expand Down Expand Up @@ -281,6 +295,139 @@ CodeGenLLVM::TypedPointer CodeGenHexagon::CreateStructRefPtr(DataType t, llvm::V
return TypedPointer();
}

// Emits a call to the LLVM intrinsic `IntID` with the given arguments.
//
// The intrinsic declaration is obtained from the current module. A vector argument whose
// type differs from the corresponding parameter type is bitcast to the parameter type when
// both have the same bit width and that width is one or two native vectors. Any other
// argument is forwarded unchanged.
//
// The number of arguments must equal the number of intrinsic parameters.
llvm::Value* CodeGenHexagon::Intrinsic(llvm::Intrinsic::ID IntID,
                                       llvm::ArrayRef<llvm::Value*> args) {
  llvm::Function* decl = llvm::Intrinsic::getDeclaration(module_.get(), IntID);
  llvm::FunctionType* intf_type = decl->getFunctionType();
  ICHECK(args.size() == intf_type->getNumParams());

  std::vector<llvm::Value*> call_args;
  call_args.reserve(args.size());

  unsigned param_pos = 0;
  for (llvm::Value* actual : args) {
    auto* formal_vec = llvm::dyn_cast<llvm::VectorType>(intf_type->getParamType(param_pos));
    auto* actual_vec = llvm::dyn_cast<llvm::VectorType>(actual->getType());
    ++param_pos;

    const bool mismatched_vectors =
        formal_vec != nullptr && actual_vec != nullptr && formal_vec != actual_vec;
    if (mismatched_vectors) {
      int formal_bits = GetTypeSizeInBits(formal_vec);
      int actual_bits = GetTypeSizeInBits(actual_vec);
      const bool is_native_width =
          formal_bits == native_vector_bits_ || formal_bits == 2 * native_vector_bits_;
      if (formal_bits == actual_bits && is_native_width) {
        // Same-sized single or paired native vectors only differ in element type.
        actual = builder_->CreateBitCast(actual, formal_vec);
      }
      // TODO(joshherr-quic): add handling of v128i1 <-> v1024i1
    }
    call_args.push_back(actual);
  }

#if TVM_LLVM_VERSION >= 90
  return builder_->CreateCall(llvm::FunctionCallee(decl), call_args);
#else
  return builder_->CreateCall(decl, call_args);
#endif
}

// Generates code for a buffer load, preferring an HVX vector table lookup when possible.
//
// The lookup path is attempted only for loads with exactly one index that is not a Ramp;
// VectorLookupLoad inspects just that one index, so loads with any other number of indices
// go straight to the generic implementation. When VectorLookupLoad declines (returns
// nullptr) the generic CodeGenLLVM implementation is used as well.
llvm::Value* CodeGenHexagon::VisitExpr_(const BufferLoadNode* op) {
  // NOTE(review): this compares the Buffer handle with its data Var; it looks like it is
  // always true -- confirm the intended condition.
  if (!op->buffer.same_as(op->buffer->data)) {
    // Check if we can generate a vector lookup.
    if (op->indices.size() == 1 && !op->indices[0].as<RampNode>()) {
      if (auto* vlut = VectorLookupLoad(op->buffer, op->dtype, op->indices)) {
        return vlut;
      }
    }
  }
  return CodeGenLLVM::VisitExpr_(op);
}

// Tries to lower a vector-indexed load of 8-bit elements from a small table to HVX vlut32
// intrinsics.
//
// The whole table is loaded into native-width vectors and byte-shuffled; each 32-entry
// sub-table is then looked up with vlut32 and the partial results are combined with XOR.
//
// \param buffer      Buffer holding the lookup table.
// \param buffer_type Type of the loaded value; only 8-bit element types are handled.
// \param indices     Load indices; must contain exactly one vector-typed index.
// \return The looked-up vector, or nullptr when this lowering does not apply. nullptr is
//         returned before any IR is emitted, so the caller can fall back to the generic
//         buffer load.
//
// The lowering is declined (nullptr) when:
//  - the load does not have exactly one index or the buffer is not one-dimensional,
//  - the index is not a vector or the element type is not 8 bits wide,
//  - the native vector size is not 128 bytes (all intrinsics used here are the _128B forms),
//  - the table extent is not a compile-time constant in [1, 256],
//  - the table extent is not a multiple of 32, or does not fill the table vectors that the
//    sub-table loop reads (the counts below use truncating division, so such tables would
//    otherwise index past the end of the local vectors or silently drop tail entries).
llvm::Value* CodeGenHexagon::VectorLookupLoad(Buffer buffer, DataType buffer_type,
                                              Array<PrimExpr> indices) {
  constexpr int kHvxVectorBytes = 128;    // Vector size assumed by the *_128B intrinsics.
  constexpr int kSubTableElems = 32;      // Table entries handled by one vlut32 call.
  constexpr int kSubTablesPerVector = 4;  // 32-entry sub-tables per 128-byte table vector.

  if (indices.size() != 1 || buffer->shape.size() != 1) return nullptr;

  PrimExpr index = indices[0];
  if (!index.dtype().is_vector()) {
    return nullptr;
  }

  if (buffer_type.bits() != 8) return nullptr;

  const auto native_vector_bytes = native_vector_bits_ / 8;
  if (native_vector_bytes != kHvxVectorBytes) return nullptr;

  // The table extent must be a known constant; check the range on the 64-bit value before
  // narrowing it to int.
  PrimExpr table_extent = arith::Analyzer().Simplify(buffer->shape[0]);
  const auto* table_extent_imm = table_extent.as<IntImmNode>();
  if (table_extent_imm == nullptr) return nullptr;
  if (table_extent_imm->value <= 0 || table_extent_imm->value > 256) return nullptr;
  const int table_elem_count = static_cast<int>(table_extent_imm->value);

  DataType table_type = buffer_type.with_lanes(table_elem_count);

  // The number of value vectors should be a power of 2.
  const int table_vec_count =
      static_cast<int>(llvm::PowerOf2Ceil(GetVectorBytes(table_type) / native_vector_bytes));
  const int table_vec_length = native_vector_bytes / buffer_type.bytes();
  const int table_iters = table_elem_count / kSubTableElems;

  // Decline tables the loops below cannot cover exactly.
  if (table_iters == 0 || table_elem_count % kSubTableElems != 0) return nullptr;
  if ((table_iters - 1) / kSubTablesPerVector >= table_vec_count) return nullptr;

  auto vshuff = [this](llvm::Value* v) {
    return Intrinsic(llvm::Intrinsic::hexagon_V6_vshuffb_128B, {v});
  };
  auto vxor = [this](llvm::Value* a, llvm::Value* b) {
    return Intrinsic(llvm::Intrinsic::hexagon_V6_vxor_128B, {a, b});
  };
  auto vlut32 = [this](llvm::Value* idx, llvm::Value* tbl, llvm::Value* sel) {
    return Intrinsic(llvm::Intrinsic::hexagon_V6_vlutvvbi_128B, {idx, tbl, sel});
  };

  auto int32 = DataType::Int(32);

  // Indexes: truncate to 8 bits and pad to a full native vector.
  llvm::Value* trunc = MakeValue(Cast(index.dtype().with_bits(8), index));
  llvm::Value* index_pad = CreateVecPad(trunc, native_vector_bytes);

  // Values: load the entire table as one vector, then cut it into native vectors.
  auto table_all =
      MakeValue(BufferLoad(buffer, {
                                       Ramp(IntImm(int32, 0), IntImm(int32, 1), table_elem_count),
                                   }));

  std::vector<llvm::Value*> vloads;
  vloads.reserve(table_vec_count);
  for (int i = 0; i != table_vec_count; ++i) {
    // CreateVecSlice will generate undefs for elements outside the source vector.
    vloads.push_back(CreateVecSlice(table_all, i * table_vec_length, table_vec_length));
  }

  // Shuffle table bytes:
  // 127, 63, 126, 62,........68, 4, 67, 3, 66, 2, 65, 1, 64, 0
  std::vector<llvm::Value*> table;
  table.reserve(vloads.size());
  for (llvm::Value* vload : vloads) table.push_back(vshuff(vload));

  // Get each 32 byte sub-table's output
  std::vector<llvm::Value*> results;
  results.reserve(table_iters);
  for (int i = 0; i < table_iters; ++i) {
    results.push_back(vlut32(index_pad, table[i / kSubTablesPerVector], ConstInt32(i % 8)));
  }

  // Combine outputs.
  // NOTE(review): presumably lanes not served by a given sub-table come back as zero, so XOR
  // acts as a merge -- confirm against the HVX vlut32 specification.
  llvm::Value* result = results[0];
  for (int i = 1; i < table_iters; ++i) result = vxor(result, results[i]);

  llvm::Type* res_type = result->getType();
  llvm::Type* ret_type = DTypeToLLVMType(buffer_type);
  if (res_type == ret_type) {
    return result;
  }

  // The lookup produces a full native vector; trim it to the requested width if needed.
  int res_bits = GetTypeSizeInBits(res_type);
  int ret_bits = GetTypeSizeInBits(ret_type);
  ICHECK_GE(res_bits, ret_bits);
  if (ret_bits < res_bits) {
#if TVM_LLVM_VERSION >= 110
    llvm::Type* res_byte_type = llvm::VectorType::get(t_int8_, res_bits / 8, /*Scalable*/ false);
#else
    llvm::Type* res_byte_type = llvm::VectorType::get(t_int8_, res_bits / 8);
#endif
    result = CreateVecSlice(builder_->CreateBitCast(result, res_byte_type), 0, ret_bits / 8);
  }
  if (result->getType() != ret_type) {
    return builder_->CreateBitCast(result, ret_type);
  }
  return result;
}

namespace {
DMLC_ATTRIBUTE_UNUSED std::ostream& operator<<(std::ostream& os, const llvm::Module& m) {
std::string ms;
Expand Down

0 comments on commit d5b41f9

Please sign in to comment.