[type] [refactor] Consistently use quant_xxx in quant-related names #5166

Merged 5 commits on Jun 15, 2022
6 changes: 3 additions & 3 deletions misc/benchmark_bit_struct_stores.py
@@ -7,10 +7,10 @@
n = 1024 * 1024 * 256

if quant:
ci16 = ti.types.quant.int(16, True)
qi16 = ti.types.quant.int(16, True)

x = ti.field(dtype=ci16)
y = ti.field(dtype=ci16)
x = ti.field(dtype=qi16)
y = ti.field(dtype=qi16)

ti.root.dense(ti.i, n).bit_struct(num_bits=32).place(x, y)
else:
6 changes: 3 additions & 3 deletions python/taichi/types/quantized_types.py
@@ -24,7 +24,7 @@ def int(bits, signed=True, compute=None): # pylint: disable=W0622
compute = impl.get_runtime().default_ip
if isinstance(compute, _ti_core.DataType):
compute = compute.get_ptr()
return _type_factory.get_custom_int_type(bits, signed, compute)
return _type_factory.get_quant_int_type(bits, signed, compute)


def fixed(frac, signed=True, range=1.0, compute=None, scale=None): # pylint: disable=W0622
@@ -51,7 +51,7 @@ def fixed(frac, signed=True, range=1.0, compute=None, scale=None): # pylint: disable=W0622
scale = range / 2**(frac - 1)
else:
scale = range / 2**frac
return _type_factory.get_custom_fixed_type(frac_type, compute, scale)
return _type_factory.get_quant_fixed_type(frac_type, compute, scale)


def float(exp, frac, signed=True, compute=None): # pylint: disable=W0622
@@ -74,7 +74,7 @@ def float(exp, frac, signed=True, compute=None): # pylint: disable=W0622
exp_type = int(bits=exp, signed=False, compute=i32)
# TODO: handle cases with frac > 32
frac_type = int(bits=frac, signed=signed, compute=i32)
return _type_factory.get_custom_float_type(frac_type, exp_type, compute)
return _type_factory.get_quant_float_type(frac_type, exp_type, compute)


__all__ = ['int', 'fixed', 'float']
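
For reference, a minimal usage sketch of the renamed quant type factories, mirroring the benchmark diff above; the field names, sizes, and parameter values here are illustrative assumptions, not part of this PR:

```python
import taichi as ti

ti.init()

qi16 = ti.types.quant.int(16, True)               # quant int, formerly named "custom int"
qfx16 = ti.types.quant.fixed(frac=16, range=2.0)  # scale defaults to range / 2**(frac - 1) for signed types
qfl16 = ti.types.quant.float(exp=6, frac=10)      # 6 exponent bits, 10 fraction bits

x = ti.field(dtype=qi16)
y = ti.field(dtype=qi16)
# Pack the two 16-bit quant ints into a single 32-bit bit struct, as the benchmark does.
ti.root.dense(ti.i, 1024).bit_struct(num_bits=32).place(x, y)
```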
8 changes: 4 additions & 4 deletions taichi/backends/cuda/codegen_cuda.cpp
@@ -538,16 +538,16 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {
ptr_type->is_bit_pointer()) {
// Bit pointer case.
auto val_type = ptr_type->get_pointee_type();
if (auto cit = val_type->cast<CustomIntType>()) {
dtype = cit->get_physical_type();
if (auto qit = val_type->cast<QuantIntType>()) {
dtype = qit->get_physical_type();
auto [data_ptr, bit_offset] = load_bit_pointer(llvm_val[stmt->src]);
data_ptr = builder->CreateBitCast(data_ptr, llvm_ptr_type(dtype));
auto data = create_intrinsic_load(dtype, data_ptr);
llvm_val[stmt] = extract_quant_int(data, bit_offset, val_type);
} else {
// TODO: support __ldg
TI_ASSERT(val_type->is<CustomFixedType>() ||
val_type->is<CustomFloatType>());
TI_ASSERT(val_type->is<QuantFixedType>() ||
val_type->is<QuantFloatType>());
llvm_val[stmt] = load_quant_fixed_or_quant_float(stmt->src);
}
} else {
58 changes: 29 additions & 29 deletions taichi/backends/metal/codegen_metal.cpp
@@ -92,8 +92,8 @@ bool is_full_bits(int bits) {
return bits == (sizeof(uint32_t) * 8);
}

void validate_cfxt_for_metal(CustomFixedType *cft) {
if (cft->get_compute_type()->as<PrimitiveType>() != PrimitiveType::f32) {
void validate_qfxt_for_metal(QuantFixedType *qfxt) {
if (qfxt->get_compute_type()->as<PrimitiveType>() != PrimitiveType::f32) {
TI_ERROR("Metal only supports 32-bit float");
}
}
@@ -969,22 +969,22 @@ class KernelCodegenImpl : public IRVisitor {
auto *ptr_type = stmt->dest->ret_type->as<PointerType>();
TI_ASSERT(ptr_type->is_bit_pointer());
auto *pointee_type = ptr_type->get_pointee_type();
CustomIntType *cit = nullptr;
QuantIntType *qit = nullptr;
std::string store_value_expr;
if (auto *cit_cast = pointee_type->cast<CustomIntType>()) {
cit = cit_cast;
if (auto *qit_cast = pointee_type->cast<QuantIntType>()) {
qit = qit_cast;
store_value_expr = stmt->val->raw_name();
} else if (auto *cfxt = pointee_type->cast<CustomFixedType>()) {
validate_cfxt_for_metal(cfxt);
auto *digits_cit = cfxt->get_digits_type()->as<CustomIntType>();
cit = digits_cit;
} else if (auto *qfxt = pointee_type->cast<QuantFixedType>()) {
validate_qfxt_for_metal(qfxt);
auto *digits_qit = qfxt->get_digits_type()->as<QuantIntType>();
qit = digits_qit;
store_value_expr = construct_quant_fixed_to_quant_int_expr(
stmt->val, cfxt->get_scale(), digits_cit);
stmt->val, qfxt->get_scale(), digits_qit);
} else {
TI_NOT_IMPLEMENTED;
}
// Type of |stmt->dest| is SNodeBitPointer
const auto num_bits = cit->get_num_bits();
const auto num_bits = qit->get_num_bits();
if (is_full_bits(num_bits)) {
emit("mtl_set_full_bits({}, {});", stmt->dest->raw_name(),
store_value_expr);
@@ -1000,16 +1000,16 @@
auto *ptr_type = stmt->src->ret_type->as<PointerType>();
TI_ASSERT(ptr_type->is_bit_pointer());
auto *pointee_type = ptr_type->get_pointee_type();
if (auto *cit = pointee_type->cast<CustomIntType>()) {
return construct_load_quant_int(stmt->src, cit);
} else if (auto *cfxt = pointee_type->cast<CustomFixedType>()) {
validate_cfxt_for_metal(cfxt);
if (auto *qit = pointee_type->cast<QuantIntType>()) {
return construct_load_quant_int(stmt->src, qit);
} else if (auto *qfxt = pointee_type->cast<QuantFixedType>()) {
validate_qfxt_for_metal(qfxt);
const auto loaded = construct_load_quant_int(
stmt->src, cfxt->get_digits_type()->as<CustomIntType>());
stmt->src, qfxt->get_digits_type()->as<QuantIntType>());
// Computes `float(digits_expr) * scale`
// See LLVM backend's reconstruct_quant_fixed()
return fmt::format("(static_cast<float>({}) * {})", loaded,
cfxt->get_scale());
qfxt->get_scale());
}
TI_NOT_IMPLEMENTED;
return "";
@@ -1023,19 +1023,19 @@
auto *ptr_type = dest_ptr->ret_type->as<PointerType>();
TI_ASSERT(ptr_type->is_bit_pointer());
auto *pointee_type = ptr_type->get_pointee_type();
CustomIntType *cit = nullptr;
QuantIntType *qit = nullptr;
std::string val_expr;
if (auto *cit_cast = pointee_type->cast<CustomIntType>()) {
cit = cit_cast;
if (auto *qit_cast = pointee_type->cast<QuantIntType>()) {
qit = qit_cast;
val_expr = stmt->val->raw_name();
} else if (auto *cfxt = pointee_type->cast<CustomFixedType>()) {
cit = cfxt->get_digits_type()->as<CustomIntType>();
} else if (auto *qfxt = pointee_type->cast<QuantFixedType>()) {
qit = qfxt->get_digits_type()->as<QuantIntType>();
val_expr = construct_quant_fixed_to_quant_int_expr(
stmt->val, cfxt->get_scale(), cit);
stmt->val, qfxt->get_scale(), qit);
} else {
TI_NOT_IMPLEMENTED;
}
const auto num_bits = cit->get_num_bits();
const auto num_bits = qit->get_num_bits();
if (is_full_bits(num_bits)) {
emit("const auto {} = mtl_atomic_add_full_bits({}, {});",
stmt->raw_name(), dest_ptr->raw_name(), val_expr);
Expand All @@ -1051,8 +1051,8 @@ class KernelCodegenImpl : public IRVisitor {
std::string construct_quant_fixed_to_quant_int_expr(
const Stmt *val_stmt,
float64 scale,
CustomIntType *digits_cit) const {
DataType compute_dt(digits_cit->get_compute_type()->as<PrimitiveType>());
QuantIntType *digits_qit) const {
DataType compute_dt(digits_qit->get_compute_type()->as<PrimitiveType>());
// This implicitly casts double to float on the host.
const float inv_scale = 1.0 / scale;
// Creating an expression (instead of holding intermediate results with
@@ -1066,9 +1066,9 @@

// Returns expression of the loaded integer.
std::string construct_load_quant_int(const Stmt *bit_ptr_stmt,
CustomIntType *cit) const {
DataType compute_dt(cit->get_compute_type()->as<PrimitiveType>());
const auto num_bits = cit->get_num_bits();
QuantIntType *qit) const {
DataType compute_dt(qit->get_compute_type()->as<PrimitiveType>());
const auto num_bits = qit->get_num_bits();
if (is_full_bits(num_bits)) {
return fmt::format("mtl_get_full_bits<{}>({})",
metal_data_type_name(compute_dt),
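
The comments in this file ("Computes `float(digits_expr) * scale`", and the implicit double-to-float cast of `inv_scale`) describe the scale-based conversion shared with the LLVM backend's reconstruct_quant_fixed(). A plain-Python illustration of that round trip follows; the scale value and the truncating int() cast are illustrative assumptions, and the exact rounding behavior is whatever the backend emits:

```python
# Illustrative sketch of the quant-fixed store/load arithmetic, not the generated code.
scale = 2.0 / 2**15      # e.g. a signed fixed type with range=2.0 and frac=16
inv_scale = 1.0 / scale  # precomputed on the host, as in construct_quant_fixed_to_quant_int_expr

def to_digits(value: float) -> int:
    # Store path: real value -> integer digits.
    return int(value * inv_scale)

def from_digits(digits: int) -> float:
    # Load path: float(digits) * scale.
    return float(digits) * scale

d = to_digits(0.125)
print(d, from_digits(d))  # 2048 0.125 -- exact up to the fixed-point resolution
```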
36 changes: 18 additions & 18 deletions taichi/codegen/codegen_llvm.cpp
@@ -1025,8 +1025,8 @@ void CodeGenLLVM::visit(RangeForStmt *for_stmt) {
llvm::Value *CodeGenLLVM::bitcast_from_u64(llvm::Value *val, DataType type) {
llvm::Type *dest_ty = nullptr;
TI_ASSERT(!type->is<PointerType>());
if (auto cit = type->cast<CustomIntType>()) {
if (cit->get_is_signed())
if (auto qit = type->cast<QuantIntType>()) {
if (qit->get_is_signed())
dest_ty = tlctx->get_data_type(PrimitiveType::i32);
else
dest_ty = tlctx->get_data_type(PrimitiveType::u32);
@@ -1056,8 +1056,8 @@ llvm::Value *CodeGenLLVM::bitcast_to_u64(llvm::Value *val, DataType type) {
if (type.is_pointer()) {
return builder->CreatePtrToInt(val, tlctx->get_data_type<int64>());
}
if (auto cit = type->cast<CustomIntType>()) {
intermediate_bits = data_type_bits(cit->get_compute_type());
if (auto qit = type->cast<QuantIntType>()) {
intermediate_bits = data_type_bits(qit->get_compute_type());
} else {
intermediate_bits = tlctx->get_data_type(type)->getPrimitiveSizeInBits();
}
@@ -1188,17 +1188,17 @@ llvm::Value *CodeGenLLVM::optimized_reduction(AtomicOpStmt *stmt) {
return nullptr;
}

llvm::Value *CodeGenLLVM::custom_type_atomic(AtomicOpStmt *stmt) {
// TODO(type): support all AtomicOpTypes on custom types
llvm::Value *CodeGenLLVM::quant_type_atomic(AtomicOpStmt *stmt) {
// TODO(type): support all AtomicOpTypes on quant types
if (stmt->op_type != AtomicOpType::add) {
return nullptr;
}

auto dst_type = stmt->dest->ret_type->as<PointerType>()->get_pointee_type();
if (auto cit = dst_type->cast<CustomIntType>()) {
return atomic_add_quant_int(stmt, cit);
} else if (auto cfxt = dst_type->cast<CustomFixedType>()) {
return atomic_add_quant_fixed(stmt, cfxt);
if (auto qit = dst_type->cast<QuantIntType>()) {
return atomic_add_quant_int(stmt, qit);
} else if (auto qfxt = dst_type->cast<QuantFixedType>()) {
return atomic_add_quant_fixed(stmt, qfxt);
} else {
return nullptr;
}
@@ -1318,7 +1318,7 @@ void CodeGenLLVM::visit(AtomicOpStmt *stmt) {

if (llvm::Value *result = optimized_reduction(stmt)) {
old_value = result;
} else if (llvm::Value *result = custom_type_atomic(stmt)) {
} else if (llvm::Value *result = quant_type_atomic(stmt)) {
old_value = result;
} else if (llvm::Value *result = real_type_atomic(stmt)) {
old_value = result;
@@ -1341,21 +1341,21 @@ void CodeGenLLVM::visit(GlobalStoreStmt *stmt) {
auto ptr_type = stmt->dest->ret_type->as<PointerType>();
if (ptr_type->is_bit_pointer()) {
auto pointee_type = ptr_type->get_pointee_type();
if (!pointee_type->is<CustomIntType>()) {
if (!pointee_type->is<QuantIntType>()) {
if (stmt->dest->as<GetChStmt>()->input_snode->type ==
SNodeType::bit_struct) {
TI_ERROR(
"Bit struct stores with type {} should have been "
"handled by BitStructStoreStmt.",
pointee_type->to_string());
} else {
TI_ERROR("Bit array only supports custom int type.");
TI_ERROR("Bit array only supports quant int type.");
}
}
llvm::Value *store_value = nullptr;
auto *cit = pointee_type->as<CustomIntType>();
auto *qit = pointee_type->as<QuantIntType>();
store_value = llvm_val[stmt->val];
store_quant_int(llvm_val[stmt->dest], cit, store_value, /*atomic=*/true);
store_quant_int(llvm_val[stmt->dest], qit, store_value, /*atomic=*/true);
} else {
builder->CreateStore(llvm_val[stmt->val], llvm_val[stmt->dest]);
}
@@ -1367,11 +1367,11 @@ void CodeGenLLVM::visit(GlobalLoadStmt *stmt) {
auto ptr_type = stmt->src->ret_type->as<PointerType>();
if (ptr_type->is_bit_pointer()) {
auto val_type = ptr_type->get_pointee_type();
if (val_type->is<CustomIntType>()) {
if (val_type->is<QuantIntType>()) {
llvm_val[stmt] = load_quant_int(llvm_val[stmt->src], val_type);
} else {
TI_ASSERT(val_type->is<CustomFixedType>() ||
val_type->is<CustomFloatType>());
TI_ASSERT(val_type->is<QuantFixedType>() ||
val_type->is<QuantFloatType>());
TI_ASSERT(stmt->src->is<GetChStmt>());
llvm_val[stmt] = load_quant_fixed_or_quant_float(stmt->src);
}
23 changes: 11 additions & 12 deletions taichi/codegen/codegen_llvm.h
@@ -219,18 +219,17 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

void visit(SNodeOpStmt *stmt) override;

llvm::Value *atomic_add_quant_fixed(AtomicOpStmt *stmt,
CustomFixedType *cfxt);
llvm::Value *atomic_add_quant_fixed(AtomicOpStmt *stmt, QuantFixedType *qfxt);

llvm::Value *atomic_add_quant_int(AtomicOpStmt *stmt, CustomIntType *cit);
llvm::Value *atomic_add_quant_int(AtomicOpStmt *stmt, QuantIntType *qit);

llvm::Value *quant_fixed_to_quant_int(CustomFixedType *cfxt,
CustomIntType *cit,
llvm::Value *quant_fixed_to_quant_int(QuantFixedType *qfxt,
QuantIntType *qit,
llvm::Value *real);

virtual llvm::Value *optimized_reduction(AtomicOpStmt *stmt);

virtual llvm::Value *custom_type_atomic(AtomicOpStmt *stmt);
virtual llvm::Value *quant_type_atomic(AtomicOpStmt *stmt);

virtual llvm::Value *integral_type_atomic(AtomicOpStmt *stmt);

@@ -248,13 +247,13 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
void visit(PtrOffsetStmt *stmt) override;

void store_quant_int(llvm::Value *bit_ptr,
CustomIntType *cit,
QuantIntType *qit,
llvm::Value *value,
bool atomic);

void store_quant_int(llvm::Value *byte_ptr,
llvm::Value *bit_offset,
CustomIntType *cit,
QuantIntType *qit,
llvm::Value *value,
bool atomic);

@@ -284,16 +283,16 @@
Type *load_type);

llvm::Value *reconstruct_quant_fixed(llvm::Value *digits,
CustomFixedType *cfxt);
QuantFixedType *qfxt);

llvm::Value *load_quant_float(llvm::Value *digits_bit_ptr,
llvm::Value *exponent_bit_ptr,
CustomFloatType *cft,
QuantFloatType *qflt,
bool shared_exponent);

llvm::Value *reconstruct_quant_float(llvm::Value *input_digits,
llvm::Value *input_exponent_val,
CustomFloatType *cft,
QuantFloatType *qflt,
bool shared_exponent);

llvm::Value *load_quant_fixed_or_quant_float(Stmt *ptr_stmt);
@@ -404,7 +403,7 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
llvm::Value *f,
llvm::Value *shared_exp);

llvm::Value *get_exponent_offset(llvm::Value *exponent, CustomFloatType *cft);
llvm::Value *get_exponent_offset(llvm::Value *exponent, QuantFloatType *qflt);

void visit(FuncCallStmt *stmt) override;
