diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py index 90475611d7341..fe11c0ddba9bb 100644 --- a/python/taichi/lang/ast/ast_transformer.py +++ b/python/taichi/lang/ast/ast_transformer.py @@ -2,6 +2,7 @@ import collections.abc import itertools import operator +import re import warnings from collections import ChainMap from sys import version_info @@ -391,6 +392,17 @@ def build_Starred(ctx, node): node.ptr = build_stmt(ctx, node.value) return node.ptr + @staticmethod + def build_FormattedValue(ctx, node): + node.ptr = build_stmt(ctx, node.value) + if node.format_spec is None or len(node.format_spec.values) == 0: + return node.ptr + values = node.format_spec.values + assert len(values) == 1 + format_str = values[0].s if version_info < (3, 8) else values[0].value + assert format_str is not None + return [node.ptr, format_str] + @staticmethod def build_JoinedStr(ctx, node): str_spec = '' @@ -398,7 +410,7 @@ def build_JoinedStr(ctx, node): for sub_node in node.values: if isinstance(sub_node, ast.FormattedValue): str_spec += '{}' - args.append(build_stmt(ctx, sub_node.value)) + args.append(build_stmt(ctx, sub_node)) elif isinstance(sub_node, ast.Constant): str_spec += sub_node.value elif isinstance(sub_node, ast.Str): @@ -495,6 +507,50 @@ def warn_if_is_external_func(ctx, node): node.lineno + ctx.lineno_offset, module="taichi") + @staticmethod + # extract format specifier from raw_string, and also handles positional arguments, if there are any + def extract_printf_format(raw_string, *raw_args, **keywords): + raw_brackets = re.findall(r'{(.*?)}', raw_string) + # fallback to old behavior + if all(':' not in bracket and not bracket.isdigit() + for bracket in raw_brackets): + raw_args = list(raw_args) + raw_args.insert(0, raw_string) + return raw_args, keywords + + brackets = [] + unnamed = 0 + for bracket in raw_brackets: + item, spec = bracket.split(':') if ':' in bracket else (bracket, + None) + item = int(item) if item.isdigit() else item + # handle unnamed positional args + if item == "": + item = unnamed + unnamed += 1 + # handle empty spec + if spec == "": + spec = None + brackets.append([item, spec]) + + # check error first + for (item, _) in brackets: + if isinstance(item, int) and item >= len(raw_args): + raise TaichiSyntaxError(f'Index {item} is out of range.') + if isinstance(item, str) and item not in keywords: + raise TaichiSyntaxError(f'Keyword "{item}" is not found.') + + args = [] + for (item, spec) in brackets: + args.append([ + raw_args[item] if isinstance(item, int) else keywords[item], + spec + ]) + + args.insert(0, re.sub(r'{.*?}', '{}', raw_string)) + # TODO: unify keyword args handling + return args, {} + @staticmethod def build_Call(ctx, node): if ASTTransformer.get_decorator(ctx, node) == 'static': @@ -530,7 +586,9 @@ def build_Call(ctx, node): if isinstance(node.func, ast.Attribute) and isinstance( node.func.value.ptr, str) and node.func.attr == 'format': - args.insert(0, node.func.value.ptr) + raw_string = node.func.value.ptr + args, keywords = ASTTransformer.extract_printf_format( + raw_string, *args, **keywords) node.ptr = impl.ti_format(*args, **keywords) return node.ptr @@ -1467,7 +1525,8 @@ def _handle_string_mod_args(ctx, node): @staticmethod def ti_format_list_to_assert_msg(raw): - entries = impl.ti_format_list_to_content_entries([raw]) + #TODO: ignore formats here for now + entries, _ = impl.ti_format_list_to_content_entries([raw]) msg = "" args = [] for entry in entries: diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py index bbba25736163c..70ebc66957d5f 100644 --- a/python/taichi/lang/impl.py +++ b/python/taichi/lang/impl.py @@ -813,6 +813,12 @@ def ti_format_list_to_content_entries(raw): def entry2content(_var): if isinstance(_var, str): return _var + # handle optional format specifier + if isinstance(_var, list): + assert len(_var) == 2 and (isinstance(_var[1], str) + or _var[1] is None) + _var[0] = Expr(_var[0]).ptr + return _var return Expr(_var).ptr def list_ti_repr(_var): @@ -832,6 +838,10 @@ def vars2entries(_vars): if len(_var) > 0 and isinstance( _var[0], str) and _var[0] == '__ti_format__': res = _var[1:] + elif len(_var) > 0 and isinstance( + _var[0], str) and _var[0] == '__ti_fmt_list__': + yield _var[1:] + continue else: res = list_ti_repr(_var) else: @@ -854,9 +864,24 @@ def fused_string(entries): if accumated: yield accumated + def extract_formats(entries): + contents = [] + formats = [] + for entry in entries: + if isinstance(entry, list): + assert len(entry) == 2 + contents.append(entry[0]) + formats.append(entry[1]) + else: + contents.append(entry) + formats.append(None) + return contents, formats + entries = vars2entries(raw) entries = fused_string(entries) - return [entry2content(entry) for entry in entries] + entries = [entry2content(entry) for entry in entries] + contents, formats = extract_formats(entries) + return contents, formats @taichi_scope @@ -869,8 +894,9 @@ def add_separators(_vars): yield end _vars = add_separators(_vars) - entries = ti_format_list_to_content_entries(_vars) - get_runtime().compiling_callable.ast_builder().create_print(entries) + contents, formats = ti_format_list_to_content_entries(_vars) + get_runtime().compiling_callable.ast_builder().create_print( + contents, formats) @taichi_scope @@ -884,6 +910,11 @@ def ti_format(*args, **kwargs): if isinstance(x, Expr): new_mixed.append('{}') args.append(x) + # add tag if encounter an Expr with format specifier + elif isinstance(x, list): + new_mixed.append('{}') + x.insert(0, '__ti_fmt_list__') + args.append(x) else: new_mixed.append(x) for k, v in kwargs.items(): diff --git a/taichi/codegen/cc/codegen_cc.cpp b/taichi/codegen/cc/codegen_cc.cpp index 29e89ab285a3d..33bf2eb3d907b 100644 --- a/taichi/codegen/cc/codegen_cc.cpp +++ b/taichi/codegen/cc/codegen_cc.cpp @@ -8,6 +8,7 @@ #include "taichi/util/line_appender.h" #include "taichi/util/str.h" #include "cc_utils.h" +#include "taichi/codegen/codegen_utils.h" #define C90_COMPAT 0 @@ -385,25 +386,27 @@ class CCTransformer : public IRVisitor { } void visit(PrintStmt *stmt) override { - std::string format; + std::string formats; std::vector values; for (int i = 0; i < stmt->contents.size(); i++) { auto const &content = stmt->contents[i]; + auto const &format = stmt->formats[i]; if (std::holds_alternative(content)) { auto arg_stmt = std::get(content); - format += data_type_format(arg_stmt->ret_type); + formats += merge_printf_specifier( + format, data_type_format(arg_stmt->ret_type), Arch::cc); values.push_back(arg_stmt->raw_name()); } else { auto str = std::get(content); - format += "%s"; + formats += "%s"; values.push_back(c_quoted(str)); } } - values.insert(values.begin(), c_quoted(format)); + values.insert(values.begin(), c_quoted(formats)); emit("printf({});", fmt::join(values, ", ")); } diff --git a/taichi/codegen/codegen_utils.h b/taichi/codegen/codegen_utils.h index 4e40bb98433d1..18a6cba87cb87 100644 --- a/taichi/codegen/codegen_utils.h +++ b/taichi/codegen/codegen_utils.h @@ -1,5 +1,6 @@ #pragma once #include "taichi/program/program.h" +#include namespace taichi::lang { @@ -7,4 +8,140 @@ inline bool codegen_vector_type(const CompileConfig &config) { return !config.real_matrix_scalarize; } +// Merge the printf specifiers from user-defined one and taichi data types, and +// rewrite the specifiers accroding to backend capabbilities. +inline std::string merge_printf_specifier( + std::optional const &from_user, + std::string const &from_data_type, + Arch arch) { + if (!from_user.has_value()) { + return from_data_type; + } + std::string const &user = from_user.value(); + + // printf format string specifiers: + // %[flags]+[width][.precision][length][conversion] + // https://en.cppreference.com/w/cpp/io/c/fprintf + const std::regex user_re = std::regex( + "([-+ #0]+)?" + "(\\d+|\\*)?" + "(\\.(?:\\d+|\\*))?" + "([hljztL]|hh|ll)?" + "([csdioxXufFeEaAgGnp])?"); + std::smatch user_match; + bool user_matched = std::regex_match(user, user_match, user_re); + if (user_matched == false) { + TI_ERROR("{} is not a valid printf specifier.", user) + } + std::string user_flags = user_match[1]; + std::string user_width = user_match[2]; + std::string user_precision = user_match[3]; + std::string user_length = user_match[4]; + std::string user_conversion = user_match[5]; + + if (user_width == "*" || user_precision == ".*" || user_conversion == "n") { + TI_ERROR("The {} printf specifier is not supported", user) + } + + const std::regex dt_re = std::regex( + "%" + "(\\.(?:\\d+))?" + "([hljztL]|hh|ll)?" + "([csdioxXufFeEaAgGnp])?"); + std::smatch dt_match; + bool dt_matched = std::regex_match(from_data_type, dt_match, dt_re); + TI_ASSERT(dt_matched); + std::string dt_precision = dt_match[1]; + std::string dt_length = dt_match[2]; + std::string dt_conversion = dt_match[3]; + + // Constant for convensions in group. + constexpr std::string_view signed_group = "di"; + constexpr std::string_view unsigned_group = "oxXu"; + constexpr std::string_view float_group = "fFeEaAgG"; + + // Vulkan doesn't support length, flags, or width specifier. + // https://vulkan.lunarg.com/doc/view/1.2.162.1/linux/debug_printf.html + // + if (arch == Arch::vulkan) { + if (!user_flags.empty()) { + TI_WARN( + "The printf flags '{}' are not supported in Vulkan, " + "and will be discarded.", + user_flags); + user_flags.clear(); + } + if (!user_width.empty()) { + TI_WARN( + "The printf width modifier '{}' is not supported in Vulkan, " + "and will be discarded.", + user_width); + user_width.clear(); + } + // except for unsigned long + if (!user_length.empty() && + !(user_length == "l" && !user_conversion.empty() && + unsigned_group.find(user_conversion) != std::string::npos)) { + TI_WARN( + "The printf length modifier '{}' is not supported in Vulkan, " + "and will be discarded.", + user_length); + user_length.clear(); + } + if (dt_precision == ".12" || dt_length == "ll") { + TI_WARN( + "Vulkan does not support 64-bit printing, except for unsigned long."); + } + } + // CUDA supports all of them but not 'F'. + // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#format-specifiers + else if (arch == Arch::cuda) { + if (user_conversion == "F") { + user_conversion = "f"; + } + } + + // Replace user_precision with dt_precision if the former is empty, + // otherwise use user specified precision. + if (user_precision.empty()) { + user_precision = dt_precision; + } + + // Discard user_length and give warning if it doesn't match with dt_length. + if (user_length != dt_length) { + if (!user_length.empty()) { + TI_WARN("The printf length specifier '{}' is overritten by '{}'", + user_length, dt_length); + } + user_length = dt_length; + } + + // Override user_conversion with dt_conversion. + if (user_conversion != dt_conversion) { + if (!user_conversion.empty() && + user_conversion.back() != dt_conversion.back()) { + // Preserves user_conversion when user and dt conversions belong to the + // same group, e.g., when printing unsigned decimal numbers in hexadecimal + // or octal format, or floating point numbers in exponential notation. + if ((signed_group.find(user_conversion.back()) != std::string::npos && + signed_group.find(dt_conversion.back()) != std::string::npos) || + (unsigned_group.find(user_conversion.back()) != std::string::npos && + unsigned_group.find(dt_conversion.back()) != std::string::npos) || + (float_group.find(user_conversion.back()) != std::string::npos && + float_group.find(dt_conversion.back()) != std::string::npos)) { + dt_conversion.back() = user_conversion.back(); + } else { + TI_WARN("The printf conversion specifier '{}' is overritten by '{}'", + user_conversion, dt_conversion); + } + } + user_conversion = dt_conversion; + } + + std::string res = "%" + user_flags + user_width + user_precision + + user_length + user_conversion; + TI_TRACE("Merge %{} and {} into {}.", user, from_data_type, res); + return res; +} + } // namespace taichi::lang diff --git a/taichi/codegen/cuda/codegen_cuda.cpp b/taichi/codegen/cuda/codegen_cuda.cpp index c89e824191305..4c4a14ad5e119 100644 --- a/taichi/codegen/cuda/codegen_cuda.cpp +++ b/taichi/codegen/cuda/codegen_cuda.cpp @@ -95,11 +95,15 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM { std::string formats; size_t num_contents = 0; - for (auto const &content : stmt->contents) { + for (auto i = 0; i < stmt->contents.size(); ++i) { + auto const &content = stmt->contents[i]; + auto const &format = stmt->formats[i]; + if (std::holds_alternative(content)) { auto arg_stmt = std::get(content); - formats += data_type_format(arg_stmt->ret_type); + formats += merge_printf_specifier( + format, data_type_format(arg_stmt->ret_type), Arch::cuda); auto value = llvm_val[arg_stmt]; auto value_type = value->getType(); diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index 18acf93f31163..ca64ed6fd9577 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -955,7 +955,10 @@ void TaskCodeGenLLVM::visit(PrintStmt *stmt) { tlctx->get_data_type(PrimitiveType::u16)); return to_print; }; - for (auto const &content : stmt->contents) { + for (auto i = 0; i < stmt->contents.size(); ++i) { + auto const &content = stmt->contents[i]; + auto const &format = stmt->formats[i]; + if (std::holds_alternative(content)) { auto arg_stmt = std::get(content); auto value = llvm_val[arg_stmt]; @@ -977,7 +980,8 @@ void TaskCodeGenLLVM::visit(PrintStmt *stmt) { formats += data_type_format(arg_stmt->ret_type); } else { args.push_back(value_for_printf(value, arg_stmt->ret_type)); - formats += data_type_format(arg_stmt->ret_type); + formats += merge_printf_specifier( + format, data_type_format(arg_stmt->ret_type), current_arch()); } } else { auto arg_str = std::get(content); diff --git a/taichi/codegen/spirv/spirv_codegen.cpp b/taichi/codegen/spirv/spirv_codegen.cpp index 223b6095bc821..12c0bc2fcdf34 100644 --- a/taichi/codegen/spirv/spirv_codegen.cpp +++ b/taichi/codegen/spirv/spirv_codegen.cpp @@ -4,6 +4,7 @@ #include #include +#include "taichi/codegen/codegen_utils.h" #include "taichi/program/program.h" #include "taichi/program/kernel.h" #include "taichi/ir/statements.h" @@ -97,6 +98,19 @@ class TaskCodegen : public IRVisitor { } } + // Replace the wild '%' in the format string with "%%". + std::string sanitize_format_string(std::string const &str) { + std::string sanitized_str; + for (char c : str) { + if (c == '%') { + sanitized_str += "%%"; + } else { + sanitized_str += c; + } + } + return sanitized_str; + } + struct Result { std::vector spirv_code; TaskAttributes task_attribs; @@ -151,17 +165,21 @@ class TaskCodegen : public IRVisitor { std::string formats; std::vector vals; - for (auto const &content : stmt->contents) { + for (auto i = 0; i < stmt->contents.size(); ++i) { + auto const &content = stmt->contents[i]; + auto const &format = stmt->formats[i]; if (std::holds_alternative(content)) { auto arg_stmt = std::get(content); TI_ASSERT(!arg_stmt->ret_type->is()); auto value = ir_->query_value(arg_stmt->raw_name()); vals.push_back(value); - formats += data_type_format(arg_stmt->ret_type, Arch::vulkan); + formats += merge_printf_specifier( + format, data_type_format(arg_stmt->ret_type, Arch::vulkan), + Arch::vulkan); } else { auto arg_str = std::get(content); - formats += arg_str; + formats += sanitize_format_string(arg_str); } } ir_->call_debugprintf(formats, vals); diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 03fb60677384a..84e75309b1324 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -1349,8 +1349,9 @@ void ASTBuilder::create_kernel_exprgroup_return(const ExprGroup &group) { } void ASTBuilder::create_print( - std::vector> contents) { - this->insert(std::make_unique(contents)); + std::vector> contents, + std::vector> formats) { + this->insert(std::make_unique(contents, formats)); } void ASTBuilder::begin_func(const std::string &funcid) { diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index 1de36d5b5118b..10fbc31f075ea 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -165,15 +165,13 @@ class FrontendIfStmt : public Stmt { class FrontendPrintStmt : public Stmt { public: using EntryType = std::variant; - std::vector contents; - - explicit FrontendPrintStmt(const std::vector &contents_) { - for (const auto &c : contents_) { - if (std::holds_alternative(c)) - contents.push_back(std::get(c)); - else - contents.push_back(c); - } + using FormatType = std::optional; + const std::vector contents; + const std::vector formats; + + FrontendPrintStmt(const std::vector &contents_, + const std::vector &formats_) + : contents(contents_), formats(formats_) { } TI_DEFINE_ACCEPT @@ -973,7 +971,8 @@ class ASTBuilder { Expr insert_thread_idx_expr(); Expr insert_patch_idx_expr(); void create_kernel_exprgroup_return(const ExprGroup &group); - void create_print(std::vector> contents); + void create_print(std::vector> contents, + std::vector> formats); void begin_func(const std::string &funcid); void end_func(const std::string &funcid); void begin_frontend_if(const Expr &cond); diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index 93e623a613a50..a91bea0e0c883 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -853,13 +853,21 @@ class IfStmt : public Stmt { class PrintStmt : public Stmt { public: using EntryType = std::variant; - std::vector contents; + using FormatType = std::optional; + const std::vector contents; + const std::vector formats; explicit PrintStmt(const std::vector &contents_) : contents(contents_) { TI_STMT_REG_FIELDS; } + PrintStmt(const std::vector &contents_, + const std::vector &formats_) + : contents(contents_), formats(formats_) { + TI_STMT_REG_FIELDS; + } + template explicit PrintStmt(Stmt *t, Args &&...args) : contents(make_entries(t, std::forward(args)...)) { diff --git a/taichi/transforms/ir_printer.cpp b/taichi/transforms/ir_printer.cpp index b26fd531301a5..62b653e5f2c2a 100644 --- a/taichi/transforms/ir_printer.cpp +++ b/taichi/transforms/ir_printer.cpp @@ -252,12 +252,20 @@ class IRPrinter : public IRVisitor { void visit(FrontendPrintStmt *print_stmt) override { std::vector contents; - for (auto const &c : print_stmt->contents) { + for (auto i = 0; i != print_stmt->contents.size(); ++i) { + auto const &c = print_stmt->contents[i]; + auto const &f = print_stmt->formats[i]; + std::string name; if (std::holds_alternative(c)) name = expr_to_string(std::get(c).expr.get()); else name = c_quoted(std::get(c)); + + if (f.has_value()) { + name += ":"; + name += f.value(); + } contents.push_back(name); } print("print {}", fmt::join(contents, ", ")); @@ -265,12 +273,20 @@ class IRPrinter : public IRVisitor { void visit(PrintStmt *print_stmt) override { std::vector names; - for (auto const &c : print_stmt->contents) { + for (auto i = 0; i != print_stmt->contents.size(); ++i) { + auto const &c = print_stmt->contents[i]; + auto const &f = print_stmt->formats[i]; + std::string name; if (std::holds_alternative(c)) name = std::get(c)->name(); else name = c_quoted(std::get(c)); + + if (f.has_value()) { + name += ":"; + name += f.value(); + } names.push_back(name); } print("print {}", fmt::join(names, ", ")); diff --git a/taichi/transforms/lower_ast.cpp b/taichi/transforms/lower_ast.cpp index 0f1563011f35a..7a633a6f7eb83 100644 --- a/taichi/transforms/lower_ast.cpp +++ b/taichi/transforms/lower_ast.cpp @@ -141,7 +141,7 @@ class LowerAST : public IRVisitor { new_contents.push_back(x); } } - fctx.push_back(new_contents); + fctx.push_back(new_contents, stmt->formats); stmt->parent->replace_with(stmt, std::move(fctx.stmts)); } diff --git a/taichi/transforms/scalarize.cpp b/taichi/transforms/scalarize.cpp index 42305cf394f25..a26d8bc388deb 100644 --- a/taichi/transforms/scalarize.cpp +++ b/taichi/transforms/scalarize.cpp @@ -263,10 +263,14 @@ class Scalarize : public BasicStmtVisitor { } void visit(PrintStmt *stmt) override { - auto &contents = stmt->contents; + auto const &contents = stmt->contents; + auto const &formats = stmt->formats; std::vector> new_contents; + // Sparse mapping between formatted expr and its specifier + std::map new_formats; for (size_t i = 0; i < contents.size(); i++) { - auto content = contents[i]; + auto const &content = contents[i]; + auto const &format = formats[i]; if (auto string_ptr = std::get_if(&content)) { new_contents.push_back(*string_ptr); } else { @@ -287,6 +291,9 @@ class Scalarize : public BasicStmtVisitor { for (size_t j = 0; j < n; j++) { size_t index = i * n + j; new_contents.push_back(matrix_init_stmt->values[index]); + if (format.has_value()) { + new_formats[matrix_init_stmt->values[index]] = format.value(); + } if (j != n - 1) new_contents.push_back(", "); } @@ -298,6 +305,9 @@ class Scalarize : public BasicStmtVisitor { } else { for (size_t i = 0; i < m; i++) { new_contents.push_back(matrix_init_stmt->values[i]); + if (format.has_value()) { + new_formats[matrix_init_stmt->values[i]] = format.value(); + } if (i != m - 1) new_contents.push_back(", "); } @@ -305,12 +315,16 @@ class Scalarize : public BasicStmtVisitor { new_contents.push_back("]"); } else { new_contents.push_back(print_stmt); + if (format.has_value()) { + new_formats[print_stmt] = format.value(); + } } } } // Merge string contents std::vector> merged_contents; + std::vector> merged_formats; std::string merged_string = ""; for (const auto &content : new_contents) { if (auto string_content = std::get_if(&content)) { @@ -318,16 +332,26 @@ class Scalarize : public BasicStmtVisitor { } else { if (!merged_string.empty()) { merged_contents.push_back(merged_string); + merged_formats.push_back(std::nullopt); merged_string = ""; } merged_contents.push_back(content); + const auto format = new_formats.find(std::get(content)); + if (format != new_formats.end()) { + merged_formats.push_back(format->second); + } else { + merged_formats.push_back(std::nullopt); + } } } - if (!merged_string.empty()) + if (!merged_string.empty()) { merged_contents.push_back(merged_string); + merged_formats.push_back(std::nullopt); + } - delayed_modifier_.insert_before(stmt, - Stmt::make(merged_contents)); + assert(merged_contents.size() == merged_formats.size()); + delayed_modifier_.insert_before( + stmt, Stmt::make(merged_contents, merged_formats)); delayed_modifier_.erase(stmt); } diff --git a/tests/python/test_print.py b/tests/python/test_print.py index 73c1a87e9d8e2..a0185f09bd4de 100644 --- a/tests/python/test_print.py +++ b/tests/python/test_print.py @@ -1,3 +1,5 @@ +import sys + import pytest import taichi as ti @@ -39,8 +41,7 @@ def func(x: ti.i32, y: ti.f32): ti.sync() -# TODO: vulkan doesn't support %s but we should ignore it instead of crashing. -@test_utils.test(exclude=[ti.vulkan, ti.dx11, ti.amdgpu]) +@test_utils.test(exclude=[ti.dx11, ti.amdgpu]) def test_print_string(): @ti.kernel def func(x: ti.i32, y: ti.f32): @@ -68,6 +69,22 @@ def func(k: ti.f32): ti.sync() +@test_utils.test(exclude=[ti.dx11, vk_on_mac, ti.amdgpu], debug=True) +def test_print_matrix_format(): + x = ti.Matrix.field(2, 3, dtype=ti.f32, shape=()) + y = ti.Vector.field(3, dtype=ti.f32, shape=3) + + @ti.kernel + def func(k: ti.f32): + x[None][0, 0] = -1.0 + y[2] += 1.0 + print(f'hello {x[None]:.2f} world!') + print(f'{(y[2] * k):.2e} {(x[None] / k):.8} {y[2]:.3}') + + func(233.3) + ti.sync() + + @test_utils.test(exclude=[ti.dx11, vk_on_mac, ti.amdgpu], debug=True) def test_print_sep_end(): @ti.kernel @@ -104,7 +121,7 @@ def func(k: ti.f32): ti.sync() -@test_utils.test(exclude=[ti.cc, ti.dx11, vk_on_mac, ti.amdgpu], debug=True) +@test_utils.test(exclude=[ti.dx11, vk_on_mac, ti.amdgpu], debug=True) def test_print_list(): x = ti.Matrix.field(2, 3, dtype=ti.f32, shape=(2, 3)) y = ti.Vector.field(3, dtype=ti.f32, shape=()) @@ -125,7 +142,7 @@ def func(k: ti.f32): ti.sync() -@test_utils.test(arch=[ti.cpu, ti.vulkan], +@test_utils.test(arch=[ti.cc, ti.cpu, ti.cuda, ti.vulkan], exclude=[vk_on_mac, ti.amdgpu], debug=True) def test_python_scope_print_field(): @@ -138,7 +155,7 @@ def test_python_scope_print_field(): print(z) -@test_utils.test(arch=[ti.cpu, ti.vulkan], +@test_utils.test(arch=[ti.cc, ti.cpu, ti.cuda, ti.vulkan], exclude=[vk_on_mac, ti.amdgpu], debug=True) def test_print_string_format(): @@ -156,7 +173,149 @@ def func(k: ti.f32): ti.sync() -@test_utils.test(arch=[ti.cpu, ti.vulkan], +@test_utils.test(arch=[ti.cc, ti.cpu, ti.cuda, ti.vulkan], + exclude=[vk_on_mac, ti.amdgpu], + debug=True) +def test_print_string_format_specifier_32(): + a = ti.field(ti.f32, 2) + + @ti.kernel + def func(u: ti.u32, i: ti.i32): + a[0] = ti.f32(1.111111) + a[1] = ti.f32(2.222222) + + print("{:} {:}".format(a[0], a[1])) + print("{:f} {:F}".format(a[0], a[1])) + print("{:e} {:E}".format(a[0], a[1])) + print("{:a} {:A}".format(a[0], a[1])) + print("{:g} {:G}".format(a[0], a[1])) + + print("{:.2} {:.3}".format(a[0], a[1])) + print("{:.2f} {:.3F}".format(a[0], a[1])) + print("{:.2e} {:.3E}".format(a[0], a[1])) + print("{:.2a} {:.3A}".format(a[0], a[1])) + print("{:.2G} {:.3G}".format(a[0], a[1])) + + print("{a0:.2f} {a1:.3f}".format(a0=a[0], a1=a[1])) + + print("{:x}".format(u)) + print("{:X}".format(u)) + print("{:o}".format(u)) + print("{:u}".format(u)) + print("{:}".format(u)) + print("{name:x}".format(name=u)) + + print("{:d}".format(i)) + print("{:i}".format(i)) + print("{:}".format(i)) + print("{name:}".format(name=i)) + + func(0xdeadbeef, 2**31 - 1) + ti.sync() + + +@test_utils.test(arch=[ti.cpu, ti.cuda], + exclude=[ti.amdgpu, ti.vulkan], + debug=True) +def test_print_string_format_specifier_64(): + a = ti.field(ti.f64, 2) + + @ti.kernel + def func(llu: ti.u64, lli: ti.i64): + a[0] = ti.f64(1.111111111111) + a[1] = ti.f64(2.222222222222) + + print("{:} {:}".format(a[0], a[1])) + print("{:f} {:F}".format(a[0], a[1])) + print("{:e} {:E}".format(a[0], a[1])) + print("{:a} {:A}".format(a[0], a[1])) + print("{:g} {:G}".format(a[0], a[1])) + + print("{:.2} {:.3}".format(a[0], a[1])) + print("{:.2f} {:.3F}".format(a[0], a[1])) + print("{:.2e} {:.3E}".format(a[0], a[1])) + print("{:.2a} {:.3A}".format(a[0], a[1])) + print("{:.2G} {:.3G}".format(a[0], a[1])) + + print("{a0:f} {a1:f}".format(a0=a[0], a1=a[1])) + + print("{:llx}".format(llu)) + print("{:llX}".format(llu)) + print("{:llo}".format(llu)) + print("{:llu}".format(llu)) + print("{:ll}".format(llu)) + print("{:x}".format(llu)) + print("{:X}".format(llu)) + print("{:o}".format(llu)) + print("{:u}".format(llu)) + print("{:}".format(llu)) + print("{name:llx}".format(name=llu)) + print("{name:x}".format(name=llu)) + + print("{:lld}".format(lli)) + print("{:lli}".format(lli)) + print("{:ll}".format(lli)) + print("{:d}".format(lli)) + print("{:i}".format(lli)) + print("{:}".format(lli)) + print("{name:ll}".format(name=lli)) + + func(0xcafebabedeadbeef, 2**63 - 1) + ti.sync() + + +@test_utils.test(arch=[ti.cc, ti.cpu, ti.cuda, ti.vulkan], + exclude=[vk_on_mac, ti.amdgpu], + debug=True) +def test_print_string_format_specifier_vulkan_ul(): + @ti.kernel + def func(llu: ti.u64): + print("{:}".format(llu)) + print("{:u}".format(llu)) + print("{:lu}".format(llu)) + print("{name:lu}".format(name=llu)) + + # FIXME: %lx works on vulkan but %lX only prints lower 32 bits... why? + print("{:lx}".format(llu)) + print("{name:lx}".format(name=llu)) + + print("{:lX}".format(llu)) + print("{name:lX}".format(name=llu)) + + func(0xcafebabedeadbeef) + ti.sync() + + +@test_utils.test(arch=[ti.cc, ti.cpu, ti.cuda, ti.vulkan], + exclude=[vk_on_mac, ti.amdgpu], + debug=True) +def test_print_string_format_positional_arg(): + a = ti.field(ti.f32, 2) + + @ti.kernel + def func(u: ti.u32): + a[0] = ti.f32(1.111111) + a[1] = ti.f32(2.222222) + print("{0:.2f} {1:.3f}".format(a[0], a[1])) + print("{1:.2f} {0:.3f}".format(a[1], a[0])) + + print("{a0:.2f} {0:.3f}".format(a[1], a0=a[0])) + print("{0:.2f} {a1:.3f}".format(a[0], a1=a[1])) + + print("{a0:.2f} {:.3f}".format(a[1], a0=a[0])) + print("{:.2f} {a1:.3f}".format(a[0], a1=a[1])) + + print("{0:x}".format(u)) + + print( + "a[0] = {0:.2f}, f = {name:.2f}, u = {u:x}, a[1] = {1:.3f}, a[1] = {1:.4f}" + .format(a[0], a[1], name=42., u=u)) + + func(0xdeadbeef) + ti.sync() + + +@test_utils.test(arch=[ti.cc, ti.cpu, ti.cuda, ti.vulkan], exclude=[vk_on_mac, ti.amdgpu], debug=True) def test_print_fstring(): @@ -171,7 +330,101 @@ def func(i: ti.i32, f: ti.f32): ti.sync() -@test_utils.test(arch=[ti.cpu, ti.cuda, ti.vulkan], +@test_utils.test(arch=[ti.cc, ti.cpu, ti.cuda, ti.vulkan], + exclude=[vk_on_mac, ti.amdgpu], + debug=True) +def test_print_fstring_specifier_32(): + a = ti.field(ti.f32, 2) + + @ti.kernel + def func(u: ti.u32, i: ti.i32): + a[0] = ti.f32(1.111111) + a[1] = ti.f32(2.222222) + + print(f"{a[0]:} {a[1]:}") + print(f"{a[0]:f} {a[1]:F}") + print(f"{a[0]:e} {a[1]:E}") + print(f"{a[0]:a} {a[1]:A}") + print(f"{a[0]:g} {a[1]:G}") + + print(f"{a[0]:.2} {a[1]:.3}") + print(f"{a[0]:.2f} {a[1]:.3F}") + print(f"{a[0]:.2e} {a[1]:.3E}") + print(f"{a[0]:.2a} {a[1]:.3A}") + print(f"{a[0]:.2G} {a[1]:.3G}") + + print(f"{u:x}") + print(f"{u:X}") + print(f"{u:o}") + print(f"{u:u}") + print(f"{u:}") + + print(f"{i:d}") + print(f"{i:i}") + print(f"{i:}") + + func(0xdeadbeef, 2**31 - 1) + ti.sync() + + +@test_utils.test(arch=[ti.cc, ti.cpu, ti.cuda], + exclude=[ti.vulkan, ti.amdgpu], + debug=True) +def test_print_fstring_specifier_64(): + a = ti.field(ti.f64, 2) + + @ti.kernel + def func(u: ti.u64, i: ti.i64): + a[0] = ti.f32(1.111111111111) + a[1] = ti.f32(2.222222222222) + + print(f"{a[0]:} {a[1]:}") + print(f"{a[0]:f} {a[1]:F}") + print(f"{a[0]:e} {a[1]:E}") + print(f"{a[0]:a} {a[1]:A}") + print(f"{a[0]:g} {a[1]:G}") + + print(f"{a[0]:.2} {a[1]:.3}") + print(f"{a[0]:.2f} {a[1]:.3F}") + print(f"{a[0]:.2e} {a[1]:.3E}") + print(f"{a[0]:.2a} {a[1]:.3A}") + print(f"{a[0]:.2G} {a[1]:.3G}") + + print(f"{a[0]:.2f} {a[1]:.3f}") + + print(f"{u:x}") + print(f"{u:X}") + print(f"{u:o}") + print(f"{u:u}") + print(f"{u:}") + + print(f"{i:d}") + print(f"{i:i}") + print(f"{i:}") + + func(0xcafebabedeadbeef, 2**63 - 1) + ti.sync() + + +@test_utils.test(arch=[ti.cc, ti.cpu, ti.cuda, ti.vulkan], + exclude=[vk_on_mac, ti.amdgpu], + debug=True) +def test_print_fstring_specifier_vulkan_ul(): + @ti.kernel + def func(llu: ti.u64): + print(f"{llu:}") + print(f"{llu:u}") + print(f"{llu:lu}") + + # FIXME: %lx works on vulkan but %lX only prints lower 32 bits... why? + print(f"{llu:lx}") + print(f"{llu:lX}") + + func(0xcafebabedeadbeef) + ti.sync() + + +@test_utils.test(arch=[ti.cc, ti.cpu, ti.cuda, ti.vulkan], exclude=[vk_on_mac, ti.amdgpu], debug=True) def test_print_u64(): @@ -183,7 +436,7 @@ def func(i: ti.u64): ti.sync() -@test_utils.test(arch=[ti.cpu, ti.cuda, ti.vulkan], +@test_utils.test(arch=[ti.cc, ti.cpu, ti.cuda, ti.vulkan], exclude=[vk_on_mac, ti.amdgpu], debug=True) def test_print_i64(): @@ -195,7 +448,7 @@ def func(i: ti.i64): ti.sync() -@test_utils.test(arch=[ti.cpu, ti.cuda, ti.vulkan], +@test_utils.test(arch=[ti.cc, ti.cpu, ti.cuda, ti.vulkan], exclude=[vk_on_mac, cuda_on_windows, ti.amdgpu], debug=True) def test_print_seq(capfd):