diff --git a/taichi/codegen/spirv/spirv_ir_builder.cpp b/taichi/codegen/spirv/spirv_ir_builder.cpp index 32383ede4bf2f..5e0727144c1db 100644 --- a/taichi/codegen/spirv/spirv_ir_builder.cpp +++ b/taichi/codegen/spirv/spirv_ir_builder.cpp @@ -1169,7 +1169,15 @@ Value IRBuilder::load_variable(Value pointer, const SType &res_type) { pointer.flag == ValueKind::kStructArrayPtr || pointer.flag == ValueKind::kPhysicalPtr); Value ret = new_value(res_type, ValueKind::kNormal); - ib_.begin(spv::OpLoad).add_seq(res_type, ret, pointer).commit(&function_); + if (pointer.flag == ValueKind::kPhysicalPtr) { + uint32_t alignment = uint32_t(get_primitive_type_size(res_type.dt)); + ib_.begin(spv::OpLoad) + .add_seq(res_type, ret, pointer, spv::MemoryAccessAlignedMask, + alignment) + .commit(&function_); + } else { + ib_.begin(spv::OpLoad).add_seq(res_type, ret, pointer).commit(&function_); + } return ret; } void IRBuilder::store_variable(Value pointer, Value value) { @@ -1177,8 +1185,7 @@ void IRBuilder::store_variable(Value pointer, Value value) { pointer.flag == ValueKind::kPhysicalPtr); TI_ASSERT(value.stype.id == pointer.stype.element_type_id); if (pointer.flag == ValueKind::kPhysicalPtr) { - Value alignment = uint_immediate_number( - t_uint32_, get_primitive_type_size(value.stype.dt)); + uint32_t alignment = uint32_t(get_primitive_type_size(value.stype.dt)); ib_.begin(spv::OpStore) .add_seq(pointer, value, spv::MemoryAccessAlignedMask, alignment) .commit(&function_);