Skip to content

Commit

Permalink
[Feature] Integrate Databricks SDK with Model Serving Auth Provider (#…
Browse files Browse the repository at this point in the history
…761)

## Changes
This PR introduces a new model serving auth method to the Databricks SDK:
- If the environment variables that identify a model serving environment
are set,
- check whether there is an OAuth token file written by the serving
environment;
- if this file exists, use the token it contains for authentication.

## Tests
Added Unit tests

- [x] `make test` run locally
- [x] `make fmt` applied
- [x] relevant integration tests applied

---------

Signed-off-by: aravind-segu <aravind.segu@databricks.com>
  • Loading branch information
aravind-segu committed Sep 18, 2024
1 parent d5ec433 commit 9d39254
Show file tree
Hide file tree
Showing 4 changed files with 199 additions and 2 deletions.
89 changes: 87 additions & 2 deletions databricks/sdk/credentials_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
import platform
import subprocess
import sys
import time
from datetime import datetime
from typing import Callable, Dict, List, Optional, Union
from typing import Callable, Dict, List, Optional, Tuple, Union

import google.auth
import requests
Expand Down Expand Up @@ -698,6 +699,90 @@ def inner() -> Dict[str, str]:
return inner


# This Code is derived from Mlflow DatabricksModelServingConfigProvider
# https://github.com/mlflow/mlflow/blob/1219e3ef1aac7d337a618a352cd859b336cf5c81/mlflow/legacy_databricks_cli/configure/provider.py#L332
class ModelServingAuthProvider():
    """Read the OAuth token from the file mounted in Databricks Model Serving.

    The token is loaded from a JSON file and cached in memory; the file is
    read again once ``refresh_duration`` seconds have passed since the last
    successful read.
    """

    _MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH = "/var/credentials-secret/model-dependencies-oauth-token"

    def __init__(self):
        # Epoch seconds after which the cached token is considered stale.
        # -1 guarantees the first call reads the file.
        self.expiry_time = -1
        self.current_token = None
        self.refresh_duration = 300  # 300 Seconds

    def should_fetch_model_serving_environment_oauth(self) -> bool:
        """
        Check whether this is the model serving environment
        Additionally check if the oauth token file path exists

        :return: True only when one of the IS_IN_*_MODEL_SERVING_ENV variables
            equals "true" and the token file exists.
        """

        is_in_model_serving_env = (os.environ.get("IS_IN_DB_MODEL_SERVING_ENV")
                                   or os.environ.get("IS_IN_DATABRICKS_MODEL_SERVING_ENV") or "false")
        return (is_in_model_serving_env == "true"
                and os.path.isfile(self._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH))

    def get_model_dependency_oauth_token(self, should_retry=True) -> str:
        """Return the OAuth token, reading the token file when the cache is stale.

        :param should_retry: when True, a failed read is logged and retried
            once after a 0.5 second pause; when False, the failure is raised.
        :return: the value stored under OAUTH_TOKEN[0].oauthTokenValue.
        :raises RuntimeError: if the file cannot be read or parsed and no
            retry is left. The original exception is chained as the cause.
        """
        # Use Cached value if it is valid
        if self.current_token is not None and self.expiry_time > time.time():
            return self.current_token

        try:
            # JSON is UTF-8 by specification; do not depend on the process locale.
            with open(self._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH, encoding="utf-8") as f:
                oauth_dict = json.load(f)
                self.current_token = oauth_dict["OAUTH_TOKEN"][0]["oauthTokenValue"]
                self.expiry_time = time.time() + self.refresh_duration
        except Exception as e:
            # sleep and retry in case of any race conditions with OAuth refreshing
            if should_retry:
                logger.warning("Unable to read oauth token on first attmept in Model Serving Environment",
                               exc_info=e)
                time.sleep(0.5)
                return self.get_model_dependency_oauth_token(should_retry=False)
            else:
                raise RuntimeError(
                    "Unable to read OAuth credentials from the file mounted in Databricks Model Serving"
                ) from e
        return self.current_token

    def get_databricks_host_token(self) -> Optional[Tuple[Optional[str], str]]:
        """Return ``(host, token)`` for the model serving environment.

        :return: None when should_fetch_model_serving_environment_oauth() is
            False; otherwise a tuple of the host and the OAuth token. The host
            is None when neither host environment variable is set.
        :raises RuntimeError: propagated from get_model_dependency_oauth_token().
        """
        if not self.should_fetch_model_serving_environment_oauth():
            return None

        # Prefer DATABRICKS_MODEL_SERVING_HOST_URL; fall back to DB_MODEL_SERVING_HOST_URL
        host = os.environ.get("DATABRICKS_MODEL_SERVING_HOST_URL") or os.environ.get(
            "DB_MODEL_SERVING_HOST_URL")
        token = self.get_model_dependency_oauth_token()

        return (host, token)


