diff --git a/misc/benchmark_bit_struct_stores.py b/misc/benchmark_bit_struct_stores.py index dcf1b9a8a85b9..821c021191f10 100644 --- a/misc/benchmark_bit_struct_stores.py +++ b/misc/benchmark_bit_struct_stores.py @@ -7,10 +7,10 @@ n = 1024 * 1024 * 256 if quant: - ci16 = ti.types.quant.int(16, True) + qi16 = ti.types.quant.int(16, True) - x = ti.field(dtype=ci16) - y = ti.field(dtype=ci16) + x = ti.field(dtype=qi16) + y = ti.field(dtype=qi16) ti.root.dense(ti.i, n).bit_struct(num_bits=32).place(x, y) else: diff --git a/python/taichi/types/quantized_types.py b/python/taichi/types/quantized_types.py index 81c281a290f85..8cb3a6bd382ba 100644 --- a/python/taichi/types/quantized_types.py +++ b/python/taichi/types/quantized_types.py @@ -24,7 +24,7 @@ def int(bits, signed=True, compute=None): # pylint: disable=W0622 compute = impl.get_runtime().default_ip if isinstance(compute, _ti_core.DataType): compute = compute.get_ptr() - return _type_factory.get_custom_int_type(bits, signed, compute) + return _type_factory.get_quant_int_type(bits, signed, compute) def fixed(frac, signed=True, range=1.0, compute=None, scale=None): # pylint: disable=W0622 @@ -51,7 +51,7 @@ def fixed(frac, signed=True, range=1.0, compute=None, scale=None): # pylint: di scale = range / 2**(frac - 1) else: scale = range / 2**frac - return _type_factory.get_custom_fixed_type(frac_type, compute, scale) + return _type_factory.get_quant_fixed_type(frac_type, compute, scale) def float(exp, frac, signed=True, compute=None): # pylint: disable=W0622 @@ -74,7 +74,7 @@ def float(exp, frac, signed=True, compute=None): # pylint: disable=W0622 exp_type = int(bits=exp, signed=False, compute=i32) # TODO: handle cases with frac > 32 frac_type = int(bits=frac, signed=signed, compute=i32) - return _type_factory.get_custom_float_type(frac_type, exp_type, compute) + return _type_factory.get_quant_float_type(frac_type, exp_type, compute) __all__ = ['int', 'fixed', 'float'] diff --git a/taichi/backends/cuda/codegen_cuda.cpp b/taichi/backends/cuda/codegen_cuda.cpp index 10c9c31ffc6ad..9ebe952e8b978 100644 --- a/taichi/backends/cuda/codegen_cuda.cpp +++ b/taichi/backends/cuda/codegen_cuda.cpp @@ -538,16 +538,16 @@ class CodeGenLLVMCUDA : public CodeGenLLVM { ptr_type->is_bit_pointer()) { // Bit pointer case. auto val_type = ptr_type->get_pointee_type(); - if (auto cit = val_type->cast()) { - dtype = cit->get_physical_type(); + if (auto qit = val_type->cast()) { + dtype = qit->get_physical_type(); auto [data_ptr, bit_offset] = load_bit_pointer(llvm_val[stmt->src]); data_ptr = builder->CreateBitCast(data_ptr, llvm_ptr_type(dtype)); auto data = create_intrinsic_load(dtype, data_ptr); llvm_val[stmt] = extract_quant_int(data, bit_offset, val_type); } else { // TODO: support __ldg - TI_ASSERT(val_type->is() || - val_type->is()); + TI_ASSERT(val_type->is() || + val_type->is()); llvm_val[stmt] = load_quant_fixed_or_quant_float(stmt->src); } } else { diff --git a/taichi/backends/metal/codegen_metal.cpp b/taichi/backends/metal/codegen_metal.cpp index 86d1127b6fd10..0e89df6c3d617 100644 --- a/taichi/backends/metal/codegen_metal.cpp +++ b/taichi/backends/metal/codegen_metal.cpp @@ -92,8 +92,8 @@ bool is_full_bits(int bits) { return bits == (sizeof(uint32_t) * 8); } -void validate_cfxt_for_metal(CustomFixedType *cft) { - if (cft->get_compute_type()->as() != PrimitiveType::f32) { +void validate_qfxt_for_metal(QuantFixedType *qfxt) { + if (qfxt->get_compute_type()->as() != PrimitiveType::f32) { TI_ERROR("Metal only supports 32-bit float"); } } @@ -969,22 +969,22 @@ class KernelCodegenImpl : public IRVisitor { auto *ptr_type = stmt->dest->ret_type->as(); TI_ASSERT(ptr_type->is_bit_pointer()); auto *pointee_type = ptr_type->get_pointee_type(); - CustomIntType *cit = nullptr; + QuantIntType *qit = nullptr; std::string store_value_expr; - if (auto *cit_cast = pointee_type->cast()) { - cit = cit_cast; + if (auto *qit_cast = pointee_type->cast()) { + qit = qit_cast; store_value_expr = stmt->val->raw_name(); - } else if (auto *cfxt = pointee_type->cast()) { - validate_cfxt_for_metal(cfxt); - auto *digits_cit = cfxt->get_digits_type()->as(); - cit = digits_cit; + } else if (auto *qfxt = pointee_type->cast()) { + validate_qfxt_for_metal(qfxt); + auto *digits_qit = qfxt->get_digits_type()->as(); + qit = digits_qit; store_value_expr = construct_quant_fixed_to_quant_int_expr( - stmt->val, cfxt->get_scale(), digits_cit); + stmt->val, qfxt->get_scale(), digits_qit); } else { TI_NOT_IMPLEMENTED; } // Type of |stmt->dest| is SNodeBitPointer - const auto num_bits = cit->get_num_bits(); + const auto num_bits = qit->get_num_bits(); if (is_full_bits(num_bits)) { emit("mtl_set_full_bits({}, {});", stmt->dest->raw_name(), store_value_expr); @@ -1000,16 +1000,16 @@ class KernelCodegenImpl : public IRVisitor { auto *ptr_type = stmt->src->ret_type->as(); TI_ASSERT(ptr_type->is_bit_pointer()); auto *pointee_type = ptr_type->get_pointee_type(); - if (auto *cit = pointee_type->cast()) { - return construct_load_quant_int(stmt->src, cit); - } else if (auto *cfxt = pointee_type->cast()) { - validate_cfxt_for_metal(cfxt); + if (auto *qit = pointee_type->cast()) { + return construct_load_quant_int(stmt->src, qit); + } else if (auto *qfxt = pointee_type->cast()) { + validate_qfxt_for_metal(qfxt); const auto loaded = construct_load_quant_int( - stmt->src, cfxt->get_digits_type()->as()); + stmt->src, qfxt->get_digits_type()->as()); // Computes `float(digits_expr) * scale` // See LLVM backend's reconstruct_quant_fixed() return fmt::format("(static_cast({}) * {})", loaded, - cfxt->get_scale()); + qfxt->get_scale()); } TI_NOT_IMPLEMENTED; return ""; @@ -1023,19 +1023,19 @@ class KernelCodegenImpl : public IRVisitor { auto *ptr_type = dest_ptr->ret_type->as(); TI_ASSERT(ptr_type->is_bit_pointer()); auto *pointee_type = ptr_type->get_pointee_type(); - CustomIntType *cit = nullptr; + QuantIntType *qit = nullptr; std::string val_expr; - if (auto *cit_cast = pointee_type->cast()) { - cit = cit_cast; + if (auto *qit_cast = pointee_type->cast()) { + qit = qit_cast; val_expr = stmt->val->raw_name(); - } else if (auto *cfxt = pointee_type->cast()) { - cit = cfxt->get_digits_type()->as(); + } else if (auto *qfxt = pointee_type->cast()) { + qit = qfxt->get_digits_type()->as(); val_expr = construct_quant_fixed_to_quant_int_expr( - stmt->val, cfxt->get_scale(), cit); + stmt->val, qfxt->get_scale(), qit); } else { TI_NOT_IMPLEMENTED; } - const auto num_bits = cit->get_num_bits(); + const auto num_bits = qit->get_num_bits(); if (is_full_bits(num_bits)) { emit("const auto {} = mtl_atomic_add_full_bits({}, {});", stmt->raw_name(), dest_ptr->raw_name(), val_expr); @@ -1051,8 +1051,8 @@ class KernelCodegenImpl : public IRVisitor { std::string construct_quant_fixed_to_quant_int_expr( const Stmt *val_stmt, float64 scale, - CustomIntType *digits_cit) const { - DataType compute_dt(digits_cit->get_compute_type()->as()); + QuantIntType *digits_qit) const { + DataType compute_dt(digits_qit->get_compute_type()->as()); // This implicitly casts double to float on the host. const float inv_scale = 1.0 / scale; // Creating an expression (instead of holding intermediate results with @@ -1066,9 +1066,9 @@ class KernelCodegenImpl : public IRVisitor { // Returns expression of the loaded integer. std::string construct_load_quant_int(const Stmt *bit_ptr_stmt, - CustomIntType *cit) const { - DataType compute_dt(cit->get_compute_type()->as()); - const auto num_bits = cit->get_num_bits(); + QuantIntType *qit) const { + DataType compute_dt(qit->get_compute_type()->as()); + const auto num_bits = qit->get_num_bits(); if (is_full_bits(num_bits)) { return fmt::format("mtl_get_full_bits<{}>({})", metal_data_type_name(compute_dt), diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index 79b5da824931d..ed811d756a983 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -1025,8 +1025,8 @@ void CodeGenLLVM::visit(RangeForStmt *for_stmt) { llvm::Value *CodeGenLLVM::bitcast_from_u64(llvm::Value *val, DataType type) { llvm::Type *dest_ty = nullptr; TI_ASSERT(!type->is()); - if (auto cit = type->cast()) { - if (cit->get_is_signed()) + if (auto qit = type->cast()) { + if (qit->get_is_signed()) dest_ty = tlctx->get_data_type(PrimitiveType::i32); else dest_ty = tlctx->get_data_type(PrimitiveType::u32); @@ -1056,8 +1056,8 @@ llvm::Value *CodeGenLLVM::bitcast_to_u64(llvm::Value *val, DataType type) { if (type.is_pointer()) { return builder->CreatePtrToInt(val, tlctx->get_data_type()); } - if (auto cit = type->cast()) { - intermediate_bits = data_type_bits(cit->get_compute_type()); + if (auto qit = type->cast()) { + intermediate_bits = data_type_bits(qit->get_compute_type()); } else { intermediate_bits = tlctx->get_data_type(type)->getPrimitiveSizeInBits(); } @@ -1188,17 +1188,17 @@ llvm::Value *CodeGenLLVM::optimized_reduction(AtomicOpStmt *stmt) { return nullptr; } -llvm::Value *CodeGenLLVM::custom_type_atomic(AtomicOpStmt *stmt) { - // TODO(type): support all AtomicOpTypes on custom types +llvm::Value *CodeGenLLVM::quant_type_atomic(AtomicOpStmt *stmt) { + // TODO(type): support all AtomicOpTypes on quant types if (stmt->op_type != AtomicOpType::add) { return nullptr; } auto dst_type = stmt->dest->ret_type->as()->get_pointee_type(); - if (auto cit = dst_type->cast()) { - return atomic_add_quant_int(stmt, cit); - } else if (auto cfxt = dst_type->cast()) { - return atomic_add_quant_fixed(stmt, cfxt); + if (auto qit = dst_type->cast()) { + return atomic_add_quant_int(stmt, qit); + } else if (auto qfxt = dst_type->cast()) { + return atomic_add_quant_fixed(stmt, qfxt); } else { return nullptr; } @@ -1318,7 +1318,7 @@ void CodeGenLLVM::visit(AtomicOpStmt *stmt) { if (llvm::Value *result = optimized_reduction(stmt)) { old_value = result; - } else if (llvm::Value *result = custom_type_atomic(stmt)) { + } else if (llvm::Value *result = quant_type_atomic(stmt)) { old_value = result; } else if (llvm::Value *result = real_type_atomic(stmt)) { old_value = result; @@ -1341,7 +1341,7 @@ void CodeGenLLVM::visit(GlobalStoreStmt *stmt) { auto ptr_type = stmt->dest->ret_type->as(); if (ptr_type->is_bit_pointer()) { auto pointee_type = ptr_type->get_pointee_type(); - if (!pointee_type->is()) { + if (!pointee_type->is()) { if (stmt->dest->as()->input_snode->type == SNodeType::bit_struct) { TI_ERROR( @@ -1349,13 +1349,13 @@ void CodeGenLLVM::visit(GlobalStoreStmt *stmt) { "handled by BitStructStoreStmt.", pointee_type->to_string()); } else { - TI_ERROR("Bit array only supports custom int type."); + TI_ERROR("Bit array only supports quant int type."); } } llvm::Value *store_value = nullptr; - auto *cit = pointee_type->as(); + auto *qit = pointee_type->as(); store_value = llvm_val[stmt->val]; - store_quant_int(llvm_val[stmt->dest], cit, store_value, /*atomic=*/true); + store_quant_int(llvm_val[stmt->dest], qit, store_value, /*atomic=*/true); } else { builder->CreateStore(llvm_val[stmt->val], llvm_val[stmt->dest]); } @@ -1367,11 +1367,11 @@ void CodeGenLLVM::visit(GlobalLoadStmt *stmt) { auto ptr_type = stmt->src->ret_type->as(); if (ptr_type->is_bit_pointer()) { auto val_type = ptr_type->get_pointee_type(); - if (val_type->is()) { + if (val_type->is()) { llvm_val[stmt] = load_quant_int(llvm_val[stmt->src], val_type); } else { - TI_ASSERT(val_type->is() || - val_type->is()); + TI_ASSERT(val_type->is() || + val_type->is()); TI_ASSERT(stmt->src->is()); llvm_val[stmt] = load_quant_fixed_or_quant_float(stmt->src); } diff --git a/taichi/codegen/codegen_llvm.h b/taichi/codegen/codegen_llvm.h index 05fbbb8ce09f8..e48d79fa6c766 100644 --- a/taichi/codegen/codegen_llvm.h +++ b/taichi/codegen/codegen_llvm.h @@ -219,18 +219,17 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { void visit(SNodeOpStmt *stmt) override; - llvm::Value *atomic_add_quant_fixed(AtomicOpStmt *stmt, - CustomFixedType *cfxt); + llvm::Value *atomic_add_quant_fixed(AtomicOpStmt *stmt, QuantFixedType *qfxt); - llvm::Value *atomic_add_quant_int(AtomicOpStmt *stmt, CustomIntType *cit); + llvm::Value *atomic_add_quant_int(AtomicOpStmt *stmt, QuantIntType *qit); - llvm::Value *quant_fixed_to_quant_int(CustomFixedType *cfxt, - CustomIntType *cit, + llvm::Value *quant_fixed_to_quant_int(QuantFixedType *qfxt, + QuantIntType *qit, llvm::Value *real); virtual llvm::Value *optimized_reduction(AtomicOpStmt *stmt); - virtual llvm::Value *custom_type_atomic(AtomicOpStmt *stmt); + virtual llvm::Value *quant_type_atomic(AtomicOpStmt *stmt); virtual llvm::Value *integral_type_atomic(AtomicOpStmt *stmt); @@ -248,13 +247,13 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { void visit(PtrOffsetStmt *stmt) override; void store_quant_int(llvm::Value *bit_ptr, - CustomIntType *cit, + QuantIntType *qit, llvm::Value *value, bool atomic); void store_quant_int(llvm::Value *byte_ptr, llvm::Value *bit_offset, - CustomIntType *cit, + QuantIntType *qit, llvm::Value *value, bool atomic); @@ -284,16 +283,16 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { Type *load_type); llvm::Value *reconstruct_quant_fixed(llvm::Value *digits, - CustomFixedType *cfxt); + QuantFixedType *qfxt); llvm::Value *load_quant_float(llvm::Value *digits_bit_ptr, llvm::Value *exponent_bit_ptr, - CustomFloatType *cft, + QuantFloatType *qflt, bool shared_exponent); llvm::Value *reconstruct_quant_float(llvm::Value *input_digits, llvm::Value *input_exponent_val, - CustomFloatType *cft, + QuantFloatType *qflt, bool shared_exponent); llvm::Value *load_quant_fixed_or_quant_float(Stmt *ptr_stmt); @@ -404,7 +403,7 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { llvm::Value *f, llvm::Value *shared_exp); - llvm::Value *get_exponent_offset(llvm::Value *exponent, CustomFloatType *cft); + llvm::Value *get_exponent_offset(llvm::Value *exponent, QuantFloatType *qflt); void visit(FuncCallStmt *stmt) override; diff --git a/taichi/codegen/codegen_llvm_quant.cpp b/taichi/codegen/codegen_llvm_quant.cpp index d77b785271eb9..c92a0720c5ec8 100644 --- a/taichi/codegen/codegen_llvm_quant.cpp +++ b/taichi/codegen/codegen_llvm_quant.cpp @@ -18,39 +18,39 @@ inline void update_mask(uint64 &mask, uint32 num_bits, uint32 offset) { } // namespace llvm::Value *CodeGenLLVM::atomic_add_quant_int(AtomicOpStmt *stmt, - CustomIntType *cit) { + QuantIntType *qit) { auto [byte_ptr, bit_offset] = load_bit_pointer(llvm_val[stmt->dest]); - auto physical_type = cit->get_physical_type(); + auto physical_type = qit->get_physical_type(); return create_call( fmt::format("atomic_add_partial_bits_b{}", data_type_bits(physical_type)), {builder->CreateBitCast(byte_ptr, llvm_ptr_type(physical_type)), - bit_offset, tlctx->get_constant(cit->get_num_bits()), + bit_offset, tlctx->get_constant(qit->get_num_bits()), builder->CreateIntCast(llvm_val[stmt->val], llvm_type(physical_type), is_signed(stmt->val->ret_type))}); } llvm::Value *CodeGenLLVM::atomic_add_quant_fixed(AtomicOpStmt *stmt, - CustomFixedType *cfxt) { + QuantFixedType *qfxt) { auto [byte_ptr, bit_offset] = load_bit_pointer(llvm_val[stmt->dest]); - auto cit = cfxt->get_digits_type()->as(); - auto val_store = quant_fixed_to_quant_int(cfxt, cit, llvm_val[stmt->val]); - auto physical_type = cit->get_physical_type(); + auto qit = qfxt->get_digits_type()->as(); + auto val_store = quant_fixed_to_quant_int(qfxt, qit, llvm_val[stmt->val]); + auto physical_type = qit->get_physical_type(); val_store = builder->CreateSExt(val_store, llvm_type(physical_type)); return create_call( fmt::format("atomic_add_partial_bits_b{}", data_type_bits(physical_type)), {builder->CreateBitCast(byte_ptr, llvm_ptr_type(physical_type)), - bit_offset, tlctx->get_constant(cit->get_num_bits()), val_store}); + bit_offset, tlctx->get_constant(qit->get_num_bits()), val_store}); } -llvm::Value *CodeGenLLVM::quant_fixed_to_quant_int(CustomFixedType *cfxt, - CustomIntType *cit, +llvm::Value *CodeGenLLVM::quant_fixed_to_quant_int(QuantFixedType *qfxt, + QuantIntType *qit, llvm::Value *real) { llvm::Value *s = nullptr; // Compute int(real * (1.0 / scale) + 0.5) - auto s_numeric = 1.0 / cfxt->get_scale(); - auto compute_type = cfxt->get_compute_type(); + auto s_numeric = 1.0 / qfxt->get_scale(); + auto compute_type = qfxt->get_compute_type(); s = builder->CreateFPCast(tlctx->get_constant(s_numeric), llvm_type(compute_type)); auto input_real = builder->CreateFPCast(real, llvm_type(compute_type)); @@ -61,36 +61,35 @@ llvm::Value *CodeGenLLVM::quant_fixed_to_quant_int(CustomFixedType *cfxt, fmt::format("rounding_prepare_f{}", data_type_bits(compute_type)), {scaled}); - if (cit->get_is_signed()) { - return builder->CreateFPToSI(scaled, llvm_type(cit->get_compute_type())); + if (qit->get_is_signed()) { + return builder->CreateFPToSI(scaled, llvm_type(qit->get_compute_type())); } else { - return builder->CreateFPToUI(scaled, llvm_type(cit->get_compute_type())); + return builder->CreateFPToUI(scaled, llvm_type(qit->get_compute_type())); } } void CodeGenLLVM::store_quant_int(llvm::Value *bit_ptr, - CustomIntType *cit, + QuantIntType *qit, llvm::Value *value, bool atomic) { auto [byte_ptr, bit_offset] = load_bit_pointer(bit_ptr); - store_quant_int(byte_ptr, bit_offset, cit, value, atomic); + store_quant_int(byte_ptr, bit_offset, qit, value, atomic); } void CodeGenLLVM::store_quant_int(llvm::Value *byte_ptr, llvm::Value *bit_offset, - CustomIntType *cit, + QuantIntType *qit, llvm::Value *value, bool atomic) { // TODO(type): CUDA only supports atomicCAS on 32- and 64-bit integers. - // Try to support CustomInt/FloatType with 8/16-bit physical - // types. + // Try to support 8/16-bit physical types. create_call(fmt::format("{}set_partial_bits_b{}", atomic ? "atomic_" : "", - data_type_bits(cit->get_physical_type())), + data_type_bits(qit->get_physical_type())), {builder->CreateBitCast(byte_ptr, - llvm_ptr_type(cit->get_physical_type())), - bit_offset, tlctx->get_constant(cit->get_num_bits()), + llvm_ptr_type(qit->get_physical_type())), + bit_offset, tlctx->get_constant(qit->get_num_bits()), builder->CreateIntCast( - value, llvm_type(cit->get_physical_type()), false)}); + value, llvm_type(qit->get_physical_type()), false)}); } void CodeGenLLVM::store_masked(llvm::Value *byte_ptr, @@ -116,31 +115,31 @@ void CodeGenLLVM::store_masked(llvm::Value *byte_ptr, } llvm::Value *CodeGenLLVM::get_exponent_offset(llvm::Value *exponent, - CustomFloatType *cft) { + QuantFloatType *qflt) { // Since we have fewer bits in the exponent type than in f32, an // offset is necessary to make sure the stored exponent values are - // representable by the exponent custom int type. + // representable by the exponent quant int type. auto cond = builder->CreateICmp(llvm::CmpInst::Predicate::ICMP_NE, exponent, tlctx->get_constant(0)); return builder->CreateSelect( - cond, tlctx->get_constant(cft->get_exponent_conversion_offset()), + cond, tlctx->get_constant(qflt->get_exponent_conversion_offset()), tlctx->get_constant(0)); } llvm::Value *CodeGenLLVM::quant_int_or_quant_fixed_to_bits(llvm::Value *val, Type *input_type, Type *output_type) { - CustomIntType *cit = nullptr; - if (auto cfxt = input_type->cast()) { - cit = cfxt->get_digits_type()->as(); - val = quant_fixed_to_quant_int(cfxt, cit, val); + QuantIntType *qit = nullptr; + if (auto qfxt = input_type->cast()) { + qit = qfxt->get_digits_type()->as(); + val = quant_fixed_to_quant_int(qfxt, qit, val); } else { - cit = input_type->as(); + qit = input_type->as(); } - if (cit->get_num_bits() < val->getType()->getIntegerBitWidth()) { + if (qit->get_num_bits() < val->getType()->getIntegerBitWidth()) { val = builder->CreateAnd( - val, tlctx->get_constant(cit->get_compute_type(), - uint64((1ULL << cit->get_num_bits()) - 1))); + val, tlctx->get_constant(qit->get_compute_type(), + uint64((1ULL << qit->get_num_bits()) - 1))); } val = builder->CreateZExt(val, llvm_type(output_type)); return val; @@ -189,12 +188,12 @@ void CodeGenLLVM::visit(BitStructStoreStmt *stmt) { } auto dtype = ch->dt; - if (auto cft = dtype->cast()) { - // Custom float type with non-shared exponent. + if (auto qflt = dtype->cast()) { + // Quant float type with non-shared exponent. llvm::Value *digit_bits = nullptr; // Extract exponent and digits from compute type (assumed to be f32 for // now). - TI_ASSERT(cft->get_compute_type()->is_primitive(PrimitiveTypeID::f32)); + TI_ASSERT(qflt->get_compute_type()->is_primitive(PrimitiveTypeID::f32)); // f32 = 1 sign bit + 8 exponent bits + 23 fraction bits @@ -202,34 +201,34 @@ void CodeGenLLVM::visit(BitStructStoreStmt *stmt) { builder->CreateBitCast(val, llvm::Type::getInt32Ty(*llvm_context)); // Rounding to nearest here. Note that if the digits overflows then the // carry-on will contribute to the exponent, which is desired. - if (cft->get_digit_bits() < 23) { + if (qflt->get_digit_bits() < 23) { f32_bits = builder->CreateAdd( - f32_bits, tlctx->get_constant(1 << (22 - cft->get_digit_bits()))); + f32_bits, tlctx->get_constant(1 << (22 - qflt->get_digit_bits()))); } auto exponent_bits = builder->CreateAShr(f32_bits, 23); exponent_bits = builder->CreateAnd(exponent_bits, tlctx->get_constant((1 << 8) - 1)); auto value_bits = builder->CreateAShr( - f32_bits, tlctx->get_constant(23 - cft->get_digit_bits())); + f32_bits, tlctx->get_constant(23 - qflt->get_digit_bits())); digit_bits = builder->CreateAnd( - value_bits, tlctx->get_constant((1 << (cft->get_digit_bits())) - 1)); + value_bits, tlctx->get_constant((1 << (qflt->get_digit_bits())) - 1)); - if (cft->get_is_signed()) { + if (qflt->get_is_signed()) { // extract the sign bit auto sign_bit = builder->CreateAnd(f32_bits, tlctx->get_constant(0x80000000u)); // insert the sign bit to digit bits digit_bits = builder->CreateOr( digit_bits, - builder->CreateLShr(sign_bit, 31 - cft->get_digit_bits())); + builder->CreateLShr(sign_bit, 31 - qflt->get_digit_bits())); } auto digits_snode = ch.get(); auto exponent_snode = digits_snode->exp_snode; - auto exponent_offset = get_exponent_offset(exponent_bits, cft); + auto exponent_offset = get_exponent_offset(exponent_bits, qflt); exponent_bits = builder->CreateSub(exponent_bits, exponent_offset); exponent_bits = create_call("max_i32", {exponent_bits, tlctx->get_constant(0)}); @@ -283,20 +282,20 @@ void CodeGenLLVM::visit(BitStructStoreStmt *stmt) { continue; } auto dtype = ch->dt; - CustomIntType *cit = nullptr; - if (auto cft = dtype->cast()) { - auto exp = cft->get_exponent_type(); - auto exponent_cit = exp->as(); + QuantIntType *qit = nullptr; + if (auto qflt = dtype->cast()) { + auto exp = qflt->get_exponent_type(); + auto exponent_qit = exp->as(); auto exponent_snode = ch->exp_snode; - update_mask(mask, exponent_cit->get_num_bits(), + update_mask(mask, exponent_qit->get_num_bits(), exponent_snode->bit_offset); - cit = cft->get_digits_type()->as(); - } else if (auto cfxt = dtype->cast()) { - cit = cfxt->get_digits_type()->as(); + qit = qflt->get_digits_type()->as(); + } else if (auto qfxt = dtype->cast()) { + qit = qfxt->get_digits_type()->as(); } else { - cit = dtype->as(); + qit = dtype->as(); } - update_mask(mask, cit->get_num_bits(), ch->bit_offset); + update_mask(mask, qit->get_num_bits(), ch->bit_offset); } store_masked(llvm_val[stmt->ptr], mask, bit_struct_physical_type, bit_struct_val, stmt->is_atomic); @@ -346,8 +345,8 @@ void CodeGenLLVM::store_quant_floats_with_shared_exponents( } } - auto first_cft = exp->exponent_users[0]->dt->as(); - auto exponent_offset = get_exponent_offset(max_exp_bits, first_cft); + auto first_qflt = exp->exponent_users[0]->dt->as(); + auto exponent_offset = get_exponent_offset(max_exp_bits, first_qflt); auto max_exp_bits_to_store = builder->CreateSub(max_exp_bits, exponent_offset); @@ -365,7 +364,7 @@ void CodeGenLLVM::store_quant_floats_with_shared_exponents( } else { masked_val = builder->CreateOr(masked_val, val); } - update_mask(mask, exp->dt->as()->get_num_bits(), + update_mask(mask, exp->dt->as()->get_num_bits(), exp->bit_offset); for (int c = 0; c < (int)exp->exponent_users.size(); c++) { @@ -374,11 +373,12 @@ void CodeGenLLVM::store_quant_floats_with_shared_exponents( auto digits = extract_digits_from_f32_with_shared_exponent(floats[c], max_exp_bits); auto digits_snode = snode->ch[ch_id].get(); - auto cft = digits_snode->dt->as(); + auto qflt = digits_snode->dt->as(); auto digits_bit_offset = digits_snode->bit_offset; - int right_shift_bits = 23 + cft->get_is_signed() - cft->get_digit_bits(); - if (!cft->get_is_signed()) { + int right_shift_bits = + 23 + qflt->get_is_signed() - qflt->get_digit_bits(); + if (!qflt->get_is_signed()) { // unsigned right_shift_bits += 1; } @@ -390,15 +390,15 @@ void CodeGenLLVM::store_quant_floats_with_shared_exponents( digits = create_call("min_u32", {digits, tlctx->get_constant((1u << 24) - 1)}); - // Compress f32 digits to cft digits. + // Compress f32 digits to qflt digits. // Note that we need to keep the leading 1 bit so 24 instead of 23 in the // following code. digits = builder->CreateLShr(digits, right_shift_bits); - if (cft->get_is_signed()) { + if (qflt->get_is_signed()) { auto float_bits = builder->CreateBitCast( floats[c], llvm::Type::getInt32Ty(*llvm_context)); auto sign_bit = builder->CreateAnd(float_bits, 1 << 31); - sign_bit = builder->CreateLShr(sign_bit, 31 - cft->get_digit_bits()); + sign_bit = builder->CreateLShr(sign_bit, 31 - qflt->get_digit_bits()); digits = builder->CreateOr(digits, sign_bit); } @@ -407,7 +407,7 @@ void CodeGenLLVM::store_quant_floats_with_shared_exponents( val = builder->CreateShl(val, digits_bit_offset); masked_val = builder->CreateOr(masked_val, val); auto num_digit_bits = - cft->get_digits_type()->as()->get_num_bits(); + qflt->get_digits_type()->as()->get_num_bits(); update_mask(mask, num_digit_bits, digits_bit_offset); } } @@ -459,25 +459,25 @@ llvm::Value *CodeGenLLVM::extract_digits_from_f32_with_shared_exponent( llvm::Value *CodeGenLLVM::extract_quant_float(llvm::Value *local_bit_struct, SNode *digits_snode) { - auto cft = digits_snode->dt->as(); - auto exponent_type = cft->get_exponent_type()->as(); - auto digits_type = cft->get_digits_type()->as(); + auto qflt = digits_snode->dt->as(); + auto exponent_type = qflt->get_exponent_type()->as(); + auto digits_type = qflt->get_digits_type()->as(); auto digits = extract_quant_int(local_bit_struct, tlctx->get_constant(digits_snode->bit_offset), digits_type); auto exponent = extract_quant_int( local_bit_struct, tlctx->get_constant(digits_snode->exp_snode->bit_offset), exponent_type); - return reconstruct_quant_float(digits, exponent, cft, + return reconstruct_quant_float(digits, exponent, qflt, digits_snode->owns_shared_exponent); } llvm::Value *CodeGenLLVM::load_quant_int(llvm::Value *ptr, Type *load_type) { - auto *cit = load_type->as(); + auto *qit = load_type->as(); auto [byte_ptr, bit_offset] = load_bit_pointer(ptr); auto bit_level_container = builder->CreateLoad(builder->CreateBitCast( - byte_ptr, llvm_ptr_type(cit->get_physical_type()))); + byte_ptr, llvm_ptr_type(qit->get_physical_type()))); return extract_quant_int(bit_level_container, bit_offset, load_type); } @@ -488,66 +488,66 @@ llvm::Value *CodeGenLLVM::extract_quant_int(llvm::Value *physical_value, // bit shifting // first left shift `physical_type - (offset + num_bits)` // then right shift `physical_type - num_bits` - auto cit = load_type->as(); + auto qit = load_type->as(); auto bit_end = - builder->CreateAdd(bit_offset, tlctx->get_constant(cit->get_num_bits())); + builder->CreateAdd(bit_offset, tlctx->get_constant(qit->get_num_bits())); auto left = builder->CreateSub( - tlctx->get_constant(data_type_bits(cit->get_physical_type())), bit_end); + tlctx->get_constant(data_type_bits(qit->get_physical_type())), bit_end); auto right = builder->CreateSub( - tlctx->get_constant(data_type_bits(cit->get_physical_type())), - tlctx->get_constant(cit->get_num_bits())); + tlctx->get_constant(data_type_bits(qit->get_physical_type())), + tlctx->get_constant(qit->get_num_bits())); left = builder->CreateIntCast(left, physical_value->getType(), false); right = builder->CreateIntCast(right, physical_value->getType(), false); auto step1 = builder->CreateShl(physical_value, left); llvm::Value *step2 = nullptr; - if (cit->get_is_signed()) + if (qit->get_is_signed()) step2 = builder->CreateAShr(step1, right); else step2 = builder->CreateLShr(step1, right); - return builder->CreateIntCast(step2, llvm_type(cit->get_compute_type()), - cit->get_is_signed()); + return builder->CreateIntCast(step2, llvm_type(qit->get_compute_type()), + qit->get_is_signed()); } llvm::Value *CodeGenLLVM::reconstruct_quant_fixed(llvm::Value *digits, - CustomFixedType *cfxt) { + QuantFixedType *qfxt) { // Compute float(digits) * scale llvm::Value *cast = nullptr; - auto compute_type = cfxt->get_compute_type()->as(); - if (cfxt->get_is_signed()) { + auto compute_type = qfxt->get_compute_type()->as(); + if (qfxt->get_is_signed()) { cast = builder->CreateSIToFP(digits, llvm_type(compute_type)); } else { cast = builder->CreateUIToFP(digits, llvm_type(compute_type)); } - llvm::Value *s = tlctx->get_constant(cfxt->get_scale()); + llvm::Value *s = tlctx->get_constant(qfxt->get_scale()); s = builder->CreateFPCast(s, llvm_type(compute_type)); return builder->CreateFMul(cast, s); } llvm::Value *CodeGenLLVM::load_quant_float(llvm::Value *digits_bit_ptr, llvm::Value *exponent_bit_ptr, - CustomFloatType *cft, + QuantFloatType *qflt, bool shared_exponent) { - auto digits = load_quant_int(digits_bit_ptr, cft->get_digits_type()); + auto digits = load_quant_int(digits_bit_ptr, qflt->get_digits_type()); auto exponent_val = load_quant_int( - exponent_bit_ptr, cft->get_exponent_type()->as()); - return reconstruct_quant_float(digits, exponent_val, cft, shared_exponent); + exponent_bit_ptr, qflt->get_exponent_type()->as()); + return reconstruct_quant_float(digits, exponent_val, qflt, shared_exponent); } llvm::Value *CodeGenLLVM::reconstruct_quant_float( llvm::Value *input_digits, llvm::Value *input_exponent_val, - CustomFloatType *cft, + QuantFloatType *qflt, bool shared_exponent) { auto digits = input_digits; auto exponent_val = input_exponent_val; // Make sure the exponent is within the range of the exponent type auto exponent_offset = - tlctx->get_constant(cft->get_exponent_conversion_offset()); + tlctx->get_constant(qflt->get_exponent_conversion_offset()); // Note that zeros need special treatment, when truncated during store. - auto exponent_type = cft->get_exponent_type()->as(); + auto exponent_type = qflt->get_exponent_type()->as(); if (exponent_type->get_num_bits() < 8) { auto cond = builder->CreateICmp(llvm::CmpInst::Predicate::ICMP_NE, exponent_val, tlctx->get_constant(0)); @@ -555,24 +555,24 @@ llvm::Value *CodeGenLLVM::reconstruct_quant_float( builder->CreateSelect(cond, exponent_offset, tlctx->get_constant(0)); } - if (cft->get_compute_type()->is_primitive(PrimitiveTypeID::f32)) { + if (qflt->get_compute_type()->is_primitive(PrimitiveTypeID::f32)) { // Construct an f32 out of exponent_val and digits // Assuming digits and exponent_val are i32 // f32 = 1 sign bit + 8 exponent bits + 23 fraction bits digits = builder->CreateAnd( digits, - (1u << cft->get_digits_type()->as()->get_num_bits()) - + (1u << qflt->get_digits_type()->as()->get_num_bits()) - 1); llvm::Value *sign_bit = nullptr; if (shared_exponent) { - if (cft->get_is_signed()) { + if (qflt->get_is_signed()) { sign_bit = builder->CreateAnd( - digits, tlctx->get_constant(1u << cft->get_digit_bits())); + digits, tlctx->get_constant(1u << qflt->get_digit_bits())); digits = builder->CreateXor(digits, sign_bit); - sign_bit = builder->CreateShl(sign_bit, 31 - cft->get_digit_bits()); + sign_bit = builder->CreateShl(sign_bit, 31 - qflt->get_digit_bits()); digits = builder->CreateShl(digits, 1); } // There is a leading 1 that marks the beginning of the digits. @@ -583,19 +583,19 @@ llvm::Value *CodeGenLLVM::reconstruct_quant_float( llvm::Intrinsic::ctlz, {llvm::Type::getInt32Ty(*llvm_context)}, {digits, tlctx->get_constant(false)}); auto extra_shift = builder->CreateSub( - tlctx->get_constant(31 - cft->get_digit_bits()), num_leading_zeros); + tlctx->get_constant(31 - qflt->get_digit_bits()), num_leading_zeros); exponent_offset = builder->CreateAdd(exponent_offset, extra_shift); - if (!cft->get_is_signed()) + if (!qflt->get_is_signed()) exponent_offset = builder->CreateAdd(exponent_offset, tlctx->get_constant(1)); auto digits_shift = builder->CreateSub( - tlctx->get_constant(23 - cft->get_digit_bits()), extra_shift); + tlctx->get_constant(23 - qflt->get_digit_bits()), extra_shift); digits = builder->CreateShl(digits, digits_shift); } else { digits = builder->CreateShl( - digits, tlctx->get_constant(23 - cft->get_digit_bits())); + digits, tlctx->get_constant(23 - qflt->get_digit_bits())); } auto fraction_bits = builder->CreateAnd(digits, (1u << 23) - 1); @@ -619,7 +619,7 @@ llvm::Value *CodeGenLLVM::reconstruct_quant_float( builder->CreateSelect(zero_output, tlctx->get_constant(0), f32_bits); } - if (cft->get_is_signed()) { + if (qflt->get_is_signed()) { if (!sign_bit) { sign_bit = builder->CreateAnd(digits, tlctx->get_constant(1u << 23)); sign_bit = builder->CreateShl(sign_bit, tlctx->get_constant(31 - 23)); @@ -637,7 +637,7 @@ llvm::Value *CodeGenLLVM::reconstruct_quant_float( llvm::Value *CodeGenLLVM::load_quant_fixed_or_quant_float(Stmt *ptr_stmt) { auto ptr = ptr_stmt->as(); auto load_type = ptr->ret_type->as()->get_pointee_type(); - if (auto cft = load_type->cast()) { + if (auto qflt = load_type->cast()) { TI_ASSERT(ptr->width() == 1); auto digits_bit_ptr = llvm_val[ptr]; auto digits_snode = ptr->output_snode; @@ -646,12 +646,12 @@ llvm::Value *CodeGenLLVM::load_quant_fixed_or_quant_float(Stmt *ptr_stmt) { TI_ASSERT(digits_snode->parent == exponent_snode->parent); auto exponent_bit_ptr = offset_bit_ptr( digits_bit_ptr, exponent_snode->bit_offset - digits_snode->bit_offset); - return load_quant_float(digits_bit_ptr, exponent_bit_ptr, cft, + return load_quant_float(digits_bit_ptr, exponent_bit_ptr, qflt, digits_snode->owns_shared_exponent); } else { - auto cfxt = load_type->as(); - auto digits = load_quant_int(llvm_val[ptr], cfxt->get_digits_type()); - return reconstruct_quant_fixed(digits, cfxt); + auto qfxt = load_type->as(); + auto digits = load_quant_int(llvm_val[ptr], qfxt->get_digits_type()); + return reconstruct_quant_fixed(digits, qfxt); } } diff --git a/taichi/ir/snode.h b/taichi/ir/snode.h index 9fb816a272835..5fd05e0dce7b4 100644 --- a/taichi/ir/snode.h +++ b/taichi/ir/snode.h @@ -135,7 +135,7 @@ class SNode { // Note: parent will not be set until structural nodes are compiled! SNode *parent{nullptr}; std::unique_ptr grad_info{nullptr}; - SNode *exp_snode{nullptr}; // for CustomFloatType + SNode *exp_snode{nullptr}; // for QuantFloatType int bit_offset{0}; // for children of bit_struct only bool placing_shared_exp{false}; SNode *currently_placing_exp_snode{nullptr}; diff --git a/taichi/ir/type.cpp b/taichi/ir/type.cpp index 789358bb64701..4de6dd12134cc 100644 --- a/taichi/ir/type.cpp +++ b/taichi/ir/type.cpp @@ -99,14 +99,14 @@ bool Type::is_primitive(PrimitiveTypeID type) const { } } -std::string CustomIntType::to_string() const { - return fmt::format("c{}{}", is_signed_ ? 'i' : 'u', num_bits_); +std::string QuantIntType::to_string() const { + return fmt::format("q{}{}", is_signed_ ? 'i' : 'u', num_bits_); } -CustomIntType::CustomIntType(int num_bits, - bool is_signed, - Type *compute_type, - Type *physical_type) +QuantIntType::QuantIntType(int num_bits, + bool is_signed, + Type *compute_type, + Type *physical_type) : compute_type_(compute_type), physical_type_(physical_type), num_bits_(num_bits), @@ -118,58 +118,58 @@ CustomIntType::CustomIntType(int num_bits, } } -CustomFixedType::CustomFixedType(Type *digits_type, - Type *compute_type, - float64 scale) +QuantFixedType::QuantFixedType(Type *digits_type, + Type *compute_type, + float64 scale) : digits_type_(digits_type), compute_type_(compute_type), scale_(scale) { - TI_ASSERT(digits_type->is()); + TI_ASSERT(digits_type->is()); TI_ASSERT(compute_type->is()); TI_ASSERT(is_real(compute_type)); } -std::string CustomFixedType::to_string() const { - return fmt::format("cfx(d={} c={} s={})", digits_type_->to_string(), +std::string QuantFixedType::to_string() const { + return fmt::format("qfx(d={} c={} s={})", digits_type_->to_string(), compute_type_->to_string(), scale_); } -bool CustomFixedType::get_is_signed() const { - return digits_type_->as()->get_is_signed(); +bool QuantFixedType::get_is_signed() const { + return digits_type_->as()->get_is_signed(); } -CustomFloatType::CustomFloatType(Type *digits_type, - Type *exponent_type, - Type *compute_type) +QuantFloatType::QuantFloatType(Type *digits_type, + Type *exponent_type, + Type *compute_type) : digits_type_(digits_type), exponent_type_(exponent_type), compute_type_(compute_type) { - TI_ASSERT(digits_type->is()); + TI_ASSERT(digits_type->is()); // We only support f32 as compute type when when using exponents TI_ASSERT(compute_type_->is_primitive(PrimitiveTypeID::f32)); - // Exponent must be unsigned custom int - TI_ASSERT(exponent_type->is()); - TI_ASSERT(exponent_type->as()->get_num_bits() <= 8); - TI_ASSERT(exponent_type->as()->get_is_signed() == false); + // Exponent must be unsigned quant int + TI_ASSERT(exponent_type->is()); + TI_ASSERT(exponent_type->as()->get_num_bits() <= 8); + TI_ASSERT(exponent_type->as()->get_is_signed() == false); TI_ASSERT(get_digit_bits() <= 23); } -std::string CustomFloatType::to_string() const { - return fmt::format("cf(d={} e={} c={})", digits_type_->to_string(), +std::string QuantFloatType::to_string() const { + return fmt::format("qfl(d={} e={} c={})", digits_type_->to_string(), exponent_type_->to_string(), compute_type_->to_string()); } -int CustomFloatType::get_exponent_conversion_offset() const { +int QuantFloatType::get_exponent_conversion_offset() const { // Note that f32 has exponent offset -127 - return 127 - - (1 << (exponent_type_->as()->get_num_bits() - 1)) + 1; + return 127 - (1 << (exponent_type_->as()->get_num_bits() - 1)) + + 1; } -int CustomFloatType::get_digit_bits() const { - return digits_type_->as()->get_num_bits() - +int QuantFloatType::get_digit_bits() const { + return digits_type_->as()->get_num_bits() - (int)get_is_signed(); } -bool CustomFloatType::get_is_signed() const { - return digits_type_->as()->get_is_signed(); +bool QuantFloatType::get_is_signed() const { + return digits_type_->as()->get_is_signed(); } BitStructType::BitStructType(PrimitiveType *physical_type, @@ -181,17 +181,17 @@ BitStructType::BitStructType(PrimitiveType *physical_type, TI_ASSERT(member_types_.size() == member_bit_offsets_.size()); int physical_type_bits = data_type_bits(physical_type); for (auto i = 0; i < member_types_.size(); ++i) { - CustomIntType *component_cit = nullptr; - if (auto cit = member_types_[i]->cast()) { - component_cit = cit; - } else if (auto cfxt = member_types_[i]->cast()) { - component_cit = cfxt->get_digits_type()->as(); - } else if (auto cft = member_types_[i]->cast()) { - component_cit = cft->get_digits_type()->as(); + QuantIntType *component_qit = nullptr; + if (auto qit = member_types_[i]->cast()) { + component_qit = qit; + } else if (auto qfxt = member_types_[i]->cast()) { + component_qit = qfxt->get_digits_type()->as(); + } else if (auto qflt = member_types_[i]->cast()) { + component_qit = qflt->get_digits_type()->as(); } else { TI_NOT_IMPLEMENTED } - auto bits_end = component_cit->get_num_bits() + member_bit_offsets_[i]; + auto bits_end = component_qit->get_num_bits() + member_bit_offsets_[i]; TI_ASSERT(physical_type_bits >= bits_end) } } diff --git a/taichi/ir/type.h b/taichi/ir/type.h index 2e8823b1f9bb0..ed5de75cedd20 100644 --- a/taichi/ir/type.h +++ b/taichi/ir/type.h @@ -186,12 +186,12 @@ class TensorType : public Type { Type *element_{nullptr}; }; -class CustomIntType : public Type { +class QuantIntType : public Type { public: - CustomIntType(int num_bits, - bool is_signed, - Type *compute_type = nullptr, - Type *physical_type = nullptr); + QuantIntType(int num_bits, + bool is_signed, + Type *compute_type = nullptr, + Type *physical_type = nullptr); std::string to_string() const override; @@ -224,9 +224,9 @@ class CustomIntType : public Type { bool is_signed_{true}; }; -class CustomFixedType : public Type { +class QuantFixedType : public Type { public: - CustomFixedType(Type *digits_type, Type *compute_type, float64 scale); + QuantFixedType(Type *digits_type, Type *compute_type, float64 scale); std::string to_string() const override; @@ -250,9 +250,9 @@ class CustomFixedType : public Type { float64 scale_{1.0}; }; -class CustomFloatType : public Type { +class QuantFloatType : public Type { public: - CustomFloatType(Type *digits_type, Type *exponent_type, Type *compute_type); + QuantFloatType(Type *digits_type, Type *exponent_type, Type *compute_type); std::string to_string() const override; @@ -319,8 +319,8 @@ class BitArrayType : public Type { element_type_(element_type_), num_elements_(num_elements_) { // TODO: avoid assertion? - TI_ASSERT(element_type_->is()); - element_num_bits_ = element_type_->as()->get_num_bits(); + TI_ASSERT(element_type_->is()); + element_num_bits_ = element_type_->as()->get_num_bits(); } std::string to_string() const override; diff --git a/taichi/ir/type_factory.cpp b/taichi/ir/type_factory.cpp index c54de8611fdb5..ce7bb4247f25e 100644 --- a/taichi/ir/type_factory.cpp +++ b/taichi/ir/type_factory.cpp @@ -45,37 +45,37 @@ Type *TypeFactory::get_pointer_type(Type *element, bool is_bit_pointer) { return pointer_types_[key].get(); } -Type *TypeFactory::get_custom_int_type(int num_bits, - bool is_signed, - Type *compute_type) { +Type *TypeFactory::get_quant_int_type(int num_bits, + bool is_signed, + Type *compute_type) { auto key = std::make_tuple(num_bits, is_signed, compute_type); - if (custom_int_types_.find(key) == custom_int_types_.end()) { - custom_int_types_[key] = - std::make_unique(num_bits, is_signed, compute_type); + if (quant_int_types_.find(key) == quant_int_types_.end()) { + quant_int_types_[key] = + std::make_unique(num_bits, is_signed, compute_type); } - return custom_int_types_[key].get(); + return quant_int_types_[key].get(); } -Type *TypeFactory::get_custom_fixed_type(Type *digits_type, - Type *compute_type, - float64 scale) { +Type *TypeFactory::get_quant_fixed_type(Type *digits_type, + Type *compute_type, + float64 scale) { auto key = std::make_tuple(digits_type, compute_type, scale); - if (custom_fixed_types_.find(key) == custom_fixed_types_.end()) { - custom_fixed_types_[key] = - std::make_unique(digits_type, compute_type, scale); + if (quant_fixed_types_.find(key) == quant_fixed_types_.end()) { + quant_fixed_types_[key] = + std::make_unique(digits_type, compute_type, scale); } - return custom_fixed_types_[key].get(); + return quant_fixed_types_[key].get(); } -Type *TypeFactory::get_custom_float_type(Type *digits_type, - Type *exponent_type, - Type *compute_type) { +Type *TypeFactory::get_quant_float_type(Type *digits_type, + Type *exponent_type, + Type *compute_type) { auto key = std::make_tuple(digits_type, exponent_type, compute_type); - if (custom_float_types_.find(key) == custom_float_types_.end()) { - custom_float_types_[key] = std::make_unique( + if (quant_float_types_.find(key) == quant_float_types_.end()) { + quant_float_types_[key] = std::make_unique( digits_type, exponent_type, compute_type); } - return custom_float_types_[key].get(); + return quant_float_types_[key].get(); } Type *TypeFactory::get_bit_struct_type(PrimitiveType *physical_type, diff --git a/taichi/ir/type_factory.h b/taichi/ir/type_factory.h index cd20513ae2f3e..9ab9aeee17edb 100644 --- a/taichi/ir/type_factory.h +++ b/taichi/ir/type_factory.h @@ -23,15 +23,15 @@ class TypeFactory { Type *get_pointer_type(Type *element, bool is_bit_pointer = false); - Type *get_custom_int_type(int num_bits, bool is_signed, Type *compute_type); + Type *get_quant_int_type(int num_bits, bool is_signed, Type *compute_type); - Type *get_custom_fixed_type(Type *digits_type, - Type *compute_type, - float64 scale); + Type *get_quant_fixed_type(Type *digits_type, + Type *compute_type, + float64 scale); - Type *get_custom_float_type(Type *digits_type, - Type *exponent_type, - Type *compute_type); + Type *get_quant_float_type(Type *digits_type, + Type *exponent_type, + Type *compute_type); Type *get_bit_struct_type(PrimitiveType *physical_type, std::vector member_types, @@ -63,15 +63,15 @@ class TypeFactory { // TODO: use unordered map std::map, std::unique_ptr> - custom_int_types_; + quant_int_types_; // TODO: use unordered map std::map, std::unique_ptr> - custom_fixed_types_; + quant_fixed_types_; // TODO: use unordered map std::map, std::unique_ptr> - custom_float_types_; + quant_float_types_; // TODO: avoid duplication std::vector> bit_struct_types_; diff --git a/taichi/ir/type_utils.cpp b/taichi/ir/type_utils.cpp index a14e435473c02..6970a4d3a0ef6 100644 --- a/taichi/ir/type_utils.cpp +++ b/taichi/ir/type_utils.cpp @@ -38,7 +38,7 @@ std::string data_type_format(DataType dt) { return "%f"; } else if (dt->is_primitive(PrimitiveTypeID::f64)) { return "%.12f"; - } else if (dt->is()) { + } else if (dt->is()) { return "%d"; } else if (dt->is_primitive(PrimitiveTypeID::f16)) { // f16 (and f32) is converted to f64 before printing, see diff --git a/taichi/ir/type_utils.h b/taichi/ir/type_utils.h index a616505de3c6d..25d5f11848fa8 100644 --- a/taichi/ir/type_utils.h +++ b/taichi/ir/type_utils.h @@ -74,15 +74,15 @@ inline PrimitiveTypeID get_primitive_data_type() { } inline bool is_quant(DataType dt) { - return dt->is() || dt->is() || - dt->is(); + return dt->is() || dt->is() || + dt->is(); } inline bool is_real(DataType dt) { return dt->is_primitive(PrimitiveTypeID::f16) || dt->is_primitive(PrimitiveTypeID::f32) || - dt->is_primitive(PrimitiveTypeID::f64) || dt->is() || - dt->is(); + dt->is_primitive(PrimitiveTypeID::f64) || dt->is() || + dt->is(); } inline bool is_integral(DataType dt) { @@ -93,13 +93,13 @@ inline bool is_integral(DataType dt) { dt->is_primitive(PrimitiveTypeID::u8) || dt->is_primitive(PrimitiveTypeID::u16) || dt->is_primitive(PrimitiveTypeID::u32) || - dt->is_primitive(PrimitiveTypeID::u64) || dt->is(); + dt->is_primitive(PrimitiveTypeID::u64) || dt->is(); } inline bool is_signed(DataType dt) { // Shall we return false if is_integral returns false? TI_ASSERT(is_integral(dt)); - if (auto t = dt->cast()) + if (auto t = dt->cast()) return t->get_is_signed(); return dt->is_primitive(PrimitiveTypeID::i8) || dt->is_primitive(PrimitiveTypeID::i16) || diff --git a/taichi/program/program.cpp b/taichi/program/program.cpp index 8186a9be18a02..e6149271d38c0 100644 --- a/taichi/program/program.cpp +++ b/taichi/program/program.cpp @@ -50,7 +50,7 @@ Program::Program(Arch desired_arch) : snode_rw_accessors_bank_(this), ndarray_rw_accessors_bank_(this) { TI_TRACE("Program initializing..."); - // For performance considerations and correctness of CustomFloatType + // For performance considerations and correctness of QuantFloatType // operations, we force floating-point operations to flush to zero on all // backends (including CPUs). #if defined(TI_ARCH_x64) diff --git a/taichi/program/snode_expr_utils.cpp b/taichi/program/snode_expr_utils.cpp index e3addfd7364f4..f0c75e7ecfa29 100644 --- a/taichi/program/snode_expr_utils.cpp +++ b/taichi/program/snode_expr_utils.cpp @@ -52,15 +52,15 @@ void place_child(Expr *expr_arg, TI_ERROR_IF(glb_var_expr->snode != nullptr, "This variable has been placed."); SNode *new_exp_snode = nullptr; - if (auto cft = glb_var_expr->dt->cast()) { - auto exp = cft->get_exponent_type(); + if (auto qflt = glb_var_expr->dt->cast()) { + auto exp = qflt->get_exponent_type(); // Non-empty exponent type. First create a place SNode for the // exponent value. if (parent->placing_shared_exp && parent->currently_placing_exp_snode != nullptr) { // Reuse existing exponent TI_ASSERT_INFO(parent->currently_placing_exp_snode_dtype == exp, - "CustomFloatTypes with shared exponents must have " + "QuantFloatTypes with shared exponents must have " "exactly the same exponent type."); new_exp_snode = parent->currently_placing_exp_snode; } else { diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index 69e691df192bb..7c58541537b02 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -1036,13 +1036,13 @@ void export_lang(py::module &m) { // the factory methods, otherwise pybind11 will delete the Types owned by // TypeFactory on Python-scope pointer destruction. py::class_(m, "TypeFactory") - .def("get_custom_int_type", &TypeFactory::get_custom_int_type, + .def("get_quant_int_type", &TypeFactory::get_quant_int_type, py::arg("num_bits"), py::arg("is_signed"), py::arg("compute_type"), py::return_value_policy::reference) - .def("get_custom_fixed_type", &TypeFactory::get_custom_fixed_type, + .def("get_quant_fixed_type", &TypeFactory::get_quant_fixed_type, py::arg("digits_type"), py::arg("compute_type"), py::arg("scale"), py::return_value_policy::reference) - .def("get_custom_float_type", &TypeFactory::get_custom_float_type, + .def("get_quant_float_type", &TypeFactory::get_quant_float_type, py::arg("digits_type"), py::arg("exponent_type"), py::arg("compute_type"), py::return_value_policy::reference); diff --git a/taichi/struct/struct_llvm.cpp b/taichi/struct/struct_llvm.cpp index da39db90a579d..adf6a9be08a9b 100644 --- a/taichi/struct/struct_llvm.cpp +++ b/taichi/struct/struct_llvm.cpp @@ -89,24 +89,24 @@ void StructCompilerLLVM::generate_types(SNode &snode) { auto &ch = snode.ch[i]; ch_types.push_back(ch->dt); ch_offsets.push_back(total_offset); - CustomIntType *component_cit = nullptr; - if (auto cit = ch->dt->cast()) { - component_cit = cit; - } else if (auto cfxt = ch->dt->cast()) { - component_cit = cfxt->get_digits_type()->as(); - } else if (auto cft = ch->dt->cast()) { - component_cit = cft->get_digits_type()->as(); + QuantIntType *component_qit = nullptr; + if (auto qit = ch->dt->cast()) { + component_qit = qit; + } else if (auto qfxt = ch->dt->cast()) { + component_qit = qfxt->get_digits_type()->as(); + } else if (auto qflt = ch->dt->cast()) { + component_qit = qflt->get_digits_type()->as(); } else { TI_ERROR("Type {} not supported.", ch->dt->to_string()); } - component_cit->set_physical_type(snode.physical_type); + component_qit->set_physical_type(snode.physical_type); if (!arch_is_cpu(arch_)) { TI_ERROR_IF(data_type_bits(snode.physical_type) < 32, "bit_struct physical type must be at least 32 bits on " "non-CPU backends."); } ch->bit_offset = total_offset; - total_offset += component_cit->get_num_bits(); + total_offset += component_qit->get_num_bits(); auto bit_struct_size = data_type_bits(snode.physical_type); TI_ERROR_IF(total_offset > bit_struct_size, "Bit struct overflows: {} bits used out of {}.", total_offset, @@ -123,7 +123,7 @@ void StructCompilerLLVM::generate_types(SNode &snode) { TI_ASSERT(snode.ch.size() == 1); auto &ch = snode.ch[0]; Type *ch_type = ch->dt; - ch->dt->as()->set_physical_type(snode.physical_type); + ch->dt->as()->set_physical_type(snode.physical_type); if (!arch_is_cpu(arch_)) { TI_ERROR_IF(data_type_bits(snode.physical_type) <= 16, "bit_array physical type must be at least 32 bits on " diff --git a/taichi/transforms/bit_loop_vectorize.cpp b/taichi/transforms/bit_loop_vectorize.cpp index f684c02144d54..ccbec9c0a84af 100644 --- a/taichi/transforms/bit_loop_vectorize.cpp +++ b/taichi/transforms/bit_loop_vectorize.cpp @@ -40,8 +40,8 @@ class BitLoopVectorize : public IRVisitor { void visit(GlobalLoadStmt *stmt) override { auto ptr_type = stmt->src->ret_type->as(); if (in_struct_for_loop && bit_vectorize != 1) { - if (ptr_type->get_pointee_type()->cast()) { - // rewrite the previous GlobalPtrStmt's return type from *cit to + if (ptr_type->get_pointee_type()->cast()) { + // rewrite the previous GlobalPtrStmt's return type from *qit to // *phy_type auto ptr = stmt->src->cast(); auto ptr_physical_type = TypeFactory::get_instance().get_pointer_type( @@ -127,8 +127,8 @@ class BitLoopVectorize : public IRVisitor { void visit(GlobalStoreStmt *stmt) override { auto ptr_type = stmt->dest->ret_type->as(); if (in_struct_for_loop && bit_vectorize != 1) { - if (ptr_type->get_pointee_type()->cast()) { - // rewrite the previous GlobalPtrStmt's return type from *cit to + if (ptr_type->get_pointee_type()->cast()) { + // rewrite the previous GlobalPtrStmt's return type from *qit to // *phy_type auto ptr = stmt->dest->cast(); auto ptr_physical_type = TypeFactory::get_instance().get_pointer_type( diff --git a/taichi/transforms/demote_atomics.cpp b/taichi/transforms/demote_atomics.cpp index 780dae2a3cf2c..1a65bcb19689d 100644 --- a/taichi/transforms/demote_atomics.cpp +++ b/taichi/transforms/demote_atomics.cpp @@ -105,9 +105,9 @@ class DemoteAtomics : public BasicStmtVisitor { } if (auto dest_pointer_type = stmt->dest->ret_type->cast()) { - if (dest_pointer_type->get_pointee_type()->is()) { + if (dest_pointer_type->get_pointee_type()->is()) { TI_WARN( - "AtomicOp on CustomFloatType is not supported. " + "AtomicOp on QuantFloatType is not supported. " "Demoting to non-atomic RMW."); demote = true; } diff --git a/taichi/transforms/make_thread_local.cpp b/taichi/transforms/make_thread_local.cpp index 68c082e822172..fc6a8ba2cfd50 100644 --- a/taichi/transforms/make_thread_local.cpp +++ b/taichi/transforms/make_thread_local.cpp @@ -106,7 +106,7 @@ void make_thread_local_offload(OffloadedStmt *offload) { offload, [](GlobalPtrStmt *dest) { // We can only optimized reductions to global ptrs with form like // loss[None] (0-D fields) for now. - // No TLS on CustomInt/FloatType. + // No TLS on quant types. return (dest->snodes[0]->type == SNodeType::place) && dest->indices.empty() && dest->snodes[0]->dt->is(); diff --git a/tests/cpp/ir/type_test.cpp b/tests/cpp/ir/type_test.cpp index 620117700e1f8..fb250f7193c79 100644 --- a/tests/cpp/ir/type_test.cpp +++ b/tests/cpp/ir/type_test.cpp @@ -14,19 +14,19 @@ TEST(Type, BitTypes) { auto i32 = TypeFactory::get_instance() .get_primitive_type(PrimitiveTypeID::i32) ->as(); - auto ci5 = TypeFactory::get_instance().get_custom_int_type(5, true, i32); - auto cu11 = TypeFactory::get_instance().get_custom_int_type(11, false, i32); + auto qi5 = TypeFactory::get_instance().get_quant_int_type(5, true, i32); + auto qu11 = TypeFactory::get_instance().get_quant_int_type(11, false, i32); auto u16 = TypeFactory::get_instance().get_primitive_int_type(16, false); auto bs = - TypeFactory::get_instance().get_bit_struct_type(u16, {ci5, cu11}, {0, 5}); + TypeFactory::get_instance().get_bit_struct_type(u16, {qi5, qu11}, {0, 5}); - EXPECT_EQ(bs->to_string(), "bs(ci5@0, cu11@5)"); + EXPECT_EQ(bs->to_string(), "bs(qi5@0, qu11@5)"); - auto ci1 = TypeFactory::get_instance().get_custom_int_type(1, true, i32); - auto ba = TypeFactory::get_instance().get_bit_array_type(i32, ci1, 32); + auto qi1 = TypeFactory::get_instance().get_quant_int_type(1, true, i32); + auto ba = TypeFactory::get_instance().get_bit_array_type(i32, qi1, 32); - EXPECT_EQ(ba->to_string(), "ba(ci1x32)"); + EXPECT_EQ(ba->to_string(), "ba(qi1x32)"); } } // namespace lang diff --git a/tests/python/test_bit_array.py b/tests/python/test_bit_array.py index 3c3f76fc30d87..169278b039a7a 100644 --- a/tests/python/test_bit_array.py +++ b/tests/python/test_bit_array.py @@ -6,9 +6,9 @@ @test_utils.test(require=ti.extension.quant, debug=True) def test_1D_bit_array(): - cu1 = ti.types.quant.int(1, False) + qu1 = ti.types.quant.int(1, False) - x = ti.field(dtype=cu1) + x = ti.field(dtype=qu1) N = 32 @@ -30,9 +30,9 @@ def verify_val(): @test_utils.test(require=ti.extension.quant, debug=True) def test_2D_bit_array(): - ci1 = ti.types.quant.int(1, False) + qi1 = ti.types.quant.int(1, False) - x = ti.field(dtype=ci1) + x = ti.field(dtype=qi1) M, N = 4, 8 diff --git a/tests/python/test_bit_array_vectorization.py b/tests/python/test_bit_array_vectorization.py index 44adc9e1943df..fb0bb4b804ccb 100644 --- a/tests/python/test_bit_array_vectorization.py +++ b/tests/python/test_bit_array_vectorization.py @@ -8,10 +8,10 @@ debug=True, cfg_optimization=False) def test_vectorized_struct_for(): - cu1 = ti.types.quant.int(1, False) + qu1 = ti.types.quant.int(1, False) - x = ti.field(dtype=cu1) - y = ti.field(dtype=cu1) + x = ti.field(dtype=qu1) + y = ti.field(dtype=qu1) N = 4096 n_blocks = 4 @@ -49,11 +49,11 @@ def verify(): @test_utils.test(require=ti.extension.quant) def test_offset_load(): - ci1 = ti.types.quant.int(1, False) + qi1 = ti.types.quant.int(1, False) - x = ti.field(dtype=ci1) - y = ti.field(dtype=ci1) - z = ti.field(dtype=ci1) + x = ti.field(dtype=qi1) + y = ti.field(dtype=qi1) + z = ti.field(dtype=qi1) N = 4096 n_blocks = 4 @@ -109,11 +109,11 @@ def verify(dx: ti.template(), dy: ti.template()): @test_utils.test(require=ti.extension.quant, debug=True) def test_evolve(): - ci1 = ti.types.quant.int(1, False) + qi1 = ti.types.quant.int(1, False) - x = ti.field(dtype=ci1) - y = ti.field(dtype=ci1) - z = ti.field(dtype=ci1) + x = ti.field(dtype=qi1) + y = ti.field(dtype=qi1) + z = ti.field(dtype=qi1) N = 4096 n_blocks = 4 diff --git a/tests/python/test_bit_struct.py b/tests/python/test_bit_struct.py index 03610eabf6e4b..04403cd150106 100644 --- a/tests/python/test_bit_struct.py +++ b/tests/python/test_bit_struct.py @@ -7,11 +7,11 @@ @test_utils.test(require=ti.extension.quant_basic, debug=True) def test_simple_array(): - ci13 = ti.types.quant.int(13, True) - cu19 = ti.types.quant.int(19, False) + qi13 = ti.types.quant.int(13, True) + qu19 = ti.types.quant.int(19, False) - x = ti.field(dtype=ci13) - y = ti.field(dtype=cu19) + x = ti.field(dtype=qi13) + y = ti.field(dtype=qu19) N = 12 @@ -41,14 +41,14 @@ def verify_val(): @test_utils.test(require=ti.extension.quant_basic, exclude=[ti.metal], debug=True) -def test_custom_int_load_and_store(): - ci13 = ti.types.quant.int(13, True) - cu14 = ti.types.quant.int(14, False) - ci5 = ti.types.quant.int(5, True) +def test_quant_int_load_and_store(): + qi13 = ti.types.quant.int(13, True) + qu14 = ti.types.quant.int(14, False) + qi5 = ti.types.quant.int(5, True) - x = ti.field(dtype=ci13) - y = ti.field(dtype=cu14) - z = ti.field(dtype=ci5) + x = ti.field(dtype=qi13) + y = ti.field(dtype=qu14) + z = ti.field(dtype=qi5) test_case_np = np.array( [[2**12 - 1, 2**14 - 1, -(2**3)], [2**11 - 1, 2**13 - 1, -(2**2)], @@ -82,9 +82,9 @@ def verify_val(idx: ti.i32): @test_utils.test(require=ti.extension.quant_basic) -def test_custom_int_full_struct(): - cit = ti.types.quant.int(32, True) - x = ti.field(dtype=cit) +def test_quant_int_full_struct(): + qit = ti.types.quant.int(32, True) + x = ti.field(dtype=qit) ti.root.dense(ti.i, 1).bit_struct(num_bits=32).place(x) x[0] = 15 @@ -95,17 +95,17 @@ def test_custom_int_full_struct(): def test_bit_struct(): - def test_single_bit_struct(physical_type, compute_type, custom_bits, + def test_single_bit_struct(physical_type, compute_type, quant_bits, test_case): ti.init(arch=ti.cpu, debug=True) - cit1 = ti.types.quant.int(custom_bits[0], True, compute_type) - cit2 = ti.types.quant.int(custom_bits[1], False, compute_type) - cit3 = ti.types.quant.int(custom_bits[2], True, compute_type) + qit1 = ti.types.quant.int(quant_bits[0], True, compute_type) + qit2 = ti.types.quant.int(quant_bits[1], False, compute_type) + qit3 = ti.types.quant.int(quant_bits[2], True, compute_type) - a = ti.field(dtype=cit1) - b = ti.field(dtype=cit2) - c = ti.field(dtype=cit3) + a = ti.field(dtype=qit1) + b = ti.field(dtype=qit2) + c = ti.field(dtype=qit3) ti.root.bit_struct(num_bits=physical_type).place(a, b, c) @ti.kernel diff --git a/tests/python/test_cast.py b/tests/python/test_cast.py index 06d48ce73e438..c4475d0c26371 100644 --- a/tests/python/test_cast.py +++ b/tests/python/test_cast.py @@ -141,15 +141,15 @@ def run_cast_u32(): @test_utils.test(arch=ti.cpu) -def test_custom_int_extension(): +def test_quant_int_extension(): x = ti.field(dtype=ti.i32, shape=2) y = ti.field(dtype=ti.u32, shape=2) - ci5 = ti.types.quant.int(5, True, ti.i16) - cu7 = ti.types.quant.int(7, False, ti.u16) + qi5 = ti.types.quant.int(5, True, ti.i16) + qu7 = ti.types.quant.int(7, False, ti.u16) - a = ti.field(dtype=ci5) - b = ti.field(dtype=cu7) + a = ti.field(dtype=qi5) + b = ti.field(dtype=qu7) ti.root.bit_struct(num_bits=32).place(a, b) diff --git a/tests/python/test_matrix_different_type.py b/tests/python/test_matrix_different_type.py index 05288080f3fae..7e4a0d5f629ca 100644 --- a/tests/python/test_matrix_different_type.py +++ b/tests/python/test_matrix_different_type.py @@ -69,12 +69,12 @@ def verify(): @test_utils.test(require=ti.extension.quant_basic) -def test_custom_type(): - cit1 = ti.types.quant.int(bits=10, signed=True) - cft1 = ti.types.quant.fixed(frac=10, signed=True, scale=0.1) - cit2 = ti.types.quant.int(bits=22, signed=False) - cft2 = ti.types.quant.fixed(frac=22, signed=False, scale=0.1) - type_list = [[cit1, cft2], [cft1, cit2]] +def test_quant_type(): + qit1 = ti.types.quant.int(bits=10, signed=True) + qfxt1 = ti.types.quant.fixed(frac=10, signed=True, scale=0.1) + qit2 = ti.types.quant.int(bits=22, signed=False) + qfxt2 = ti.types.quant.fixed(frac=22, signed=False, scale=0.1) + type_list = [[qit1, qfxt2], [qfxt1, qit2]] a = ti.Matrix.field(len(type_list), len(type_list[0]), dtype=type_list) b = ti.Matrix.field(len(type_list), len(type_list[0]), dtype=type_list) c = ti.Matrix.field(len(type_list), len(type_list[0]), dtype=type_list) diff --git a/tests/python/test_custom_type_atomics.py b/tests/python/test_quant_atomics.py similarity index 70% rename from tests/python/test_custom_type_atomics.py rename to tests/python/test_quant_atomics.py index 5569d4bf5f132..97ed1c930d92d 100644 --- a/tests/python/test_custom_type_atomics.py +++ b/tests/python/test_quant_atomics.py @@ -8,14 +8,14 @@ @test_utils.test(require=ti.extension.quant_basic, exclude=[ti.metal], debug=True) -def test_custom_int_atomics(): - ci13 = ti.types.quant.int(13, True) - ci5 = ti.types.quant.int(5, True) - cu2 = ti.types.quant.int(2, False) +def test_quant_int_atomics(): + qi13 = ti.types.quant.int(13, True) + qi5 = ti.types.quant.int(5, True) + qu2 = ti.types.quant.int(2, False) - x = ti.field(dtype=ci13) - y = ti.field(dtype=ci5) - z = ti.field(dtype=cu2) + x = ti.field(dtype=qi13) + y = ti.field(dtype=qi5) + z = ti.field(dtype=qu2) ti.root.bit_struct(num_bits=32).place(x, y, z) @@ -43,10 +43,10 @@ def foo(): @test_utils.test(require=[ti.extension.quant_basic, ti.extension.data64], debug=True) -def test_custom_int_atomics_b64(): - ci13 = ti.types.quant.int(13, True) +def test_quant_int_atomics_b64(): + qi13 = ti.types.quant.int(13, True) - x = ti.field(dtype=ci13) + x = ti.field(dtype=qi13) ti.root.bit_array(ti.i, 4, num_bits=64).place(x) @@ -67,12 +67,12 @@ def foo(): @test_utils.test(require=ti.extension.quant_basic, debug=True) -def test_custom_float_atomics(): - cft13 = ti.types.quant.fixed(frac=13, signed=True, scale=0.1) - cft19 = ti.types.quant.fixed(frac=19, signed=False, scale=0.1) +def test_quant_fixed_atomics(): + qfxt13 = ti.types.quant.fixed(frac=13, signed=True, scale=0.1) + qfxt19 = ti.types.quant.fixed(frac=19, signed=False, scale=0.1) - x = ti.field(dtype=cft13) - y = ti.field(dtype=cft19) + x = ti.field(dtype=qfxt13) + y = ti.field(dtype=qfxt19) ti.root.bit_struct(num_bits=32).place(x, y) diff --git a/tests/python/test_custom_float.py b/tests/python/test_quant_fixed.py similarity index 78% rename from tests/python/test_custom_float.py rename to tests/python/test_quant_fixed.py index 789079385777f..082b697576dc4 100644 --- a/tests/python/test_custom_float.py +++ b/tests/python/test_quant_fixed.py @@ -7,9 +7,9 @@ @test_utils.test(require=ti.extension.quant_basic) -def test_custom_float(): - cft = ti.types.quant.fixed(frac=32, range=2) - x = ti.field(dtype=cft) +def test_quant_fixed(): + qfxt = ti.types.quant.fixed(frac=32, range=2) + x = ti.field(dtype=qfxt) ti.root.bit_struct(num_bits=32).place(x) @@ -28,10 +28,10 @@ def foo(): @test_utils.test(require=ti.extension.quant_basic) -def test_custom_matrix_rotation(): - cft = ti.types.quant.fixed(frac=16, range=1.2) +def test_quant_fixed_matrix_rotation(): + qfxt = ti.types.quant.fixed(frac=16, range=1.2) - x = ti.Matrix.field(2, 2, dtype=cft) + x = ti.Matrix.field(2, 2, dtype=qfxt) ti.root.bit_struct(num_bits=32).place(x.get_scalar_field(0, 0), x.get_scalar_field(0, 1)) @@ -56,9 +56,9 @@ def rotate_18_degrees(): @test_utils.test(require=ti.extension.quant_basic) -def test_custom_float_implicit_cast(): - cft = ti.types.quant.fixed(frac=13, scale=0.1) - x = ti.field(dtype=cft) +def test_quant_fixed_implicit_cast(): + qfxt = ti.types.quant.fixed(frac=13, scale=0.1) + x = ti.field(dtype=qfxt) ti.root.bit_struct(num_bits=32).place(x) @@ -71,9 +71,9 @@ def foo(): @test_utils.test(require=ti.extension.quant_basic) -def test_cache_read_only(): - cft = ti.types.quant.fixed(frac=15, scale=0.1) - x = ti.field(dtype=cft) +def test_quant_fixed_cache_read_only(): + qfxt = ti.types.quant.fixed(frac=15, scale=0.1) + x = ti.field(dtype=qfxt) ti.root.bit_struct(num_bits=32).place(x) diff --git a/tests/python/test_custom_float_exponents.py b/tests/python/test_quant_float.py similarity index 78% rename from tests/python/test_custom_float_exponents.py rename to tests/python/test_quant_float.py index 5c2be6f58eab5..3ae22f1bcc165 100644 --- a/tests/python/test_custom_float_exponents.py +++ b/tests/python/test_quant_float.py @@ -7,9 +7,9 @@ @test_utils.test(require=ti.extension.quant) -def test_custom_float_unsigned(): - cft = ti.types.quant.float(exp=6, frac=13, signed=False) - x = ti.field(dtype=cft) +def test_quant_float_unsigned(): + qflt = ti.types.quant.float(exp=6, frac=13, signed=False) + x = ti.field(dtype=qflt) ti.root.bit_struct(num_bits=32).place(x) @@ -26,9 +26,9 @@ def test_custom_float_unsigned(): @test_utils.test(require=ti.extension.quant) -def test_custom_float_signed(): - cft = ti.types.quant.float(exp=6, frac=13, signed=True) - x = ti.field(dtype=cft) +def test_quant_float_signed(): + qflt = ti.types.quant.float(exp=6, frac=13, signed=True) + x = ti.field(dtype=qflt) ti.root.bit_struct(num_bits=32).place(x) @@ -54,9 +54,9 @@ def test_custom_float_signed(): @pytest.mark.parametrize('digits_bits', [23, 24]) @test_utils.test(require=ti.extension.quant) -def test_custom_float_precision(digits_bits): - cft = ti.types.quant.float(exp=8, frac=digits_bits) - x = ti.field(dtype=cft) +def test_quant_float_precision(digits_bits): + qflt = ti.types.quant.float(exp=8, frac=digits_bits) + x = ti.field(dtype=qflt) ti.root.bit_struct(num_bits=32).place(x) @@ -75,9 +75,9 @@ def test_custom_float_precision(digits_bits): @pytest.mark.parametrize('signed', [True, False]) @test_utils.test(require=ti.extension.quant) -def test_custom_float_truncation(signed): - cft = ti.types.quant.float(exp=5, frac=2, signed=signed) - x = ti.field(dtype=cft) +def test_quant_float_truncation(signed): + qflt = ti.types.quant.float(exp=5, frac=2, signed=signed) + x = ti.field(dtype=qflt) ti.root.bit_struct(num_bits=32).place(x) @@ -103,9 +103,9 @@ def test_custom_float_truncation(signed): @test_utils.test(require=ti.extension.quant) -def test_custom_float_atomic_demotion(): - cft = ti.types.quant.float(exp=5, frac=2) - x = ti.field(dtype=cft) +def test_quant_float_atomic_demotion(): + qflt = ti.types.quant.float(exp=5, frac=2) + x = ti.field(dtype=qflt) ti.root.bit_struct(num_bits=32).place(x) diff --git a/tests/python/test_custom_float_shared_exp.py b/tests/python/test_quant_float_shared_exp.py similarity index 75% rename from tests/python/test_custom_float_shared_exp.py rename to tests/python/test_quant_float_shared_exp.py index 9ca7dd8f15b0f..00decdd5e1d77 100644 --- a/tests/python/test_custom_float_shared_exp.py +++ b/tests/python/test_quant_float_shared_exp.py @@ -8,10 +8,10 @@ @pytest.mark.parametrize('exponent_bits', [5, 6, 7, 8]) @test_utils.test(require=ti.extension.quant) def test_shared_exponents(exponent_bits): - cft1 = ti.types.quant.float(exp=exponent_bits, frac=10, signed=False) - cft2 = ti.types.quant.float(exp=exponent_bits, frac=14, signed=False) - a = ti.field(dtype=cft1) - b = ti.field(dtype=cft2) + qflt1 = ti.types.quant.float(exp=exponent_bits, frac=10, signed=False) + qflt2 = ti.types.quant.float(exp=exponent_bits, frac=14, signed=False) + a = ti.field(dtype=qflt1) + b = ti.field(dtype=qflt2) ti.root.bit_struct(num_bits=32).place(a, b, shared_exponent=True) assert a[None] == 0.0 @@ -71,10 +71,10 @@ def foo(x: ti.f32, y: ti.f32): @pytest.mark.parametrize('exponent_bits', [5, 6, 7, 8]) @test_utils.test(require=ti.extension.quant) def test_shared_exponent_add(exponent_bits): - cft1 = ti.types.quant.float(exp=exponent_bits, frac=10, signed=False) - cft2 = ti.types.quant.float(exp=exponent_bits, frac=14, signed=False) - a = ti.field(dtype=cft1) - b = ti.field(dtype=cft2) + qflt1 = ti.types.quant.float(exp=exponent_bits, frac=10, signed=False) + qflt2 = ti.types.quant.float(exp=exponent_bits, frac=14, signed=False) + a = ti.field(dtype=qflt1) + b = ti.field(dtype=qflt2) ti.root.bit_struct(num_bits=32).place(a, b, shared_exponent=True) @ti.kernel @@ -104,10 +104,10 @@ def foo(x: ti.f32, y: ti.f32): @pytest.mark.parametrize('exponent_bits', [5, 6, 7, 8]) @test_utils.test(require=ti.extension.quant) def test_shared_exponent_borrow(exponent_bits): - cft1 = ti.types.quant.float(exp=exponent_bits, frac=10, signed=False) - cft2 = ti.types.quant.float(exp=exponent_bits, frac=14, signed=False) - a = ti.field(dtype=cft1) - b = ti.field(dtype=cft2) + qflt1 = ti.types.quant.float(exp=exponent_bits, frac=10, signed=False) + qflt2 = ti.types.quant.float(exp=exponent_bits, frac=14, signed=False) + a = ti.field(dtype=qflt1) + b = ti.field(dtype=qflt2) ti.root.bit_struct(num_bits=32).place(a, b, shared_exponent=True) @ti.kernel @@ -129,11 +129,11 @@ def inc(): @pytest.mark.parametrize('exponent_bits', [5, 6, 7, 8]) @test_utils.test(require=ti.extension.quant) -def test_negative(exponent_bits): - cft1 = ti.types.quant.float(exp=exponent_bits, frac=10, signed=False) - cft2 = ti.types.quant.float(exp=exponent_bits, frac=14, signed=True) - a = ti.field(dtype=cft1) - b = ti.field(dtype=cft2) +def test_shared_exponent_negative(exponent_bits): + qflt1 = ti.types.quant.float(exp=exponent_bits, frac=10, signed=False) + qflt2 = ti.types.quant.float(exp=exponent_bits, frac=14, signed=True) + a = ti.field(dtype=qflt1) + b = ti.field(dtype=qflt2) ti.root.bit_struct(num_bits=32).place(a, b, shared_exponent=True) a[None] = 37 @@ -144,4 +144,4 @@ def test_negative(exponent_bits): # TODO: test precision # TODO: make sure unsigned has one more effective significand bit -# TODO: test shared exponent floats with custom int in a single bit struct +# TODO: test shared exponent floats with quant int in a single bit struct diff --git a/tests/python/test_custom_int.py b/tests/python/test_quant_int.py similarity index 69% rename from tests/python/test_custom_int.py rename to tests/python/test_quant_int.py index 1fd0077e9280f..35578f80d9a89 100644 --- a/tests/python/test_custom_int.py +++ b/tests/python/test_quant_int.py @@ -3,9 +3,9 @@ @test_utils.test(require=ti.extension.quant_basic) -def test_custom_int_implicit_cast(): - ci13 = ti.types.quant.int(13, True) - x = ti.field(dtype=ci13) +def test_quant_int_implicit_cast(): + qi13 = ti.types.quant.int(13, True) + x = ti.field(dtype=qi13) ti.root.bit_struct(num_bits=32).place(x) diff --git a/tests/python/test_custom_float_time_integration.py b/tests/python/test_quant_time_integration.py similarity index 79% rename from tests/python/test_custom_float_time_integration.py rename to tests/python/test_quant_time_integration.py index 459276781e001..566fb4b8d4b8a 100644 --- a/tests/python/test_custom_float_time_integration.py +++ b/tests/python/test_quant_time_integration.py @@ -7,23 +7,23 @@ from tests import test_utils -@pytest.mark.parametrize('use_cft,use_exponent,use_shared_exp', +@pytest.mark.parametrize('use_quant,use_exponent,use_shared_exp', [(False, False, False), (True, False, False), (True, True, False), (True, True, True)]) @test_utils.test(require=ti.extension.quant) -def test_custom_float_time_integration(use_cft, use_exponent, use_shared_exp): - if use_cft: +def test_quant_time_integration(use_quant, use_exponent, use_shared_exp): + if use_quant: if use_exponent: - cft = ti.types.quant.float(exp=6, frac=13) - x = ti.Vector.field(2, dtype=cft) + qflt = ti.types.quant.float(exp=6, frac=13) + x = ti.Vector.field(2, dtype=qflt) if use_shared_exp: ti.root.bit_struct(num_bits=32).place(x, shared_exponent=True) else: ti.root.bit_struct(num_bits=32).place(x.get_scalar_field(0)) ti.root.bit_struct(num_bits=32).place(x.get_scalar_field(1)) else: - cft = ti.types.quant.fixed(frac=16, range=2) - x = ti.Vector.field(2, dtype=cft) + qfxt = ti.types.quant.fixed(frac=16, range=2) + x = ti.Vector.field(2, dtype=qfxt) ti.root.bit_struct(num_bits=32).place(x) else: x = ti.Vector.field(2, dtype=ti.f32, shape=()) diff --git a/tests/python/test_struct_for.py b/tests/python/test_struct_for.py index 4e4b557c0711b..18bf87e2e99fc 100644 --- a/tests/python/test_struct_for.py +++ b/tests/python/test_struct_for.py @@ -267,8 +267,8 @@ def count() -> int: def test_struct_for_quant(): n = 8 - ci13 = ti.types.quant.int(13, True) - x = ti.field(dtype=ci13) + qi13 = ti.types.quant.int(13, True) + x = ti.field(dtype=qi13) ti.root.dense(ti.i, n).bit_struct(num_bits=32).place(x)