Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for public IP connections #227

Merged
merged 11 commits into from
Jan 29, 2024
3 changes: 2 additions & 1 deletion google/cloud/alloydb/connector/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
from google.cloud.alloydb.connector.async_connector import AsyncConnector
from google.cloud.alloydb.connector.connector import Connector
from google.cloud.alloydb.connector.instance import IPTypes
from google.cloud.alloydb.connector.version import __version__

__all__ = ["__version__", "Connector", "AsyncConnector"]
__all__ = ["__version__", "Connector", "AsyncConnector", "IPTypes"]
8 changes: 7 additions & 1 deletion google/cloud/alloydb/connector/async_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import google.cloud.alloydb.connector.asyncpg as asyncpg
from google.cloud.alloydb.connector.client import AlloyDBClient
from google.cloud.alloydb.connector.instance import Instance
from google.cloud.alloydb.connector.instance import IPTypes
from google.cloud.alloydb.connector.utils import generate_keys

if TYPE_CHECKING:
Expand All @@ -46,6 +47,8 @@ class AsyncConnector:
alloydb_api_endpoint (str): Base URL to use when calling
the AlloyDB API endpoint. Defaults to "https://alloydb.googleapis.com".
enable_iam_auth (bool): Enables automatic IAM database authentication.
ip_type (IPTypes): Default IP type for all AlloyDB connections.
Defaults to IPTypes.PRIVATE for private IP connections.
"""

def __init__(
Expand All @@ -54,12 +57,14 @@ def __init__(
quota_project: Optional[str] = None,
alloydb_api_endpoint: str = "https://alloydb.googleapis.com",
enable_iam_auth: bool = False,
ip_type: IPTypes = IPTypes.PRIVATE,
) -> None:
self._instances: Dict[str, Instance] = {}
# initialize default params
self._quota_project = quota_project
self._alloydb_api_endpoint = alloydb_api_endpoint
self._enable_iam_auth = enable_iam_auth
self._ip_type = ip_type
# initialize credentials
scopes = ["https://www.googleapis.com/auth/cloud-platform"]
if credentials:
Expand Down Expand Up @@ -136,7 +141,8 @@ async def connect(
kwargs.pop("port", None)

# get connection info for AlloyDB instance
ip_address, context = await instance.connection_info()
ip_type: IPTypes = kwargs.pop("ip_type", self._ip_type)
ip_address, context = await instance.connection_info(ip_type)

# callable to be used for auto IAM authn
def get_authentication_token() -> str:
Expand Down
11 changes: 7 additions & 4 deletions google/cloud/alloydb/connector/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from __future__ import annotations

import logging
from typing import List, Optional, Tuple, TYPE_CHECKING
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING

import aiohttp
from google.auth.transport.requests import Request
Expand Down Expand Up @@ -81,7 +81,7 @@ async def _get_metadata(
region: str,
cluster: str,
name: str,
) -> str:
) -> Dict[str, Optional[str]]:
"""
Fetch the metadata for a given AlloyDB instance.