@credentials_strategy('model-serving', [])
def model_serving_auth(cfg: 'Config') -> Optional[CredentialsProvider]:
    """Credentials strategy that authenticates with the Model Serving OAuth token.

    :param cfg: the SDK configuration; ``cfg.host`` is filled in from the
        serving environment only when it is currently None.
    :return: None when not running in Databricks Model Serving or when the
        credentials cannot be read; otherwise a callable producing the
        ``Authorization`` header for each request.
    """
    try:
        model_serving_auth_provider = ModelServingAuthProvider()
        # get_databricks_host_token() returns None when the environment check
        # fails, so test its result directly instead of unpacking it blindly.
        host_and_token = model_serving_auth_provider.get_databricks_host_token()
        if host_and_token is None:
            logger.debug("model-serving: Not in Databricks Model Serving, skipping")
            return None
        host, token = host_and_token
        if token is None:
            raise ValueError(
                "Got malformed auth (empty token) when fetching auth implicitly available in Model Serving Environment. Please contact Databricks support"
            )
        if cfg.host is None:
            cfg.host = host
    except Exception as e:
        # Best effort: any failure here means this strategy does not apply.
        logger.warning("Unable to get auth from Databricks Model Serving Environment", exc_info=e)
        return None

    logger.info("Using Databricks Model Serving Authentication")

    def inner() -> Dict[str, str]:
        # Call here again to get the refreshed token
        refreshed = model_serving_auth_provider.get_databricks_host_token()
        if refreshed is None:
            # Fail with a clear message instead of a tuple-unpacking TypeError.
            raise RuntimeError(
                "Databricks Model Serving credentials are no longer available in this environment")
        _, token = refreshed
        return {"Authorization": f"Bearer {token}"}

    return inner


class DefaultCredentials:
""" Select the first applicable credential provider from the chain """

Expand All @@ -706,7 +791,7 @@ def __init__(self) -> None:
self._auth_providers = [
pat_auth, basic_auth, metadata_service, oauth_service_principal, azure_service_principal,
github_oidc_azure, azure_cli, external_browser, databricks_cli, runtime_native_auth,
google_credentials, google_id
google_credentials, google_id, model_serving_auth
]

def auth_type(self) -> str:
Expand Down
98 changes: 98 additions & 0 deletions tests/test_model_serving_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import time

import pytest

from databricks.sdk.core import Config

from .conftest import raises

default_auth_base_error_message = \
"default auth: cannot configure default credentials, " \
"please check https://docs.databricks.com/en/dev-tools/auth.html#databricks-client-unified-authentication " \
"to configure credentials for your preferred authentication method"


@pytest.mark.parametrize(
    "env_values, oauth_file_name",
    [
        ([('IS_IN_DB_MODEL_SERVING_ENV', 'true'), ('DB_MODEL_SERVING_HOST_URL', 'x')],
         "tests/testdata/model-serving-test-token"),
        ([('IS_IN_DATABRICKS_MODEL_SERVING_ENV', 'true'), ('DB_MODEL_SERVING_HOST_URL', 'x')],
         "tests/testdata/model-serving-test-token"),
        ([('IS_IN_DB_MODEL_SERVING_ENV', 'true'), ('DATABRICKS_MODEL_SERVING_HOST_URL', 'x')],
         "tests/testdata/model-serving-test-token"),
        ([('IS_IN_DATABRICKS_MODEL_SERVING_ENV', 'true'), ('DATABRICKS_MODEL_SERVING_HOST_URL', 'x')],
         "tests/testdata/model-serving-test-token"),
    ])
