Skip to content

Commit

Permalink
Prevent use of uninitialized scalar Parameters in JIT code (#7847, pa…
Browse files Browse the repository at this point in the history
…rtial)
  • Loading branch information
steven-johnson committed Sep 18, 2023
1 parent 68a0341 commit 9495153
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 13 deletions.
5 changes: 4 additions & 1 deletion src/InferArguments.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include "ExternFuncArgument.h"
#include "Function.h"
#include "IROperator.h"
#include "IRVisitor.h"
#include "InferArguments.h"

Expand Down Expand Up @@ -197,7 +198,9 @@ class InferArguments : public IRGraphVisitor {

ArgumentEstimates argument_estimates = p.get_argument_estimates();
if (!p.is_buffer()) {
argument_estimates.scalar_def = p.scalar_expr();
// We don't want to crater here if a scalar param isn't set;
// instead, default to a zero of the right type, like we used to.
argument_estimates.scalar_def = p.has_scalar_expr() ? p.scalar_expr() : make_zero(p.type());
argument_estimates.scalar_min = p.min_value();
argument_estimates.scalar_max = p.max_value();
argument_estimates.scalar_estimate = p.estimate();
Expand Down
28 changes: 27 additions & 1 deletion src/Parameter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ struct ParameterContents {
std::vector<BufferConstraint> buffer_constraints;
Expr scalar_default, scalar_min, scalar_max, scalar_estimate;
const bool is_buffer;
bool data_ever_set = false;
MemoryType memory_type = MemoryType::Auto;

ParameterContents(Type t, bool b, int d, const std::string &n)
Expand All @@ -46,6 +47,10 @@ void destroy<Halide::Internal::ParameterContents>(const ParameterContents *p) {

} // namespace Internal

void Parameter::check_data_ever_set() const {
user_assert(contents->data_ever_set) << "Parameter " << name() << " has never had a scalar value set.\n";
}

void Parameter::check_defined() const {
user_assert(defined()) << "Parameter is undefined\n";
}
Expand Down Expand Up @@ -123,8 +128,14 @@ bool Parameter::is_buffer() const {
return contents->is_buffer;
}

bool Parameter::has_scalar_expr() const {
return defined() && !contents->is_buffer && contents->data_ever_set;
}

Expr Parameter::scalar_expr() const {
check_is_scalar();
// Redundant here, since every call to scalar<>() also checks this.
// check_data_ever_set();
const Type t = type();
if (t.is_float()) {
switch (t.bits()) {
Expand Down Expand Up @@ -198,16 +209,31 @@ void Parameter::set_buffer(const Buffer<> &b) {
contents->buffer = b;
}

void *Parameter::scalar_address() const {
const void *Parameter::read_only_scalar_address() const {
check_is_scalar();
// Code that calls this method is (presumably) going
// to read from the address, so complain if the scalar value
// has never been set.
check_data_ever_set();
return &contents->data;
}

uint64_t Parameter::scalar_raw_value() const {
check_is_scalar();
check_data_ever_set();
return contents->data;
}

void Parameter::set_scalar(const Type &val_type, halide_scalar_value_t val) {
check_type(val_type);
// Setting this to zero isn't strictly necessary, but it does
// mean that the 'unused' bits of the field are never affected by what
// may have previously been there.
contents->data = 0;
memcpy(&contents->data, &val, val_type.bytes());
contents->data_ever_set = true;
}

/** Tests if this handle is the same as another handle */
bool Parameter::same_as(const Parameter &other) const {
return contents.same_as(other.contents);
Expand Down
25 changes: 15 additions & 10 deletions src/Parameter.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ struct ParameterContents;
/** A reference-counted handle to a parameter to a halide
* pipeline. May be a scalar parameter or a buffer */
class Parameter {
void check_data_ever_set() const;
void check_defined() const;
void check_is_buffer() const;
void check_is_scalar() const;
Expand All @@ -54,8 +55,9 @@ class Parameter {

/** Get the pointer to the current value of the scalar
* parameter. For a given parameter, this address will never
* change. Only relevant when jitting. */
void *scalar_address() const;
* change. Note that this can only be used to *read* from -- it must
* not be written to, so don't cast away the constness. Only relevant when jitting. */
const void *read_only_scalar_address() const;

/** Get the raw data of the current value of the scalar
* parameter. Only relevant when serializing. */
Expand Down Expand Up @@ -112,27 +114,30 @@ class Parameter {
template<typename T>
HALIDE_NO_USER_CODE_INLINE T scalar() const {
check_type(type_of<T>());
return *((const T *)(scalar_address()));
check_data_ever_set();
return *((const T *)(read_only_scalar_address()));
}

/** This returns the current value of scalar<type()>()
* as an Expr. */
* as an Expr. If no value has ever been set, it will assert-fail */
Expr scalar_expr() const;

/** This returns true if scalar_expr() would return a valid Expr,
* false if not. */
bool has_scalar_expr() const;

/** If the parameter is a scalar parameter, set its current
* value. Only relevant when jitting */
template<typename T>
HALIDE_NO_USER_CODE_INLINE void set_scalar(T val) {
check_type(type_of<T>());
*((T *)(scalar_address())) = val;
halide_scalar_value_t sv;
memcpy(&sv, &val, sizeof(val));
set_scalar(type_of<T>(), sv);
}

/** If the parameter is a scalar parameter, set its current
* value. Only relevant when jitting */
HALIDE_NO_USER_CODE_INLINE void set_scalar(const Type &val_type, halide_scalar_value_t val) {
check_type(val_type);
memcpy(scalar_address(), &val, val_type.bytes());
}
void set_scalar(const Type &val_type, halide_scalar_value_t val);

/** If the parameter is a buffer parameter, get its currently
* bound buffer. Only relevant when jitting */
Expand Down
2 changes: 1 addition & 1 deletion src/Pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -843,7 +843,7 @@ void Pipeline::prepare_jit_call_arguments(RealizationArg &outputs, const Target
}
debug(2) << "JIT input ImageParam argument ";
} else {
args_result.store[arg_index++] = p.scalar_address();
args_result.store[arg_index++] = p.read_only_scalar_address();
debug(2) << "JIT input scalar argument ";
}
}
Expand Down
1 change: 1 addition & 0 deletions test/error/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ tests(GROUPS error
undefined_pipeline_compile.cpp
undefined_pipeline_realize.cpp
undefined_rdom_dimension.cpp
uninitialized_param.cpp
unknown_target.cpp
vector_tile.cpp
vectorize_dynamic.cpp
Expand Down
22 changes: 22 additions & 0 deletions test/error/uninitialized_param.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#include "Halide.h"
#include <stdio.h>

using namespace Halide;

int main(int argc, char **argv) {
ImageParam image_param(Int(32), 2, "image_param");
Param<int> scalar_param("scalar_param");

Var x("x"), y("y");
Func f("f");

f(x, y) = image_param(x, y) + scalar_param;

Buffer<int> b(10, 10);
image_param.set(b);

f.realize({10, 10});

printf("Success!\n");
return 0;
}

0 comments on commit 9495153

Please sign in to comment.