Skip to content

Commit

Permalink
feat: add auto IAM authn for asyncpg (#210)
Browse files Browse the repository at this point in the history
  • Loading branch information
jackwotherspoon committed Jan 16, 2024
1 parent 1ed5aa2 commit 165b059
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 8 deletions.
22 changes: 20 additions & 2 deletions google/cloud/alloydb/connector/async_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@
from types import TracebackType
from typing import Any, Dict, Optional, Type, TYPE_CHECKING

from google.auth import default
import google.auth
from google.auth.credentials import with_scopes_if_required
import google.auth.transport.requests

import google.cloud.alloydb.connector.asyncpg as asyncpg
from google.cloud.alloydb.connector.client import AlloyDBClient
Expand All @@ -44,25 +45,28 @@ class AsyncConnector:
Defaults to None, picking up project from environment.
alloydb_api_endpoint (str): Base URL to use when calling
the AlloyDB API endpoint. Defaults to "https://alloydb.googleapis.com".
enable_iam_auth (bool): Enables automatic IAM database authentication.
"""

def __init__(
self,
credentials: Optional[Credentials] = None,
quota_project: Optional[str] = None,
alloydb_api_endpoint: str = "https://alloydb.googleapis.com",
enable_iam_auth: bool = False,
) -> None:
self._instances: Dict[str, Instance] = {}
# initialize default params
self._quota_project = quota_project
self._alloydb_api_endpoint = alloydb_api_endpoint
self._enable_iam_auth = enable_iam_auth
# initialize credentials
scopes = ["https://www.googleapis.com/auth/cloud-platform"]
if credentials:
self._credentials = with_scopes_if_required(credentials, scopes=scopes)
# otherwise use application default credentials
else:
self._credentials, _ = default(scopes=scopes)
self._credentials, _ = google.auth.default(scopes=scopes)

# check if AsyncConnector is being initialized with event loop running
# Otherwise we will lazy init keys
Expand Down Expand Up @@ -107,6 +111,8 @@ async def connect(
driver=driver,
)

enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth)

# use existing connection info if possible
if instance_uri in self._instances:
instance = self._instances[instance_uri]
Expand All @@ -132,6 +138,18 @@ async def connect(
# get connection info for AlloyDB instance
ip_address, context = await instance.connection_info()

# callable to be used for auto IAM authn
def get_authentication_token() -> str:
"""Get OAuth2 access token to be used for IAM database authentication"""
# refresh credentials if expired
if not self._credentials.valid:
request = google.auth.transport.requests.Request()
self._credentials.refresh(request)
return self._credentials.token

# if enable_iam_auth is set, use auth token as database password
if enable_iam_auth:
kwargs["password"] = get_authentication_token
try:
return await connector(ip_address, context, **kwargs)
except Exception:
Expand Down
94 changes: 94 additions & 0 deletions tests/system/test_asyncpg_iam_authn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from datetime import datetime
import os
from typing import Tuple

import asyncpg
import sqlalchemy
import sqlalchemy.ext.asyncio

from google.cloud.alloydb.connector import AsyncConnector


async def create_sqlalchemy_engine(
inst_uri: str,
user: str,
db: str,
) -> Tuple[sqlalchemy.ext.asyncio.engine.AsyncEngine, AsyncConnector]:
"""Creates a connection pool for an AlloyDB instance and returns the pool
and the connector. Callers are responsible for closing the pool and the
connector.
A sample invocation looks like:
pool, connector = await create_sqlalchemy_engine(
inst_uri,
user,
db,
)
async with pool.connect() as conn:
time = (await conn.execute(sqlalchemy.text("SELECT NOW()"))).fetchone()
conn.commit()
curr_time = time[0]
# do something with query result
await connector.close()
Args:
instance_uri (str):
The instance URI specifies the instance relative to the project,
region, and cluster. For example:
"projects/my-project/locations/us-central1/clusters/my-cluster/instances/my-instance"
user (str):
The formatted IAM database username.
e.g., my-email@test.com, service-account@project-id.iam
db_name (str):
The name of the database, e.g., mydb
"""
connector = AsyncConnector()

async def getconn() -> asyncpg.Connection:
conn: asyncpg.Connection = await connector.connect(
inst_uri,
"asyncpg",
user=user,
db=db,
enable_iam_auth=True,
)
return conn

# create async SQLAlchemy connection pool
engine = sqlalchemy.ext.asyncio.create_async_engine(
"postgresql+asyncpg://",
async_creator=getconn,
execution_options={"isolation_level": "AUTOCOMMIT"},
)
return engine, connector


async def test_asyncpg_iam_authn_time() -> None:
"""Basic test to get time from database."""
inst_uri = os.environ["ALLOYDB_INSTANCE_URI"]
user = os.environ["ALLOYDB_IAM_USER"]
db = os.environ["ALLOYDB_DB"]

pool, connector = await create_sqlalchemy_engine(inst_uri, user, db)
async with pool.connect() as conn:
time = (await conn.execute(sqlalchemy.text("SELECT NOW()"))).fetchone()
curr_time = time[0]
assert type(curr_time) is datetime
await connector.close()
# cleanup AsyncEngine
await pool.dispose()
10 changes: 4 additions & 6 deletions tests/system/test_pg8000_iam_authn.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@

from datetime import datetime
import os
from typing import Tuple

# [START alloydb_sqlalchemy_connect_connector]
import pg8000
import sqlalchemy

Expand All @@ -26,7 +26,7 @@ def create_sqlalchemy_engine(
inst_uri: str,
user: str,
db: str,
) -> (sqlalchemy.engine.Engine, Connector):
) -> Tuple[sqlalchemy.engine.Engine, Connector]:
"""Creates a connection pool for an AlloyDB instance and returns the pool
and the connector. Callers are responsible for closing the pool and the
connector.
Expand All @@ -51,7 +51,8 @@ def create_sqlalchemy_engine(
region, and cluster. For example:
"projects/my-project/locations/us-central1/clusters/my-cluster/instances/my-instance"
user (str):
The database user name, e.g., postgres
The formatted IAM database username.
e.g., my-email@test.com, service-account@project-id.iam
db_name (str):
The name of the database, e.g., mydb
"""
Expand All @@ -76,9 +77,6 @@ def getconn() -> pg8000.dbapi.Connection:
return engine, connector


# [END alloydb_sqlalchemy_connect_connector]


def test_pg8000_iam_authn_time() -> None:
"""Basic test to get time from database."""
inst_uri = os.environ["ALLOYDB_INSTANCE_URI"]
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/test_async_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ async def test_AsyncConnector_init(credentials: FakeCredentials) -> None:
assert connector._alloydb_api_endpoint == ALLOYDB_API_ENDPOINT
assert connector._client is None
assert connector._credentials == credentials
assert connector._enable_iam_auth is False
await connector.close()


Expand All @@ -52,6 +53,7 @@ async def test_AsyncConnector_context_manager(
assert connector._alloydb_api_endpoint == ALLOYDB_API_ENDPOINT
assert connector._client is None
assert connector._credentials == credentials
assert connector._enable_iam_auth is False


TEST_INSTANCE_NAME = "/".join(
Expand Down

0 comments on commit 165b059

Please sign in to comment.