diff --git a/google/cloud/alloydb/connector/async_connector.py b/google/cloud/alloydb/connector/async_connector.py index 0da68f1..6114609 100644 --- a/google/cloud/alloydb/connector/async_connector.py +++ b/google/cloud/alloydb/connector/async_connector.py @@ -18,8 +18,9 @@ from types import TracebackType from typing import Any, Dict, Optional, Type, TYPE_CHECKING -from google.auth import default +import google.auth from google.auth.credentials import with_scopes_if_required +import google.auth.transport.requests import google.cloud.alloydb.connector.asyncpg as asyncpg from google.cloud.alloydb.connector.client import AlloyDBClient @@ -44,6 +45,7 @@ class AsyncConnector: Defaults to None, picking up project from environment. alloydb_api_endpoint (str): Base URL to use when calling the AlloyDB API endpoint. Defaults to "https://alloydb.googleapis.com". + enable_iam_auth (bool): Enables automatic IAM database authentication. """ def __init__( @@ -51,18 +53,20 @@ def __init__( credentials: Optional[Credentials] = None, quota_project: Optional[str] = None, alloydb_api_endpoint: str = "https://alloydb.googleapis.com", + enable_iam_auth: bool = False, ) -> None: self._instances: Dict[str, Instance] = {} # initialize default params self._quota_project = quota_project self._alloydb_api_endpoint = alloydb_api_endpoint + self._enable_iam_auth = enable_iam_auth # initialize credentials scopes = ["https://www.googleapis.com/auth/cloud-platform"] if credentials: self._credentials = with_scopes_if_required(credentials, scopes=scopes) # otherwise use application default credentials else: - self._credentials, _ = default(scopes=scopes) + self._credentials, _ = google.auth.default(scopes=scopes) # check if AsyncConnector is being initialized with event loop running # Otherwise we will lazy init keys @@ -107,6 +111,8 @@ async def connect( driver=driver, ) + enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth) + # use existing connection info if possible if instance_uri in self._instances: instance = self._instances[instance_uri] @@ -132,6 +138,18 @@ async def connect( # get connection info for AlloyDB instance ip_address, context = await instance.connection_info() + # callable to be used for auto IAM authn + def get_authentication_token() -> str: + """Get OAuth2 access token to be used for IAM database authentication""" + # refresh credentials if expired + if not self._credentials.valid: + request = google.auth.transport.requests.Request() + self._credentials.refresh(request) + return self._credentials.token + + # if enable_iam_auth is set, use auth token as database password + if enable_iam_auth: + kwargs["password"] = get_authentication_token try: return await connector(ip_address, context, **kwargs) except Exception: diff --git a/tests/system/test_asyncpg_iam_authn.py b/tests/system/test_asyncpg_iam_authn.py new file mode 100644 index 0000000..9ec8c79 --- /dev/null +++ b/tests/system/test_asyncpg_iam_authn.py @@ -0,0 +1,94 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datetime import datetime +import os +from typing import Tuple + +import asyncpg +import sqlalchemy +import sqlalchemy.ext.asyncio + +from google.cloud.alloydb.connector import AsyncConnector + + +async def create_sqlalchemy_engine( + inst_uri: str, + user: str, + db: str, +) -> Tuple[sqlalchemy.ext.asyncio.engine.AsyncEngine, AsyncConnector]: + """Creates a connection pool for an AlloyDB instance and returns the pool + and the connector. Callers are responsible for closing the pool and the + connector. + + A sample invocation looks like: + + pool, connector = await create_sqlalchemy_engine( + inst_uri, + user, + db, + ) + async with pool.connect() as conn: + time = (await conn.execute(sqlalchemy.text("SELECT NOW()"))).fetchone() + conn.commit() + curr_time = time[0] + # do something with query result + await connector.close() + + Args: + instance_uri (str): + The instance URI specifies the instance relative to the project, + region, and cluster. For example: + "projects/my-project/locations/us-central1/clusters/my-cluster/instances/my-instance" + user (str): + The formatted IAM database username. + e.g., my-email@test.com, service-account@project-id.iam + db_name (str): + The name of the database, e.g., mydb + """ + connector = AsyncConnector() + + async def getconn() -> asyncpg.Connection: + conn: asyncpg.Connection = await connector.connect( + inst_uri, + "asyncpg", + user=user, + db=db, + enable_iam_auth=True, + ) + return conn + + # create async SQLAlchemy connection pool + engine = sqlalchemy.ext.asyncio.create_async_engine( + "postgresql+asyncpg://", + async_creator=getconn, + execution_options={"isolation_level": "AUTOCOMMIT"}, + ) + return engine, connector + + +async def test_asyncpg_iam_authn_time() -> None: + """Basic test to get time from database.""" + inst_uri = os.environ["ALLOYDB_INSTANCE_URI"] + user = os.environ["ALLOYDB_IAM_USER"] + db = os.environ["ALLOYDB_DB"] + + pool, connector = await create_sqlalchemy_engine(inst_uri, user, db) + async with pool.connect() as conn: + time = (await conn.execute(sqlalchemy.text("SELECT NOW()"))).fetchone() + curr_time = time[0] + assert type(curr_time) is datetime + await connector.close() + # cleanup AsyncEngine + await pool.dispose() diff --git a/tests/system/test_pg8000_iam_authn.py b/tests/system/test_pg8000_iam_authn.py index 62400fe..0ed356e 100644 --- a/tests/system/test_pg8000_iam_authn.py +++ b/tests/system/test_pg8000_iam_authn.py @@ -14,8 +14,8 @@ from datetime import datetime import os +from typing import Tuple -# [START alloydb_sqlalchemy_connect_connector] import pg8000 import sqlalchemy @@ -26,7 +26,7 @@ def create_sqlalchemy_engine( inst_uri: str, user: str, db: str, -) -> (sqlalchemy.engine.Engine, Connector): +) -> Tuple[sqlalchemy.engine.Engine, Connector]: """Creates a connection pool for an AlloyDB instance and returns the pool and the connector. Callers are responsible for closing the pool and the connector. @@ -51,7 +51,8 @@ def create_sqlalchemy_engine( region, and cluster. For example: "projects/my-project/locations/us-central1/clusters/my-cluster/instances/my-instance" user (str): - The database user name, e.g., postgres + The formatted IAM database username. + e.g., my-email@test.com, service-account@project-id.iam db_name (str): The name of the database, e.g., mydb """ @@ -76,9 +77,6 @@ def getconn() -> pg8000.dbapi.Connection: return engine, connector -# [END alloydb_sqlalchemy_connect_connector] - - def test_pg8000_iam_authn_time() -> None: """Basic test to get time from database.""" inst_uri = os.environ["ALLOYDB_INSTANCE_URI"] diff --git a/tests/unit/test_async_connector.py b/tests/unit/test_async_connector.py index 279a8bd..f1c1957 100644 --- a/tests/unit/test_async_connector.py +++ b/tests/unit/test_async_connector.py @@ -36,6 +36,7 @@ async def test_AsyncConnector_init(credentials: FakeCredentials) -> None: assert connector._alloydb_api_endpoint == ALLOYDB_API_ENDPOINT assert connector._client is None assert connector._credentials == credentials + assert connector._enable_iam_auth is False await connector.close() @@ -52,6 +53,7 @@ async def test_AsyncConnector_context_manager( assert connector._alloydb_api_endpoint == ALLOYDB_API_ENDPOINT assert connector._client is None assert connector._credentials == credentials + assert connector._enable_iam_auth is False TEST_INSTANCE_NAME = "/".join(