Expand All @@ -97,7 +97,7 @@ async def _get_metadata(
name (str): The name of the AlloyDB instance.

Returns:
str: IP address of the AlloyDB instance.
dict: IP addresses of the AlloyDB instance.
"""
logger.debug(f"['{project}/{region}/{cluster}/{name}']: Requesting metadata")

Expand All @@ -114,7 +114,10 @@ async def _get_metadata(
resp = await self._client.get(url, headers=headers, raise_for_status=True)
resp_dict = await resp.json()

return resp_dict["ipAddress"]
return {
"PRIVATE": resp_dict.get("ipAddress"),
"PUBLIC": resp_dict.get("publicIpAddress"),
}

async def _get_client_certificate(
self,
Expand Down
8 changes: 7 additions & 1 deletion google/cloud/alloydb/connector/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

from google.cloud.alloydb.connector.client import AlloyDBClient
from google.cloud.alloydb.connector.instance import Instance
from google.cloud.alloydb.connector.instance import IPTypes
import google.cloud.alloydb.connector.pg8000 as pg8000
from google.cloud.alloydb.connector.utils import generate_keys
import google.cloud.alloydb_connectors_v1.proto.resources_pb2 as connectorspb
Expand Down Expand Up @@ -56,6 +57,8 @@ class Connector:
alloydb_api_endpoint (str): Base URL to use when calling
the AlloyDB API endpoint. Defaults to "https://alloydb.googleapis.com".
enable_iam_auth (bool): Enables automatic IAM database authentication.
ip_type (IPTypes): Default IP type for all AlloyDB connections.
Defaults to IPTypes.PRIVATE for private IP connections.
"""

def __init__(
Expand All @@ -64,6 +67,7 @@ def __init__(
quota_project: Optional[str] = None,
alloydb_api_endpoint: str = "https://alloydb.googleapis.com",
enable_iam_auth: bool = False,
ip_type: IPTypes = IPTypes.PRIVATE,
) -> None:
# create event loop and start it in background thread
self._loop: asyncio.AbstractEventLoop = asyncio.new_event_loop()
Expand All @@ -74,6 +78,7 @@ def __init__(
self._quota_project = quota_project
self._alloydb_api_endpoint = alloydb_api_endpoint
self._enable_iam_auth = enable_iam_auth
self._ip_type = ip_type
# initialize credentials
scopes = ["https://www.googleapis.com/auth/cloud-platform"]
if credentials:
Expand Down Expand Up @@ -163,7 +168,8 @@ async def connect_async(self, instance_uri: str, driver: str, **kwargs: Any) ->
kwargs.pop("port", None)

# get connection info for AlloyDB instance
ip_address, context = await instance.connection_info()
ip_type: IPTypes = kwargs.pop("ip_type", self._ip_type)
ip_address, context = await instance.connection_info(ip_type)

# synchronous drivers are blocking and run using executor
try:
Expand Down
4 changes: 4 additions & 0 deletions google/cloud/alloydb/connector/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,7 @@

class RefreshError(Exception):
pass


class IPTypeNotFoundError(Exception):
pass
23 changes: 21 additions & 2 deletions google/cloud/alloydb/connector/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
from __future__ import annotations

import asyncio
from enum import Enum
import logging
import re
from typing import Tuple, TYPE_CHECKING

from google.cloud.alloydb.connector.exceptions import IPTypeNotFoundError
from google.cloud.alloydb.connector.exceptions import RefreshError
from google.cloud.alloydb.connector.rate_limiter import AsyncRateLimiter
from google.cloud.alloydb.connector.refresh import _is_valid
Expand All @@ -39,6 +41,15 @@
)


class IPTypes(Enum):
"""
Enum for specifying IP type to connect to AlloyDB with.
"""

PUBLIC: str = "PUBLIC"
PRIVATE: str = "PRIVATE"


def _parse_instance_uri(instance_uri: str) -> Tuple[str, str, str, str]:
# should take form "projects/<PROJECT>/locations/<REGION>/clusters/<CLUSTER>/instances/<INSTANCE>"
if INSTANCE_URI_REGEX.fullmatch(instance_uri) is None:
Expand Down Expand Up @@ -214,16 +225,24 @@ async def force_refresh(self) -> None:
if not await _is_valid(self._current):
self._current = self._next

async def connection_info(self) -> Tuple[str, ssl.SSLContext]:
async def connection_info(self, ip_type: IPTypes) -> Tuple[str, ssl.SSLContext]:
"""
Return connection info for current refresh result.

Args:
ip_type (IpTypes): Type of AlloyDB instance IP to connect over.
Returns:
Tuple[str, ssl.SSLContext]: AlloyDB instance IP address
and configured TLS connection.
"""
refresh: RefreshResult = await self._current
return refresh.instance_ip, refresh.context
ip_address = refresh.ip_addrs.get(ip_type.value)
if ip_address is None:
raise IPTypeNotFoundError(
"AlloyDB instance does not have an IP addresses matching "
f"type: '{ip_type.value}'"
)
return ip_address, refresh.context

async def close(self) -> None:
"""
Expand Down
11 changes: 7 additions & 4 deletions google/cloud/alloydb/connector/refresh.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import logging
import ssl
from tempfile import TemporaryDirectory
from typing import List, Tuple, TYPE_CHECKING
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING

from cryptography import x509

Expand Down Expand Up @@ -71,16 +71,19 @@ class RefreshResult:
Builds the TLS context required to connect to AlloyDB database.

Args:
instance_ip (str): The IP address of the AlloyDB instance.
ip_addrs (Dict[str, str]): The IP addresses of the AlloyDB instance.
key (rsa.RSAPrivateKey): Private key for the client connection.
certs (Tuple[str, List(str)]): Client cert and CA certs for establishing
the chain of trust used in building the TLS context.
"""

def __init__(
self, instance_ip: str, key: rsa.RSAPrivateKey, certs: Tuple[str, List[str]]
self,
ip_addrs: Dict[str, Optional[str]],
key: rsa.RSAPrivateKey,
certs: Tuple[str, List[str]],
) -> None:
self.instance_ip = instance_ip
self.ip_addrs = ip_addrs
# create TLS context
self.context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
# update ssl.PROTOCOL_TLS_CLIENT default
Expand Down
4 changes: 4 additions & 0 deletions tests/system/test_asyncpg_iam_authn.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import os
from typing import Tuple

# [START alloydb_sqlalchemy_connect_async_connector_iam_authn]
import asyncpg
import sqlalchemy
import sqlalchemy.ext.asyncio
Expand Down Expand Up @@ -78,6 +79,9 @@ async def getconn() -> asyncpg.Connection:
return engine, connector


# [END alloydb_sqlalchemy_connect_async_connector_iam_authn]


async def test_asyncpg_iam_authn_time() -> None:
"""Basic test to get time from database."""
inst_uri = os.environ["ALLOYDB_INSTANCE_URI"]
Expand Down
103 changes: 103 additions & 0 deletions tests/system/test_asyncpg_public_ip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Tuple

# [START alloydb_sqlalchemy_connect_async_connector_public_ip]
import asyncpg
import pytest
import sqlalchemy
import sqlalchemy.ext.asyncio

from google.cloud.alloydb.connector import AsyncConnector
from google.cloud.alloydb.connector import IPTypes


async def create_sqlalchemy_engine(
inst_uri: str,
user: str,
password: str,
db: str,
) -> Tuple[sqlalchemy.ext.asyncio.engine.AsyncEngine, AsyncConnector]:
"""Creates a connection pool for an AlloyDB instance and returns the pool
and the connector. Callers are responsible for closing the pool and the
connector.

A sample invocation looks like:

engine, connector = await create_sqlalchemy_engine(
inst_uri,
user,
password,
db,
)
async with engine.connect() as conn:
time = await conn.execute(sqlalchemy.text("SELECT NOW()")).fetchone()
curr_time = time[0]
# do something with query result
await connector.close()

Args:
instance_uri (str):
The instance URI specifies the instance relative to the project,
region, and cluster. For example:
"projects/my-project/locations/us-central1/clusters/my-cluster/instances/my-instance"
user (str):
The database user name, e.g., postgres
password (str):
The database user's password, e.g., secret-password
db_name (str):
The name of the database, e.g., mydb
"""
connector = AsyncConnector()

async def getconn() -> asyncpg.Connection:
conn: asyncpg.Connection = await connector.connect(
inst_uri,
"asyncpg",
user=user,
password=password,
db=db,
ip_type=IPTypes.PUBLIC,
)
return conn

# create SQLAlchemy connection pool
engine = sqlalchemy.ext.asyncio.create_async_engine(
"postgresql+asyncpg://",
async_creator=getconn,
execution_options={"isolation_level": "AUTOCOMMIT"},
)
return engine, connector


# [END alloydb_sqlalchemy_connect_async_connector_public_ip]


@pytest.mark.asyncio
async def test_connection_with_asyncpg() -> None:
"""Basic test to get time from database."""
inst_uri = os.environ["ALLOYDB_INSTANCE_URI"]
user = os.environ["ALLOYDB_USER"]
password = os.environ["ALLOYDB_PASS"]
db = os.environ["ALLOYDB_DB"]

pool, connector = await create_sqlalchemy_engine(inst_uri, user, password, db)

async with pool.connect() as conn:
res = (await conn.execute(sqlalchemy.text("SELECT 1"))).fetchone()
assert res[0] == 1

await connector.close()
3 changes: 2 additions & 1 deletion tests/system/test_pg8000_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from datetime import datetime
import os
from typing import Tuple

# [START alloydb_sqlalchemy_connect_connector]
import pg8000
Expand All @@ -27,7 +28,7 @@ def create_sqlalchemy_engine(
user: str,
password: str,
db: str,
) -> (sqlalchemy.engine.Engine, Connector):
) -> Tuple[sqlalchemy.engine.Engine, Connector]:
"""Creates a connection pool for an AlloyDB instance and returns the pool
and the connector. Callers are responsible for closing the pool and the
connector.
Expand Down
5 changes: 4 additions & 1 deletion tests/system/test_pg8000_iam_authn.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import os
from typing import Tuple

# [START alloydb_sqlalchemy_connect_connector_iam_authn]
import pg8000
import sqlalchemy

Expand Down Expand Up @@ -73,10 +74,12 @@ def getconn() -> pg8000.dbapi.Connection:
"postgresql+pg8000://",
creator=getconn,
)
engine.dialect.description_encoding = None
return engine, connector


# [END alloydb_sqlalchemy_connect_connector_iam_authn]


def test_pg8000_iam_authn_time() -> None:
"""Basic test to get time from database."""
inst_uri = os.environ["ALLOYDB_INSTANCE_URI"]
Expand Down
Loading