Skip to content

Commit

Permalink
ENGINT-156: Cache session config
Browse files Browse the repository at this point in the history
GitOrigin-RevId: c7c80034f8f94811e49e51389ed1a95dd71fdfe2
  • Loading branch information
drew committed Mar 21, 2022
1 parent 048c8b5 commit 072fd03
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 8 deletions.
9 changes: 9 additions & 0 deletions src/gretel_client/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
RunnerMode,
write_config,
)
from gretel_client.projects.common import f
from gretel_client.users import users


class GretelCliHandler(ExceptionHandler):
Expand Down Expand Up @@ -89,6 +91,13 @@ def configure(
sc.print(data=config.masked)


@cli.command(help="Check account, user and configuration details.")
@pass_session
def whoami(sc: SessionContext):
me = users.get_me()
sc.print(data={f.EMAIL: me[f.EMAIL], "config": sc.config.masked})


cli.add_command(models)
cli.add_command(records)
cli.add_command(projects)
Expand Down
2 changes: 1 addition & 1 deletion src/gretel_client/cli/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def handle_error(ex: Exception, ctx: click.Context):
how the error came to be.
"""

if ctx.obj.debug:
if ctx.obj and ctx.obj.debug:
ctx.obj.log.debug(traceback.format_exc())

for ex_t, handler in exception_map().items():
Expand Down
90 changes: 84 additions & 6 deletions src/gretel_client/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@
import json
import logging
import os
import traceback

from enum import Enum
from getpass import getpass
from pathlib import Path
from typing import Optional, Type, TypeVar, Union

from urllib3.util import Retry

from gretel_client.rest.api.projects_api import ProjectsApi
from gretel_client.rest.api.users_api import UsersApi
from gretel_client.rest.api_client import ApiClient
from gretel_client.rest.configuration import Configuration

Expand All @@ -37,6 +40,8 @@
GRETEL_PREVIEW_FEATURES = "GRETEL_PREVIEW_FEATURES"
"""Env variable to manage preview features"""

GRETEL_ENVS = [GRETEL_API_KEY, GRETEL_PROJECT]


class PreviewFeatures(Enum):
"""Manage preview feature configurations"""
Expand Down Expand Up @@ -136,6 +141,10 @@ def _get_api_client(
)
return ApiClient(configuration)

@property
def email(self) -> str:
return self.get_api(UsersApi).users_me()["data"]["me"]["email"]

def get_api(
self,
api_interface: Type[T],
Expand Down Expand Up @@ -185,7 +194,7 @@ def __eq__(self, other: ClientConfig) -> bool:
def masked(self) -> dict:
"""Returns a masked representation of the config object."""
c = self.as_dict
c["api_key"] = "[redacted from output]"
c["api_key"] = self.masked_api_key
return c

@property
Expand All @@ -208,6 +217,23 @@ def _get_config_path() -> Path:
return Path().home() / f".{GRETEL}" / "config.json"


def clear_gretel_config():
"""Removes any Gretel configuration files from the host file system.
If any Gretel related environment variables exist, this will also remove
them from the current processes.
"""
try:
config = _get_config_path()
config.unlink()
config.parent.rmdir()
except (FileNotFoundError, OSError):
pass
for env_var in GRETEL_ENVS:
if env_var in os.environ:
del os.environ[env_var]


def _load_config(config_path: Path = None) -> ClientConfig:
"""This will load in a Gretel config that can be used for making
requests to Gretel's API.
Expand Down Expand Up @@ -264,18 +290,70 @@ def get_session_config() -> ClientConfig:
return _session_client_config


def configure_session(config: Union[str, ClientConfig]):
def configure_session(
config: Optional[Union[str, ClientConfig]] = None,
*,
api_key: Optional[str] = None,
endpoint: Optional[str] = None,
cache: str = "no",
validate: bool = False,
clear: bool = False,
):
"""Updates client config for the session
Args:
config: The config to update. If the config is a string, this function
will attempt to parse it as a Gretel URI.
config: The config to update. This config takes precedence over
other parameters such as ``api_key`` or ``endpoint``.
api_key: Configures your Gretel API key. If ``api_key`` is set to
"prompt" and no Api Key is found on the system, ``getpass``
will be used to prompt for the key.
endpoint: Specifies the Gretel API endpoint. This must be a fully
qualified URL.
cache: Valid options include "yes" and "no". If cache is "no"
the session configuration will not be written to disk. If cache is
"yes", session configuration will be written to disk only if a
configuration doesn't exist.
validate: If set to ``True`` this will check that login credentials
are valid.
clear: If set to ``True`` any existing Gretel credentials will be
removed from the host.
"""
if clear:
clear_gretel_config()

if not config:
config = _load_config()

if isinstance(config, str):
raise NotImplementedError("Gretel URIs are not supported yet.")

if api_key == "prompt":
if config.api_key:
print("Found cached Gretel credentials")
else:
api_key = getpass("Gretel Api Key")

if api_key and api_key.startswith("grt") or endpoint:
config = ClientConfig(endpoint=endpoint, api_key=api_key)

if cache == "yes":
try:
ClientConfig.from_file(_get_config_path())
except Exception:
print("Caching Gretel config to disk.")
write_config(config)

global _session_client_config
if isinstance(config, ClientConfig):
_session_client_config = config
if isinstance(config, str):
raise NotImplementedError("Gretel URIs are not supported yet.")

if validate:
print(f"Using endpoint {config.endpoint}")
try:
print(f"Logged in as {config.email} \u2705")
except Exception:
print("Failed to validate credentials. Please check your config.")
traceback.print_exc()


_custom_logger = None
Expand Down
1 change: 1 addition & 0 deletions src/gretel_client/projects/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ class f:
RUNNER_MODE = "runner_mode"
CONTAINER_IMAGE = "container_image"
HANDLER = "handler"
EMAIL = "email"


YES = "yes"
Expand Down
39 changes: 38 additions & 1 deletion tests/gretel_client/test_config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import os

from pathlib import Path
from unittest.mock import patch
from unittest import mock
from unittest.mock import MagicMock, patch

import pytest

Expand Down Expand Up @@ -60,3 +61,39 @@ def test_configure_preview_features():
_load_config()
) # ensure the session is reloaded with new env variables
assert get_session_config().preview_features_enabled


@patch("gretel_client.config.get_session_config")
@patch("gretel_client.config._get_config_path")
@patch("gretel_client.config.getpass")
@patch("gretel_client.config.write_config")
def test_configure_session_with_cache(
write_config: MagicMock,
get_pass: MagicMock,
_get_config_path: MagicMock,
get_session_config: MagicMock,
):
get_session_config.return_value = None
_get_config_path.return_value = Path("/path/that/does/not/exist")
get_pass.return_value = "grtu..."

with mock.patch.dict("os.environ", {}, clear=True):
configure_session(api_key="prompt", cache="yes")
get_pass.assert_called_once()
write_config.assert_called_once()
assert write_config.call_args[0][0].api_key == "grtu..."

write_config.reset_mock()

configure_session(api_key="grtu...")
write_config.assert_not_called()


@patch("gretel_client.config._get_config_path")
def test_clear_gretel_config(_get_config_path: MagicMock):
_get_config_path.return_value.exists.return_value = False
with mock.patch.dict("os.environ", {}, clear=True):
configure_session(clear=True)
config_path = _get_config_path.return_value
config_path.unlink.assert_called_once()
config_path.parent.rmdir.assert_called_once()

0 comments on commit 072fd03

Please sign in to comment.