def test_model_serving_auth(env_values, oauth_file_name, monkeypatch):
    """Every pairing of environment flag and host variable selects model-serving auth."""
    # Point the provider at the token fixture in the test directory.
    monkeypatch.setattr(
        "databricks.sdk.credentials_provider.ModelServingAuthProvider._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH",
        oauth_file_name)
    # One flag variable plus one host variable identify the serving environment.
    for name, value in env_values:
        monkeypatch.setenv(name, value)

    cfg = Config()

    assert cfg.auth_type == 'model-serving'
    auth_headers = cfg.authenticate()
    assert cfg.host == 'x'
    # The expected token is the one stored in the fixture file.
    assert auth_headers.get("Authorization") == 'Bearer databricks_sdk_unit_test_token'


@pytest.mark.parametrize(
    "env_values, oauth_file_name",
    [
        # Not in Model Serving and Invalid File Name
        ([], "invalid_file_name"),
        # In Model Serving and Invalid File Name
        ([('IS_IN_DB_MODEL_SERVING_ENV', 'true')], "invalid_file_name"),
        # In Model Serving and Invalid File Name
        ([('IS_IN_DATABRICKS_MODEL_SERVING_ENV', 'true')], "invalid_file_name"),
        # Not in Model Serving and Valid File Name
        ([], "tests/testdata/model-serving-test-token"),
    ])
@raises(default_auth_base_error_message)
def test_model_serving_auth_errors(env_values, oauth_file_name, monkeypatch):
    """Config() fails with the default-auth error when the flag or the token file is missing."""
    monkeypatch.setattr(
        "databricks.sdk.credentials_provider.ModelServingAuthProvider._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH",
        oauth_file_name)
    for name, value in env_values:
        monkeypatch.setenv(name, value)

    Config()


def test_model_serving_auth_refresh(monkeypatch):
    """The cached token is reused until the refresh window passes, then re-read from disk."""
    token_path_attr = (
        "databricks.sdk.credentials_provider.ModelServingAuthProvider._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH")
    clock_attr = 'databricks.sdk.credentials_provider.time.time'

    # Flag and host variables that identify the serving environment.
    monkeypatch.setenv('IS_IN_DB_MODEL_SERVING_ENV', 'true')
    monkeypatch.setenv('DB_MODEL_SERVING_HOST_URL', 'x')

    # Point the provider at the token fixture in the test directory.
    monkeypatch.setattr(token_path_attr, "tests/testdata/model-serving-test-token")

    cfg = Config()
    assert cfg.auth_type == 'model-serving'

    current_time = time.time()
    first_headers = cfg.authenticate()
    assert cfg.host == 'x'
    # Token defined in the test file
    assert first_headers.get("Authorization") == 'Bearer databricks_sdk_unit_test_token'

    # Simulate a token refresh by pointing the provider at a new file.
    monkeypatch.setattr(token_path_attr, "tests/testdata/model-serving-test-token-v2")

    # Ten seconds later the cache is still valid, so the new path is ignored.
    monkeypatch.setattr(clock_attr, lambda: current_time + 10)

    cached_headers = cfg.authenticate()
    assert cfg.host == 'x'
    assert cached_headers.get("Authorization") == 'Bearer databricks_sdk_unit_test_token'

    # The expiry is 300 seconds, so jumping 600 seconds ahead forces a re-read from the new path.
    monkeypatch.setattr(clock_attr, lambda: current_time + 600)

    refreshed_headers = cfg.authenticate()
    assert cfg.host == 'x'
    # The V2 fixture is read now.
    assert refreshed_headers.get("Authorization") == 'Bearer databricks_sdk_unit_test_token_v2'
7 changes: 7 additions & 0 deletions tests/testdata/model-serving-test-token
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"OAUTH_TOKEN": [
{
"oauthTokenValue": "databricks_sdk_unit_test_token"
}
]
}
7 changes: 7 additions & 0 deletions tests/testdata/model-serving-test-token-v2
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"OAUTH_TOKEN": [
{
"oauthTokenValue": "databricks_sdk_unit_test_token_v2"
}
]
}

0 comments on commit 9d39254

Please sign in to comment.