Commit

Merge pull request langchain-ai#2 from lossyrob/feature/azuredb-postgres-entra-example

Update Azure DB for PostgreSQL example with Entra auth.
AbeOmor authored Oct 8, 2024
2 parents 79e867e + 1a8ce2b commit 4f25c4c
Showing 1 changed file with 180 additions and 59 deletions.
239 changes: 180 additions & 59 deletions docs/docs/integrations/vectorstores/azure_db_for_postgresql.ipynb
@@ -17,20 +17,28 @@
"id": "691dfd83",
"metadata": {},
"source": [
"## Vector Support on Azure Database for PostgreSQL - Flexible Server"
"## Vector Support on Azure Database for PostgreSQL - Flexible Server\n",
"\n",
"Azure Database for PostgreSQL - Flexible Server enables you to efficiently store and query millions of vector embeddings in PostgreSQL. As well as scale your AI use cases from POC to production:\n",
"\n",
"- Provides a familiar SQL interface for querying vector embeddings and relational data.\n",
"- Boosts `pgvector` with a faster and more precise similarity search across 100M+ vectors using DiskANN indexing algorithm.\n",
"- Simplifies operations by integrating relational metadata, vector embeddings, and time-series data into a single database.\n",
"- Leverages the power of the robust PostgreSQL ecosystem and Azure Cloud for enterprise-grade features inculding replication, and high availability."
]
},
{
"cell_type": "markdown",
"id": "7d44e542",
"id": "fad650f8",
"metadata": {},
"source": [
"Azure Database for PostgreSQL - Flexible Server enables you to efficiently store and query millions of vector embeddings in PostgreSQL. As well as scale your AI use cases from POC to production:\n",
"## Authentication\n",
"\n",
"- Provides a familiar SQL interface for querying vector embeddings and relational data.\n",
"- Boosts `pgvector` with a faster and more precise similarity search across 100M+ vectors using DiskANN indexing algorithm.\n",
"- Simplifies operations by integrating relational metadata, vector embeddings, and time-series data into a single database.\n",
"- Leverages the power of the robust PostgreSQL ecosystem and Azure Cloud for enterprise-grade features inculding replication, and high availability."
"Azure Database for PostgreSQL - Flexible Server supports password-based as well as [Microsoft Entra](https://learn.microsoft.com/en-us/azure/postgresql/flexible-server/concepts-azure-ad-authentication) (formerly Azure Active Directory) authentication.\n",
"Entra authentication allows you to use Entra identity to authenticate to your PostgreSQL server. This eliminates the need to manage separate usernames and passwords for your database users, and allows\n",
"you to leverage the same security mechanisms that you use for other Azure services.\n",
"\n",
"This notebook is set up to use either authentication method. You can configure whether or not to use Entra authentication later in the notebook."
]
},
{
@@ -46,22 +54,14 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"id": "92df32f0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Note: you may need to restart the kernel to use updated packages.\n",
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"outputs": [],
"source": [
"%pip install -qU langchain_postgres\n",
"%pip install -qU langchain-openai"
"%pip install -qU langchain-openai\n",
"%pip install -qU azure-identity"
]
},
{
@@ -85,58 +85,34 @@
"id": "eee31ce1-2c28-484d-82be-d22d9f9a31fd",
"metadata": {},
"source": [
"### Credentials\n",
"### Connection Details\n",
"\n",
"You will need you Azure Database for PostgreSQL [connection details](https://learn.microsoft.com/en-us/azure/postgresql/flexible-server/quickstart-create-server-portal#get-the-connection-information) and add them as environment variables to run this notebook."
]
},
{
"cell_type": "markdown",
"id": "fc023baa",
"metadata": {},
"source": [
"Set environment variables for the connection URI elements"
"You will need you Azure Database for PostgreSQL [connection details](https://learn.microsoft.com/en-us/azure/postgresql/flexible-server/quickstart-create-server-portal#get-the-connection-information) and add them as environment variables to run this notebook.\n",
"\n",
"Set the `USE_ENTRA_AUTH` flag to `True` if you want to use Microsoft Entra authentication. If using Entra authentication, you will only need to supply the host and database name. If using password authentication, you'll also need to set the username and password."
]
},
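The rule above (host and database name for Entra authentication, plus username and password for password authentication) can be sketched as a small check. This is a minimal illustration that assumes the environment variable names used in this notebook; `missing_connection_settings` is a hypothetical helper, not part of the notebook, LangChain, or the Azure SDK.

```python
import os


def missing_connection_settings(use_entra_auth, env=None):
    """Return the names of required settings that are unset or empty.

    Hypothetical helper for illustration only.
    """
    env = os.environ if env is None else env
    # Host and database name are needed for both authentication modes.
    required = ["DBHOST", "DBNAME"]
    if not use_entra_auth:
        # Password authentication also needs explicit credentials.
        required += ["DBUSER", "DBPASSWORD"]
    return [name for name in required if not env.get(name)]
```

Calling `missing_connection_settings(USE_ENTRA_AUTH)` before building the connection would surface a forgotten variable early instead of failing with a `KeyError` later.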
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 1,
"id": "ed8c87bf",
"metadata": {},
"outputs": [],
"source": [
"import getpass\n",
"import os\n",
"\n",
"USE_ENTRA_AUTH = True\n",
"\n",
"# Supply the connection details for the database\n",
"os.environ[\"DBHOST\"] = \"<server-name>\"\n",
"os.environ[\"DBNAME\"] = \"<database-name>\"\n",
"os.environ[\"DBUSER\"] = \"<username>\"\n",
"os.environ[\"DBPASSWORD\"] = getpass.getpass(\"Database Password:\")\n",
"os.environ[\"SSLMODE\"] = \"require\""
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "0f8e2f23",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import urllib.parse\n",
"\n",
"\n",
"def get_connection_uri():\n",
" # Read URI parameters from the environment\n",
" dbhost = os.environ[\"DBHOST\"]\n",
" dbname = os.environ[\"DBNAME\"]\n",
" dbuser = urllib.parse.quote(os.environ[\"DBUSER\"])\n",
" password = os.environ[\"DBPASSWORD\"]\n",
" sslmode = os.environ[\"SSLMODE\"]\n",
"os.environ[\"SSLMODE\"] = \"require\"\n",
"\n",
" # Construct connection URI\n",
" db_uri = f\"postgresql://{dbuser}:{password}@{dbhost}/{dbname}?sslmode={sslmode}\"\n",
" return db_uri"
"if not USE_ENTRA_AUTH:\n",
" # If using a username and password, supply them here\n",
" os.environ[\"DBUSER\"] = \"<username>\"\n",
" os.environ[\"DBPASSWORD\"] = getpass.getpass(\"Database Password:\")"
]
},
{
@@ -160,7 +136,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": null,
"id": "94f5c129",
"metadata": {},
"outputs": [],
@@ -186,9 +162,153 @@
"## Initialization"
]
},
{
"cell_type": "markdown",
"id": "fa69c6c7",
"metadata": {},
"source": [
"### Supporting Functions for Entra Authentication\n",
"\n",
"The cell below contains functions that set up LangChain to use Entra authentication. It provides a `TokenManager` class that \n",
"retrieves tokens for the Azure DB for PostgreSQL service using `DefaultAzureCredential` from the `azure.identity` library.\n",
"It will cache tokens and refresh them when they are close to expiring. This is used to ensure the sqlalchemy Engine has\n",
"a valid token with which to create new connections. It will also parse the token, which is a Java Web Token (JWT), to extract\n",
"the username that is used to connect to the database.\n",
"\n",
"The create_postgres_engine function creates a sqlalchemy `Engine` that dynamically sets the username and password based on\n",
"the token fetched from the TokenManager. This `Engine` can be passed into the `connection` parameter of the `PGVector` LangChain VectorStore."
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "65ff46e2",
"metadata": {},
"outputs": [],
"source": [
"import base64\n",
"import json\n",
"import time\n",
"\n",
"from azure.identity import DefaultAzureCredential\n",
"from sqlalchemy import create_engine, event\n",
"from sqlalchemy.engine.url import URL\n",
"\n",
"\n",
"class TokenManager:\n",
" def __init__(self):\n",
" self.credential = DefaultAzureCredential()\n",
" self.token_data = None\n",
" self.scope = \"https://ossrdbms-aad.database.windows.net/.default\"\n",
"\n",
" def _decode_jwt(self, token):\n",
" \"\"\"Decode the JWT payload to extract claims.\"\"\"\n",
" payload = token.split(\".\")[1]\n",
" padding = \"=\" * (4 - len(payload) % 4)\n",
" decoded_payload = base64.urlsafe_b64decode(payload + padding)\n",
" return json.loads(decoded_payload)\n",
"\n",
" def _is_token_expiring_soon(self, claims, buffer_time=300):\n",
" \"\"\"Check if the token is expiring soon (within buffer_time seconds).\"\"\"\n",
" current_time = time.time()\n",
" expiration_time = claims.get(\"exp\", 0)\n",
" return current_time > (expiration_time - buffer_time)\n",
"\n",
" def get_token_and_username(self):\n",
" \"\"\"Fetches a new token if needed and returns the username and token.\"\"\"\n",
" if not self.token_data or self._is_token_expiring_soon(\n",
" self.token_data[\"claims\"]\n",
" ):\n",
" # Fetch a new token and extract the username\n",
" token = self.credential.get_token(self.scope)\n",
" claims = self._decode_jwt(token.token)\n",
" username = claims.get(\"upn\")\n",
"\n",
" self.token_data = {\n",
" \"token\": token.token,\n",
" \"claims\": claims,\n",
" \"username\": username,\n",
" }\n",
"\n",
" return self.token_data[\"username\"], self.token_data[\"token\"]\n",
"\n",
"\n",
"def create_postgres_engine():\n",
" db_url = URL.create(\n",
" drivername=\"postgresql+psycopg\",\n",
" username=\"\", # This will be replaced dynamically\n",
" password=\"\", # This will be replaced dynamically\n",
" host=os.environ[\"DBHOST\"],\n",
" port=os.environ.get(\"DBPORT\", 5432),\n",
" database=os.environ[\"DBNAME\"],\n",
" )\n",
"\n",
" # Create a sqlalchemy engine\n",
" engine = create_engine(db_url, echo=True)\n",
"\n",
" # Initialize the TokenManager\n",
" token_manager = TokenManager()\n",
"\n",
" # Listen for the connection event to inject dynamic credentials\n",
" @event.listens_for(engine, \"do_connect\")\n",
" def provide_dynamic_credentials(dialect, conn_rec, cargs, cparams):\n",
" # Fetch the dynamic username and token\n",
" username, token = token_manager.get_token_and_username()\n",
"\n",
" # Override the connection parameters\n",
" cparams[\"user\"] = username\n",
" cparams[\"password\"] = token\n",
"\n",
" return engine"
]
},
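The token handling in `TokenManager` can be exercised without an Azure account. The sketch below uses only the standard library to decode the payload of a JWT and apply the same expiry check. The function names, the `now` parameter, and the sample token are illustrative assumptions, and the sketch does not verify the token signature.

```python
import base64
import json
import time


def decode_jwt_claims(token):
    """Decode the (unverified) payload segment of a JWT into a dict."""
    payload = token.split(".")[1]
    # base64url input must be a multiple of 4 characters long; add 0 to 3 "=".
    payload += "=" * (-len(payload) % 4)
    return json.loads(base64.urlsafe_b64decode(payload))


def is_expiring_soon(claims, buffer_time=300, now=None):
    """Return True if the `exp` claim is within buffer_time seconds of `now`."""
    now = time.time() if now is None else now
    return now > claims.get("exp", 0) - buffer_time


def _b64url(data):
    """Encode a dict as an unpadded base64url JSON segment (demo helper)."""
    raw = json.dumps(data).encode()
    return base64.urlsafe_b64encode(raw).rstrip(b"=").decode()


# Build a made-up, unsigned token purely for demonstration.
sample_token = ".".join(
    [_b64url({"alg": "none"}), _b64url({"upn": "user@example.com", "exp": 2000}), ""]
)
claims = decode_jwt_claims(sample_token)
```

Passing `now` explicitly keeps the expiry check deterministic, which makes the refresh buffer easy to reason about: with `exp` at 2000 and the default 300-second buffer, the token counts as expiring from time 1700 onward.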
{
"cell_type": "markdown",
"id": "804e77eb",
"metadata": {},
"source": [
"### Supporting Function for Password Authentication\n",
"\n",
"If not using Entra authentication, the `get_connection_uri` provides a connection URI that pulls the username and password from environment variables."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "36dd0ebd",
"metadata": {},
"outputs": [],
"source": [
"import urllib.parse\n",
"\n",
"\n",
"def get_connection_uri():\n",
" # Read URI parameters from the environment\n",
" dbhost = os.environ[\"DBHOST\"]\n",
" dbname = os.environ[\"DBNAME\"]\n",
" dbuser = urllib.parse.quote(os.environ[\"DBUSER\"])\n",
" password = os.environ[\"DBPASSWORD\"]\n",
" sslmode = os.environ[\"SSLMODE\"]\n",
"\n",
" # Construct connection URI\n",
" # Use psycopg 3!\n",
" db_uri = (\n",
" f\"postgresql+psycopg://{dbuser}:{password}@{dbhost}/{dbname}?sslmode={sslmode}\"\n",
" )\n",
" return db_uri"
]
},
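Usernames and passwords can contain characters such as `@`, `/`, or `:` that have special meaning inside a URI. The following is a minimal sketch of percent-encoding both values before assembling the URI, using only `urllib.parse`. The function name `build_connection_uri`, its parameters, and the sample host and credentials are hypothetical.

```python
import urllib.parse


def build_connection_uri(host, dbname, user, password, sslmode="require"):
    """Assemble a psycopg 3 style SQLAlchemy URI with encoded credentials."""
    # safe="" also encodes "/", which quote() leaves untouched by default.
    user = urllib.parse.quote(user, safe="")
    password = urllib.parse.quote(password, safe="")
    return f"postgresql+psycopg://{user}:{password}@{host}/{dbname}?sslmode={sslmode}"


# Made-up values chosen to include characters that need encoding.
uri = build_connection_uri(
    "example.postgres.database.azure.com", "mydb", "app@user", "p@ss/w:rd"
)
```

Decoding the password component of the result with `urllib.parse.unquote` returns the original string, so nothing is lost by encoding it.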
{
"cell_type": "markdown",
"id": "c1bfc61e",
"metadata": {},
"source": [
"### Creating the vector store"
]
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": null,
"id": "979a65bd-742f-4b0d-be1e-c0baae245ec6",
"metadata": {
"tags": []
@@ -199,9 +319,10 @@
"from langchain_postgres import PGVector\n",
"from langchain_postgres.vectorstores import PGVector\n",
"\n",
"connection = get_connection_uri() # Uses psycopg3!\n",
"collection_name = \"my_docs\"\n",
"\n",
"# The connection is either a sqlalchemy engine or a connection URI\n",
"connection = create_postgres_engine() if USE_ENTRA_AUTH else get_connection_uri()\n",
"\n",
"vector_store = PGVector(\n",
" embeddings=embeddings,\n",
