Skip to content

Commit

Permalink
Python bindings should preserve float32 scalar Exprs (fixed #8414)
Browse files Browse the repository at this point in the history
Tweak the PyBind11 code so that float32 scalars created via numpy.float32() don't get converted into boolean expressions.
  • Loading branch information
steven-johnson committed Sep 16, 2024
1 parent 4d368bf commit 6ab39d3
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 14 deletions.
41 changes: 27 additions & 14 deletions python_bindings/src/halide/halide_/PyExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,29 +22,42 @@ void define_expr(py::module &m) {

auto expr_class =
py::class_<Expr>(m, "Expr")
// Default ctor
.def(py::init<>())
.def(py::init([](bool b) {
return Internal::make_bool(b);
}))
// PyBind11 searches in declared order,
// int should be tried before float conversion
.def(py::init<int>())
.def(py::init<int64_t>())
// Python float is implemented by double
// But Halide prohibits implicitly construct by double.
.def(py::init([](double v) {
return double_to_expr_check(v);
}))
.def(py::init<std::string>())

// for implicitly_convertible
// For implicitly_convertible
.def(py::init([](const FuncRef &f) -> Expr { return f; }))
.def(py::init([](const FuncTupleElementRef &f) -> Expr { return f; }))
.def(py::init([](const Param<> &p) -> Expr { return p; }))
.def(py::init([](const RDom &r) -> Expr { return r; }))
.def(py::init([](const RVar &r) -> Expr { return r; }))
.def(py::init([](const Var &v) -> Expr { return v; }))

// Weird types
.def(py::init<std::string>())

// Numeric types.
// This is tricky. PyBind11 tries the conversions in declared order,
// and we generally want to prefer int over float conversion (to avoid
// accidental promotion). However, we want to keep a float32 as a float32
// (e.g. specified via numpy.float32()) and the implicit Expr conversion
// will confuse PyBind, hence the apparently wrong order.
.def(py::init<float>())
.def(py::init([](bool b) {
return Internal::make_bool(b);
}))
.def(py::init<int32_t>())
.def(py::init<int64_t>())
// Most scalar fp values we get from Python will actually be doubles;
// for efficiency, we want to store these as float32 instead of float64.
// This may not always be the right decision -- e.g., if someone
// constructs something via numpy.float64() they will be unhappy --
// but changing the behavior now would likely cause lots of subtle
// regressions.
.def(py::init([](double v) {
return double_to_expr_check(v);
}))

.def("__bool__", to_bool)
.def("__nonzero__", to_bool)

Expand Down
12 changes: 12 additions & 0 deletions python_bindings/test/correctness/basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,17 @@ def test_implicit_convert_int64():
assert (hl.i32(0) + (0x7fffffff+1)).type() == hl.Int(64)


def test_explicit_expr_ctors():
assert (hl.Expr(np.bool_(0))).type() == hl.Bool()
assert (hl.Expr(np.int32(0))).type() == hl.Int(32), (hl.Expr(np.int32(0))).type()
assert (hl.Expr(np.int64(0x7fffffff+1))).type() == hl.Int(64), (hl.Expr(np.int64(0x7fffffff+1))).type()
assert (hl.Expr(np.float32(0))).type() == hl.Float(32), (hl.Expr(np.float32(0))).type()
# Note that this is deliberate: we have aggressively downscaled scalar
# float64 values from Python into float32, and we aren't going to change
# that now.
assert (hl.Expr(np.float64(0))).type() == hl.Float(32)


if __name__ == "__main__":
test_compiletime_error()
test_runtime_error()
Expand All @@ -469,3 +480,4 @@ def test_implicit_convert_int64():
test_bool_conversion()
test_requirements()
test_implicit_convert_int64()
test_explicit_expr_ctors()

0 comments on commit 6ab39d3

Please sign in to comment.