Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add more ways to connect to weaviate #35864

Merged
merged 5 commits into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 34 additions & 16 deletions airflow/providers/weaviate/hooks/weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@

from __future__ import annotations

import warnings
from typing import Any

import weaviate
from weaviate import Client as WeaviateClient
from weaviate.auth import AuthApiKey, AuthBearerToken, AuthClientCredentials, AuthClientPassword

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.hooks.base import BaseHook


Expand All @@ -40,19 +43,19 @@ def __init__(self, conn_id: str = default_conn_name, *args: Any, **kwargs: Any)
super().__init__(*args, **kwargs)
self.conn_id = conn_id

@staticmethod
def get_connection_form_widgets() -> dict[str, Any]:
@classmethod
def get_connection_form_widgets(cls) -> dict[str, Any]:
"""Returns connection widgets to add to connection form."""
from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget
from flask_babel import lazy_gettext
from wtforms import PasswordField

return {
"token": PasswordField(lazy_gettext("Weaviate API Token"), widget=BS3PasswordFieldWidget()),
"token": PasswordField(lazy_gettext("Weaviate API Key"), widget=BS3PasswordFieldWidget()),
}

@staticmethod
def get_ui_field_behaviour() -> dict[str, Any]:
@classmethod
def get_ui_field_behaviour(cls) -> dict[str, Any]:
"""Returns custom field behaviour."""
return {
"hidden_fields": ["port", "schema"],
Expand All @@ -62,28 +65,43 @@ def get_ui_field_behaviour() -> dict[str, Any]:
},
}

def get_client(self) -> weaviate.Client:
def get_conn(self) -> WeaviateClient:
conn = self.get_connection(self.conn_id)
url = conn.host
username = conn.login or ""
password = conn.password or ""
extras = conn.extra_dejson
token = extras.pop("token", "")
access_token = extras.get("access_token", None)
refresh_token = extras.get("refresh_token", None)
expires_in = extras.get("expires_in", 60)
# previously token was used as api_key(backwards compatibility)
api_key = extras.get("api_key", None) or extras.get("token", None)
client_secret = extras.get("client_secret", None)
additional_headers = extras.pop("additional_headers", {})
scope = conn.extra_dejson.get("oidc_scope", "offline_access")

