Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Vulkan] [bug] Change the format string of 64bit unsigned integer type from %llu to %lu #6308

Merged
merged 7 commits into from
Oct 14, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion taichi/codegen/spirv/spirv_codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ class TaskCodegen : public IRVisitor {

auto value = ir_->query_value(arg_stmt->raw_name());
vals.push_back(value);
formats += data_type_format(arg_stmt->ret_type);
formats += data_type_format(arg_stmt->ret_type, Arch::vulkan);
} else {
auto arg_str = std::get<std::string>(content);
formats += arg_str;
Expand Down
12 changes: 6 additions & 6 deletions taichi/ir/type_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,16 +89,16 @@ std::string tensor_type_format_helper(const std::vector<int> &shape,
return fmt;
}

std::string tensor_type_format(DataType t) {
std::string tensor_type_format(DataType t, Arch arch) {
TI_ASSERT(t->is<TensorType>());
auto tensor_type = t->as<TensorType>();
auto shape = tensor_type->get_shape();
auto element_type = tensor_type->get_element_type();
auto element_type_format = data_type_format(element_type);
auto element_type_format = data_type_format(element_type, arch);
return tensor_type_format_helper(shape, element_type_format, 0);
}

std::string data_type_format(DataType dt) {
std::string data_type_format(DataType dt, Arch arch) {
if (dt->is_primitive(PrimitiveTypeID::i8)) {
// i8/u8 is converted to i16/u16 before printing, because CUDA doesn't
// support the "%hhd"/"%hhu" specifiers.
Expand All @@ -116,9 +116,9 @@ std::string data_type_format(DataType dt) {
} else if (dt->is_primitive(PrimitiveTypeID::i64)) {
// Use %lld on Windows.
// Discussion: https://github.com/taichi-dev/taichi/issues/2522
return "%lld";
return arch == Arch::vulkan ? "%ld" : "%lld";
} else if (dt->is_primitive(PrimitiveTypeID::u64)) {
return "%llu";
return arch == Arch::vulkan ? "%lu" : "%llu";
} else if (dt->is_primitive(PrimitiveTypeID::f32)) {
return "%f";
} else if (dt->is_primitive(PrimitiveTypeID::f64)) {
Expand All @@ -131,7 +131,7 @@ std::string data_type_format(DataType dt) {
// TaskCodeGenCUDA::visit(PrintStmt *stmt) for more details.
return "%f";
} else if (dt->is<TensorType>()) {
return tensor_type_format(dt);
return tensor_type_format(dt, arch);
} else {
TI_NOT_IMPLEMENTED
}
Expand Down
3 changes: 2 additions & 1 deletion taichi/ir/type_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "taichi/ir/type.h"
#include "taichi/ir/type_factory.h"
#include "taichi/rhi/arch.h"

namespace taichi::lang {

Expand All @@ -11,7 +12,7 @@ TI_DLL_EXPORT std::string data_type_name(DataType t);

TI_DLL_EXPORT int data_type_size(DataType t);

TI_DLL_EXPORT std::string data_type_format(DataType dt);
TI_DLL_EXPORT std::string data_type_format(DataType dt, Arch arch = Arch::x64);

inline int data_type_bits(DataType t) {
return data_type_size(t) * 8;
Expand Down
24 changes: 24 additions & 0 deletions tests/python/test_print.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,27 @@ def func(i: ti.i32, f: ti.f32):

func(123, 4.56)
ti.sync()


@test_utils.test(arch=[ti.cpu, ti.cuda, ti.vulkan],
exclude=[vk_on_mac],
debug=True)
def test_print_u64():
@ti.kernel
def func(i: ti.u64):
print("i =", i)

func(2**64 - 1)
ti.sync()


@test_utils.test(arch=[ti.cpu, ti.cuda, ti.vulkan],
exclude=[vk_on_mac],
debug=True)
def test_print_i64():
@ti.kernel
def func(i: ti.i64):
print("i =", i)

func(-2**63)
ti.sync()