diff --git a/rust/core/src/driver_manager.rs b/rust/core/src/driver_manager.rs index ca0d53a634..c486aaaabb 100644 --- a/rust/core/src/driver_manager.rs +++ b/rust/core/src/driver_manager.rs @@ -240,19 +240,11 @@ impl ManagedDriver { check_status(status, error)?; Ok(driver) } -} - -impl Driver for ManagedDriver { - type DatabaseType = ManagedDatabase; - fn new_database(&mut self) -> Result { - self.new_database_with_opts(None) - } - - fn new_database_with_opts( - &mut self, - opts: impl IntoIterator::Option, OptionValue)>, - ) -> Result { + /// Returns a new database using the loaded driver. + /// + /// This uses `&mut self` to prevent a deadlock. + fn database_new(&mut self) -> Result { let driver = &self.inner.driver.lock().unwrap(); let mut database = ffi::FFI_AdbcDatabase::default(); @@ -262,10 +254,17 @@ impl Driver for ManagedDriver { let status = unsafe { method(&mut database, &mut error) }; check_status(status, error)?; - // DatabaseSetOption - for (key, value) in opts { - set_option_database(driver, &mut database, self.inner.version, key, value)?; - } + Ok(database) + } + + /// Initialize the given database using the loaded driver. + /// + /// This uses `&mut self` to prevent a deadlock. + fn database_init( + &mut self, + mut database: ffi::FFI_AdbcDatabase, + ) -> Result { + let driver = &self.inner.driver.lock().unwrap(); // DatabaseInit let mut error = ffi::FFI_AdbcError::with_driver(driver); @@ -273,6 +272,40 @@ impl Driver for ManagedDriver { let status = unsafe { method(&mut database, &mut error) }; check_status(status, error)?; + Ok(database) + } +} + +impl Driver for ManagedDriver { + type DatabaseType = ManagedDatabase; + + fn new_database(&mut self) -> Result { + // Construct a new database. + let database = self.database_new()?; + // Initialize the database. + let database = self.database_init(database)?; + let inner = Arc::new(ManagedDatabaseInner { + database: Mutex::new(database), + driver: self.inner.clone(), + }); + Ok(Self::DatabaseType { inner }) + } + + fn new_database_with_opts( + &mut self, + opts: impl IntoIterator::Option, OptionValue)>, + ) -> Result { + // Construct a new database. + let mut database = self.database_new()?; + // Set the options. + { + let driver = &self.inner.driver.lock().unwrap(); + for (key, value) in opts { + set_option_database(driver, &mut database, self.inner.version, key, value)?; + } + } + // Initialize the database. + let database = self.database_init(database)?; let inner = Arc::new(ManagedDatabaseInner { database: Mutex::new(database), driver: self.inner.clone(), @@ -425,6 +458,41 @@ impl ManagedDatabase { fn driver_version(&self) -> AdbcVersion { self.inner.driver.version } + + /// Returns a new connection using the loaded driver. + /// + /// This uses `&mut self` to prevent a deadlock. + fn connection_new(&mut self) -> Result { + let driver = &self.inner.driver.driver.lock().unwrap(); + let mut connection = ffi::FFI_AdbcConnection::default(); + + // ConnectionNew + let mut error = ffi::FFI_AdbcError::with_driver(driver); + let method = driver_method!(driver, ConnectionNew); + let status = unsafe { method(&mut connection, &mut error) }; + check_status(status, error)?; + + Ok(connection) + } + + /// Initialize the given connection using the loaded driver. + /// + /// This uses `&mut self` to prevent a deadlock. + fn connection_init( + &mut self, + mut connection: ffi::FFI_AdbcConnection, + ) -> Result { + let driver = &self.inner.driver.driver.lock().unwrap(); + let mut database = self.inner.database.lock().unwrap(); + + // ConnectionInit + let mut error = ffi::FFI_AdbcError::with_driver(driver); + let method = driver_method!(driver, ConnectionInit); + let status = unsafe { method(&mut connection, &mut *database, &mut error) }; + check_status(status, error)?; + + Ok(connection) + } } impl Optionable for ManagedDatabase { @@ -497,35 +565,38 @@ impl Database for ManagedDatabase { type ConnectionType = ManagedConnection; fn new_connection(&mut self) -> Result { - self.new_connection_with_opts(None) + // Construct a new connection. + let connection = self.connection_new()?; + // Initialize the connection. + let connection = self.connection_init(connection)?; + let inner = ManagedConnectionInner { + connection: Mutex::new(connection), + database: self.inner.clone(), + }; + Ok(Self::ConnectionType { + inner: Arc::new(inner), + }) } fn new_connection_with_opts( &mut self, opts: impl IntoIterator::Option, OptionValue)>, ) -> Result { - let driver = &self.inner.driver.driver.lock().unwrap(); - let mut database = self.inner.database.lock().unwrap(); - let mut connection = ffi::FFI_AdbcConnection::default(); - let mut error = ffi::FFI_AdbcError::with_driver(driver); - let method = driver_method!(driver, ConnectionNew); - let status = unsafe { method(&mut connection, &mut error) }; - check_status(status, error)?; - - for (key, value) in opts { - set_option_connection(driver, &mut connection, self.driver_version(), key, value)?; + // Construct a new connection. + let mut connection = self.connection_new()?; + // Set the options. + { + let driver = &self.inner.driver.driver.lock().unwrap(); + for (key, value) in opts { + set_option_connection(driver, &mut connection, self.driver_version(), key, value)?; + } } - - let mut error = ffi::FFI_AdbcError::with_driver(driver); - let method = driver_method!(driver, ConnectionInit); - let status = unsafe { method(&mut connection, database.deref_mut(), &mut error) }; - check_status(status, error)?; - + // Initialize the connection. + let connection = self.connection_init(connection)?; let inner = ManagedConnectionInner { connection: Mutex::new(connection), database: self.inner.clone(), }; - Ok(Self::ConnectionType { inner: Arc::new(inner), }) diff --git a/rust/drivers/dummy/src/lib.rs b/rust/drivers/dummy/src/lib.rs index 4ba348c3a8..a841fe0e37 100644 --- a/rust/drivers/dummy/src/lib.rs +++ b/rust/drivers/dummy/src/lib.rs @@ -183,16 +183,14 @@ impl Driver for DummyDriver { type DatabaseType = DummyDatabase; fn new_database(&mut self) -> Result { - self.new_database_with_opts(None) + Ok(Self::DatabaseType::default()) } fn new_database_with_opts( &mut self, opts: impl IntoIterator::Option, OptionValue)>, ) -> Result { - let mut database = Self::DatabaseType { - options: HashMap::new(), - }; + let mut database = Self::DatabaseType::default(); for (key, value) in opts { database.set_option(key, value)?; } @@ -200,6 +198,7 @@ impl Driver for DummyDriver { } } +#[derive(Default)] pub struct DummyDatabase { options: HashMap, } @@ -232,16 +231,14 @@ impl Database for DummyDatabase { type ConnectionType = DummyConnection; fn new_connection(&mut self) -> Result { - self.new_connection_with_opts(None) + Ok(Self::ConnectionType::default()) } fn new_connection_with_opts( &mut self, opts: impl IntoIterator::Option, OptionValue)>, ) -> Result { - let mut connection = Self::ConnectionType { - options: HashMap::new(), - }; + let mut connection = Self::ConnectionType::default(); for (key, value) in opts { connection.set_option(key, value)?; } @@ -249,6 +246,7 @@ impl Database for DummyDatabase { } } +#[derive(Default)] pub struct DummyConnection { options: HashMap, } @@ -281,9 +279,7 @@ impl Connection for DummyConnection { type StatementType = DummyStatement; fn new_statement(&mut self) -> Result { - Ok(Self::StatementType { - options: HashMap::new(), - }) + Ok(Self::StatementType::default()) } // This method is used to test that errors round-trip correctly. @@ -798,6 +794,7 @@ impl Connection for DummyConnection { } } +#[derive(Default)] pub struct DummyStatement { options: HashMap, }