From e5ef56ce8f9a40651e366436d9eea67c94a30abd Mon Sep 17 00:00:00 2001 From: Ephraim Anierobi Date: Sun, 26 Nov 2023 14:50:17 +0100 Subject: [PATCH 1/5] Add more ways to connect to weaviate There are other options for connecting to weaviate. This commit adds these other options and also improved the imports/typing --- airflow/providers/weaviate/hooks/weaviate.py | 31 ++-- .../providers/weaviate/hooks/test_weaviate.py | 148 ++++++++++++++++++ 2 files changed, 167 insertions(+), 12 deletions(-) diff --git a/airflow/providers/weaviate/hooks/weaviate.py b/airflow/providers/weaviate/hooks/weaviate.py index c8b0ed05d4d99..a8895b1caf804 100644 --- a/airflow/providers/weaviate/hooks/weaviate.py +++ b/airflow/providers/weaviate/hooks/weaviate.py @@ -19,7 +19,8 @@ from typing import Any -import weaviate +from weaviate import Client as WeaviateClient +from weaviate.auth import AuthApiKey, AuthBearerToken, AuthClientCredentials, AuthClientPassword from airflow.hooks.base import BaseHook @@ -62,29 +63,35 @@ def get_ui_field_behaviour() -> dict[str, Any]: }, } - def get_client(self) -> weaviate.Client: + def get_client(self) -> WeaviateClient: conn = self.get_connection(self.conn_id) url = conn.host username = conn.login or "" password = conn.password or "" extras = conn.extra_dejson - token = extras.pop("token", "") + access_token = extras.get("access_token", None) + refresh_token = extras.get("refresh_token", None) + expires_in = extras.get("expires_in", 60) + # previously token was used as api_key(backwards compatibility) + api_key = extras.get("api_key", None) or extras.get("token", None) + client_secret = extras.get("client_secret", None) additional_headers = extras.pop("additional_headers", {}) - scope = conn.extra_dejson.get("oidc_scope", "offline_access") - - if token == "" and username != "": - auth_client_secret = weaviate.AuthClientPassword( - username=username, password=password, scope=scope + scope = extras.get("scope", None) or extras.get("oidc_scope", "offline_access") + if api_key: + auth_client_secret = AuthApiKey(api_key) + elif access_token: + auth_client_secret = AuthBearerToken( + access_token, expires_in=expires_in, refresh_token=refresh_token ) + elif client_secret: + auth_client_secret = AuthClientCredentials(client_secret=client_secret, scope=scope) else: - auth_client_secret = weaviate.AuthApiKey(token) + auth_client_secret = AuthClientPassword(username=username, password=password, scope=scope) - client = weaviate.Client( + return WeaviateClient( url=url, auth_client_secret=auth_client_secret, additional_headers=additional_headers ) - return client - def test_connection(self) -> tuple[bool, str]: try: client = self.get_client() diff --git a/tests/providers/weaviate/hooks/test_weaviate.py b/tests/providers/weaviate/hooks/test_weaviate.py index 56f57ebc9bac1..389ae22d3e5b0 100644 --- a/tests/providers/weaviate/hooks/test_weaviate.py +++ b/tests/providers/weaviate/hooks/test_weaviate.py @@ -16,12 +16,16 @@ # under the License. from __future__ import annotations +from unittest import mock from unittest.mock import MagicMock, Mock, patch import pytest +from airflow.models import Connection from airflow.providers.weaviate.hooks.weaviate import WeaviateHook +pytestmark = pytest.mark.db_test + TEST_CONN_ID = "test_weaviate_conn" @@ -38,6 +42,150 @@ def weaviate_hook(): return hook +@pytest.fixture +def mock_auth_api_key(): + with mock.patch("airflow.providers.weaviate.hooks.weaviate.AuthApiKey") as m: + yield m + + +@pytest.fixture +def mock_auth_bearer_token(): + with mock.patch("airflow.providers.weaviate.hooks.weaviate.AuthBearerToken") as m: + yield m + + +@pytest.fixture +def mock_auth_client_credentials(): + with mock.patch("airflow.providers.weaviate.hooks.weaviate.AuthClientCredentials") as m: + yield m + + +@pytest.fixture +def mock_auth_client_password(): + with mock.patch("airflow.providers.weaviate.hooks.weaviate.AuthClientPassword") as m: + yield m + + +class TestWeaviateHook: + """ + Test the WeaviateHook Hook. + """ + + @pytest.fixture(autouse=True) + def setup_method(self, monkeypatch): + """Set up the test method.""" + self.weaviate_api_key1 = "weaviate_api_key1" + self.weaviate_api_key2 = "weaviate_api_key2" + self.api_key = "api_key" + self.weaviate_client_credentials = "weaviate_client_credentials" + self.client_secret = "client_secret" + self.scope = "scope1 scope2" + self.client_password = "client_password" + self.client_bearer_token = "client_bearer_token" + self.host = "http://localhost:8080" + conns = ( + Connection( + conn_id=self.weaviate_api_key1, + host=self.host, + conn_type="weaviate", + extra={"api_key": self.api_key}, + ), + Connection( + conn_id=self.weaviate_api_key2, + host=self.host, + conn_type="weaviate", + extra={"token": self.api_key}, + ), + Connection( + conn_id=self.weaviate_client_credentials, + host=self.host, + conn_type="weaviate", + extra={"client_secret": self.client_secret, "scope": self.scope}, + ), + Connection( + conn_id=self.client_password, + host=self.host, + conn_type="weaviate", + login="login", + password="password", + ), + Connection( + conn_id=self.client_bearer_token, + host=self.host, + conn_type="weaviate", + extra={ + "access_token": self.client_bearer_token, + "expires_in": 30, + "refresh_token": "refresh_token", + }, + ), + ) + for conn in conns: + monkeypatch.setenv(f"AIRFLOW_CONN_{conn.conn_id.upper()}", conn.get_uri()) + + @mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateClient") + def test_get_client_with_api_key_in_extra(self, mock_client, mock_auth_api_key): + hook = WeaviateHook(conn_id=self.weaviate_api_key1) + hook.get_client() + mock_auth_api_key.assert_called_once_with(self.api_key) + mock_client.assert_called_once_with( + url=self.host, auth_client_secret=mock_auth_api_key(api_key=self.api_key), additional_headers={} + ) + + @mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateClient") + def test_get_client_with_token_in_extra(self, mock_client, mock_auth_api_key): + # when token is passed in extra + hook = WeaviateHook(conn_id=self.weaviate_api_key2) + hook.get_client() + mock_auth_api_key.assert_called_once_with(self.api_key) + mock_client.assert_called_once_with( + url=self.host, auth_client_secret=mock_auth_api_key(api_key=self.api_key), additional_headers={} + ) + + @mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateClient") + def test_get_client_with_access_token_in_extra(self, mock_client, mock_auth_bearer_token): + hook = WeaviateHook(conn_id=self.client_bearer_token) + hook.get_client() + mock_auth_bearer_token.assert_called_once_with( + self.client_bearer_token, expires_in=30, refresh_token="refresh_token" + ) + mock_client.assert_called_once_with( + url=self.host, + auth_client_secret=mock_auth_bearer_token( + access_token=self.client_bearer_token, expires_in=30, refresh_token="refresh_token" + ), + additional_headers={}, + ) + + @mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateClient") + def test_get_client_with_client_secret_in_extra(self, mock_client, mock_auth_client_credentials): + hook = WeaviateHook(conn_id=self.weaviate_client_credentials) + hook.get_client() + mock_auth_client_credentials.assert_called_once_with( + client_secret=self.client_secret, scope=self.scope + ) + mock_client.assert_called_once_with( + url=self.host, + auth_client_secret=mock_auth_client_credentials(api_key=self.client_secret, scope=self.scope), + additional_headers={}, + ) + + @mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateClient") + def test_get_client_with_client_password_in_extra(self, mock_client, mock_auth_client_password): + hook = WeaviateHook(conn_id=self.client_password) + hook.get_client() + mock_auth_client_password.assert_called_once_with( + username="login", password="password", scope="offline_access" + ) + mock_client.assert_called_once_with( + url=self.host, + auth_client_secret=mock_auth_client_password( + username="login", password="password", scope="offline_access" + ), + additional_headers={}, + ) + + def test_create_class(weaviate_hook): """ Test the create_class method of WeaviateHook. From 2999c5177a642ee66f0a4ff4141ef69cdd3e6b26 Mon Sep 17 00:00:00 2001 From: Ephraim Anierobi Date: Tue, 28 Nov 2023 12:56:53 +0100 Subject: [PATCH 2/5] fixup! Add more ways to connect to weaviate --- airflow/providers/weaviate/hooks/weaviate.py | 16 ++++++++++------ .../connections.rst | 19 +++++++++++++++++++ 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/airflow/providers/weaviate/hooks/weaviate.py b/airflow/providers/weaviate/hooks/weaviate.py index a8895b1caf804..7a0d36fc8f4c6 100644 --- a/airflow/providers/weaviate/hooks/weaviate.py +++ b/airflow/providers/weaviate/hooks/weaviate.py @@ -41,19 +41,19 @@ def __init__(self, conn_id: str = default_conn_name, *args: Any, **kwargs: Any) super().__init__(*args, **kwargs) self.conn_id = conn_id - @staticmethod - def get_connection_form_widgets() -> dict[str, Any]: + @classmethod + def get_connection_form_widgets(cls) -> dict[str, Any]: """Returns connection widgets to add to connection form.""" from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget from flask_babel import lazy_gettext from wtforms import PasswordField return { - "token": PasswordField(lazy_gettext("Weaviate API Token"), widget=BS3PasswordFieldWidget()), + "token": PasswordField(lazy_gettext("Weaviate API Key"), widget=BS3PasswordFieldWidget()), } - @staticmethod - def get_ui_field_behaviour() -> dict[str, Any]: + @classmethod + def get_ui_field_behaviour(cls) -> dict[str, Any]: """Returns custom field behaviour.""" return { "hidden_fields": ["port", "schema"], @@ -63,7 +63,7 @@ def get_ui_field_behaviour() -> dict[str, Any]: }, } - def get_client(self) -> WeaviateClient: + def get_conn(self) -> WeaviateClient: conn = self.get_connection(self.conn_id) url = conn.host username = conn.login or "" @@ -92,6 +92,10 @@ def get_client(self) -> WeaviateClient: url=url, auth_client_secret=auth_client_secret, additional_headers=additional_headers ) + def get_client(self) -> WeaviateClient: + # Keeping this for backwards compatibility + return self.get_conn() + def test_connection(self) -> tuple[bool, str]: try: client = self.get_client() diff --git a/docs/apache-airflow-providers-weaviate/connections.rst b/docs/apache-airflow-providers-weaviate/connections.rst index 5e16164ff68d7..ed8faf8508b3e 100644 --- a/docs/apache-airflow-providers-weaviate/connections.rst +++ b/docs/apache-airflow-providers-weaviate/connections.rst @@ -42,6 +42,8 @@ OIDC Password (optional) Extra (optional) Specify the extra parameters (as json dictionary) that can be used in the connection. All parameters are optional. + The extras are those parameters that are acceptable in the different authentication methods + here: `Authentication `__ * If you'd like to use Vectorizers for your class, configure the API keys to use the corresponding embedding API. The extras accepts a key ``additional_headers`` containing the dictionary @@ -50,3 +52,20 @@ Extra (optional) Weaviate API Token (optional) Specify your Weaviate API Key to connect when API Key option is to be used for authentication. + +Supported Authentication Methods +-------------------------------- +* API Key Authentication:* This method uses the Weaviate API Key to authenticate the connection. You can either have the + API key in the ``Weaviate API Token`` field or in the extra field as a dictionary with key ``token`` or ``api_key`` and + value as the API key. + +* Bearer Token Authentication:* This method uses the Access Token to authenticate the connection. You need to +have the Access Token in the extra field as a dictionary with key ``access_token`` and value as the Access Token. Other +parameters such as ``expires_in`` and ``refresh_token`` are optional. + +* Client Credentials Authentication:* This method uses the Client Credentials to authenticate the connection. You need to +have the Client Credentials in the extra field as a dictionary with key ``client_secret`` and value as the Client Credentials. +The ``scope`` is optional. + +* Password Authentication:* This method uses the username and password to authenticate the connection. You can specify the +scope in the extra field as a dictionary with key ``scope`` and value as the scope. The ``scope`` is optional. From 45037e8c6d35e09cf3890a777742b69789b53de9 Mon Sep 17 00:00:00 2001 From: Ephraim Anierobi Date: Tue, 28 Nov 2023 13:57:48 +0100 Subject: [PATCH 3/5] fixup! fixup! Add more ways to connect to weaviate --- airflow/providers/weaviate/hooks/weaviate.py | 2 +- .../connections.rst | 18 +++++----- .../providers/weaviate/hooks/test_weaviate.py | 34 ++++++++----------- 3 files changed, 25 insertions(+), 29 deletions(-) diff --git a/airflow/providers/weaviate/hooks/weaviate.py b/airflow/providers/weaviate/hooks/weaviate.py index 7a0d36fc8f4c6..1900aaaab75ea 100644 --- a/airflow/providers/weaviate/hooks/weaviate.py +++ b/airflow/providers/weaviate/hooks/weaviate.py @@ -76,7 +76,7 @@ def get_conn(self) -> WeaviateClient: api_key = extras.get("api_key", None) or extras.get("token", None) client_secret = extras.get("client_secret", None) additional_headers = extras.pop("additional_headers", {}) - scope = extras.get("scope", None) or extras.get("oidc_scope", "offline_access") + scope = extras.get("scope", None) or extras.get("oidc_scope", None) if api_key: auth_client_secret = AuthApiKey(api_key) elif access_token: diff --git a/docs/apache-airflow-providers-weaviate/connections.rst b/docs/apache-airflow-providers-weaviate/connections.rst index ed8faf8508b3e..081fe14d92acb 100644 --- a/docs/apache-airflow-providers-weaviate/connections.rst +++ b/docs/apache-airflow-providers-weaviate/connections.rst @@ -55,17 +55,17 @@ Weaviate API Token (optional) Supported Authentication Methods -------------------------------- -* API Key Authentication:* This method uses the Weaviate API Key to authenticate the connection. You can either have the +* API Key Authentication: This method uses the Weaviate API Key to authenticate the connection. You can either have the API key in the ``Weaviate API Token`` field or in the extra field as a dictionary with key ``token`` or ``api_key`` and value as the API key. -* Bearer Token Authentication:* This method uses the Access Token to authenticate the connection. You need to -have the Access Token in the extra field as a dictionary with key ``access_token`` and value as the Access Token. Other -parameters such as ``expires_in`` and ``refresh_token`` are optional. +* Bearer Token Authentication: This method uses the Access Token to authenticate the connection. You need to + have the Access Token in the extra field as a dictionary with key ``access_token`` and value as the Access Token. Other + parameters such as ``expires_in`` and ``refresh_token`` are optional. -* Client Credentials Authentication:* This method uses the Client Credentials to authenticate the connection. You need to -have the Client Credentials in the extra field as a dictionary with key ``client_secret`` and value as the Client Credentials. -The ``scope`` is optional. +* Client Credentials Authentication: This method uses the Client Credentials to authenticate the connection. You need to + have the Client Credentials in the extra field as a dictionary with key ``client_secret`` and value as the Client Credentials. + The ``scope`` is optional. -* Password Authentication:* This method uses the username and password to authenticate the connection. You can specify the -scope in the extra field as a dictionary with key ``scope`` and value as the scope. The ``scope`` is optional. +* Password Authentication: This method uses the username and password to authenticate the connection. You can specify the + scope in the extra field as a dictionary with key ``scope`` and value as the scope. The ``scope`` is optional. diff --git a/tests/providers/weaviate/hooks/test_weaviate.py b/tests/providers/weaviate/hooks/test_weaviate.py index 389ae22d3e5b0..98a2a1f0760cd 100644 --- a/tests/providers/weaviate/hooks/test_weaviate.py +++ b/tests/providers/weaviate/hooks/test_weaviate.py @@ -124,28 +124,28 @@ def setup_method(self, monkeypatch): monkeypatch.setenv(f"AIRFLOW_CONN_{conn.conn_id.upper()}", conn.get_uri()) @mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateClient") - def test_get_client_with_api_key_in_extra(self, mock_client, mock_auth_api_key): + def test_get_conn_with_api_key_in_extra(self, mock_client, mock_auth_api_key): hook = WeaviateHook(conn_id=self.weaviate_api_key1) - hook.get_client() + hook.get_conn() mock_auth_api_key.assert_called_once_with(self.api_key) mock_client.assert_called_once_with( url=self.host, auth_client_secret=mock_auth_api_key(api_key=self.api_key), additional_headers={} ) @mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateClient") - def test_get_client_with_token_in_extra(self, mock_client, mock_auth_api_key): + def test_get_conn_with_token_in_extra(self, mock_client, mock_auth_api_key): # when token is passed in extra hook = WeaviateHook(conn_id=self.weaviate_api_key2) - hook.get_client() + hook.get_conn() mock_auth_api_key.assert_called_once_with(self.api_key) mock_client.assert_called_once_with( url=self.host, auth_client_secret=mock_auth_api_key(api_key=self.api_key), additional_headers={} ) @mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateClient") - def test_get_client_with_access_token_in_extra(self, mock_client, mock_auth_bearer_token): + def test_get_conn_with_access_token_in_extra(self, mock_client, mock_auth_bearer_token): hook = WeaviateHook(conn_id=self.client_bearer_token) - hook.get_client() + hook.get_conn() mock_auth_bearer_token.assert_called_once_with( self.client_bearer_token, expires_in=30, refresh_token="refresh_token" ) @@ -158,9 +158,9 @@ def test_get_client_with_access_token_in_extra(self, mock_client, mock_auth_bear ) @mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateClient") - def test_get_client_with_client_secret_in_extra(self, mock_client, mock_auth_client_credentials): + def test_get_conn_with_client_secret_in_extra(self, mock_client, mock_auth_client_credentials): hook = WeaviateHook(conn_id=self.weaviate_client_credentials) - hook.get_client() + hook.get_conn() mock_auth_client_credentials.assert_called_once_with( client_secret=self.client_secret, scope=self.scope ) @@ -171,17 +171,13 @@ def test_get_client_with_client_secret_in_extra(self, mock_client, mock_auth_cli ) @mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateClient") - def test_get_client_with_client_password_in_extra(self, mock_client, mock_auth_client_password): + def test_get_conn_with_client_password_in_extra(self, mock_client, mock_auth_client_password): hook = WeaviateHook(conn_id=self.client_password) - hook.get_client() - mock_auth_client_password.assert_called_once_with( - username="login", password="password", scope="offline_access" - ) + hook.get_conn() + mock_auth_client_password.assert_called_once_with(username="login", password="password", scope=None) mock_client.assert_called_once_with( url=self.host, - auth_client_secret=mock_auth_client_password( - username="login", password="password", scope="offline_access" - ), + auth_client_secret=mock_auth_client_password(username="login", password="password", scope=None), additional_headers={}, ) @@ -192,7 +188,7 @@ def test_create_class(weaviate_hook): """ # Mock the Weaviate Client mock_client = MagicMock() - weaviate_hook.get_client = MagicMock(return_value=mock_client) + weaviate_hook.get_conn = MagicMock(return_value=mock_client) # Define test class JSON test_class_json = { @@ -213,7 +209,7 @@ def test_create_schema(weaviate_hook): """ # Mock the Weaviate Client mock_client = MagicMock() - weaviate_hook.get_client = MagicMock(return_value=mock_client) + weaviate_hook.get_conn = MagicMock(return_value=mock_client) # Define test schema JSON test_schema_json = { @@ -238,7 +234,7 @@ def test_batch_data(weaviate_hook): """ # Mock the Weaviate Client mock_client = MagicMock() - weaviate_hook.get_client = MagicMock(return_value=mock_client) + weaviate_hook.get_conn = MagicMock(return_value=mock_client) # Define test data test_class_name = "TestClass" From 5613e843b4cc3c95346bf4187e6859519f7b5561 Mon Sep 17 00:00:00 2001 From: Ephraim Anierobi Date: Tue, 28 Nov 2023 17:09:57 +0100 Subject: [PATCH 4/5] add depreccation --- airflow/providers/weaviate/hooks/weaviate.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/airflow/providers/weaviate/hooks/weaviate.py b/airflow/providers/weaviate/hooks/weaviate.py index 1900aaaab75ea..151aaabea6f1f 100644 --- a/airflow/providers/weaviate/hooks/weaviate.py +++ b/airflow/providers/weaviate/hooks/weaviate.py @@ -17,11 +17,13 @@ from __future__ import annotations +import warnings from typing import Any from weaviate import Client as WeaviateClient from weaviate.auth import AuthApiKey, AuthBearerToken, AuthClientCredentials, AuthClientPassword +from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.hooks.base import BaseHook @@ -94,6 +96,11 @@ def get_conn(self) -> WeaviateClient: def get_client(self) -> WeaviateClient: # Keeping this for backwards compatibility + warnings.warn( + "The `get_client` method has been renamed to `get_conn`", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) return self.get_conn() def test_connection(self) -> tuple[bool, str]: From 1f12b0a6e2e4c05c2ab9566afb77143e20b18198 Mon Sep 17 00:00:00 2001 From: Ephraim Anierobi Date: Tue, 28 Nov 2023 18:26:16 +0100 Subject: [PATCH 5/5] remove mark as dbtest --- tests/providers/weaviate/hooks/test_weaviate.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/providers/weaviate/hooks/test_weaviate.py b/tests/providers/weaviate/hooks/test_weaviate.py index 98a2a1f0760cd..0274004fc004f 100644 --- a/tests/providers/weaviate/hooks/test_weaviate.py +++ b/tests/providers/weaviate/hooks/test_weaviate.py @@ -24,8 +24,6 @@ from airflow.models import Connection from airflow.providers.weaviate.hooks.weaviate import WeaviateHook -pytestmark = pytest.mark.db_test - TEST_CONN_ID = "test_weaviate_conn"