diff --git a/src/oracledb/pool.py b/src/oracledb/pool.py index 4487a60a..34501cc8 100644 --- a/src/oracledb/pool.py +++ b/src/oracledb/pool.py @@ -411,7 +411,7 @@ def _set_connection_type(self, conn_class): conn_class = connection_module.Connection elif not issubclass( conn_class, connection_module.Connection - ) or issubclass(connection_module.AsyncConnection): + ) or issubclass(conn_class, connection_module.AsyncConnection): errors._raise_err(errors.ERR_INVALID_CONN_CLASS) self._connection_type = conn_class self._connection_method = oracledb.connect diff --git a/tests/test_2400_pool.py b/tests/test_2400_pool.py index 7142f5e5..f3a99213 100644 --- a/tests/test_2400_pool.py +++ b/tests/test_2400_pool.py @@ -866,6 +866,16 @@ def test_2434_invalid_pool_class(self): pool_class=int, ) + def test_2435_pool_with_connectiontype(self): + "2435 - test creating a pool with a subclassed connection type" + + class MyConnection(oracledb.Connection): + pass + + pool = test_env.get_pool(connectiontype=MyConnection) + with pool.acquire() as conn: + self.assertIsInstance(conn, MyConnection) + if __name__ == "__main__": test_env.run_test_cases() diff --git a/utils/templates/pool.py b/utils/templates/pool.py index 60a00e8a..2e48c6ae 100644 --- a/utils/templates/pool.py +++ b/utils/templates/pool.py @@ -409,7 +409,7 @@ def _set_connection_type(self, conn_class): conn_class = connection_module.Connection elif not issubclass( conn_class, connection_module.Connection - ) or issubclass(connection_module.AsyncConnection): + ) or issubclass(conn_class, connection_module.AsyncConnection): errors._raise_err(errors.ERR_INVALID_CONN_CLASS) self._connection_type = conn_class self._connection_method = oracledb.connect