Skip to content

Commit

Permalink
[Lang] [type] Refine SNode with quant 8/n: Replace bit_struct with ti…
Browse files Browse the repository at this point in the history
….BitpackedFields (#5532)

* [Lang] [type] Replace bit_struct with ti.BitpackedFields

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix lifetime problem

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [Lang] [type] Replace bit_struct with ti.BitpackedFields

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix lifetime problem

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
strongoier and pre-commit-ci[bot] authored Jul 29, 2022
1 parent 4423e30 commit d88dacf
Show file tree
Hide file tree
Showing 25 changed files with 240 additions and 172 deletions.
11 changes: 2 additions & 9 deletions python/taichi/_snode/fields_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,6 @@ def bitmasked(self, indices: Union[Sequence[_Axis], _Axis],
self.empty = False
return self.root.bitmasked(indices, dimensions)

def bit_struct(self, num_bits: int):
"""Same as :func:`taichi.lang.snode.SNode.bit_struct`"""
self._check_not_finalized()
self.empty = False
return self.root.bit_struct(num_bits)

def quant_array(self, indices: Union[Sequence[_Axis], _Axis],
dimensions: Union[Sequence[int], int], num_bits: int):
"""Same as :func:`taichi.lang.snode.SNode.quant_array`"""
Expand All @@ -113,12 +107,11 @@ def quant_array(self, indices: Union[Sequence[_Axis], _Axis],

def place(self,
*args: Any,
offset: Optional[Union[Sequence[int], int]] = None,
shared_exponent: bool = False):
offset: Optional[Union[Sequence[int], int]] = None):
"""Same as :func:`taichi.lang.snode.SNode.place`"""
self._check_not_finalized()
self.empty = False
self.root.place(*args, offset=offset, shared_exponent=shared_exponent)
self.root.place(*args, offset=offset)

def lazy_grad(self):
"""Same as :func:`taichi.lang.snode.SNode.lazy_grad`"""
Expand Down
32 changes: 31 additions & 1 deletion python/taichi/lang/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,4 +401,34 @@ def __init__(self, accessor, key):
self.key = key


__all__ = ["Field", "ScalarField"]
class BitpackedFields:
    """Taichi bitpacked fields, where fields with quantized types are packed together.

    Args:
        max_num_bits (int): Maximum number of bits all fields inside can occupy in total. Only 32 or 64 is allowed.
    """
    def __init__(self, max_num_bits):
        # List of (field expression, member id) pairs in placement order; the
        # member id is the value returned by the type builder's `add_member`.
        self.fields = []
        self.bit_struct_type_builder = _ti_core.BitStructTypeBuilder(
            max_num_bits)

    def place(self, *args, shared_exponent=False):
        """Places a list of fields with quantized types inside.

        All arguments are validated before the underlying type builder is
        touched, so a rejected call leaves both the builder and
        ``self.fields`` unchanged.

        Args:
            *args (List[Field]): A list of fields with quantized types to place.
            shared_exponent (bool): Whether the fields have a shared exponent.

        Raises:
            AssertionError: If any argument is not a :class:`Field`.
        """
        # Validate up front: failing halfway through would otherwise leave the
        # builder in shared-exponent mode with some members already added.
        for arg in args:
            assert isinstance(arg, Field), (
                f'{arg} is not a Field and cannot be placed in '
                'BitpackedFields')

        builder = self.bit_struct_type_builder
        if shared_exponent:
            builder.begin_placing_shared_exponent()
        for arg in args:
            for var in arg._get_field_members():
                member_id = builder.add_member(var.ptr.get_dt())
                self.fields.append((var.ptr, member_id))
        if shared_exponent:
            builder.end_placing_shared_exponent()


__all__ = ["BitpackedFields", "Field", "ScalarField"]
31 changes: 11 additions & 20 deletions python/taichi/lang/snode.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from taichi._lib import core as _ti_core
from taichi.lang import expr, impl, matrix
from taichi.lang.field import Field
from taichi.lang.field import BitpackedFields, Field


class SNode:
Expand Down Expand Up @@ -96,17 +96,6 @@ def bitmasked(self, axes, dimensions):
self.ptr.bitmasked(axes, dimensions,
impl.current_cfg().packed))

def bit_struct(self, num_bits: int):
"""Adds a bit_struct SNode as a child component of `self`.
Args:
num_bits: Number of bits to use.
Returns:
The added :class:`~taichi.lang.SNode` instance.
"""
return SNode(self.ptr.bit_struct(num_bits, impl.current_cfg().packed))

def quant_array(self, axes, dimensions, num_bits):
"""Adds a quant_array SNode as a child component of `self`.
Expand All @@ -124,13 +113,12 @@ def quant_array(self, axes, dimensions, num_bits):
self.ptr.quant_array(axes, dimensions, num_bits,
impl.current_cfg().packed))

def place(self, *args, offset=None, shared_exponent=False):
def place(self, *args, offset=None):
"""Places a list of Taichi fields under the `self` container.
Args:
*args (List[ti.field]): A list of Taichi fields to place.
offset (Union[Number, tuple[Number]]): Offset of the field domain.
shared_exponent (bool): Only useful for quant types.
Returns:
The `self` container.
Expand All @@ -139,20 +127,23 @@ def place(self, *args, offset=None, shared_exponent=False):
offset = ()
if isinstance(offset, numbers.Number):
offset = (offset, )
if shared_exponent:
self.ptr.begin_shared_exp_placement()

for arg in args:
if isinstance(arg, Field):
if isinstance(arg, BitpackedFields):
bit_struct_type = arg.bit_struct_type_builder.build()
bit_struct_snode = self.ptr.bit_struct(
bit_struct_type,
impl.current_cfg().packed)
for (field, id_in_bit_struct) in arg.fields:
bit_struct_snode.place(field, offset, id_in_bit_struct)
elif isinstance(arg, Field):
for var in arg._get_field_members():
self.ptr.place(var.ptr, offset)
self.ptr.place(var.ptr, offset, -1)
elif isinstance(arg, list):
for x in arg:
self.place(x, offset=offset)
else:
raise ValueError(f'{arg} cannot be placed')
if shared_exponent:
self.ptr.end_shared_exp_placement()
return self

def lazy_grad(self):
Expand Down
7 changes: 4 additions & 3 deletions taichi/analysis/offline_cache_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,6 @@ static void get_offline_cache_key_of_snode_impl(
serializer(snode->chunk_size);
serializer(snode->cell_size_bytes);
serializer(snode->offset_bytes_in_parent_cell);
if (snode->physical_type) {
serializer(snode->physical_type->to_string());
}
serializer(snode->dt->to_string());
serializer(snode->has_ambient);
if (!snode->ambient_val.dt->is_primitive(PrimitiveTypeID::unknown)) {
Expand All @@ -119,6 +116,10 @@ static void get_offline_cache_key_of_snode_impl(
get_offline_cache_key_of_snode_impl(dual_snode, serializer, visited);
}
}
if (snode->physical_type) {
serializer(snode->physical_type->to_string());
}
serializer(snode->id_in_bit_struct);
serializer(snode->is_bit_level);
serializer(snode->is_path_all_dense);
serializer(snode->node_type_name);
Expand Down
1 change: 0 additions & 1 deletion taichi/codegen/llvm/struct_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ void StructCompilerLLVM::generate_types(SNode &snode) {
"bit_struct physical type must be at least 32 bits on "
"non-CPU backends.");
}
snode.dt = snode.bit_struct_type_builder->build();
body_type = tlctx_->get_data_type(snode.physical_type);
} else if (type == SNodeType::quant_array) {
// A quant array SNode should have only one child
Expand Down
18 changes: 3 additions & 15 deletions taichi/ir/snode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,10 @@ SNode &SNode::dynamic(const Axis &expr, int n, int chunk_size, bool packed) {
return snode;
}

SNode &SNode::bit_struct(int num_bits, bool packed) {
SNode &SNode::bit_struct(BitStructType *bit_struct_type, bool packed) {
auto &snode = create_node({}, {}, SNodeType::bit_struct, packed);
snode.physical_type =
TypeFactory::get_instance().get_primitive_int_type(num_bits, false);
snode.bit_struct_type_builder =
std::make_unique<BitStructTypeBuilder>(snode.physical_type);
snode.dt = bit_struct_type;
snode.physical_type = bit_struct_type->get_physical_type();
return snode;
}

Expand Down Expand Up @@ -285,16 +283,6 @@ bool SNode::need_activation() const {
type == SNodeType::bitmasked || type == SNodeType::dynamic;
}

void SNode::begin_shared_exp_placement() {
TI_ASSERT(bit_struct_type_builder);
bit_struct_type_builder->begin_placing_shared_exponent();
}

void SNode::end_shared_exp_placement() {
TI_ASSERT(bit_struct_type_builder);
bit_struct_type_builder->end_placing_shared_exponent();
}

bool SNode::is_primal() const {
return grad_info && grad_info->is_primal();
}
Expand Down
22 changes: 7 additions & 15 deletions taichi/ir/snode.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,21 +128,17 @@ class SNode {
int chunk_size{0};
std::size_t cell_size_bytes{0};
std::size_t offset_bytes_in_parent_cell{0};
PrimitiveType *physical_type{nullptr}; // for bit_struct and quant_array only
DataType dt;
bool has_ambient{false};
TypedConstant ambient_val;
// Note: parent will not be set until structural nodes are compiled!
SNode *parent{nullptr};
std::unique_ptr<GradInfoProvider> grad_info{nullptr};

std::unique_ptr<BitStructTypeBuilder> bit_struct_type_builder{nullptr};
int id_in_bit_struct{0}; // for children of bit_struct only

// is_bit_level=false: the SNode is not bitpacked
// is_bit_level=true: the SNode is bitpacked (i.e., strictly inside bit_struct
// or quant_array)
bool is_bit_level{false};
// Quant
PrimitiveType *physical_type{nullptr}; // for bit_struct and quant_array only
int id_in_bit_struct{-1}; // for children of bit_struct only
bool is_bit_level{false}; // true if inside bit_struct or quant_array

// Whether the path from root to |this| contains only `dense` SNodes.
bool is_path_all_dense{true};
Expand Down Expand Up @@ -239,7 +235,7 @@ class SNode {
return snode_type_name(type);
}

SNode &bit_struct(int bits, bool packed);
SNode &bit_struct(BitStructType *bit_struct_type, bool packed);

SNode &quant_array(const std::vector<Axis> &axes,
const std::vector<int> &sizes,
Expand Down Expand Up @@ -324,8 +320,8 @@ class SNode {

int shape_along_axis(int i) const;

void place(Expr &expr, const std::vector<int> &offset) {
place_child(&expr, offset, this, snode_to_glb_var_exprs_);
void place(Expr &expr, const std::vector<int> &offset, int id_in_bit_struct) {
place_child(&expr, offset, id_in_bit_struct, this, snode_to_glb_var_exprs_);
}

void lazy_grad(bool is_adjoint, bool is_dual) {
Expand All @@ -342,10 +338,6 @@ class SNode {

uint64 fetch_reader_result(); // TODO: refactor

void begin_shared_exp_placement();

void end_shared_exp_placement();

// SNodeTree part

void set_snode_tree_id(int id);
Expand Down
2 changes: 1 addition & 1 deletion taichi/ir/type_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ Type *TypeFactory::get_quant_float_type(Type *digits_type,
return quant_float_types_[key].get();
}

Type *TypeFactory::get_bit_struct_type(
BitStructType *TypeFactory::get_bit_struct_type(
PrimitiveType *physical_type,
const std::vector<Type *> &member_types,
const std::vector<int> &member_bit_offsets,
Expand Down
4 changes: 2 additions & 2 deletions taichi/ir/type_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class TypeFactory {
Type *exponent_type,
Type *compute_type);

Type *get_bit_struct_type(
BitStructType *get_bit_struct_type(
PrimitiveType *physical_type,
const std::vector<Type *> &member_types,
const std::vector<int> &member_bit_offsets,
Expand Down Expand Up @@ -78,7 +78,7 @@ class TypeFactory {
quant_float_types_;

// TODO: avoid duplication
std::vector<std::unique_ptr<Type>> bit_struct_types_;
std::vector<std::unique_ptr<BitStructType>> bit_struct_types_;

// TODO: avoid duplication
std::vector<std::unique_ptr<Type>> quant_array_types_;
Expand Down
7 changes: 4 additions & 3 deletions taichi/ir/type_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,9 @@ inline TypedConstant get_min_value(DataType dt) {

class BitStructTypeBuilder {
public:
explicit BitStructTypeBuilder(PrimitiveType *physical_type)
: physical_type_(physical_type) {
explicit BitStructTypeBuilder(int max_num_bits) {
physical_type_ =
TypeFactory::get_instance().get_primitive_int_type(max_num_bits);
}

int add_member(Type *member_type) {
Expand Down Expand Up @@ -225,7 +226,7 @@ class BitStructTypeBuilder {
is_placing_shared_exponent_ = false;
}

Type *build() const {
BitStructType *build() const {
return TypeFactory::get_instance().get_bit_struct_type(
physical_type_, member_types_, member_bit_offsets_,
member_owns_shared_exponents_, member_exponents_,
Expand Down
10 changes: 4 additions & 6 deletions taichi/program/snode_expr_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,13 @@ class GradInfoImpl final : public SNode::GradInfoProvider {

void place_child(Expr *expr_arg,
const std::vector<int> &offset,
int id_in_bit_struct,
SNode *parent,
SNodeGlobalVarExprMap *snode_to_exprs) {
if (parent->type == SNodeType::root) {
// never directly place to root
auto &ds = parent->dense(std::vector<Axis>(), {}, false);
place_child(expr_arg, offset, &ds, snode_to_exprs);
place_child(expr_arg, offset, id_in_bit_struct, &ds, snode_to_exprs);
} else {
TI_ASSERT(expr_arg->is<GlobalVariableExpression>());
auto glb_var_expr = expr_arg->cast<GlobalVariableExpression>();
Expand All @@ -66,10 +67,7 @@ void place_child(Expr *expr_arg,
std::make_unique<GradInfoImpl>(glb_var_expr.get());
(*snode_to_exprs)[glb_var_expr->snode] = glb_var_expr;
child.dt = glb_var_expr->dt;
if (parent->bit_struct_type_builder) {
child.id_in_bit_struct =
parent->bit_struct_type_builder->add_member(child.dt);
}
child.id_in_bit_struct = id_in_bit_struct;
if (!offset.empty())
child.set_index_offsets(offset);
}
Expand Down Expand Up @@ -100,7 +98,7 @@ void make_lazy_grad(SNode *snode,
}
}
for (auto p : new_grads) {
place_child(&p, /*offset=*/{}, snode, snode_to_exprs);
place_child(&p, /*offset=*/{}, -1, snode, snode_to_exprs);
}
}

Expand Down
1 change: 1 addition & 0 deletions taichi/program/snode_expr_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ using SNodeGlobalVarExprMap =

void place_child(Expr *expr_arg,
const std::vector<int> &offset,
int id_in_bit_struct,
SNode *parent,
SNodeGlobalVarExprMap *snode_to_exprs);

Expand Down
21 changes: 18 additions & 3 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -525,9 +525,7 @@ void export_lang(py::module &m) {
[](SNode *snode) { return snode->num_active_indices; })
.def_readonly("cell_size_bytes", &SNode::cell_size_bytes)
.def_readonly("offset_bytes_in_parent_cell",
&SNode::offset_bytes_in_parent_cell)
.def("begin_shared_exp_placement", &SNode::begin_shared_exp_placement)
.def("end_shared_exp_placement", &SNode::end_shared_exp_placement);
&SNode::offset_bytes_in_parent_cell);

py::class_<SNodeTree>(m, "SNodeTree")
.def("id", &SNodeTree::id)
Expand Down Expand Up @@ -715,6 +713,12 @@ void export_lang(py::module &m) {
.def("set_adjoint", &Expr::set_adjoint)
.def("set_dual", &Expr::set_dual)
.def("set_attribute", &Expr::set_attribute)
.def(
"get_dt",
[&](Expr *expr) -> const Type * {
return expr->cast<GlobalVariableExpression>()->dt;
},
py::return_value_policy::reference)
.def("get_ret_type", &Expr::get_ret_type)
.def("type_check", &Expr::type_check)
.def("get_expr_name",
Expand Down Expand Up @@ -1079,6 +1083,17 @@ void export_lang(py::module &m) {
m.def("get_type_factory_instance", TypeFactory::get_instance,
py::return_value_policy::reference);

py::class_<BitStructType>(m, "BitStructType");
py::class_<BitStructTypeBuilder>(m, "BitStructTypeBuilder")
.def(py::init<int>())
.def("begin_placing_shared_exponent",
&BitStructTypeBuilder::begin_placing_shared_exponent)
.def("end_placing_shared_exponent",
&BitStructTypeBuilder::end_placing_shared_exponent)
.def("add_member", &BitStructTypeBuilder::add_member)
.def("build", &BitStructTypeBuilder::build,
py::return_value_policy::reference);

m.def("decl_tensor_type",
[&](std::vector<int> shape, const DataType &element_type) {
return TypeFactory::create_tensor_type(shape, element_type);
Expand Down
Loading

0 comments on commit d88dacf

Please sign in to comment.