Skip to content

Commit

Permalink
Error at compile-time when a non-subclassable class is being subclassed
Browse files Browse the repository at this point in the history
Previously this crashed at runtime.
  • Loading branch information
ChayimFriedman2 committed Aug 19, 2024
1 parent 4211c57 commit f927019
Show file tree
Hide file tree
Showing 19 changed files with 69 additions and 66 deletions.
1 change: 1 addition & 0 deletions newsfragments/4453.changed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Make subclassing a class that doesn't allow that a compile-time error instead of runtime
13 changes: 13 additions & 0 deletions pyo3-macros-backend/src/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2244,7 +2244,20 @@ impl<'a> PyClassImplsBuilder<'a> {
quote! { #pyo3_path::PyAny }
};

let pyclass_base_type_impl = attr.options.subclass.map(|subclass| {
quote_spanned! { subclass.span() =>
impl #pyo3_path::impl_::pyclass::PyClassBaseType for #cls {
type LayoutAsBase = #pyo3_path::impl_::pycell::PyClassObject<Self>;
type BaseNativeType = <Self as #pyo3_path::impl_::pyclass::PyClassImpl>::BaseNativeType;
type Initializer = #pyo3_path::pyclass_init::PyClassInitializer<Self>;
type PyClassMutability = <Self as #pyo3_path::impl_::pyclass::PyClassImpl>::PyClassMutability;
}
}
});

