Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent use of uninitialized scalar Parameters in JIT code (#7847, partial) #7853

Merged
merged 11 commits into from
Sep 27, 2023
5 changes: 0 additions & 5 deletions python_bindings/src/halide/halide_/PyParameter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,6 @@ void define_parameter(py::module &m) {
.def(py::init<const Parameter &>(), py::arg("p"))
.def(py::init<const Type &, bool, int>())
.def(py::init<const Type &, bool, int, const std::string &>())
.def(py::init<const Type &, bool, int, const std::string &,
const Buffer<void> &, int, const std::vector<BufferConstraint> &,
MemoryType>())
.def(py::init<const Type &, bool, int, const std::string &,
uint64_t, const Expr &, const Expr &, const Expr &, const Expr &>())
.def("_to_argument", [](const Parameter &p) -> Argument {
return Argument(p.name(),
p.is_buffer() ? Argument::InputBuffer : Argument::InputScalar,
Expand Down
16 changes: 13 additions & 3 deletions src/Deserialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1159,14 +1159,24 @@ Parameter Deserializer::deserialize_parameter(const Serialize::Parameter *parame
deserialize_vector<Serialize::BufferConstraint, BufferConstraint>(parameter->buffer_constraints(),
&Deserializer::deserialize_buffer_constraint);
const auto memory_type = deserialize_memory_type(parameter->memory_type());
return Parameter(type, is_buffer, dimensions, name, Buffer<>(), host_alignment, buffer_constraints, memory_type);
return Parameter(type, dimensions, name, Buffer<>(), host_alignment, buffer_constraints, memory_type);
} else {
const uint64_t data = parameter->data();
static_assert(FLATBUFFERS_USE_STD_OPTIONAL);
const auto make_optional_halide_scalar_value_t = [](const std::optional<uint64_t> &v) -> std::optional<halide_scalar_value_t> {
if (v.has_value()) {
halide_scalar_value_t scalar_data;
scalar_data.u.u64 = v.value();
return std::optional<halide_scalar_value_t>(scalar_data);
} else {
return std::nullopt;
}
};
const std::optional<halide_scalar_value_t> scalar_data = make_optional_halide_scalar_value_t(parameter->scalar_data());
const auto scalar_default = deserialize_expr(parameter->scalar_default_type(), parameter->scalar_default());
const auto scalar_min = deserialize_expr(parameter->scalar_min_type(), parameter->scalar_min());
const auto scalar_max = deserialize_expr(parameter->scalar_max_type(), parameter->scalar_max());
const auto scalar_estimate = deserialize_expr(parameter->scalar_estimate_type(), parameter->scalar_estimate());
return Parameter(type, is_buffer, dimensions, name, data, scalar_default, scalar_min, scalar_max, scalar_estimate);
return Parameter(type, dimensions, name, scalar_data, scalar_default, scalar_min, scalar_max, scalar_estimate);
}
}

Expand Down
3 changes: 2 additions & 1 deletion src/Generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -2001,7 +2001,8 @@ class GeneratorInput_Scalar : public GeneratorInputImpl<T, Expr> {

void set_def_min_max() override {
for (Parameter &p : this->parameters_) {
p.set_scalar<TBase>(def_);
// No: we want to leave the Parameter unset here.
// p.set_scalar<TBase>(def_);
p.set_default_value(def_expr_);
}
}
Expand Down
5 changes: 4 additions & 1 deletion src/InferArguments.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include "ExternFuncArgument.h"
#include "Function.h"
#include "IROperator.h"
#include "IRVisitor.h"
#include "InferArguments.h"

Expand Down Expand Up @@ -197,7 +198,9 @@ class InferArguments : public IRGraphVisitor {

ArgumentEstimates argument_estimates = p.get_argument_estimates();
if (!p.is_buffer()) {
argument_estimates.scalar_def = p.scalar_expr();
// We don't want to crater here if a scalar param isn't set;
// instead, default to a zero of the right type, like we used to.
argument_estimates.scalar_def = p.has_scalar_value() ? p.scalar_expr() : make_zero(p.type());
argument_estimates.scalar_min = p.min_value();
argument_estimates.scalar_max = p.max_value();
argument_estimates.scalar_estimate = p.estimate();
Expand Down
90 changes: 64 additions & 26 deletions src/Parameter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ struct ParameterContents {
const int dimensions;
const std::string name;
Buffer<> buffer;
uint64_t data = 0;
std::optional<halide_scalar_value_t> scalar_data;
int host_alignment;
std::vector<BufferConstraint> buffer_constraints;
Expr scalar_default, scalar_min, scalar_max, scalar_estimate;
Expand Down Expand Up @@ -82,21 +82,21 @@ Parameter::Parameter(const Type &t, bool is_buffer, int d, const std::string &na
internal_assert(is_buffer || d == 0) << "Scalar parameters should be zero-dimensional";
}

Parameter::Parameter(const Type &t, bool is_buffer, int dimensions, const std::string &name,
Parameter::Parameter(const Type &t, int dimensions, const std::string &name,
const Buffer<void> &buffer, int host_alignment, const std::vector<BufferConstraint> &buffer_constraints,
MemoryType memory_type)
: contents(new Internal::ParameterContents(t, is_buffer, dimensions, name)) {
: contents(new Internal::ParameterContents(t, /*is_buffer*/ true, dimensions, name)) {
contents->buffer = buffer;
contents->host_alignment = host_alignment;
contents->buffer_constraints = buffer_constraints;
contents->memory_type = memory_type;
}

Parameter::Parameter(const Type &t, bool is_buffer, int dimensions, const std::string &name,
uint64_t data, const Expr &scalar_default, const Expr &scalar_min,
Parameter::Parameter(const Type &t, int dimensions, const std::string &name,
const std::optional<halide_scalar_value_t> &scalar_data, const Expr &scalar_default, const Expr &scalar_min,
const Expr &scalar_max, const Expr &scalar_estimate)
: contents(new Internal::ParameterContents(t, is_buffer, dimensions, name)) {
contents->data = data;
: contents(new Internal::ParameterContents(t, /*is_buffer*/ false, dimensions, name)) {
contents->scalar_data = scalar_data;
contents->scalar_default = scalar_default;
contents->scalar_min = scalar_min;
contents->scalar_max = scalar_max;
Expand All @@ -123,51 +123,55 @@ bool Parameter::is_buffer() const {
return contents->is_buffer;
}

bool Parameter::has_scalar_value() const {
return defined() && !contents->is_buffer && contents->scalar_data.has_value();
}

Expr Parameter::scalar_expr() const {
check_is_scalar();
const auto sv = scalar_data_checked();
const Type t = type();
if (t.is_float()) {
switch (t.bits()) {
case 16:
if (t.is_bfloat()) {
return Expr(scalar<bfloat16_t>());
return Expr(bfloat16_t::make_from_bits(sv.u.u16));
} else {
return Expr(scalar<float16_t>());
return Expr(float16_t::make_from_bits(sv.u.u16));
}
case 32:
return Expr(scalar<float>());
return Expr(sv.u.f32);
case 64:
return Expr(scalar<double>());
return Expr(sv.u.f64);
}
} else if (t.is_int()) {
switch (t.bits()) {
case 8:
return Expr(scalar<int8_t>());
return Expr(sv.u.i8);
case 16:
return Expr(scalar<int16_t>());
return Expr(sv.u.i16);
case 32:
return Expr(scalar<int32_t>());
return Expr(sv.u.i32);
case 64:
return Expr(scalar<int64_t>());
return Expr(sv.u.i64);
}
} else if (t.is_uint()) {
switch (t.bits()) {
case 1:
return Internal::make_bool(scalar<bool>());
return Internal::make_bool(sv.u.b);
case 8:
return Expr(scalar<uint8_t>());
return Expr(sv.u.u8);
case 16:
return Expr(scalar<uint16_t>());
return Expr(sv.u.u16);
case 32:
return Expr(scalar<uint32_t>());
return Expr(sv.u.u32);
case 64:
return Expr(scalar<uint64_t>());
return Expr(sv.u.u64);
}
} else if (t.is_handle()) {
// handles are always uint64 internally.
switch (t.bits()) {
case 64:
return Expr(scalar<uint64_t>());
return Expr(sv.u.u64);
}
}
internal_error << "Unsupported type " << t << " in scalar_expr\n";
Expand Down Expand Up @@ -198,14 +202,48 @@ void Parameter::set_buffer(const Buffer<> &b) {
contents->buffer = b;
}

void *Parameter::scalar_address() const {
const void *Parameter::read_only_scalar_address() const {
check_is_scalar();
return &contents->data;
// Use explicit if here (rather than user_assert) so that we don't
// have to disable bugprone-unchecked-optional-access in clang-tidy,
// which is a useful check.
const auto &sv = contents->scalar_data;
if (sv.has_value()) {
return std::addressof(sv.value());
} else {
user_error << "Parameter " << name() << " does not have a valid scalar value.\n";
return nullptr;
}
}

std::optional<halide_scalar_value_t> Parameter::scalar_data() const {
return defined() ? contents->scalar_data : std::nullopt;
}

uint64_t Parameter::scalar_raw_value() const {
halide_scalar_value_t Parameter::scalar_data_checked() const {
check_is_scalar();
return contents->data;
// Use explicit if here (rather than user_assert) so that we don't
// have to disable bugprone-unchecked-optional-access in clang-tidy,
// which is a useful check.
halide_scalar_value_t result;
const auto &sv = contents->scalar_data;
if (sv.has_value()) {
result = sv.value();
} else {
user_error << "Parameter " << name() << " does not have a valid scalar value.\n";
result.u.u64 = 0; // silence "possibly uninitialized" compiler warning
}
return result;
}

halide_scalar_value_t Parameter::scalar_data_checked(const Type &val_type) const {
check_type(val_type);
return scalar_data_checked();
}

void Parameter::set_scalar(const Type &val_type, halide_scalar_value_t val) {
check_type(val_type);
contents->scalar_data = std::optional<halide_scalar_value_t>(val);
}

/** Tests if this handle is the same as another handle */
Expand Down
71 changes: 45 additions & 26 deletions src/Parameter.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
/** \file
* Defines the internal representation of parameters to halide piplines
*/
#include <optional>
#include <string>

#include "Buffer.h"
Expand All @@ -25,7 +26,9 @@ struct BufferConstraint {
};

namespace Internal {

#ifdef WITH_SERIALIZATION
class Deserializer;
class Serializer;
#endif
struct ParameterContents;
Expand All @@ -45,21 +48,41 @@ class Parameter {
Internal::IntrusivePtr<Internal::ParameterContents> contents;

#ifdef WITH_SERIALIZATION
friend class Internal::Serializer; //< for scalar_raw_value()
friend class Internal::Deserializer; //< for scalar_data()
friend class Internal::Serializer; //< for scalar_data()
#endif
friend class Pipeline; //< for scalar_address()
friend class Pipeline; //< for read_only_scalar_address()

/** Get the raw currently-bound buffer. null if unbound */
const halide_buffer_t *raw_buffer() const;

/** Get the pointer to the current value of the scalar
* parameter. For a given parameter, this address will never
* change. Only relevant when jitting. */
void *scalar_address() const;
* change. Note that this can only be used to *read* from -- it must
* not be written to, so don't cast away the constness. Only relevant when jitting. */
const void *read_only_scalar_address() const;

/** If the Parameter is a scalar, and the scalar data is valid, return
* the scalar data. Otherwise, return nullopt. */
std::optional<halide_scalar_value_t> scalar_data() const;

/** If the Parameter is a scalar and has a valid scalar value, return it.
* Otherwise, assert-fail. */
halide_scalar_value_t scalar_data_checked() const;

/** If the Parameter is a scalar *of the given type* and has a valid scalar value, return it.
* Otherwise, assert-fail. */
halide_scalar_value_t scalar_data_checked(const Type &val_type) const;

/** Get the raw data of the current value of the scalar
* parameter. Only relevant when serializing. */
uint64_t scalar_raw_value() const;
/** Construct a new buffer parameter via deserialization. */
Parameter(const Type &t, int dimensions, const std::string &name,
const Buffer<void> &buffer, int host_alignment, const std::vector<BufferConstraint> &buffer_constraints,
MemoryType memory_type);

/** Construct a new scalar parameter via deserialization. */
Parameter(const Type &t, int dimensions, const std::string &name,
const std::optional<halide_scalar_value_t> &scalar_data, const Expr &scalar_default, const Expr &scalar_min,
const Expr &scalar_max, const Expr &scalar_estimate);

public:
/** Construct a new undefined handle */
Expand All @@ -81,15 +104,6 @@ class Parameter {
* explicitly specified (as opposed to autogenerated). */
Parameter(const Type &t, bool is_buffer, int dimensions, const std::string &name);

/** Construct a new parameter via deserialization. */
Parameter(const Type &t, bool is_buffer, int dimensions, const std::string &name,
const Buffer<void> &buffer, int host_alignment, const std::vector<BufferConstraint> &buffer_constraints,
MemoryType memory_type);

Parameter(const Type &t, bool is_buffer, int dimensions, const std::string &name,
uint64_t data, const Expr &scalar_default, const Expr &scalar_min,
const Expr &scalar_max, const Expr &scalar_estimate);

Parameter(const Parameter &) = default;
Parameter &operator=(const Parameter &) = default;
Parameter(Parameter &&) = default;
Expand All @@ -111,28 +125,33 @@ class Parameter {
* bound value. Only relevant when jitting */
template<typename T>
HALIDE_NO_USER_CODE_INLINE T scalar() const {
check_type(type_of<T>());
return *((const T *)(scalar_address()));
static_assert(sizeof(T) <= sizeof(halide_scalar_value_t));
const auto sv = scalar_data_checked(type_of<T>());
T t;
memcpy(&t, &sv.u.u64, sizeof(t));
return t;
}

/** This returns the current value of scalar<type()>()
* as an Expr. */
/** This returns the current value of scalar<type()>() as an Expr.
* If the Parameter is not scalar, or its scalar data is not valid, this will assert-fail. */
Expr scalar_expr() const;

/** This returns true if scalar_expr() would return a valid Expr,
* false if not. */
bool has_scalar_value() const;

/** If the parameter is a scalar parameter, set its current
* value. Only relevant when jitting */
template<typename T>
HALIDE_NO_USER_CODE_INLINE void set_scalar(T val) {
check_type(type_of<T>());
*((T *)(scalar_address())) = val;
halide_scalar_value_t sv;
memcpy(&sv.u.u64, &val, sizeof(val));
set_scalar(type_of<T>(), sv);
}

/** If the parameter is a scalar parameter, set its current
* value. Only relevant when jitting */
HALIDE_NO_USER_CODE_INLINE void set_scalar(const Type &val_type, halide_scalar_value_t val) {
check_type(val_type);
memcpy(scalar_address(), &val, val_type.bytes());
}
void set_scalar(const Type &val_type, halide_scalar_value_t val);

/** If the parameter is a buffer parameter, get its currently
* bound buffer. Only relevant when jitting */
Expand Down
2 changes: 1 addition & 1 deletion src/Pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -843,7 +843,7 @@ void Pipeline::prepare_jit_call_arguments(RealizationArg &outputs, const Target
}
debug(2) << "JIT input ImageParam argument ";
} else {
args_result.store[arg_index++] = p.scalar_address();
args_result.store[arg_index++] = p.read_only_scalar_address();
debug(2) << "JIT input scalar argument ";
}
}
Expand Down
10 changes: 8 additions & 2 deletions src/Serialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1298,13 +1298,19 @@ Offset<Serialize::Parameter> Serializer::serialize_parameter(FlatBufferBuilder &
return Serialize::CreateParameter(builder, defined, is_buffer, type_serialized, dimensions, name_serialized, host_alignment,
builder.CreateVector(buffer_constraints_serialized), memory_type_serialized);
} else {
const uint64_t data = parameter.scalar_raw_value();
static_assert(FLATBUFFERS_USE_STD_OPTIONAL);
const auto make_optional_u64 = [](const std::optional<halide_scalar_value_t> &v) -> std::optional<uint64_t> {
return v.has_value() ?
std::optional<uint64_t>(v.value().u.u64) :
std::nullopt;
};
const auto scalar_data = make_optional_u64(parameter.scalar_data());
const auto scalar_default_serialized = serialize_expr(builder, parameter.default_value());
const auto scalar_min_serialized = serialize_expr(builder, parameter.min_value());
const auto scalar_max_serialized = serialize_expr(builder, parameter.max_value());
const auto scalar_estimate_serialized = serialize_expr(builder, parameter.estimate());
return Serialize::CreateParameter(builder, defined, is_buffer, type_serialized,
dimensions, name_serialized, 0, 0, Serialize::MemoryType_Auto, data,
dimensions, name_serialized, 0, 0, Serialize::MemoryType_Auto, scalar_data,
scalar_default_serialized.first, scalar_default_serialized.second,
scalar_min_serialized.first, scalar_min_serialized.second,
scalar_max_serialized.first, scalar_max_serialized.second,
Expand Down
2 changes: 1 addition & 1 deletion src/halide_ir.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -618,7 +618,7 @@ table Parameter {
host_alignment: int32;
buffer_constraints: [BufferConstraint];
memory_type: MemoryType;
data: uint64;
scalar_data: uint64 = null; // Note: it is valid for this to be omitted, even if is_buffer = false.
scalar_default: Expr;
scalar_min: Expr;
scalar_max: Expr;
Expand Down
Loading