Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent use of uninitialized scalar Parameters in JIT code (#7847, partial) #7853

Merged
merged 11 commits into from
Sep 27, 2023
5 changes: 4 additions & 1 deletion src/InferArguments.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include "ExternFuncArgument.h"
#include "Function.h"
#include "IROperator.h"
#include "IRVisitor.h"
#include "InferArguments.h"

Expand Down Expand Up @@ -197,7 +198,9 @@ class InferArguments : public IRGraphVisitor {

ArgumentEstimates argument_estimates = p.get_argument_estimates();
if (!p.is_buffer()) {
argument_estimates.scalar_def = p.scalar_expr();
// We don't want to crater here if a scalar param isn't set;
// instead, default to a zero of the right type, like we used to.
argument_estimates.scalar_def = p.has_scalar_expr() ? p.scalar_expr() : make_zero(p.type());
argument_estimates.scalar_min = p.min_value();
argument_estimates.scalar_max = p.max_value();
argument_estimates.scalar_estimate = p.estimate();
Expand Down
28 changes: 27 additions & 1 deletion src/Parameter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ struct ParameterContents {
std::vector<BufferConstraint> buffer_constraints;
Expr scalar_default, scalar_min, scalar_max, scalar_estimate;
const bool is_buffer;
bool data_ever_set = false;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is the place we realize whenever we add sth to the frontend IR, serialization needs to be updated at the same time

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

data_ever_set looks like a field that needs to be consistent across serialization (i.e. A parameter that has a data set should also be known as data-ever-set after serialization roundtrip). In this case, we need to add this field to the serialization implementation to properly serialize it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm... didn't think about that.

I'm actually kinda unsure as to whether we should be saving the scalar value of the Param (if any) to the Serialization or not -- I know we do now (and it's required for the serialization-round-trip-via-JIT hack), but conceptually, it seems like the wrong thing to do.

A Parameter (either scalar or buffer) is conceptually just a placeholder with required-type information. Is it really desirable (in the general sense) to save whatever happens to be stuff in there?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it does seem weird to me at first as well, had to admit it.

But now I treat this new field in the PR same/similar to the purpose of defined, then it makes more sense

MemoryType memory_type = MemoryType::Auto;

ParameterContents(Type t, bool b, int d, const std::string &n)
Expand All @@ -46,6 +47,10 @@ void destroy<Halide::Internal::ParameterContents>(const ParameterContents *p) {

} // namespace Internal

void Parameter::check_data_ever_set() const {
user_assert(contents->data_ever_set) << "Parameter " << name() << " has never had a scalar value set.\n";
}

void Parameter::check_defined() const {
user_assert(defined()) << "Parameter is undefined\n";
}
Expand Down Expand Up @@ -123,8 +128,14 @@ bool Parameter::is_buffer() const {
return contents->is_buffer;
}

bool Parameter::has_scalar_expr() const {
return defined() && !contents->is_buffer && contents->data_ever_set;
}

Expr Parameter::scalar_expr() const {
check_is_scalar();
// Redundant here, since every call to scalar<>() also checks this.
// check_data_ever_set();
const Type t = type();
if (t.is_float()) {
switch (t.bits()) {
Expand Down Expand Up @@ -198,16 +209,31 @@ void Parameter::set_buffer(const Buffer<> &b) {
contents->buffer = b;
}

void *Parameter::scalar_address() const {
const void *Parameter::read_only_scalar_address() const {
check_is_scalar();
// Code that calls this method is (presumably) going
// to read from the address, so complain if the scalar value
// has never been set.
check_data_ever_set();
return &contents->data;
}

uint64_t Parameter::scalar_raw_value() const {
check_is_scalar();
check_data_ever_set();
return contents->data;
}

void Parameter::set_scalar(const Type &val_type, halide_scalar_value_t val) {
check_type(val_type);
// Setting this to zero isn't strictly necessary, but it does
// mean that the 'unused' bits of the field are never affected by what
// may have previously been there.
contents->data = 0;
memcpy(&contents->data, &val, val_type.bytes());
contents->data_ever_set = true;
}

/** Tests if this handle is the same as another handle */
bool Parameter::same_as(const Parameter &other) const {
return contents.same_as(other.contents);
Expand Down
25 changes: 15 additions & 10 deletions src/Parameter.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ struct ParameterContents;
/** A reference-counted handle to a parameter to a halide
* pipeline. May be a scalar parameter or a buffer */
class Parameter {
void check_data_ever_set() const;
void check_defined() const;
void check_is_buffer() const;
void check_is_scalar() const;
Expand All @@ -54,8 +55,9 @@ class Parameter {

/** Get the pointer to the current value of the scalar
* parameter. For a given parameter, this address will never
* change. Only relevant when jitting. */
void *scalar_address() const;
* change. Note that this can only be used to *read* from -- it must
* not be written to, so don't cast away the constness. Only relevant when jitting. */
const void *read_only_scalar_address() const;

/** Get the raw data of the current value of the scalar
* parameter. Only relevant when serializing. */
Expand Down Expand Up @@ -112,27 +114,30 @@ class Parameter {
template<typename T>
HALIDE_NO_USER_CODE_INLINE T scalar() const {
check_type(type_of<T>());
return *((const T *)(scalar_address()));
check_data_ever_set();
return *((const T *)(read_only_scalar_address()));
}

/** This returns the current value of scalar<type()>()
* as an Expr. */
* as an Expr. If no value has ever been set, it will assert-fail */
Expr scalar_expr() const;

/** This returns true if scalar_expr() would return a valid Expr,
* false if not. */
bool has_scalar_expr() const;

/** If the parameter is a scalar parameter, set its current
* value. Only relevant when jitting */
template<typename T>
HALIDE_NO_USER_CODE_INLINE void set_scalar(T val) {
check_type(type_of<T>());
*((T *)(scalar_address())) = val;
halide_scalar_value_t sv;
memcpy(&sv.u.u64, &val, sizeof(val));
set_scalar(type_of<T>(), sv);
}

/** If the parameter is a scalar parameter, set its current
* value. Only relevant when jitting */
HALIDE_NO_USER_CODE_INLINE void set_scalar(const Type &val_type, halide_scalar_value_t val) {
check_type(val_type);
memcpy(scalar_address(), &val, val_type.bytes());
}
void set_scalar(const Type &val_type, halide_scalar_value_t val);

/** If the parameter is a buffer parameter, get its currently
* bound buffer. Only relevant when jitting */
Expand Down
2 changes: 1 addition & 1 deletion src/Pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -843,7 +843,7 @@ void Pipeline::prepare_jit_call_arguments(RealizationArg &outputs, const Target
}
debug(2) << "JIT input ImageParam argument ";
} else {
args_result.store[arg_index++] = p.scalar_address();
args_result.store[arg_index++] = p.read_only_scalar_address();
debug(2) << "JIT input scalar argument ";
}
}
Expand Down
1 change: 1 addition & 0 deletions test/correctness/func_clone.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ int update_defined_after_clone_test() {
return 1;
}

param.set(false);
Buffer<int> im = g.realize({200, 200});
steven-johnson marked this conversation as resolved.
Show resolved Hide resolved
auto func = [](int x, int y) {
return ((0 <= x && x <= 99) && (0 <= y && y <= 99) && (x < y)) ? 3 * (x + y) : (x + y);
Expand Down
4 changes: 4 additions & 0 deletions test/correctness/specialize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,8 @@ int main(int argc, char **argv) {
Buffer<int> input(3, 3), output(3, 3);
// Shouldn't throw a bounds error:
im.set(input);
cond1.set(false);
cond2.set(false);
out.realize(output);

if (if_then_else_count != 1) {
Expand Down Expand Up @@ -476,6 +478,8 @@ int main(int argc, char **argv) {
Buffer<int> input(3, 3), output(3, 3);
// Shouldn't throw a bounds error:
im.set(input);
cond1.set(false);
cond2.set(false);
out.realize(output);

// There should have been 2 Ifs total: They are the
Expand Down
1 change: 1 addition & 0 deletions test/error/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ tests(GROUPS error
undefined_pipeline_compile.cpp
undefined_pipeline_realize.cpp
undefined_rdom_dimension.cpp
uninitialized_param.cpp
unknown_target.cpp
vector_tile.cpp
vectorize_dynamic.cpp
Expand Down
22 changes: 22 additions & 0 deletions test/error/uninitialized_param.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#include "Halide.h"
#include <stdio.h>

using namespace Halide;

int main(int argc, char **argv) {
ImageParam image_param(Int(32), 2, "image_param");
Param<int> scalar_param("scalar_param");

Var x("x"), y("y");
Func f("f");

f(x, y) = image_param(x, y) + scalar_param;

Buffer<int> b(10, 10);
image_param.set(b);

f.realize({10, 10});

printf("Success!\n");
return 0;
}