diff --git a/python/taichi/lang/field.py b/python/taichi/lang/field.py index 51375dcee3808..e36d88f43aafd 100644 --- a/python/taichi/lang/field.py +++ b/python/taichi/lang/field.py @@ -1,5 +1,6 @@ import taichi.lang from taichi._lib import core as _ti_core +from taichi.lang.exception import TaichiSyntaxError from taichi.lang.util import (in_python_scope, python_scope, to_numpy_type, to_paddle_type, to_pytorch_type) @@ -421,14 +422,20 @@ def place(self, *args, shared_exponent=False): """ if shared_exponent: self.bit_struct_type_builder.begin_placing_shared_exponent() + count = 0 for arg in args: assert isinstance(arg, Field) for var in arg._get_field_members(): self.fields.append((var.ptr, self.bit_struct_type_builder.add_member( var.ptr.get_dt()))) + count += 1 if shared_exponent: self.bit_struct_type_builder.end_placing_shared_exponent() + if count <= 1: + raise TaichiSyntaxError( + "At least 2 fields need to be placed when shared_exponent=True" + ) __all__ = ["BitpackedFields", "Field", "ScalarField"] diff --git a/taichi/ir/type.cpp b/taichi/ir/type.cpp index 194b1211d1547..6b9d1a51e7990 100644 --- a/taichi/ir/type.cpp +++ b/taichi/ir/type.cpp @@ -170,17 +170,14 @@ BitStructType::BitStructType( PrimitiveType *physical_type, const std::vector &member_types, const std::vector &member_bit_offsets, - const std::vector &member_owns_shared_exponents, const std::vector &member_exponents, const std::vector> &member_exponent_users) : physical_type_(physical_type), member_types_(member_types), member_bit_offsets_(member_bit_offsets), - member_owns_shared_exponents_(member_owns_shared_exponents), member_exponents_(member_exponents), member_exponent_users_(member_exponent_users) { TI_ASSERT(member_types_.size() == member_bit_offsets_.size()); - TI_ASSERT(member_types_.size() == member_owns_shared_exponents_.size()); TI_ASSERT(member_types_.size() == member_exponents_.size()); TI_ASSERT(member_types_.size() == member_exponent_users_.size()); int physical_type_bits = data_type_bits(physical_type_); @@ -202,9 +199,6 @@ BitStructType::BitStructType( TI_ASSERT(physical_type_bits >= member_total_bits); for (auto i = 0; i < member_types_.size(); ++i) { auto exponent = member_exponents_[i]; - if (member_owns_shared_exponents_[i]) { - TI_ASSERT(exponent != -1); - } if (exponent != -1) { TI_ASSERT(std::find(member_exponent_users_[exponent].begin(), member_exponent_users_[exponent].end(), @@ -224,7 +218,7 @@ std::string BitStructType::to_string() const { member_bit_offsets_[i]); if (member_exponents_[i] != -1) { str += fmt::format(" {}exp={}", - member_owns_shared_exponents_[i] ? "shared_" : "", + get_member_owns_shared_exponent(i) ? "shared_" : "", member_exponents_[i]); } if (i + 1 < num_members) { diff --git a/taichi/ir/type.h b/taichi/ir/type.h index 9294f484103eb..339e2553ffb32 100644 --- a/taichi/ir/type.h +++ b/taichi/ir/type.h @@ -273,7 +273,6 @@ class BitStructType : public Type { BitStructType(PrimitiveType *physical_type, const std::vector &member_types, const std::vector &member_bit_offsets, - const std::vector &member_owns_shared_exponents, const std::vector &member_exponents, const std::vector> &member_exponent_users); @@ -296,7 +295,8 @@ class BitStructType : public Type { } bool get_member_owns_shared_exponent(int i) const { - return member_owns_shared_exponents_[i]; + return member_exponents_[i] != -1 && + member_exponent_users_[member_exponents_[i]].size() > 1; } int get_member_exponent(int i) const { @@ -311,7 +311,6 @@ class BitStructType : public Type { PrimitiveType *physical_type_; std::vector member_types_; std::vector member_bit_offsets_; - std::vector member_owns_shared_exponents_; std::vector member_exponents_; std::vector> member_exponent_users_; }; diff --git a/taichi/ir/type_factory.cpp b/taichi/ir/type_factory.cpp index c4ddee7726801..cb4c5804a2cef 100644 --- a/taichi/ir/type_factory.cpp +++ b/taichi/ir/type_factory.cpp @@ -82,12 +82,11 @@ BitStructType *TypeFactory::get_bit_struct_type( PrimitiveType *physical_type, const std::vector &member_types, const std::vector &member_bit_offsets, - const std::vector &member_owns_shared_exponents, const std::vector &member_exponents, const std::vector> &member_exponent_users) { bit_struct_types_.push_back(std::make_unique( - physical_type, member_types, member_bit_offsets, - member_owns_shared_exponents, member_exponents, member_exponent_users)); + physical_type, member_types, member_bit_offsets, member_exponents, + member_exponent_users)); return bit_struct_types_.back().get(); } diff --git a/taichi/ir/type_factory.h b/taichi/ir/type_factory.h index 5242210040654..4c06c98ea17f2 100644 --- a/taichi/ir/type_factory.h +++ b/taichi/ir/type_factory.h @@ -37,7 +37,6 @@ class TypeFactory { PrimitiveType *physical_type, const std::vector &member_types, const std::vector &member_bit_offsets, - const std::vector &member_owns_shared_exponents, const std::vector &member_exponents, const std::vector> &member_exponent_users); diff --git a/taichi/ir/type_utils.h b/taichi/ir/type_utils.h index 5d7bc7b4f2aab..6d3b97154c94f 100644 --- a/taichi/ir/type_utils.h +++ b/taichi/ir/type_utils.h @@ -203,9 +203,6 @@ class BitStructTypeBuilder { } } auto digits_id = add_member_impl(member_type); - if (is_placing_shared_exponent_) { - member_owns_shared_exponents_[digits_id] = true; - } member_exponents_[digits_id] = exponent_id; member_exponent_users_[exponent_id].push_back(digits_id); return digits_id; @@ -228,8 +225,7 @@ class BitStructTypeBuilder { BitStructType *build() const { return TypeFactory::get_instance().get_bit_struct_type( - physical_type_, member_types_, member_bit_offsets_, - member_owns_shared_exponents_, member_exponents_, + physical_type_, member_types_, member_bit_offsets_, member_exponents_, member_exponent_users_); } @@ -238,7 +234,6 @@ class BitStructTypeBuilder { int old_num_members = member_types_.size(); member_types_.push_back(member_type); member_bit_offsets_.push_back(member_total_bits_); - member_owns_shared_exponents_.push_back(false); member_exponents_.push_back(-1); member_exponent_users_.push_back({}); QuantIntType *member_qit = nullptr; @@ -263,7 +258,6 @@ class BitStructTypeBuilder { std::vector member_types_; std::vector member_bit_offsets_; int member_total_bits_{0}; - std::vector member_owns_shared_exponents_; std::vector member_exponents_; std::vector> member_exponent_users_; bool is_placing_shared_exponent_{false}; diff --git a/tests/cpp/ir/type_test.cpp b/tests/cpp/ir/type_test.cpp index 9a5b3fa9f5480..f14aaacd7f524 100644 --- a/tests/cpp/ir/type_test.cpp +++ b/tests/cpp/ir/type_test.cpp @@ -18,19 +18,24 @@ TEST(Type, TypeToString) { auto qfl = TypeFactory::get_instance().get_quant_float_type(qi5, qu7, f32); auto bs1 = TypeFactory::get_instance().get_bit_struct_type( - u16, {qi5, qu7}, {0, 5}, {false, false}, {-1, -1}, {{}, {}}); + /*physical_type=*/u16, /*member_types=*/{qi5, qu7}, + /*member_bit_offsets=*/{0, 5}, /*member_exponents=*/{-1, -1}, + /*member_exponent_users=*/{{}, {}}); EXPECT_EQ(bs1->to_string(), "bs(0: qi5@0, 1: qu7@5)"); auto bs2 = TypeFactory::get_instance().get_bit_struct_type( - u32, {qu7, qfl, qu7, qfl}, {0, 7, 12, 19}, {false, false, false, false}, - {-1, 0, -1, 2}, {{1}, {}, {3}, {}}); + /*physical_type=*/u32, /*member_types=*/{qu7, qfl, qu7, qfl}, + /*member_bit_offsets=*/{0, 7, 12, 19}, + /*member_exponents=*/{-1, 0, -1, 2}, + /*member_exponent_users=*/{{1}, {}, {3}, {}}); EXPECT_EQ(bs2->to_string(), "bs(0: qu7@0, 1: qfl(d=qi5 e=qu7 c=f32)@7 exp=0, 2: qu7@12, 3: " "qfl(d=qi5 e=qu7 c=f32)@19 exp=2)"); auto bs3 = TypeFactory::get_instance().get_bit_struct_type( - u32, {qu7, qfl, qfl}, {0, 7, 12}, {false, true, true}, {-1, 0, 0}, - {{1, 2}, {}, {}}); + /*physical_type=*/u32, /*member_types=*/{qu7, qfl, qfl}, + /*member_bit_offsets=*/{0, 7, 12}, /*member_exponents=*/{-1, 0, 0}, + /*member_exponent_users=*/{{1, 2}, {}, {}}); EXPECT_EQ(bs3->to_string(), "bs(0: qu7@0, 1: qfl(d=qi5 e=qu7 c=f32)@7 shared_exp=0, 2: " "qfl(d=qi5 e=qu7 c=f32)@12 shared_exp=0)"); diff --git a/tests/python/test_bitpacked_fields.py b/tests/python/test_bitpacked_fields.py index 2a1c8294628f1..67b8345abc74c 100644 --- a/tests/python/test_bitpacked_fields.py +++ b/tests/python/test_bitpacked_fields.py @@ -1,5 +1,5 @@ import numpy as np -from pytest import approx +import pytest import taichi as ti from tests import test_utils @@ -177,7 +177,7 @@ def assign(): for i in range(N): if i // block_size % 2 == 0: - assert x[i] == approx(i, abs=1e-3) + assert x[i] == pytest.approx(i, abs=1e-3) else: assert x[i] == 0 @@ -213,3 +213,15 @@ def verify_val(): set_val() verify_val() + + +@test_utils.test() +def test_invalid_place(): + f15 = ti.types.quant.float(exp=5, frac=10) + p = ti.field(dtype=f15) + bitpack = ti.BitpackedFields(max_num_bits=32) + with pytest.raises( + ti.TaichiCompilationError, + match= + 'At least 2 fields need to be placed when shared_exponent=True'): + bitpack.place(p, shared_exponent=True)