if token == "" and username != "":
auth_client_secret = weaviate.AuthClientPassword(
username=username, password=password, scope=scope
scope = extras.get("scope", None) or extras.get("oidc_scope", None)
if api_key:
auth_client_secret = AuthApiKey(api_key)
elif access_token:
auth_client_secret = AuthBearerToken(
access_token, expires_in=expires_in, refresh_token=refresh_token
)
elif client_secret:
auth_client_secret = AuthClientCredentials(client_secret=client_secret, scope=scope)
else:
auth_client_secret = weaviate.AuthApiKey(token)
auth_client_secret = AuthClientPassword(username=username, password=password, scope=scope)

client = weaviate.Client(
return WeaviateClient(
url=url, auth_client_secret=auth_client_secret, additional_headers=additional_headers
)

return client
def get_client(self) -> WeaviateClient:
# Keeping this for backwards compatibility
warnings.warn(
"The `get_client` method has been renamed to `get_conn`",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
return self.get_conn()

def test_connection(self) -> tuple[bool, str]:
try:
Expand Down
19 changes: 19 additions & 0 deletions docs/apache-airflow-providers-weaviate/connections.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ OIDC Password (optional)
Extra (optional)
Specify the extra parameters (as json dictionary) that can be used in the
connection. All parameters are optional.
The extras are those parameters that are acceptable in the different authentication methods
here: `Authentication <https://weaviate-python-client.readthedocs.io/en/stable/weaviate.auth.html>`__

* If you'd like to use Vectorizers for your class, configure the API keys to use the corresponding
embedding API. The extras accepts a key ``additional_headers`` containing the dictionary
Expand All @@ -50,3 +52,20 @@ Extra (optional)

Weaviate API Token (optional)
Specify your Weaviate API Key to connect when API Key option is to be used for authentication.

Supported Authentication Methods
--------------------------------
* API Key Authentication: This method uses the Weaviate API Key to authenticate the connection. You can either have the
API key in the ``Weaviate API Token`` field or in the extra field as a dictionary with key ``token`` or ``api_key`` and
value as the API key.

* Bearer Token Authentication: This method uses the Access Token to authenticate the connection. You need to
have the Access Token in the extra field as a dictionary with key ``access_token`` and value as the Access Token. Other
parameters such as ``expires_in`` and ``refresh_token`` are optional.

* Client Credentials Authentication: This method uses the Client Credentials to authenticate the connection. You need to
have the Client Credentials in the extra field as a dictionary with key ``client_secret`` and value as the Client Credentials.
The ``scope`` is optional.

* Password Authentication: This method uses the username and password to authenticate the connection. You can specify the
scope in the extra field as a dictionary with key ``scope`` and value as the scope. The ``scope`` is optional.
150 changes: 147 additions & 3 deletions tests/providers/weaviate/hooks/test_weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,16 @@
# under the License.
from __future__ import annotations

from unittest import mock
from unittest.mock import MagicMock, Mock, patch

import pytest

from airflow.models import Connection
from airflow.providers.weaviate.hooks.weaviate import WeaviateHook

pytestmark = pytest.mark.db_test
ephraimbuddy marked this conversation as resolved.
Show resolved Hide resolved

TEST_CONN_ID = "test_weaviate_conn"


Expand All @@ -38,13 +42,153 @@ def weaviate_hook():
return hook


@pytest.fixture
def mock_auth_api_key():
with mock.patch("airflow.providers.weaviate.hooks.weaviate.AuthApiKey") as m:
yield m


@pytest.fixture
def mock_auth_bearer_token():
with mock.patch("airflow.providers.weaviate.hooks.weaviate.AuthBearerToken") as m:
yield m


@pytest.fixture
def mock_auth_client_credentials():
with mock.patch("airflow.providers.weaviate.hooks.weaviate.AuthClientCredentials") as m:
yield m


@pytest.fixture
def mock_auth_client_password():
with mock.patch("airflow.providers.weaviate.hooks.weaviate.AuthClientPassword") as m:
yield m


class TestWeaviateHook:
"""
Test the WeaviateHook Hook.
"""

@pytest.fixture(autouse=True)
def setup_method(self, monkeypatch):
"""Set up the test method."""
self.weaviate_api_key1 = "weaviate_api_key1"
self.weaviate_api_key2 = "weaviate_api_key2"
self.api_key = "api_key"
self.weaviate_client_credentials = "weaviate_client_credentials"
self.client_secret = "client_secret"
self.scope = "scope1 scope2"
self.client_password = "client_password"
self.client_bearer_token = "client_bearer_token"
self.host = "http://localhost:8080"
conns = (
Connection(
conn_id=self.weaviate_api_key1,
host=self.host,
conn_type="weaviate",
extra={"api_key": self.api_key},
),
Connection(
conn_id=self.weaviate_api_key2,
host=self.host,
conn_type="weaviate",
extra={"token": self.api_key},
),
Connection(
conn_id=self.weaviate_client_credentials,
host=self.host,
conn_type="weaviate",
extra={"client_secret": self.client_secret, "scope": self.scope},
),
Connection(
conn_id=self.client_password,
host=self.host,
conn_type="weaviate",
login="login",
password="password",
),
Connection(
conn_id=self.client_bearer_token,
host=self.host,
conn_type="weaviate",
extra={
"access_token": self.client_bearer_token,
"expires_in": 30,
"refresh_token": "refresh_token",
},
),
)
for conn in conns:
monkeypatch.setenv(f"AIRFLOW_CONN_{conn.conn_id.upper()}", conn.get_uri())

@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateClient")
def test_get_conn_with_api_key_in_extra(self, mock_client, mock_auth_api_key):
hook = WeaviateHook(conn_id=self.weaviate_api_key1)
hook.get_conn()
mock_auth_api_key.assert_called_once_with(self.api_key)
mock_client.assert_called_once_with(
url=self.host, auth_client_secret=mock_auth_api_key(api_key=self.api_key), additional_headers={}
)

@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateClient")
def test_get_conn_with_token_in_extra(self, mock_client, mock_auth_api_key):
# when token is passed in extra
hook = WeaviateHook(conn_id=self.weaviate_api_key2)
hook.get_conn()
mock_auth_api_key.assert_called_once_with(self.api_key)
mock_client.assert_called_once_with(
url=self.host, auth_client_secret=mock_auth_api_key(api_key=self.api_key), additional_headers={}
)

@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateClient")
def test_get_conn_with_access_token_in_extra(self, mock_client, mock_auth_bearer_token):
hook = WeaviateHook(conn_id=self.client_bearer_token)
hook.get_conn()
mock_auth_bearer_token.assert_called_once_with(
self.client_bearer_token, expires_in=30, refresh_token="refresh_token"
)
mock_client.assert_called_once_with(
url=self.host,
auth_client_secret=mock_auth_bearer_token(
access_token=self.client_bearer_token, expires_in=30, refresh_token="refresh_token"
),
additional_headers={},
)

@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateClient")
def test_get_conn_with_client_secret_in_extra(self, mock_client, mock_auth_client_credentials):
hook = WeaviateHook(conn_id=self.weaviate_client_credentials)
hook.get_conn()
mock_auth_client_credentials.assert_called_once_with(
client_secret=self.client_secret, scope=self.scope
)
mock_client.assert_called_once_with(
url=self.host,
auth_client_secret=mock_auth_client_credentials(api_key=self.client_secret, scope=self.scope),
additional_headers={},
)

@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateClient")
def test_get_conn_with_client_password_in_extra(self, mock_client, mock_auth_client_password):
hook = WeaviateHook(conn_id=self.client_password)
hook.get_conn()
mock_auth_client_password.assert_called_once_with(username="login", password="password", scope=None)
mock_client.assert_called_once_with(
url=self.host,
auth_client_secret=mock_auth_client_password(username="login", password="password", scope=None),
additional_headers={},
)


def test_create_class(weaviate_hook):
"""
Test the create_class method of WeaviateHook.
"""
# Mock the Weaviate Client
mock_client = MagicMock()
weaviate_hook.get_client = MagicMock(return_value=mock_client)
weaviate_hook.get_conn = MagicMock(return_value=mock_client)

# Define test class JSON
test_class_json = {
Expand All @@ -65,7 +209,7 @@ def test_create_schema(weaviate_hook):
"""
# Mock the Weaviate Client
mock_client = MagicMock()
weaviate_hook.get_client = MagicMock(return_value=mock_client)
weaviate_hook.get_conn = MagicMock(return_value=mock_client)

# Define test schema JSON
test_schema_json = {
Expand All @@ -90,7 +234,7 @@ def test_batch_data(weaviate_hook):
"""
# Mock the Weaviate Client
mock_client = MagicMock()
weaviate_hook.get_client = MagicMock(return_value=mock_client)
weaviate_hook.get_conn = MagicMock(return_value=mock_client)

# Define test data
test_class_name = "TestClass"
Expand Down