Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Lang] [type] Refine SNode with quant 5/n: Rename bit_array to quant_array #5344

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions python/taichi/_snode/fields_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,12 @@ def bit_struct(self, num_bits: int):
self.empty = False
return self.root.bit_struct(num_bits)

def bit_array(self, indices: Union[Sequence[_Axis], _Axis],
dimensions: Union[Sequence[int], int], num_bits: int):
"""Same as :func:`taichi.lang.snode.SNode.bit_array`"""
def quant_array(self, indices: Union[Sequence[_Axis], _Axis],
dimensions: Union[Sequence[int], int], num_bits: int):
"""Same as :func:`taichi.lang.snode.SNode.quant_array`"""
self._check_not_finalized()
self.empty = False
return self.root.bit_array(indices, dimensions, num_bits)
return self.root.quant_array(indices, dimensions, num_bits)

def place(self,
*args: Any,
Expand Down
8 changes: 4 additions & 4 deletions python/taichi/lang/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,7 @@ def _block_dim_adaptive(block_dim_adaptive):


def _bit_vectorize():
"""Enable bit vectorization of struct fors on bit_arrays.
"""Enable bit vectorization of struct fors on quant_arrays.
"""
get_runtime().prog.current_ast_builder().bit_vectorize()

Expand All @@ -614,7 +614,7 @@ def loop_config(*,
serialize (bool): Whether to let the for loop execute serially, `serialize=True` equals to `parallelize=1`
parallelize (int): The number of threads to use on CPU
block_dim_adaptive (bool): Whether to allow backends set block_dim adaptively, enabled by default
bit_vectorize (bool): Whether to enable bit vectorization of struct fors on bit_arrays.
bit_vectorize (bool): Whether to enable bit vectorization of struct fors on quant_arrays.

Examples::

Expand Down Expand Up @@ -644,8 +644,8 @@ def fill():
x = ti.field(dtype=u1)
y = ti.field(dtype=u1)
cell = ti.root.dense(ti.ij, (128, 4))
cell.bit_array(ti.j, 32).place(x)
cell.bit_array(ti.j, 32).place(y)
cell.quant_array(ti.j, 32).place(x)
cell.quant_array(ti.j, 32).place(y)
@ti.kernel
def copy():
ti.loop_config(bit_vectorize=True)
Expand Down
8 changes: 4 additions & 4 deletions python/taichi/lang/snode.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@ def bit_struct(self, num_bits: int):
"""
return SNode(self.ptr.bit_struct(num_bits, impl.current_cfg().packed))

def bit_array(self, axes, dimensions, num_bits):
"""Adds a bit_array SNode as a child component of `self`.
def quant_array(self, axes, dimensions, num_bits):
"""Adds a quant_array SNode as a child component of `self`.

Args:
axes (List[Axis]): Axes to activate.
Expand All @@ -121,8 +121,8 @@ def bit_array(self, axes, dimensions, num_bits):
if isinstance(dimensions, int):
dimensions = [dimensions] * len(axes)
return SNode(
self.ptr.bit_array(axes, dimensions, num_bits,
impl.current_cfg().packed))
self.ptr.quant_array(axes, dimensions, num_bits,
impl.current_cfg().packed))

