Skip to content

Commit

Permalink
[Lang] [type] Disallow reading a whole bit_struct (#5061)
Browse files Browse the repository at this point in the history
  • Loading branch information
strongoier authored May 30, 2022
1 parent 359afd2 commit 0f4c950
Show file tree
Hide file tree
Showing 7 changed files with 4 additions and 211 deletions.
176 changes: 0 additions & 176 deletions misc/visualize_quant_types.py

This file was deleted.

8 changes: 0 additions & 8 deletions python/taichi/lang/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,14 +191,6 @@ def subscript(value, *_indices, skip_reordered=False):
])
ret.any_array_access = any_array_access
return ret
if isinstance(value, SNode):
# When reading bit structure we only support the 0-D case for now.
field_dim = 0
if field_dim != index_dim:
raise IndexError(
f'Field with dim {field_dim} accessed with indices of dim {index_dim}'
)
return Expr(_ti_core.subscript(value.ptr, indices_expr_group))
# Directly evaluate in Python for non-Taichi types
return value.__getitem__(*_indices)

Expand Down
6 changes: 1 addition & 5 deletions taichi/ir/expression_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,7 @@ class ExpressionHumanFriendlyPrinter : public ExpressionPrinter {
}

void visit(GlobalPtrExpression *expr) override {
if (expr->snode) {
emit(expr->snode->get_node_type_name_hinted());
} else {
expr->var->accept(this);
}
expr->var->accept(this);
emit('[');
emit_vector(expr->indices.exprs);
emit(']');
Expand Down
10 changes: 2 additions & 8 deletions taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -365,10 +365,7 @@ void GlobalVariableExpression::flatten(FlattenContext *ctx) {

void GlobalPtrExpression::type_check(CompileConfig *) {
// Currently, dimension compatibility check happens in Python
if (snode != nullptr) {
TI_ASSERT(snode->dt->is<BitStructType>());
ret_type = snode->dt->cast<BitStructType>()->get_physical_type();
} else if (var.is<GlobalVariableExpression>()) {
if (var.is<GlobalVariableExpression>()) {
ret_type =
var.cast<GlobalVariableExpression>()->snode->dt->get_compute_type();
} else if (var.is<ExternalTensorExpression>()) {
Expand All @@ -391,10 +388,7 @@ void GlobalPtrExpression::flatten(FlattenContext *ctx) {
std::vector<Stmt *> index_stmts;
std::vector<int> offsets;
SNode *snode = nullptr;
if (this->snode != nullptr) {
snode = this->snode;
}
if (bool(var) && var.is<GlobalVariableExpression>()) {
if (var.is<GlobalVariableExpression>()) {
snode = var.cast<GlobalVariableExpression>()->snode;
offsets = snode->index_offsets;
}
Expand Down
5 changes: 0 additions & 5 deletions taichi/ir/frontend_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -463,18 +463,13 @@ class GlobalVariableExpression : public Expression {

class GlobalPtrExpression : public Expression {
public:
SNode *snode{nullptr};
Expr var;
ExprGroup indices;

GlobalPtrExpression(const Expr &var, const ExprGroup &indices)
: var(var), indices(indices) {
}

GlobalPtrExpression(SNode *snode, const ExprGroup &indices)
: snode(snode), indices(indices) {
}

void type_check(CompileConfig *config) override;

void flatten(FlattenContext *ctx) override;
Expand Down
4 changes: 0 additions & 4 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -865,10 +865,6 @@ void export_lang(py::module &m) {
Expr::make<TensorElementExpression, const Expr &, const ExprGroup &,
const std::vector<int> &, int>);

m.def("subscript", [](SNode *snode, const ExprGroup &indices) {
return Expr::make<GlobalPtrExpression>(snode, indices);
});

m.def("get_external_tensor_dim", [](const Expr &expr) {
TI_ASSERT(expr.is<ExternalTensorExpression>());
return expr.cast<ExternalTensorExpression>()->dim;
Expand Down
6 changes: 1 addition & 5 deletions taichi/transforms/type_check.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,11 +123,7 @@ class TypeCheck : public IRVisitor {

void visit(GlobalLoadStmt *stmt) override {
auto pointee_type = stmt->src->ret_type.ptr_removed();
if (auto bit_struct = pointee_type->cast<BitStructType>()) {
stmt->ret_type = bit_struct->get_physical_type();
} else {
stmt->ret_type = pointee_type->get_compute_type();
}
stmt->ret_type = pointee_type->get_compute_type();
}

void visit(SNodeOpStmt *stmt) override {
Expand Down

0 comments on commit 0f4c950

Please sign in to comment.