Skip to content

Commit

Permalink
Use timestamp check for token refresh
Browse files Browse the repository at this point in the history
Signed-off-by: Terence Lim <terencelimxp@gmail.com>
  • Loading branch information
terryyylim committed Apr 8, 2021
1 parent f20ccfa commit 6ca95c1
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 3 deletions.
13 changes: 12 additions & 1 deletion sdk/python/feast/grpc/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import time
from http import HTTPStatus

import grpc
Expand Down Expand Up @@ -159,6 +160,7 @@ def __init__(self, config: Config):

self._static_token = None
self._token = None
self._token_expiry_ts = time.time()

# If provided, set a static token
if config.exists(opt.AUTH_TOKEN):
Expand All @@ -169,6 +171,9 @@ def __init__(self, config: Config):

def get_signed_meta(self):
""" Creates a signed authorization metadata token."""

if time.time() > self._token_expiry_ts:
self._refresh_token()
return (("authorization", "Bearer {}".format(self._token)),)

def _refresh_token(self):
Expand All @@ -179,10 +184,13 @@ def _refresh_token(self):
self._token = self._static_token
return

from google.oauth2.id_token import fetch_id_token
from google.oauth2.id_token import fetch_id_token, verify_oauth2_token

try:
self._token = fetch_id_token(self._request, audience="feast.dev")
self._token_expiry_ts = verify_oauth2_token(self._token, self._request)[
"exp"
]
return
except DefaultCredentialsError:
pass
Expand All @@ -195,6 +203,9 @@ def _refresh_token(self):
credentials.refresh(self._request)
if hasattr(credentials, "id_token"):
self._token = credentials.id_token
self._token_expiry_ts = verify_oauth2_token(self._token, self._request)[
"exp"
]
return
except DefaultCredentialsError:
pass # Could not determine credentials, skip
Expand Down
12 changes: 10 additions & 2 deletions sdk/python/tests/grpc/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,13 @@ def test_get_auth_metadata_plugin_oauth_should_raise_when_config_is_incorrect(
get_auth_metadata_plugin(config_with_missing_variable)


@patch(
"google.oauth2.id_token.verify_token",
return_value={"iss": "accounts.google.com", "exp": 12341234},
)
@patch("google.oauth2.id_token.fetch_id_token", return_value="Some Token")
def test_get_auth_metadata_plugin_google_should_pass_with_token_from_gcloud_sdk(
fetch_id_token, config_google
verify_token, fetch_id_token, config_google
):
auth_metadata_plugin = get_auth_metadata_plugin(config_google)
assert isinstance(auth_metadata_plugin, GoogleOpenIDAuthMetadataPlugin)
Expand All @@ -158,6 +162,10 @@ def test_get_auth_metadata_plugin_google_should_pass_with_token_from_gcloud_sdk(
)


@patch(
"google.oauth2.id_token.verify_token",
return_value={"iss": "accounts.google.com", "exp": 12341234},
)
@patch(
"google.auth.default",
return_value=[
Expand All @@ -167,7 +175,7 @@ def test_get_auth_metadata_plugin_google_should_pass_with_token_from_gcloud_sdk(
)
@patch("google.oauth2.id_token.fetch_id_token", side_effect=DefaultCredentialsError())
def test_get_auth_metadata_plugin_google_should_pass_with_token_from_google_auth_lib(
fetch_id_token, default, config_google
verify_token, fetch_id_token, default, config_google
):
auth_metadata_plugin = get_auth_metadata_plugin(config_google)
assert isinstance(auth_metadata_plugin, GoogleOpenIDAuthMetadataPlugin)
Expand Down

0 comments on commit 6ca95c1

Please sign in to comment.