diff --git a/misc/visualize_quant_types.py b/misc/visualize_quant_types.py deleted file mode 100644 index 62fb507a456a4..0000000000000 --- a/misc/visualize_quant_types.py +++ /dev/null @@ -1,176 +0,0 @@ -import argparse -import math -import os -from struct import pack, unpack - -import taichi as ti - -ti.init() - -f19 = ti.types.quant.float(exp=6, frac=13, signed=True) -f16 = ti.types.quant.float(exp=5, frac=11, signed=True) -fixed16 = ti.types.quant.fixed(frac=16, range=2) - -vf19 = ti.Vector.field(2, dtype=f19) -bs_vf19 = ti.root.bit_struct(num_bits=32) -bs_vf19.place(vf19, shared_exponent=True) - -vf16 = ti.Vector.field(2, dtype=f16) -bs_vf16 = ti.root.bit_struct(num_bits=32) -bs_vf16.place(vf16) - -vfixed16 = ti.Vector.field(2, dtype=fixed16) -bs_vfixed16 = ti.root.bit_struct(num_bits=32) -bs_vfixed16.place(vfixed16) - - -@ti.kernel -def set_vals(x: ti.f32, y: ti.f32): - val = ti.Vector([x, y]) - vf16[None] = val - vf19[None] = val - vfixed16[None] = val - - -def serialize_i32(x): - s = '' - for i in reversed(range(32)): - s += f'{(x>>i) & 1}' - return s - - -def serialize_f32(x): - b = pack('f', x) - n = unpack('i', b)[0] - return serialize_i32(n) - - -@ti.kernel -def fetch_bs(bs: ti.template()) -> ti.i32: - return bs[None] - - -coord = ti.GUI(res=(800, 800), background_color=0xFFFFFF) -numbers = ti.GUI(res=(800, 800), background_color=0xFFFFFF) - - -def draw_coord(t, f): - cx, cy = 0.5, 0.5 - lx, ly = 0.4, 0.4 - l1 = lx * 0.8 - al = 0.02 - coord.line(begin=(cx - lx, cy), - end=(cx + lx, cy), - radius=3, - color=0x666666) - coord.line(begin=(cx, cy - ly), - end=(cx, cy + ly), - radius=3, - color=0x666666) - coord.line(begin=(cx + lx - al, cy - al), - end=(cx + lx, cy), - radius=3, - color=0x666666) - coord.line(begin=(cx + lx - al, cy + al), - end=(cx + lx, cy), - radius=3, - color=0x666666) - coord.line(begin=(cx - al, cy + ly - al), - end=(cx, cy + ly), - radius=3, - color=0x666666) - coord.line(begin=(cx + al, cy + ly - al), - end=(cx, cy + ly), - radius=3, - color=0x666666) - - def transform(p): - return cx + l1 * p[0], cy + l1 * p[1] - - segments = 300 - for i in range(segments): - t1 = i / segments - t2 = (i + 1) / segments - coord.line(begin=transform(f(t1)), - end=transform(f(t2)), - radius=3, - color=0x0) - - coord.circle(pos=transform(f(t)), color=0xDD1122, radius=10) - - -frames = 300 - -parser = argparse.ArgumentParser() -parser.add_argument('-c', '--curve', type=int, help='Curve type', default=0) - -args = parser.parse_args() - -if args.curve == 0: - - def f(t): - return math.cos(t * 2 * math.pi), math.sin(t * 2 * math.pi) -elif args.curve == 1: - - def f(t): - t = math.cos(t * 2 * math.pi) * 0.5 + 0.5 - return 1 - t, t -elif args.curve == 2: - - def f(t): - t = math.cos(t * 2 * math.pi) - t = t * 2.3 - s = 0.1 - return math.exp(t) * s, math.exp(-t) * s - - -folder = f'curve{args.curve}' -os.makedirs(folder, exist_ok=True) - -for i in range(frames * 2 + 1): - t = i / frames - - draw_coord(t, f) - coord.show(f'{folder}/coord_{i:04d}.png') - - x, y = f(t) - set_vals(x, y) - - fs = 100 - color = 0x111111 - - def reorder(b, seg): - r = '' - seg = [0] + seg + [32] - for i in range(len(seg) - 1): - r = r + b[32 - seg[i + 1]:32 - seg[i]] - return r - - def real_to_str(x): - s = '' - if x < 0: - s = '' - else: - s = ' ' - return s + f'{x:.4f}' - - numbers.text(real_to_str(x), (0.05, 0.9), font_size=fs, color=color) - numbers.text(real_to_str(y), (0.55, 0.9), font_size=fs, color=color) - - fs = 49 - - bits = [bs_vf19, bs_vf16, bs_vfixed16] - seg = [[], [], [6, 19], [5, 16, 21], [16]] - bits = list(map(lambda x: serialize_i32(fetch_bs(x)), bits)) - - bits = [serialize_f32(x), serialize_f32(y)] + bits - - for j in range(len(bits)): - b = reorder(bits[j], seg[j]) - numbers.text(b, (0.05, 0.7 - j * 0.15), font_size=fs, color=color) - - numbers.show(f'{folder}/numbers_{i:04d}.png') - -os.system( - f'ti video {folder}/numbers*.png -f 60 -c 2 -o numbers{args.curve}.mp4') -os.system(f'ti video {folder}/coord*.png -f 60 -c 2 -o coord{args.curve}.mp4') diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py index 6147f27e2e4ae..d15731be885ac 100644 --- a/python/taichi/lang/impl.py +++ b/python/taichi/lang/impl.py @@ -191,14 +191,6 @@ def subscript(value, *_indices, skip_reordered=False): ]) ret.any_array_access = any_array_access return ret - if isinstance(value, SNode): - # When reading bit structure we only support the 0-D case for now. - field_dim = 0 - if field_dim != index_dim: - raise IndexError( - f'Field with dim {field_dim} accessed with indices of dim {index_dim}' - ) - return Expr(_ti_core.subscript(value.ptr, indices_expr_group)) # Directly evaluate in Python for non-Taichi types return value.__getitem__(*_indices) diff --git a/taichi/ir/expression_printer.h b/taichi/ir/expression_printer.h index f6bb7e607f2ef..29f391d17334a 100644 --- a/taichi/ir/expression_printer.h +++ b/taichi/ir/expression_printer.h @@ -101,11 +101,7 @@ class ExpressionHumanFriendlyPrinter : public ExpressionPrinter { } void visit(GlobalPtrExpression *expr) override { - if (expr->snode) { - emit(expr->snode->get_node_type_name_hinted()); - } else { - expr->var->accept(this); - } + expr->var->accept(this); emit('['); emit_vector(expr->indices.exprs); emit(']'); diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index cfc5a8f4ae3c8..9ded4ffae1146 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -365,10 +365,7 @@ void GlobalVariableExpression::flatten(FlattenContext *ctx) { void GlobalPtrExpression::type_check(CompileConfig *) { // Currently, dimension compatibility check happens in Python - if (snode != nullptr) { - TI_ASSERT(snode->dt->is()); - ret_type = snode->dt->cast()->get_physical_type(); - } else if (var.is()) { + if (var.is()) { ret_type = var.cast()->snode->dt->get_compute_type(); } else if (var.is()) { @@ -391,10 +388,7 @@ void GlobalPtrExpression::flatten(FlattenContext *ctx) { std::vector index_stmts; std::vector offsets; SNode *snode = nullptr; - if (this->snode != nullptr) { - snode = this->snode; - } - if (bool(var) && var.is()) { + if (var.is()) { snode = var.cast()->snode; offsets = snode->index_offsets; } diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index d34f7b8274c7f..93752a4fca9da 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -463,7 +463,6 @@ class GlobalVariableExpression : public Expression { class GlobalPtrExpression : public Expression { public: - SNode *snode{nullptr}; Expr var; ExprGroup indices; @@ -471,10 +470,6 @@ class GlobalPtrExpression : public Expression { : var(var), indices(indices) { } - GlobalPtrExpression(SNode *snode, const ExprGroup &indices) - : snode(snode), indices(indices) { - } - void type_check(CompileConfig *config) override; void flatten(FlattenContext *ctx) override; diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index 96c55be9890b7..8b2a5546d8b6b 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -865,10 +865,6 @@ void export_lang(py::module &m) { Expr::make &, int>); - m.def("subscript", [](SNode *snode, const ExprGroup &indices) { - return Expr::make(snode, indices); - }); - m.def("get_external_tensor_dim", [](const Expr &expr) { TI_ASSERT(expr.is()); return expr.cast()->dim; diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index a92b85ce96fed..87c7cc39723c4 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -123,11 +123,7 @@ class TypeCheck : public IRVisitor { void visit(GlobalLoadStmt *stmt) override { auto pointee_type = stmt->src->ret_type.ptr_removed(); - if (auto bit_struct = pointee_type->cast()) { - stmt->ret_type = bit_struct->get_physical_type(); - } else { - stmt->ret_type = pointee_type->get_compute_type(); - } + stmt->ret_type = pointee_type->get_compute_type(); } void visit(SNodeOpStmt *stmt) override {