def place(self, *args, offset=None, shared_exponent=False):
"""Places a list of Taichi fields under the `self` container.
Expand Down
24 changes: 12 additions & 12 deletions taichi/codegen/llvm/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ std::unique_ptr<RuntimeObject> CodeGenLLVM::emit_struct_meta_object(
meta =
std::make_unique<RuntimeObject>("BitmaskedMeta", this, builder.get());
emit_struct_meta_base("Bitmasked", meta->ptr, snode);
} else if (snode->type == SNodeType::bit_array) {
} else if (snode->type == SNodeType::quant_array) {
meta = std::make_unique<RuntimeObject>("DenseMeta", this, builder.get());
emit_struct_meta_base("Dense", meta->ptr, snode);
} else {
Expand Down Expand Up @@ -1351,7 +1351,7 @@ void CodeGenLLVM::visit(GlobalStoreStmt *stmt) {
"handled by BitStructStoreStmt.",
pointee_type->to_string());
} else {
TI_ERROR("Bit array only supports quant int type.");
TI_ERROR("Quant array only supports quant int type.");
}
}
store_quant_int(llvm_val[stmt->dest], pointee_type->as<QuantIntType>(),
Expand Down Expand Up @@ -1414,8 +1414,8 @@ std::string CodeGenLLVM::get_runtime_snode_name(SNode *snode) {
return "Bitmasked";
} else if (snode->type == SNodeType::bit_struct) {
return "BitStruct";
} else if (snode->type == SNodeType::bit_array) {
return "BitArray";
} else if (snode->type == SNodeType::quant_array) {
return "QuantArray";
} else {
TI_P(snode_type_name(snode->type));
TI_NOT_IMPLEMENTED
Expand Down Expand Up @@ -1536,9 +1536,9 @@ void CodeGenLLVM::visit(SNodeLookupStmt *stmt) {
{llvm_val[stmt->input_index]});
} else if (snode->type == SNodeType::bit_struct) {
llvm_val[stmt] = parent;
} else if (snode->type == SNodeType::bit_array) {
} else if (snode->type == SNodeType::quant_array) {
auto element_num_bits =
snode->dt->as<BitArrayType>()->get_element_num_bits();
snode->dt->as<QuantArrayType>()->get_element_num_bits();
auto offset = tlctx->get_constant(element_num_bits);
offset = builder->CreateMul(offset, llvm_val[stmt->input_index]);
llvm_val[stmt] = create_bit_ptr(llvm_val[stmt->input_snode], offset);
Expand All @@ -1549,7 +1549,7 @@ void CodeGenLLVM::visit(SNodeLookupStmt *stmt) {
}

void CodeGenLLVM::visit(GetChStmt *stmt) {
if (stmt->input_snode->type == SNodeType::bit_array) {
if (stmt->input_snode->type == SNodeType::quant_array) {
llvm_val[stmt] = llvm_val[stmt->input_ptr];
} else if (stmt->ret_type->as<PointerType>()->is_bit_pointer()) {
auto bit_struct = stmt->input_snode->dt->cast<BitStructType>();
Expand Down Expand Up @@ -1728,16 +1728,16 @@ void CodeGenLLVM::create_offload_struct_for(OffloadedStmt *stmt, bool spmd) {
llvm::Function *body = nullptr;
auto leaf_block = stmt->snode;

// For a bit-vectorized loop over a bit array, we generate struct for on its
// For a bit-vectorized loop over a quant array, we generate struct for on its
// parent node (must be "dense") instead of itself for higher performance.
if (stmt->is_bit_vectorized) {
if (leaf_block->type == SNodeType::bit_array &&
if (leaf_block->type == SNodeType::quant_array &&
leaf_block->parent->type == SNodeType::dense) {
leaf_block = leaf_block->parent;
} else {
TI_ERROR(
"A bit-vectorized struct-for must loop over a bit array with a dense "
"parent");
"A bit-vectorized struct-for must loop over a quant array with a "
"dense parent");
}
}

Expand Down Expand Up @@ -1869,7 +1869,7 @@ void CodeGenLLVM::create_offload_struct_for(OffloadedStmt *stmt, bool spmd) {
create_call(refine, {parent_coordinates, new_coordinates,
builder->CreateLoad(loop_index)});

// For a bit-vectorized loop over a bit array, one more refine step is
// For a bit-vectorized loop over a quant array, one more refine step is
// needed to make final coordinates non-consecutive, since each thread will
// process multiple coordinates via vectorization
if (stmt->is_bit_vectorized) {
Expand Down
8 changes: 4 additions & 4 deletions taichi/codegen/llvm/struct_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,17 +87,17 @@ void StructCompilerLLVM::generate_types(SNode &snode) {
}
snode.dt = snode.bit_struct_type_builder->build();
body_type = tlctx_->get_data_type(snode.physical_type);
} else if (type == SNodeType::bit_array) {
// A bit array SNode should have only one child
} else if (type == SNodeType::quant_array) {
// A quant array SNode should have only one child
TI_ASSERT(snode.ch.size() == 1);
auto &ch = snode.ch[0];
Type *ch_type = ch->dt;
if (!arch_is_cpu(arch_)) {
TI_ERROR_IF(data_type_bits(snode.physical_type) <= 16,
"bit_array physical type must be at least 32 bits on "
"quant_array physical type must be at least 32 bits on "
"non-CPU backends.");
}
snode.dt = TypeFactory::get_instance().get_bit_array_type(
snode.dt = TypeFactory::get_instance().get_quant_array_type(
snode.physical_type, ch_type, snode.num_cells_per_container);

DataType container_primitive_type(snode.physical_type);
Expand Down
2 changes: 1 addition & 1 deletion taichi/codegen/metal/struct_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class StructCompiler {
if (ty == SNodeType::place) {
// do nothing
}
CHECK_UNSUPPORTED_TYPE(bit_array)
CHECK_UNSUPPORTED_TYPE(quant_array)
CHECK_UNSUPPORTED_TYPE(hash)
else {
max_snodes_ = std::max(max_snodes_, sn->id);
Expand Down
2 changes: 1 addition & 1 deletion taichi/inc/snodes.inc.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@ PER_SNODE(bitmasked)
PER_SNODE(hash)
PER_SNODE(place)
PER_SNODE(bit_struct)
PER_SNODE(bit_array)
PER_SNODE(quant_array)
PER_SNODE(undefined)
12 changes: 6 additions & 6 deletions taichi/ir/snode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ SNode &SNode::insert_children(SNodeType t) {
std::memcpy(new_ch->physical_index_position, physical_index_position,
sizeof(physical_index_position));
new_ch->num_active_indices = num_active_indices;
if (type == SNodeType::bit_struct || type == SNodeType::bit_array) {
if (type == SNodeType::bit_struct || type == SNodeType::quant_array) {
new_ch->is_bit_level = true;
} else {
new_ch->is_bit_level = is_bit_level;
Expand Down Expand Up @@ -145,11 +145,11 @@ SNode &SNode::bit_struct(int num_bits, bool packed) {
return snode;
}

SNode &SNode::bit_array(const std::vector<Axis> &axes,
const std::vector<int> &sizes,
int bits,
bool packed) {
auto &snode = create_node(axes, sizes, SNodeType::bit_array, packed);
SNode &SNode::quant_array(const std::vector<Axis> &axes,
const std::vector<int> &sizes,
int bits,
bool packed) {
auto &snode = create_node(axes, sizes, SNodeType::quant_array, packed);
snode.physical_type =
TypeFactory::get_instance().get_primitive_int_type(bits, false);
return snode;
Expand Down
12 changes: 6 additions & 6 deletions taichi/ir/snode.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ class SNode {
int chunk_size{0};
std::size_t cell_size_bytes{0};
std::size_t offset_bytes_in_parent_cell{0};
PrimitiveType *physical_type{nullptr}; // for bit_struct and bit_array only
PrimitiveType *physical_type{nullptr}; // for bit_struct and quant_array only
DataType dt;
bool has_ambient{false};
TypedConstant ambient_val;
Expand All @@ -147,7 +147,7 @@ class SNode {

// is_bit_level=false: the SNode is not bitpacked
// is_bit_level=true: the SNode is bitpacked (i.e., strictly inside bit_struct
// or bit_array)
// or quant_array)
bool is_bit_level{false};

// Whether the path from root to |this| contains only `dense` SNodes.
Expand Down Expand Up @@ -247,10 +247,10 @@ class SNode {

SNode &bit_struct(int bits, bool packed);

SNode &bit_array(const std::vector<Axis> &axes,
const std::vector<int> &sizes,
int bits,
bool packed);
SNode &quant_array(const std::vector<Axis> &axes,
const std::vector<int> &sizes,
int bits,
bool packed);

void print();

Expand Down
4 changes: 2 additions & 2 deletions taichi/ir/type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,8 @@ std::string BitStructType::to_string() const {
return str + ")";
}

std::string BitArrayType::to_string() const {
return fmt::format("ba({}x{})", element_type_->to_string(), num_elements_);
std::string QuantArrayType::to_string() const {
return fmt::format("qa({}x{})", element_type_->to_string(), num_elements_);
}

std::string TypedConstant::stringify() const {
Expand Down
8 changes: 4 additions & 4 deletions taichi/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -298,11 +298,11 @@ class BitStructType : public Type {
std::vector<int> member_bit_offsets_;
};

class BitArrayType : public Type {
class QuantArrayType : public Type {
public:
BitArrayType(PrimitiveType *physical_type,
Type *element_type_,
int num_elements_)
QuantArrayType(PrimitiveType *physical_type,
Type *element_type_,
int num_elements_)
: physical_type_(physical_type),
element_type_(element_type_),
num_elements_(num_elements_) {
Expand Down
10 changes: 5 additions & 5 deletions taichi/ir/type_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,12 @@ Type *TypeFactory::get_bit_struct_type(PrimitiveType *physical_type,
return bit_struct_types_.back().get();
}

Type *TypeFactory::get_bit_array_type(PrimitiveType *physical_type,
Type *element_type,
int num_elements) {
bit_array_types_.push_back(std::make_unique<BitArrayType>(
Type *TypeFactory::get_quant_array_type(PrimitiveType *physical_type,
Type *element_type,
int num_elements) {
quant_array_types_.push_back(std::make_unique<QuantArrayType>(
physical_type, element_type, num_elements));
return bit_array_types_.back().get();
return quant_array_types_.back().get();
}

PrimitiveType *TypeFactory::get_primitive_int_type(int bits, bool is_signed) {
Expand Down
8 changes: 4 additions & 4 deletions taichi/ir/type_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ class TypeFactory {
std::vector<Type *> member_types,
std::vector<int> member_bit_offsets);

Type *get_bit_array_type(PrimitiveType *physical_type,
Type *element_type,
int num_elements);
Type *get_quant_array_type(PrimitiveType *physical_type,
Type *element_type,
int num_elements);

static DataType create_vector_or_scalar_type(int width,
DataType element,
Expand Down Expand Up @@ -77,7 +77,7 @@ class TypeFactory {
std::vector<std::unique_ptr<Type>> bit_struct_types_;

// TODO: avoid duplication
std::vector<std::unique_ptr<Type>> bit_array_types_;
std::vector<std::unique_ptr<Type>> quant_array_types_;

std::mutex mut_;
};
Expand Down
3 changes: 2 additions & 1 deletion taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,8 @@ void export_lang(py::module &m) {
bool))(&SNode::bitmasked),
py::return_value_policy::reference)
.def("bit_struct", &SNode::bit_struct, py::return_value_policy::reference)
.def("bit_array", &SNode::bit_array, py::return_value_policy::reference)
.def("quant_array", &SNode::quant_array,
py::return_value_policy::reference)
.def("place", &SNode::place)
.def("data_type", [](SNode *snode) { return snode->dt; })
.def("name", [](SNode *snode) { return snode->name; })
Expand Down
Loading