Ok(quote! {
#pyclass_base_type_impl

impl #pyo3_path::impl_::pyclass::PyClassImpl for #cls {
const IS_BASETYPE: bool = #is_basetype;
const IS_SUBCLASS: bool = #is_subclass;
Expand Down
1 change: 1 addition & 0 deletions src/exceptions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@ macro_rules! impl_native_exception (

$crate::impl_exception_boilerplate!($name);
$crate::pyobject_native_type!($name, $layout, |_py| unsafe { $crate::ffi::$exc_name as *mut $crate::ffi::PyTypeObject } $(, #checkfunction=$checkfunction)?);
$crate::pyobject_subclassable_native_type!($name, $layout);
);
($name:ident, $exc_name:ident, $doc:expr) => (
impl_native_exception!($name, $exc_name, $doc, $crate::ffi::PyBaseExceptionObject);
Expand Down
23 changes: 12 additions & 11 deletions src/impl_/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1113,7 +1113,18 @@ impl<T> PyClassThreadChecker<T> for ThreadCheckerImpl {
#[cfg_attr(
all(diagnostic_namespace, feature = "abi3"),
diagnostic::on_unimplemented(
note = "with the `abi3` feature enabled, PyO3 does not support subclassing native types"
message = "pyclass `{Self}` cannot be subclassed",
label = "required for `#[pyclass(extends={Self})]`",
note = "if you own `{Self}`, add `subclass` to the `#[pyclass]` macro: `#[pyclass(subclass)]`",
note = "with the `abi3` feature enabled, PyO3 does not support subclassing native types",
)
)]
#[cfg_attr(
all(diagnostic_namespace, not(feature = "abi3")),
diagnostic::on_unimplemented(
message = "pyclass `{Self}` cannot be subclassed",
label = "required for `#[pyclass(extends={Self})]`",
note = "if you own `{Self}`, add `subclass` to the `#[pyclass]` macro: `#[pyclass(subclass)]`",
)
)]
pub trait PyClassBaseType: Sized {
Expand All @@ -1123,16 +1134,6 @@ pub trait PyClassBaseType: Sized {
type PyClassMutability: PyClassMutability;
}

/// All mutable PyClasses can be used as a base type.
///
/// In the future this will be extended to immutable PyClasses too.
impl<T: PyClass> PyClassBaseType for T {
type LayoutAsBase = crate::impl_::pycell::PyClassObject<T>;
type BaseNativeType = T::BaseNativeType;
type Initializer = crate::pyclass_init::PyClassInitializer<Self>;
type PyClassMutability = T::PyClassMutability;
}

/// Implementation of tp_dealloc for pyclasses without gc
pub(crate) unsafe extern "C" fn tp_dealloc<T: PyClass>(obj: *mut ffi::PyObject) {
crate::impl_::trampoline::dealloc(obj, PyClassObject::<T>::tp_dealloc)
Expand Down
2 changes: 1 addition & 1 deletion src/pyclass_init.rs
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ where
impl<S, B> From<(S, B)> for PyClassInitializer<S>
where
S: PyClass<BaseType = B>,
B: PyClass,
B: PyClass + PyClassBaseType<Initializer = PyClassInitializer<B>>,
B::BaseType: PyClassBaseType<Initializer = PyNativeTypeInitializer<B::BaseType>>,
{
fn from(sub_and_base: (S, B)) -> PyClassInitializer<S> {
Expand Down
1 change: 1 addition & 0 deletions src/types/any.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ pyobject_native_type_info!(
);

pyobject_native_type_sized!(PyAny, ffi::PyObject);
pyobject_subclassable_native_type!(PyAny, ffi::PyObject);

/// This trait represents the Python APIs which are usable on all Python objects.
///
Expand Down
2 changes: 2 additions & 0 deletions src/types/complex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ use std::os::raw::c_double;
#[repr(transparent)]
pub struct PyComplex(PyAny);

pyobject_subclassable_native_type!(PyComplex, ffi::PyComplexObject);

pyobject_native_type!(
PyComplex,
ffi::PyComplexObject,
Expand Down
5 changes: 5 additions & 0 deletions src/types/datetime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ pyobject_native_type!(
#module=Some("datetime"),
#checkfunction=PyDate_Check
);
pyobject_subclassable_native_type!(PyDate, crate::ffi::PyDateTime_Date);

impl PyDate {
/// Creates a new `datetime.date`.
Expand Down Expand Up @@ -248,6 +249,7 @@ pyobject_native_type!(
#module=Some("datetime"),
#checkfunction=PyDateTime_Check
);
pyobject_subclassable_native_type!(PyDateTime, crate::ffi::PyDateTime_DateTime);

impl PyDateTime {
/// Creates a new `datetime.datetime` object.
Expand Down Expand Up @@ -424,6 +426,7 @@ pyobject_native_type!(
#module=Some("datetime"),
#checkfunction=PyTime_Check
);
pyobject_subclassable_native_type!(PyTime, crate::ffi::PyDateTime_Time);

impl PyTime {
/// Creates a new `datetime.time` object.
Expand Down Expand Up @@ -550,6 +553,7 @@ pyobject_native_type!(
#module=Some("datetime"),
#checkfunction=PyTZInfo_Check
);
pyobject_subclassable_native_type!(PyTzInfo, crate::ffi::PyObject);

/// Equivalent to `datetime.timezone.utc`
pub fn timezone_utc_bound(py: Python<'_>) -> Bound<'_, PyTzInfo> {
Expand Down Expand Up @@ -594,6 +598,7 @@ pyobject_native_type!(
#module=Some("datetime"),
#checkfunction=PyDelta_Check
);
pyobject_subclassable_native_type!(PyDelta, crate::ffi::PyDateTime_Delta);

impl PyDelta {
/// Creates a new `timedelta`.
Expand Down
3 changes: 3 additions & 0 deletions src/types/dict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ use crate::{ffi, Python, ToPyObject};
#[repr(transparent)]
pub struct PyDict(PyAny);

#[cfg(not(feature = "abi3"))]
pyobject_subclassable_native_type!(PyDict, crate::ffi::PyDictObject);

pyobject_native_type!(
PyDict,
ffi::PyDictObject,
Expand Down
3 changes: 3 additions & 0 deletions src/types/float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ use std::os::raw::c_double;
#[repr(transparent)]
pub struct PyFloat(PyAny);

#[cfg(not(feature = "abi3"))]
pyobject_subclassable_native_type!(PyFloat, crate::ffi::PyFloatObject);

pyobject_native_type!(
PyFloat,
ffi::PyFloatObject,
Expand Down
2 changes: 2 additions & 0 deletions src/types/frozenset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ impl<'py> PyFrozenSetBuilder<'py> {
#[repr(transparent)]
pub struct PyFrozenSet(PyAny);

#[cfg(not(feature = "abi3"))]
pyobject_subclassable_native_type!(PyFrozenSet, crate::ffi::PySetObject);
#[cfg(not(any(PyPy, GraalPy)))]
pyobject_native_type!(
PyFrozenSet,
Expand Down
13 changes: 10 additions & 3 deletions src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,10 +194,8 @@ macro_rules! pyobject_native_type_core {

#[doc(hidden)]
#[macro_export]
macro_rules! pyobject_native_type_sized {
macro_rules! pyobject_subclassable_native_type {
($name:ty, $layout:path $(;$generics:ident)*) => {
unsafe impl $crate::type_object::PyLayout<$name> for $layout {}
impl $crate::type_object::PySizedLayout<$name> for $layout {}
impl<$($generics,)*> $crate::impl_::pyclass::PyClassBaseType for $name {
type LayoutAsBase = $crate::impl_::pycell::PyClassObjectBase<$layout>;
type BaseNativeType = $name;
Expand All @@ -207,6 +205,15 @@ macro_rules! pyobject_native_type_sized {
}
}

#[doc(hidden)]
#[macro_export]
macro_rules! pyobject_native_type_sized {
($name:ty, $layout:path $(;$generics:ident)*) => {
unsafe impl $crate::type_object::PyLayout<$name> for $layout {}
impl $crate::type_object::PySizedLayout<$name> for $layout {}
};
}

/// Declares all of the boilerplate for Python types which can be inherited from (because the exact
/// Python layout is known).
#[doc(hidden)]
Expand Down
3 changes: 3 additions & 0 deletions src/types/set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ use std::ptr;
#[repr(transparent)]
pub struct PySet(PyAny);

#[cfg(not(feature = "abi3"))]
pyobject_subclassable_native_type!(PySet, crate::ffi::PySetObject);

#[cfg(not(any(PyPy, GraalPy)))]
pyobject_native_type!(
PySet,
Expand Down
3 changes: 3 additions & 0 deletions src/types/weakref/reference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ use super::PyWeakrefMethods;
#[repr(transparent)]
pub struct PyWeakrefReference(PyAny);

#[cfg(not(any(PyPy, GraalPy)))]
pyobject_subclassable_native_type!(PyWeakrefReference, crate::ffi::PyWeakReference);

#[cfg(not(any(PyPy, GraalPy, Py_LIMITED_API)))]
pyobject_native_type!(
PyWeakrefReference,
Expand Down
1 change: 1 addition & 0 deletions tests/test_compile_error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,5 @@ fn test_compile_errors() {
#[cfg(all(Py_LIMITED_API, not(Py_3_9)))]
t.compile_fail("tests/ui/abi3_dict.rs");
t.compile_fail("tests/ui/duplicate_pymodule_submodule.rs");
t.compile_fail("tests/ui/invalid_base_class.rs");
}
27 changes: 0 additions & 27 deletions tests/test_declarative_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ use pyo3::create_exception;
use pyo3::exceptions::PyException;
use pyo3::prelude::*;
use pyo3::sync::GILOnceCell;
#[cfg(not(Py_LIMITED_API))]
use pyo3::types::PyBool;

#[path = "../src/tests/common.rs"]
mod common;
Expand Down Expand Up @@ -186,31 +184,6 @@ fn test_declarative_module() {
})
}

#[cfg(not(Py_LIMITED_API))]
#[pyclass(extends = PyBool)]
struct ExtendsBool;

#[cfg(not(Py_LIMITED_API))]
#[pymodule]
mod class_initialization_module {
#[pymodule_export]
use super::ExtendsBool;
}

#[test]
#[cfg(not(Py_LIMITED_API))]
fn test_class_initialization_fails() {
Python::with_gil(|py| {
let err = class_initialization_module::_PYO3_DEF
.make_module(py)
.unwrap_err();
assert_eq!(
err.to_string(),
"RuntimeError: An error occurred while initializing class ExtendsBool"
);
})
}

#[pymodule]
mod r#type {
#[pymodule_export]
Expand Down
23 changes: 0 additions & 23 deletions tests/test_inheritance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -345,26 +345,3 @@ fn test_subclass_ref_counts() {
);
})
}

#[test]
#[cfg(not(Py_LIMITED_API))]
fn module_add_class_inherit_bool_fails() {
use pyo3::types::PyBool;

#[pyclass(extends = PyBool)]
struct ExtendsBool;

Python::with_gil(|py| {
let m = PyModule::new(py, "test_module").unwrap();

let err = m.add_class::<ExtendsBool>().unwrap_err();
assert_eq!(
err.to_string(),
"RuntimeError: An error occurred while initializing class ExtendsBool"
);
assert_eq!(
err.cause(py).unwrap().to_string(),
"TypeError: type 'bool' is not an acceptable base type"
);
})
}
2 changes: 1 addition & 1 deletion tests/test_macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ macro_rules! make_struct_using_macro {
// Ensure that one doesn't need to fall back on the escape type: tt
// in order to macro create pyclass.
($class_name:ident, $py_name:literal) => {
#[pyclass(name=$py_name)]
#[pyclass(name=$py_name, subclass)]
struct $class_name {}
};
}
Expand Down
7 changes: 7 additions & 0 deletions tests/ui/invalid_base_class.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
use pyo3::prelude::*;
use pyo3::types::PyBool;

#[pyclass(extends=PyBool)]
struct ExtendsBool;

fn main() {}

0 comments on commit f927019

Please sign in to comment.