Skip to content

Commit

Permalink
Apply __bool__ conversion only to numpy.bool_ to avoid false positives.
Browse files Browse the repository at this point in the history
  • Loading branch information
adamreichold committed Dec 17, 2023
1 parent 45afbc2 commit e0a2c6f
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 62 deletions.
6 changes: 6 additions & 0 deletions pytests/src/misc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,14 @@ fn issue_219() {
Python::with_gil(|_| {});
}

#[pyfunction]
fn accepts_bool(val: bool) -> bool {
val
}

#[pymodule]
pub fn misc(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(issue_219, m)?)?;
m.add_function(wrap_pyfunction!(accepts_bool, m)?)?;
Ok(())
}
10 changes: 10 additions & 0 deletions pytests/tests/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,13 @@ def test_import_in_subinterpreter_forbidden():
)

_xxsubinterpreters.destroy(sub_interpreter)


def test_accepts_numpy_bool():
# binary numpy wheel not available on all platforms
numpy = pytest.importorskip("numpy")

assert pyo3_pytests.misc.accepts_bool(True) is True
assert pyo3_pytests.misc.accepts_bool(False) is False
assert pyo3_pytests.misc.accepts_bool(numpy.bool_(True)) is True
assert pyo3_pytests.misc.accepts_bool(numpy.bool_(False)) is False
95 changes: 33 additions & 62 deletions src/types/boolobject.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,10 @@ impl IntoPy<PyObject> for bool {
/// Fails with `TypeError` if the input is not a Python `bool`.
impl<'source> FromPyObject<'source> for bool {
fn extract(obj: &'source PyAny) -> PyResult<Self> {
if let Ok(obj) = obj.downcast::<PyBool>() {
return Ok(obj.is_true());
}
let err = match obj.downcast::<PyBool>() {
Ok(obj) => return Ok(obj.is_true()),
Err(err) => err,
};

let missing_conversion = |obj: &PyAny| {
PyTypeError::new_err(format!(
Expand All @@ -92,28 +93,42 @@ impl<'source> FromPyObject<'source> for bool {
unsafe {
let ptr = obj.as_ptr();

if let Some(tp_as_number) = (*(*ptr).ob_type).tp_as_number.as_ref() {
if let Some(nb_bool) = tp_as_number.nb_bool {
match (nb_bool)(ptr) {
0 => return Ok(false),
1 => return Ok(true),
_ => return Err(crate::PyErr::fetch(obj.py())),
if libc::strcmp(
(*ffi::Py_TYPE(ptr)).tp_name,
b"numpy.bool_\0".as_ptr().cast(),
) == 0
{
if let Some(tp_as_number) = (*(*ptr).ob_type).tp_as_number.as_ref() {
if let Some(nb_bool) = tp_as_number.nb_bool {
match (nb_bool)(ptr) {
0 => return Ok(false),
1 => return Ok(true),
_ => return Err(crate::PyErr::fetch(obj.py())),
}
}
}
}

Err(missing_conversion(obj))
return Err(missing_conversion(obj));
}
}

#[cfg(any(Py_LIMITED_API, PyPy))]
{
let meth = obj
.lookup_special(crate::intern!(obj.py(), "__bool__"))?
.ok_or_else(|| missing_conversion(obj))?;

let obj = meth.call0()?.downcast::<PyBool>()?;
Ok(obj.is_true())
if obj
.get_type()
.name()
.map_or(false, |name| name == "numpy.bool_")
{
let meth = obj
.lookup_special(crate::intern!(obj.py(), "__bool__"))?
.ok_or_else(|| missing_conversion(obj))?;

let obj = meth.call0()?.downcast::<PyBool>()?;
return Ok(obj.is_true());
}
}

Err(err.into())
}

#[cfg(feature = "experimental-inspect")]
Expand All @@ -124,7 +139,7 @@ impl<'source> FromPyObject<'source> for bool {

#[cfg(test)]
mod tests {
use crate::types::{PyAny, PyBool, PyModule};
use crate::types::{PyAny, PyBool};
use crate::Python;
use crate::ToPyObject;

Expand All @@ -147,48 +162,4 @@ mod tests {
assert!(false.to_object(py).is(PyBool::new(py, false)));
});
}

#[test]
fn test_magic_method() {
Python::with_gil(|py| {
let module = PyModule::from_code(
py,
r#"
class A:
def __bool__(self): return True
class B:
def __bool__(self): return "not a bool"
class C:
def __len__(self): return 23
class D:
pass
"#,
"test.py",
"test",
)
.unwrap();

let a = module.getattr("A").unwrap().call0().unwrap();
assert!(a.extract::<bool>().unwrap());

let b = module.getattr("B").unwrap().call0().unwrap();
assert!(matches!(
&*b.extract::<bool>().unwrap_err().to_string(),
"TypeError: 'str' object cannot be converted to 'PyBool'"
| "TypeError: __bool__ should return bool, returned str"
));

let c = module.getattr("C").unwrap().call0().unwrap();
assert_eq!(
c.extract::<bool>().unwrap_err().to_string(),
"TypeError: object of type '<class 'test.C'>' does not define a '__bool__' conversion",
);

let d = module.getattr("D").unwrap().call0().unwrap();
assert_eq!(
d.extract::<bool>().unwrap_err().to_string(),
"TypeError: object of type '<class 'test.D'>' does not define a '__bool__' conversion",
);
});
}
}

0 comments on commit e0a2c6f

Please sign in to comment.