From 949515309767d6997b3cad747fb5b65ac6161184 Mon Sep 17 00:00:00 2001 From: Steven Johnson Date: Mon, 18 Sep 2023 15:25:11 -0700 Subject: [PATCH 01/10] Prevent use of uninitialized scalar Parameters in JIT code (#7847, partial) --- src/InferArguments.cpp | 5 ++++- src/Parameter.cpp | 28 +++++++++++++++++++++++++++- src/Parameter.h | 25 +++++++++++++++---------- src/Pipeline.cpp | 2 +- test/error/CMakeLists.txt | 1 + test/error/uninitialized_param.cpp | 22 ++++++++++++++++++++++ 6 files changed, 70 insertions(+), 13 deletions(-) create mode 100644 test/error/uninitialized_param.cpp diff --git a/src/InferArguments.cpp b/src/InferArguments.cpp index d2f55b1fa781..9fabb47b9ef8 100644 --- a/src/InferArguments.cpp +++ b/src/InferArguments.cpp @@ -5,6 +5,7 @@ #include "ExternFuncArgument.h" #include "Function.h" +#include "IROperator.h" #include "IRVisitor.h" #include "InferArguments.h" @@ -197,7 +198,9 @@ class InferArguments : public IRGraphVisitor { ArgumentEstimates argument_estimates = p.get_argument_estimates(); if (!p.is_buffer()) { - argument_estimates.scalar_def = p.scalar_expr(); + // We don't want to crater here if a scalar param isn't set; + // instead, default to a zero of the right type, like we used to. + argument_estimates.scalar_def = p.has_scalar_expr() ? p.scalar_expr() : make_zero(p.type()); argument_estimates.scalar_min = p.min_value(); argument_estimates.scalar_max = p.max_value(); argument_estimates.scalar_estimate = p.estimate(); diff --git a/src/Parameter.cpp b/src/Parameter.cpp index 1155d5468f30..f3455c699f47 100644 --- a/src/Parameter.cpp +++ b/src/Parameter.cpp @@ -20,6 +20,7 @@ struct ParameterContents { std::vector buffer_constraints; Expr scalar_default, scalar_min, scalar_max, scalar_estimate; const bool is_buffer; + bool data_ever_set = false; MemoryType memory_type = MemoryType::Auto; ParameterContents(Type t, bool b, int d, const std::string &n) @@ -46,6 +47,10 @@ void destroy(const ParameterContents *p) { } // namespace Internal +void Parameter::check_data_ever_set() const { + user_assert(contents->data_ever_set) << "Parameter " << name() << " has never had a scalar value set.\n"; +} + void Parameter::check_defined() const { user_assert(defined()) << "Parameter is undefined\n"; } @@ -123,8 +128,14 @@ bool Parameter::is_buffer() const { return contents->is_buffer; } +bool Parameter::has_scalar_expr() const { + return defined() && !contents->is_buffer && contents->data_ever_set; +} + Expr Parameter::scalar_expr() const { check_is_scalar(); + // Redundant here, since every call to scalar<>() also checks this. + // check_data_ever_set(); const Type t = type(); if (t.is_float()) { switch (t.bits()) { @@ -198,16 +209,31 @@ void Parameter::set_buffer(const Buffer<> &b) { contents->buffer = b; } -void *Parameter::scalar_address() const { +const void *Parameter::read_only_scalar_address() const { check_is_scalar(); + // Code that calls this method is (presumably) going + // to read from the address, so complain if the scalar value + // has never been set. + check_data_ever_set(); return &contents->data; } uint64_t Parameter::scalar_raw_value() const { check_is_scalar(); + check_data_ever_set(); return contents->data; } +void Parameter::set_scalar(const Type &val_type, halide_scalar_value_t val) { + check_type(val_type); + // Setting this to zero isn't strictly necessary, but it does + // mean that the 'unused' bits of the field are never affected by what + // may have previously been there. + contents->data = 0; + memcpy(&contents->data, &val, val_type.bytes()); + contents->data_ever_set = true; +} + /** Tests if this handle is the same as another handle */ bool Parameter::same_as(const Parameter &other) const { return contents.same_as(other.contents); diff --git a/src/Parameter.h b/src/Parameter.h index 712380c2576e..424555ed9253 100644 --- a/src/Parameter.h +++ b/src/Parameter.h @@ -35,6 +35,7 @@ struct ParameterContents; /** A reference-counted handle to a parameter to a halide * pipeline. May be a scalar parameter or a buffer */ class Parameter { + void check_data_ever_set() const; void check_defined() const; void check_is_buffer() const; void check_is_scalar() const; @@ -54,8 +55,9 @@ class Parameter { /** Get the pointer to the current value of the scalar * parameter. For a given parameter, this address will never - * change. Only relevant when jitting. */ - void *scalar_address() const; + * change. Note that this can only be used to *read* from -- it must + * not be written to, so don't cast away the constness. Only relevant when jitting. */ + const void *read_only_scalar_address() const; /** Get the raw data of the current value of the scalar * parameter. Only relevant when serializing. */ @@ -112,27 +114,30 @@ class Parameter { template HALIDE_NO_USER_CODE_INLINE T scalar() const { check_type(type_of()); - return *((const T *)(scalar_address())); + check_data_ever_set(); + return *((const T *)(read_only_scalar_address())); } /** This returns the current value of scalar() - * as an Expr. */ + * as an Expr. If no value has ever been set, it will assert-fail */ Expr scalar_expr() const; + /** This returns true if scalar_expr() would return a valid Expr, + * false if not. */ + bool has_scalar_expr() const; + /** If the parameter is a scalar parameter, set its current * value. Only relevant when jitting */ template HALIDE_NO_USER_CODE_INLINE void set_scalar(T val) { - check_type(type_of()); - *((T *)(scalar_address())) = val; + halide_scalar_value_t sv; + memcpy(&sv, &val, sizeof(val)); + set_scalar(type_of(), sv); } /** If the parameter is a scalar parameter, set its current * value. Only relevant when jitting */ - HALIDE_NO_USER_CODE_INLINE void set_scalar(const Type &val_type, halide_scalar_value_t val) { - check_type(val_type); - memcpy(scalar_address(), &val, val_type.bytes()); - } + void set_scalar(const Type &val_type, halide_scalar_value_t val); /** If the parameter is a buffer parameter, get its currently * bound buffer. Only relevant when jitting */ diff --git a/src/Pipeline.cpp b/src/Pipeline.cpp index d51438e275fe..631033404137 100644 --- a/src/Pipeline.cpp +++ b/src/Pipeline.cpp @@ -843,7 +843,7 @@ void Pipeline::prepare_jit_call_arguments(RealizationArg &outputs, const Target } debug(2) << "JIT input ImageParam argument "; } else { - args_result.store[arg_index++] = p.scalar_address(); + args_result.store[arg_index++] = p.read_only_scalar_address(); debug(2) << "JIT input scalar argument "; } } diff --git a/test/error/CMakeLists.txt b/test/error/CMakeLists.txt index 6e69490657f5..69e0979163b5 100644 --- a/test/error/CMakeLists.txt +++ b/test/error/CMakeLists.txt @@ -107,6 +107,7 @@ tests(GROUPS error undefined_pipeline_compile.cpp undefined_pipeline_realize.cpp undefined_rdom_dimension.cpp + uninitialized_param.cpp unknown_target.cpp vector_tile.cpp vectorize_dynamic.cpp diff --git a/test/error/uninitialized_param.cpp b/test/error/uninitialized_param.cpp new file mode 100644 index 000000000000..2fa8d47b2ec3 --- /dev/null +++ b/test/error/uninitialized_param.cpp @@ -0,0 +1,22 @@ +#include "Halide.h" +#include + +using namespace Halide; + +int main(int argc, char **argv) { + ImageParam image_param(Int(32), 2, "image_param"); + Param scalar_param("scalar_param"); + + Var x("x"), y("y"); + Func f("f"); + + f(x, y) = image_param(x, y) + scalar_param; + + Buffer b(10, 10); + image_param.set(b); + + f.realize({10, 10}); + + printf("Success!\n"); + return 0; +} From c8bb1afd089589c413234ede59c4b5436ad8947a Mon Sep 17 00:00:00 2001 From: Steven Johnson Date: Mon, 18 Sep 2023 16:46:58 -0700 Subject: [PATCH 02/10] Fix broken tests --- test/correctness/func_clone.cpp | 8 -------- test/correctness/specialize.cpp | 4 ++++ 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/test/correctness/func_clone.cpp b/test/correctness/func_clone.cpp index c5cfbe868b61..2e1d35f019e8 100644 --- a/test/correctness/func_clone.cpp +++ b/test/correctness/func_clone.cpp @@ -159,14 +159,6 @@ int update_defined_after_clone_test() { return 1; } - Buffer im = g.realize({200, 200}); - auto func = [](int x, int y) { - return ((0 <= x && x <= 99) && (0 <= y && y <= 99) && (x < y)) ? 3 * (x + y) : (x + y); - }; - if (check_image(im, func)) { - return 1; - } - for (bool param_value : {false, true}) { param.set(param_value); diff --git a/test/correctness/specialize.cpp b/test/correctness/specialize.cpp index 54c74ca8ada9..1a807003f72a 100644 --- a/test/correctness/specialize.cpp +++ b/test/correctness/specialize.cpp @@ -447,6 +447,8 @@ int main(int argc, char **argv) { Buffer input(3, 3), output(3, 3); // Shouldn't throw a bounds error: im.set(input); + cond1.set(false); + cond2.set(false); out.realize(output); if (if_then_else_count != 1) { @@ -476,6 +478,8 @@ int main(int argc, char **argv) { Buffer input(3, 3), output(3, 3); // Shouldn't throw a bounds error: im.set(input); + cond1.set(false); + cond2.set(false); out.realize(output); // There should have been 2 Ifs total: They are the From a60eb9911333c4870bb8f733380bcb68e45f8061 Mon Sep 17 00:00:00 2001 From: Steven Johnson Date: Mon, 18 Sep 2023 18:01:38 -0700 Subject: [PATCH 03/10] Update Parameter.h --- src/Parameter.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Parameter.h b/src/Parameter.h index 424555ed9253..47126fef7325 100644 --- a/src/Parameter.h +++ b/src/Parameter.h @@ -131,7 +131,7 @@ class Parameter { template HALIDE_NO_USER_CODE_INLINE void set_scalar(T val) { halide_scalar_value_t sv; - memcpy(&sv, &val, sizeof(val)); + memcpy(&sv.u.u64, &val, sizeof(val)); set_scalar(type_of(), sv); } From 699ec871c584c9a5a60c6735226936c45c9227e4 Mon Sep 17 00:00:00 2001 From: Steven Johnson Date: Tue, 19 Sep 2023 09:07:16 -0700 Subject: [PATCH 04/10] Update func_clone.cpp --- test/correctness/func_clone.cpp | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/test/correctness/func_clone.cpp b/test/correctness/func_clone.cpp index 2e1d35f019e8..8bf0b4d80e87 100644 --- a/test/correctness/func_clone.cpp +++ b/test/correctness/func_clone.cpp @@ -159,6 +159,15 @@ int update_defined_after_clone_test() { return 1; } + param.set(false); + Buffer im = g.realize({200, 200}); + auto func = [](int x, int y) { + return ((0 <= x && x <= 99) && (0 <= y && y <= 99) && (x < y)) ? 3 * (x + y) : (x + y); + }; + if (check_image(im, func)) { + return 1; + } + for (bool param_value : {false, true}) { param.set(param_value); From 296a58a7df4e8c221f79d2420130d93f730fdb44 Mon Sep 17 00:00:00 2001 From: Steven Johnson Date: Tue, 19 Sep 2023 11:34:23 -0700 Subject: [PATCH 05/10] Fix Generators too --- src/Generator.h | 3 +- src/Pipeline.cpp | 22 ++++++++++++ test/error/CMakeLists.txt | 2 ++ test/error/uninitialized_param_2.cpp | 46 +++++++++++++++++++++++++ test/error/uninitialized_param_3.cpp | 50 ++++++++++++++++++++++++++++ 5 files changed, 122 insertions(+), 1 deletion(-) create mode 100644 test/error/uninitialized_param_2.cpp create mode 100644 test/error/uninitialized_param_3.cpp diff --git a/src/Generator.h b/src/Generator.h index d0614b24efb9..9bc335b52ed7 100644 --- a/src/Generator.h +++ b/src/Generator.h @@ -2001,7 +2001,8 @@ class GeneratorInput_Scalar : public GeneratorInputImpl { void set_def_min_max() override { for (Parameter &p : this->parameters_) { - p.set_scalar(def_); + // No: we want to leave the Parameter unset here. + // p.set_scalar(def_); p.set_default_value(def_expr_); } } diff --git a/src/Pipeline.cpp b/src/Pipeline.cpp index 631033404137..27b164e7b626 100644 --- a/src/Pipeline.cpp +++ b/src/Pipeline.cpp @@ -70,6 +70,23 @@ std::string sanitize_function_name(const std::string &s) { return name; } +class FindVariablesWithUnsetParams : public IRVisitor { + using IRVisitor::visit; + + void visit(const Variable *op) override { + if (op->param.defined() && !op->param.is_buffer() && op->param.name() != "__user_context") { + user_assert(op->param.has_scalar_expr()) + << "You cannot call realize() or compile_to_callable() on Halide code that references the scalar Generator Input " << op->param.name() + << ", as the value will never be defined at Generator compile time. " + << "Consider scheduling the function as compute_root().memoize() instead.\n"; + } + IRVisitor::visit(op); + } + +public: + FindVariablesWithUnsetParams() = default; +}; + } // namespace namespace Internal { @@ -808,6 +825,11 @@ void Pipeline::prepare_jit_call_arguments(RealizationArg &outputs, const Target bool is_bounds_inference, JITCallArgs &args_result) { user_assert(defined()) << "Can't realize an undefined Pipeline\n"; + FindVariablesWithUnsetParams find_unset; + for (const LoweredFunc &f : contents->module.functions()) { + f.body.accept(&find_unset); + } + size_t total_outputs = 0; for (const Func &out : this->outputs()) { total_outputs += out.outputs(); diff --git a/test/error/CMakeLists.txt b/test/error/CMakeLists.txt index 69e0979163b5..902c16ee7679 100644 --- a/test/error/CMakeLists.txt +++ b/test/error/CMakeLists.txt @@ -108,6 +108,8 @@ tests(GROUPS error undefined_pipeline_realize.cpp undefined_rdom_dimension.cpp uninitialized_param.cpp + uninitialized_param_2.cpp + uninitialized_param_3.cpp unknown_target.cpp vector_tile.cpp vectorize_dynamic.cpp diff --git a/test/error/uninitialized_param_2.cpp b/test/error/uninitialized_param_2.cpp new file mode 100644 index 000000000000..1f4eed42027a --- /dev/null +++ b/test/error/uninitialized_param_2.cpp @@ -0,0 +1,46 @@ +#include "Halide.h" +#include "halide_test_dirs.h" +#include + +using namespace Halide; +using namespace Halide::ConciseCasts; + +#include "Halide.h" + +namespace { + +Var x; + +class PleaseFail : public Halide::Generator { +public: + Input> input{"input"}; + Input scalar_input{"scalar_input"}; + Output> output{"output"}; + + void generate() { + Func lut_fn("lut_fn"); + lut_fn(x) = u8_sat(x * scalar_input / 255.f); + + // This should always fail, because it depends on a scalar input + // that *cannot* have a valid value at this point. + auto lut = lut_fn.realize({256}); + + output(x) = input(x) + lut[0](x); + } +}; + +} // namespace + +HALIDE_REGISTER_GENERATOR(PleaseFail, PleaseFail) + +int main(int argc, char **argv) { + Halide::Internal::ExecuteGeneratorArgs args; + args.output_dir = Internal::get_test_tmp_dir(); + args.output_types = std::set{OutputFileType::object}; + args.targets = std::vector{get_target_from_environment()}; + args.generator_name = "PleaseFail"; + execute_generator(args); + + printf("Success!\n"); + return 0; +} diff --git a/test/error/uninitialized_param_3.cpp b/test/error/uninitialized_param_3.cpp new file mode 100644 index 000000000000..f7927840edaf --- /dev/null +++ b/test/error/uninitialized_param_3.cpp @@ -0,0 +1,50 @@ +#include "Halide.h" +#include "halide_test_dirs.h" +#include + +using namespace Halide; +using namespace Halide::ConciseCasts; + +#include "Halide.h" + +namespace { + +Var x; + +class PleaseFail : public Halide::Generator { +public: + Input> input{"input"}; + Input scalar_input{"scalar_input"}; + Output> output{"output"}; + + void generate() { + Func lut_fn("lut_fn"); + lut_fn(x) = u8_sat(x * scalar_input / 255.f); + + // This should always fail, because it depends on a scalar input + // that *cannot* have a valid value at this point. + auto lut_callable = lut_fn.compile_to_callable({}); + + Buffer lut(256); + int r = lut_callable(lut); + assert(r == 0); + + output(x) = input(x) + lut(x); + } +}; + +} // namespace + +HALIDE_REGISTER_GENERATOR(PleaseFail, PleaseFail) + +int main(int argc, char **argv) { + Halide::Internal::ExecuteGeneratorArgs args; + args.output_dir = Internal::get_test_tmp_dir(); + args.output_types = std::set{OutputFileType::object}; + args.targets = std::vector{get_target_from_environment()}; + args.generator_name = "PleaseFail"; + execute_generator(args); + + printf("Success!\n"); + return 0; +} From 2fa1c73401dba77a532227d5a67bd3b8c9541be2 Mon Sep 17 00:00:00 2001 From: Steven Johnson Date: Mon, 25 Sep 2023 13:04:02 -0700 Subject: [PATCH 06/10] Fixes --- .../src/halide/halide_/PyParameter.cpp | 5 -- src/Deserialization.cpp | 16 +++- src/InferArguments.cpp | 4 +- src/Parameter.cpp | 77 +++++++++---------- src/Parameter.h | 53 ++++++++----- src/Pipeline.cpp | 2 +- src/Serialization.cpp | 10 ++- src/halide_ir.fbs | 2 +- 8 files changed, 94 insertions(+), 75 deletions(-) diff --git a/python_bindings/src/halide/halide_/PyParameter.cpp b/python_bindings/src/halide/halide_/PyParameter.cpp index 87da6610a691..8464ed387d23 100644 --- a/python_bindings/src/halide/halide_/PyParameter.cpp +++ b/python_bindings/src/halide/halide_/PyParameter.cpp @@ -30,11 +30,6 @@ void define_parameter(py::module &m) { .def(py::init(), py::arg("p")) .def(py::init()) .def(py::init()) - .def(py::init &, int, const std::vector &, - MemoryType>()) - .def(py::init()) .def("_to_argument", [](const Parameter &p) -> Argument { return Argument(p.name(), p.is_buffer() ? Argument::InputBuffer : Argument::InputScalar, diff --git a/src/Deserialization.cpp b/src/Deserialization.cpp index c56badf7b887..2f90fe4159a3 100644 --- a/src/Deserialization.cpp +++ b/src/Deserialization.cpp @@ -1159,14 +1159,24 @@ Parameter Deserializer::deserialize_parameter(const Serialize::Parameter *parame deserialize_vector(parameter->buffer_constraints(), &Deserializer::deserialize_buffer_constraint); const auto memory_type = deserialize_memory_type(parameter->memory_type()); - return Parameter(type, is_buffer, dimensions, name, Buffer<>(), host_alignment, buffer_constraints, memory_type); + return Parameter(type, dimensions, name, Buffer<>(), host_alignment, buffer_constraints, memory_type); } else { - const uint64_t data = parameter->data(); + static_assert(FLATBUFFERS_USE_STD_OPTIONAL); + const auto make_optional_halide_scalar_value_t = [](const std::optional &v) -> std::optional { + if (v.has_value()) { + halide_scalar_value_t scalar_data; + scalar_data.u.u64 = v.value(); + return std::optional(scalar_data); + } else { + return std::nullopt; + } + }; + const std::optional scalar_data = make_optional_halide_scalar_value_t(parameter->scalar_data()); const auto scalar_default = deserialize_expr(parameter->scalar_default_type(), parameter->scalar_default()); const auto scalar_min = deserialize_expr(parameter->scalar_min_type(), parameter->scalar_min()); const auto scalar_max = deserialize_expr(parameter->scalar_max_type(), parameter->scalar_max()); const auto scalar_estimate = deserialize_expr(parameter->scalar_estimate_type(), parameter->scalar_estimate()); - return Parameter(type, is_buffer, dimensions, name, data, scalar_default, scalar_min, scalar_max, scalar_estimate); + return Parameter(type, dimensions, name, std::move(scalar_data), scalar_default, scalar_min, scalar_max, scalar_estimate); } } diff --git a/src/InferArguments.cpp b/src/InferArguments.cpp index 9fabb47b9ef8..5a3e7c952a7a 100644 --- a/src/InferArguments.cpp +++ b/src/InferArguments.cpp @@ -198,9 +198,7 @@ class InferArguments : public IRGraphVisitor { ArgumentEstimates argument_estimates = p.get_argument_estimates(); if (!p.is_buffer()) { - // We don't want to crater here if a scalar param isn't set; - // instead, default to a zero of the right type, like we used to. - argument_estimates.scalar_def = p.has_scalar_expr() ? p.scalar_expr() : make_zero(p.type()); + argument_estimates.scalar_def = p.default_value(); argument_estimates.scalar_min = p.min_value(); argument_estimates.scalar_max = p.max_value(); argument_estimates.scalar_estimate = p.estimate(); diff --git a/src/Parameter.cpp b/src/Parameter.cpp index f3455c699f47..e196fbd3d87d 100644 --- a/src/Parameter.cpp +++ b/src/Parameter.cpp @@ -15,12 +15,11 @@ struct ParameterContents { const int dimensions; const std::string name; Buffer<> buffer; - uint64_t data = 0; + std::optional scalar_data; int host_alignment; std::vector buffer_constraints; Expr scalar_default, scalar_min, scalar_max, scalar_estimate; const bool is_buffer; - bool data_ever_set = false; MemoryType memory_type = MemoryType::Auto; ParameterContents(Type t, bool b, int d, const std::string &n) @@ -47,8 +46,8 @@ void destroy(const ParameterContents *p) { } // namespace Internal -void Parameter::check_data_ever_set() const { - user_assert(contents->data_ever_set) << "Parameter " << name() << " has never had a scalar value set.\n"; +void Parameter::check_has_scalar_data() const { + user_assert(contents->scalar_data.has_value()) << "Parameter " << name() << " has never had a scalar value set.\n"; } void Parameter::check_defined() const { @@ -87,21 +86,21 @@ Parameter::Parameter(const Type &t, bool is_buffer, int d, const std::string &na internal_assert(is_buffer || d == 0) << "Scalar parameters should be zero-dimensional"; } -Parameter::Parameter(const Type &t, bool is_buffer, int dimensions, const std::string &name, +Parameter::Parameter(const Type &t, int dimensions, const std::string &name, const Buffer &buffer, int host_alignment, const std::vector &buffer_constraints, MemoryType memory_type) - : contents(new Internal::ParameterContents(t, is_buffer, dimensions, name)) { + : contents(new Internal::ParameterContents(t, /*is_buffer*/ true, dimensions, name)) { contents->buffer = buffer; contents->host_alignment = host_alignment; contents->buffer_constraints = buffer_constraints; contents->memory_type = memory_type; } -Parameter::Parameter(const Type &t, bool is_buffer, int dimensions, const std::string &name, - uint64_t data, const Expr &scalar_default, const Expr &scalar_min, +Parameter::Parameter(const Type &t, int dimensions, const std::string &name, + std::optional scalar_data, const Expr &scalar_default, const Expr &scalar_min, const Expr &scalar_max, const Expr &scalar_estimate) - : contents(new Internal::ParameterContents(t, is_buffer, dimensions, name)) { - contents->data = data; + : contents(new Internal::ParameterContents(t, /*is_buffer*/ false, dimensions, name)) { + contents->scalar_data = std::move(scalar_data); contents->scalar_default = scalar_default; contents->scalar_min = scalar_min; contents->scalar_max = scalar_max; @@ -128,57 +127,57 @@ bool Parameter::is_buffer() const { return contents->is_buffer; } -bool Parameter::has_scalar_expr() const { - return defined() && !contents->is_buffer && contents->data_ever_set; +bool Parameter::has_scalar_value() const { + return defined() && !contents->is_buffer && contents->scalar_data.has_value(); } Expr Parameter::scalar_expr() const { check_is_scalar(); - // Redundant here, since every call to scalar<>() also checks this. - // check_data_ever_set(); + check_has_scalar_data(); + const auto sv = contents->scalar_data.value(); const Type t = type(); if (t.is_float()) { switch (t.bits()) { case 16: if (t.is_bfloat()) { - return Expr(scalar()); + return Expr(bfloat16_t::make_from_bits(sv.u.u16)); } else { - return Expr(scalar()); + return Expr(float16_t::make_from_bits(sv.u.u16)); } case 32: - return Expr(scalar()); + return Expr(sv.u.f32); case 64: - return Expr(scalar()); + return Expr(sv.u.f64); } } else if (t.is_int()) { switch (t.bits()) { case 8: - return Expr(scalar()); + return Expr(sv.u.i8); case 16: - return Expr(scalar()); + return Expr(sv.u.i16); case 32: - return Expr(scalar()); + return Expr(sv.u.i32); case 64: - return Expr(scalar()); + return Expr(sv.u.i64); } } else if (t.is_uint()) { switch (t.bits()) { case 1: - return Internal::make_bool(scalar()); + return Internal::make_bool(sv.u.b); case 8: - return Expr(scalar()); + return Expr(sv.u.u8); case 16: - return Expr(scalar()); + return Expr(sv.u.u16); case 32: - return Expr(scalar()); + return Expr(sv.u.u32); case 64: - return Expr(scalar()); + return Expr(sv.u.u64); } } else if (t.is_handle()) { // handles are always uint64 internally. switch (t.bits()) { case 64: - return Expr(scalar()); + return Expr(sv.u.u64); } } internal_error << "Unsupported type " << t << " in scalar_expr\n"; @@ -214,24 +213,24 @@ const void *Parameter::read_only_scalar_address() const { // Code that calls this method is (presumably) going // to read from the address, so complain if the scalar value // has never been set. - check_data_ever_set(); - return &contents->data; + check_has_scalar_data(); + return std::addressof(contents->scalar_data.value()); } -uint64_t Parameter::scalar_raw_value() const { +std::optional Parameter::scalar_data() const { + return defined() ? contents->scalar_data : std::nullopt; +} + +halide_scalar_value_t Parameter::scalar_data_checked(const Type &val_type) const { check_is_scalar(); - check_data_ever_set(); - return contents->data; + check_type(val_type); + check_has_scalar_data(); + return contents->scalar_data.value(); } void Parameter::set_scalar(const Type &val_type, halide_scalar_value_t val) { check_type(val_type); - // Setting this to zero isn't strictly necessary, but it does - // mean that the 'unused' bits of the field are never affected by what - // may have previously been there. - contents->data = 0; - memcpy(&contents->data, &val, val_type.bytes()); - contents->data_ever_set = true; + contents->scalar_data = std::optional(val); } /** Tests if this handle is the same as another handle */ diff --git a/src/Parameter.h b/src/Parameter.h index 47126fef7325..e50ca8135182 100644 --- a/src/Parameter.h +++ b/src/Parameter.h @@ -4,6 +4,7 @@ /** \file * Defines the internal representation of parameters to halide piplines */ +#include #include #include "Buffer.h" @@ -25,7 +26,9 @@ struct BufferConstraint { }; namespace Internal { + #ifdef WITH_SERIALIZATION +class Deserializer; class Serializer; #endif struct ParameterContents; @@ -35,7 +38,7 @@ struct ParameterContents; /** A reference-counted handle to a parameter to a halide * pipeline. May be a scalar parameter or a buffer */ class Parameter { - void check_data_ever_set() const; + void check_has_scalar_data() const; void check_defined() const; void check_is_buffer() const; void check_is_scalar() const; @@ -46,9 +49,10 @@ class Parameter { Internal::IntrusivePtr contents; #ifdef WITH_SERIALIZATION - friend class Internal::Serializer; //< for scalar_raw_value() + friend class Internal::Deserializer; //< for scalar_data() + friend class Internal::Serializer; //< for scalar_data() #endif - friend class Pipeline; //< for scalar_address() + friend class Pipeline; //< for read_only_scalar_address() /** Get the raw currently-bound buffer. null if unbound */ const halide_buffer_t *raw_buffer() const; @@ -59,9 +63,23 @@ class Parameter { * not be written to, so don't cast away the constness. Only relevant when jitting. */ const void *read_only_scalar_address() const; - /** Get the raw data of the current value of the scalar - * parameter. Only relevant when serializing. */ - uint64_t scalar_raw_value() const; + /** If the Parameter is a scalar, and the scalar data is valid, return + * the scalar data. Otherwise, return nullopt. */ + std::optional scalar_data() const; + + /** If the Parameter is a scalar of the given type and has a valid scalar value, return it. + * Otherwise, assert-fail. */ + halide_scalar_value_t scalar_data_checked(const Type &val_type) const; + + /** Construct a new buffer parameter via deserialization. */ + Parameter(const Type &t, int dimensions, const std::string &name, + const Buffer &buffer, int host_alignment, const std::vector &buffer_constraints, + MemoryType memory_type); + + /** Construct a new scalar parameter via deserialization. */ + Parameter(const Type &t, int dimensions, const std::string &name, + std::optional scalar_data, const Expr &scalar_default, const Expr &scalar_min, + const Expr &scalar_max, const Expr &scalar_estimate); public: /** Construct a new undefined handle */ @@ -83,15 +101,6 @@ class Parameter { * explicitly specified (as opposed to autogenerated). */ Parameter(const Type &t, bool is_buffer, int dimensions, const std::string &name); - /** Construct a new parameter via deserialization. */ - Parameter(const Type &t, bool is_buffer, int dimensions, const std::string &name, - const Buffer &buffer, int host_alignment, const std::vector &buffer_constraints, - MemoryType memory_type); - - Parameter(const Type &t, bool is_buffer, int dimensions, const std::string &name, - uint64_t data, const Expr &scalar_default, const Expr &scalar_min, - const Expr &scalar_max, const Expr &scalar_estimate); - Parameter(const Parameter &) = default; Parameter &operator=(const Parameter &) = default; Parameter(Parameter &&) = default; @@ -113,18 +122,20 @@ class Parameter { * bound value. Only relevant when jitting */ template HALIDE_NO_USER_CODE_INLINE T scalar() const { - check_type(type_of()); - check_data_ever_set(); - return *((const T *)(read_only_scalar_address())); + static_assert(sizeof(T) <= sizeof(halide_scalar_value_t)); + const auto sv = scalar_data_checked(type_of()); + T t; + memcpy(&t, &sv.u.u64, sizeof(t)); + return t; } - /** This returns the current value of scalar() - * as an Expr. If no value has ever been set, it will assert-fail */ + /** This returns the current value of scalar() as an Expr. + * If the Parameter is not scalar, or its scalar data is not valid, this will assert-fail. */ Expr scalar_expr() const; /** This returns true if scalar_expr() would return a valid Expr, * false if not. */ - bool has_scalar_expr() const; + bool has_scalar_value() const; /** If the parameter is a scalar parameter, set its current * value. Only relevant when jitting */ diff --git a/src/Pipeline.cpp b/src/Pipeline.cpp index 27b164e7b626..c9ff0a10a580 100644 --- a/src/Pipeline.cpp +++ b/src/Pipeline.cpp @@ -75,7 +75,7 @@ class FindVariablesWithUnsetParams : public IRVisitor { void visit(const Variable *op) override { if (op->param.defined() && !op->param.is_buffer() && op->param.name() != "__user_context") { - user_assert(op->param.has_scalar_expr()) + user_assert(op->param.has_scalar_value()) << "You cannot call realize() or compile_to_callable() on Halide code that references the scalar Generator Input " << op->param.name() << ", as the value will never be defined at Generator compile time. " << "Consider scheduling the function as compute_root().memoize() instead.\n"; diff --git a/src/Serialization.cpp b/src/Serialization.cpp index cc9daa9fe325..d9afcb5e083b 100644 --- a/src/Serialization.cpp +++ b/src/Serialization.cpp @@ -1298,13 +1298,19 @@ Offset Serializer::serialize_parameter(FlatBufferBuilder & return Serialize::CreateParameter(builder, defined, is_buffer, type_serialized, dimensions, name_serialized, host_alignment, builder.CreateVector(buffer_constraints_serialized), memory_type_serialized); } else { - const uint64_t data = parameter.scalar_raw_value(); + static_assert(FLATBUFFERS_USE_STD_OPTIONAL); + const auto make_optional_u64 = [](const std::optional &v) -> std::optional { + return v.has_value() ? + std::optional(v.value().u.u64) : + std::nullopt; + }; + const auto scalar_data = make_optional_u64(parameter.scalar_data()); const auto scalar_default_serialized = serialize_expr(builder, parameter.default_value()); const auto scalar_min_serialized = serialize_expr(builder, parameter.min_value()); const auto scalar_max_serialized = serialize_expr(builder, parameter.max_value()); const auto scalar_estimate_serialized = serialize_expr(builder, parameter.estimate()); return Serialize::CreateParameter(builder, defined, is_buffer, type_serialized, - dimensions, name_serialized, 0, 0, Serialize::MemoryType_Auto, data, + dimensions, name_serialized, 0, 0, Serialize::MemoryType_Auto, scalar_data, scalar_default_serialized.first, scalar_default_serialized.second, scalar_min_serialized.first, scalar_min_serialized.second, scalar_max_serialized.first, scalar_max_serialized.second, diff --git a/src/halide_ir.fbs b/src/halide_ir.fbs index cab57a7dc044..685d20995761 100644 --- a/src/halide_ir.fbs +++ b/src/halide_ir.fbs @@ -618,7 +618,7 @@ table Parameter { host_alignment: int32; buffer_constraints: [BufferConstraint]; memory_type: MemoryType; - data: uint64; + scalar_data: uint64 = null; // Note: it is valid for this to be omitted, even if is_buffer = false. scalar_default: Expr; scalar_min: Expr; scalar_max: Expr; From 3963c5e008b3ddb1b4ba6ed5b27176e31f16aa1e Mon Sep 17 00:00:00 2001 From: Steven Johnson Date: Mon, 25 Sep 2023 13:12:43 -0700 Subject: [PATCH 07/10] Update InferArguments.cpp --- src/InferArguments.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/InferArguments.cpp b/src/InferArguments.cpp index 5a3e7c952a7a..020d4184642d 100644 --- a/src/InferArguments.cpp +++ b/src/InferArguments.cpp @@ -198,7 +198,9 @@ class InferArguments : public IRGraphVisitor { ArgumentEstimates argument_estimates = p.get_argument_estimates(); if (!p.is_buffer()) { - argument_estimates.scalar_def = p.default_value(); + // We don't want to crater here if a scalar param isn't set; + // instead, default to a zero of the right type, like we used to. + argument_estimates.scalar_def = p.has_scalar_value() ? p.scalar_expr() : make_zero(p.type()); argument_estimates.scalar_min = p.min_value(); argument_estimates.scalar_max = p.max_value(); argument_estimates.scalar_estimate = p.estimate(); From 57838ba7a18fc53fc5833b72b449b664bfb18feb Mon Sep 17 00:00:00 2001 From: Steven Johnson Date: Mon, 25 Sep 2023 15:25:35 -0700 Subject: [PATCH 08/10] Fixes --- src/Deserialization.cpp | 2 +- src/Parameter.cpp | 27 +++++++++++---------------- src/Parameter.h | 9 ++++++--- 3 files changed, 18 insertions(+), 20 deletions(-) diff --git a/src/Deserialization.cpp b/src/Deserialization.cpp index 2f90fe4159a3..b882f8b0b157 100644 --- a/src/Deserialization.cpp +++ b/src/Deserialization.cpp @@ -1176,7 +1176,7 @@ Parameter Deserializer::deserialize_parameter(const Serialize::Parameter *parame const auto scalar_min = deserialize_expr(parameter->scalar_min_type(), parameter->scalar_min()); const auto scalar_max = deserialize_expr(parameter->scalar_max_type(), parameter->scalar_max()); const auto scalar_estimate = deserialize_expr(parameter->scalar_estimate_type(), parameter->scalar_estimate()); - return Parameter(type, dimensions, name, std::move(scalar_data), scalar_default, scalar_min, scalar_max, scalar_estimate); + return Parameter(type, dimensions, name, scalar_data, scalar_default, scalar_min, scalar_max, scalar_estimate); } } diff --git a/src/Parameter.cpp b/src/Parameter.cpp index e196fbd3d87d..6338c6b82de2 100644 --- a/src/Parameter.cpp +++ b/src/Parameter.cpp @@ -46,10 +46,6 @@ void destroy(const ParameterContents *p) { } // namespace Internal -void Parameter::check_has_scalar_data() const { - user_assert(contents->scalar_data.has_value()) << "Parameter " << name() << " has never had a scalar value set.\n"; -} - void Parameter::check_defined() const { user_assert(defined()) << "Parameter is undefined\n"; } @@ -97,10 +93,10 @@ Parameter::Parameter(const Type &t, int dimensions, const std::string &name, } Parameter::Parameter(const Type &t, int dimensions, const std::string &name, - std::optional scalar_data, const Expr &scalar_default, const Expr &scalar_min, + const std::optional &scalar_data, const Expr &scalar_default, const Expr &scalar_min, const Expr &scalar_max, const Expr &scalar_estimate) : contents(new Internal::ParameterContents(t, /*is_buffer*/ false, dimensions, name)) { - contents->scalar_data = std::move(scalar_data); + contents->scalar_data = scalar_data; contents->scalar_default = scalar_default; contents->scalar_min = scalar_min; contents->scalar_max = scalar_max; @@ -132,9 +128,7 @@ bool Parameter::has_scalar_value() const { } Expr Parameter::scalar_expr() const { - check_is_scalar(); - check_has_scalar_data(); - const auto sv = contents->scalar_data.value(); + const auto sv = scalar_data_checked(); const Type t = type(); if (t.is_float()) { switch (t.bits()) { @@ -210,10 +204,7 @@ void Parameter::set_buffer(const Buffer<> &b) { const void *Parameter::read_only_scalar_address() const { check_is_scalar(); - // Code that calls this method is (presumably) going - // to read from the address, so complain if the scalar value - // has never been set. - check_has_scalar_data(); + user_assert(contents->scalar_data.has_value()) << "Parameter " << name() << " does not have a valid scalar value.\n"; return std::addressof(contents->scalar_data.value()); } @@ -221,13 +212,17 @@ std::optional Parameter::scalar_data() const { return defined() ? contents->scalar_data : std::nullopt; } -halide_scalar_value_t Parameter::scalar_data_checked(const Type &val_type) const { +halide_scalar_value_t Parameter::scalar_data_checked() const { check_is_scalar(); - check_type(val_type); - check_has_scalar_data(); + user_assert(contents->scalar_data.has_value()) << "Parameter " << name() << " does not have a valid scalar value.\n"; return contents->scalar_data.value(); } +halide_scalar_value_t Parameter::scalar_data_checked(const Type &val_type) const { + check_type(val_type); + return scalar_data_checked(); +} + void Parameter::set_scalar(const Type &val_type, halide_scalar_value_t val) { check_type(val_type); contents->scalar_data = std::optional(val); diff --git a/src/Parameter.h b/src/Parameter.h index e50ca8135182..09840683980c 100644 --- a/src/Parameter.h +++ b/src/Parameter.h @@ -38,7 +38,6 @@ struct ParameterContents; /** A reference-counted handle to a parameter to a halide * pipeline. May be a scalar parameter or a buffer */ class Parameter { - void check_has_scalar_data() const; void check_defined() const; void check_is_buffer() const; void check_is_scalar() const; @@ -67,7 +66,11 @@ class Parameter { * the scalar data. Otherwise, return nullopt. */ std::optional scalar_data() const; - /** If the Parameter is a scalar of the given type and has a valid scalar value, return it. + /** If the Parameter is a scalar and has a valid scalar value, return it. + * Otherwise, assert-fail. */ + halide_scalar_value_t scalar_data_checked() const; + + /** If the Parameter is a scalar *of the given type* and has a valid scalar value, return it. * Otherwise, assert-fail. */ halide_scalar_value_t scalar_data_checked(const Type &val_type) const; @@ -78,7 +81,7 @@ class Parameter { /** Construct a new scalar parameter via deserialization. */ Parameter(const Type &t, int dimensions, const std::string &name, - std::optional scalar_data, const Expr &scalar_default, const Expr &scalar_min, + const std::optional &scalar_data, const Expr &scalar_default, const Expr &scalar_min, const Expr &scalar_max, const Expr &scalar_estimate); public: From b914ccd17dfed3dde36b48a5e8943eecce4e5919 Mon Sep 17 00:00:00 2001 From: Steven Johnson Date: Mon, 25 Sep 2023 16:10:59 -0700 Subject: [PATCH 09/10] pacify clang-tidy --- src/Parameter.cpp | 26 ++++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/src/Parameter.cpp b/src/Parameter.cpp index 6338c6b82de2..d9616b5bebf8 100644 --- a/src/Parameter.cpp +++ b/src/Parameter.cpp @@ -204,8 +204,16 @@ void Parameter::set_buffer(const Buffer<> &b) { const void *Parameter::read_only_scalar_address() const { check_is_scalar(); - user_assert(contents->scalar_data.has_value()) << "Parameter " << name() << " does not have a valid scalar value.\n"; - return std::addressof(contents->scalar_data.value()); + // Use explicit if here (rather than user_assert) so that we don't + // have to disable bugprone-unchecked-optional-access in clang-tidy, + // which is a useful check. + const auto &sv = contents->scalar_data; + if (sv.has_value()) { + return std::addressof(sv.value()); + } else { + user_error << "Parameter " << name() << " does not have a valid scalar value.\n"; + return nullptr; + } } std::optional Parameter::scalar_data() const { @@ -214,8 +222,18 @@ std::optional Parameter::scalar_data() const { halide_scalar_value_t Parameter::scalar_data_checked() const { check_is_scalar(); - user_assert(contents->scalar_data.has_value()) << "Parameter " << name() << " does not have a valid scalar value.\n"; - return contents->scalar_data.value(); + // Use explicit if here (rather than user_assert) so that we don't + // have to disable bugprone-unchecked-optional-access in clang-tidy, + // which is a useful check. + halide_scalar_value_t result; + const auto &sv = contents->scalar_data; + if (sv.has_value()) { + result = sv.value(); + } else { + user_error << "Parameter " << name() << " does not have a valid scalar value.\n"; + result.u.u64 = 0; // silence "possibly uninitialized" compiler warning + } + return result; } halide_scalar_value_t Parameter::scalar_data_checked(const Type &val_type) const { From 70135b09b0f4e176c269cfbacf07cb854432b5ba Mon Sep 17 00:00:00 2001 From: Steven Johnson Date: Tue, 26 Sep 2023 11:06:36 -0700 Subject: [PATCH 10/10] fixes --- src/Pipeline.cpp | 22 ------------ test/error/CMakeLists.txt | 1 - test/error/uninitialized_param_3.cpp | 50 ---------------------------- 3 files changed, 73 deletions(-) delete mode 100644 test/error/uninitialized_param_3.cpp diff --git a/src/Pipeline.cpp b/src/Pipeline.cpp index c9ff0a10a580..631033404137 100644 --- a/src/Pipeline.cpp +++ b/src/Pipeline.cpp @@ -70,23 +70,6 @@ std::string sanitize_function_name(const std::string &s) { return name; } -class FindVariablesWithUnsetParams : public IRVisitor { - using IRVisitor::visit; - - void visit(const Variable *op) override { - if (op->param.defined() && !op->param.is_buffer() && op->param.name() != "__user_context") { - user_assert(op->param.has_scalar_value()) - << "You cannot call realize() or compile_to_callable() on Halide code that references the scalar Generator Input " << op->param.name() - << ", as the value will never be defined at Generator compile time. " - << "Consider scheduling the function as compute_root().memoize() instead.\n"; - } - IRVisitor::visit(op); - } - -public: - FindVariablesWithUnsetParams() = default; -}; - } // namespace namespace Internal { @@ -825,11 +808,6 @@ void Pipeline::prepare_jit_call_arguments(RealizationArg &outputs, const Target bool is_bounds_inference, JITCallArgs &args_result) { user_assert(defined()) << "Can't realize an undefined Pipeline\n"; - FindVariablesWithUnsetParams find_unset; - for (const LoweredFunc &f : contents->module.functions()) { - f.body.accept(&find_unset); - } - size_t total_outputs = 0; for (const Func &out : this->outputs()) { total_outputs += out.outputs(); diff --git a/test/error/CMakeLists.txt b/test/error/CMakeLists.txt index 902c16ee7679..dde9d35fea0b 100644 --- a/test/error/CMakeLists.txt +++ b/test/error/CMakeLists.txt @@ -109,7 +109,6 @@ tests(GROUPS error undefined_rdom_dimension.cpp uninitialized_param.cpp uninitialized_param_2.cpp - uninitialized_param_3.cpp unknown_target.cpp vector_tile.cpp vectorize_dynamic.cpp diff --git a/test/error/uninitialized_param_3.cpp b/test/error/uninitialized_param_3.cpp deleted file mode 100644 index f7927840edaf..000000000000 --- a/test/error/uninitialized_param_3.cpp +++ /dev/null @@ -1,50 +0,0 @@ -#include "Halide.h" -#include "halide_test_dirs.h" -#include - -using namespace Halide; -using namespace Halide::ConciseCasts; - -#include "Halide.h" - -namespace { - -Var x; - -class PleaseFail : public Halide::Generator { -public: - Input> input{"input"}; - Input scalar_input{"scalar_input"}; - Output> output{"output"}; - - void generate() { - Func lut_fn("lut_fn"); - lut_fn(x) = u8_sat(x * scalar_input / 255.f); - - // This should always fail, because it depends on a scalar input - // that *cannot* have a valid value at this point. - auto lut_callable = lut_fn.compile_to_callable({}); - - Buffer lut(256); - int r = lut_callable(lut); - assert(r == 0); - - output(x) = input(x) + lut(x); - } -}; - -} // namespace - -HALIDE_REGISTER_GENERATOR(PleaseFail, PleaseFail) - -int main(int argc, char **argv) { - Halide::Internal::ExecuteGeneratorArgs args; - args.output_dir = Internal::get_test_tmp_dir(); - args.output_types = std::set{OutputFileType::object}; - args.targets = std::vector{get_target_from_environment()}; - args.generator_name = "PleaseFail"; - execute_generator(args); - - printf("Success!\n"); - return 0; -}