High-level interface for inference API
* implement high-level staging models and inference api

* some refactoring; preparing for more

* implement factory pattern

* remove staging from this PR; add integration tests

* try to fix config issue in test

* hardcode endpoint in test

* move default parameters

* add test for streaming

* address matt k feedback

* add unit test for invalid backend model

* remove f string in logging; unit test for factories; minor cleanup

* address PR comments; update doc strings; add examples

* analysis utils module created

* add analysis utils to build

* add progress bar when not streaming

* make progress bar suppressible

GitOrigin-RevId: 71591768b016301927ae29b601a2092c08f88092
johnnygreco committed Jan 17, 2024
1 parent 77cdcf8 commit 89c4512
Showing 11 changed files with 789 additions and 34 deletions.
46 changes: 46 additions & 0 deletions src/gretel_client/analysis_utils.py
@@ -0,0 +1,46 @@
from typing import Optional

from gretel_client.dataframe import _DataFrameT

try:
import pandas as pd

PANDAS_IS_INSTALLED = True
except ImportError:
PANDAS_IS_INSTALLED = False

try:
import IPython

IPYTHON_IS_INSTALLED = True
except ImportError:
IPYTHON_IS_INSTALLED = False


def display_dataframe_in_notebook(
dataframe: _DataFrameT, settings: Optional[dict] = None
) -> None:
"""Display pandas DataFrame in notebook with better settings for readability.
This function is intended to be used in a Jupyter notebook.
Args:
dataframe: The pandas DataFrame to display.
settings: Optional properties to set on the DataFrame's style.
If None, default settings with text wrapping are used.
"""
if not PANDAS_IS_INSTALLED:
raise ImportError("Pandas is required to display dataframes in notebooks.")
if not IPYTHON_IS_INSTALLED:
raise ImportError("IPython is required to display dataframes in notebooks.")
if not isinstance(dataframe, pd.DataFrame):
raise TypeError(
f"Expected `dataframe` to be of type pandas.DataFrame, "
f"you gave {type(dataframe)}"
)
settings = settings or {
"text-align": "left",
"white-space": "normal",
"height": "auto",
}
IPython.display.display(dataframe.style.set_properties(**settings))
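For orientation, here is a minimal notebook sketch of the new helper added in analysis_utils.py; the sample DataFrame and style values are purely illustrative and not part of this commit.

import pandas as pd

from gretel_client.analysis_utils import display_dataframe_in_notebook

# Illustrative data only.
df = pd.DataFrame({"prompt": ["generate a sample record"], "result": ["..."]})

# Uses the default style (left-aligned, wrapped text) when no settings are given.
display_dataframe_in_notebook(df)

# Style properties can also be passed explicitly.
display_dataframe_in_notebook(df, settings={"text-align": "left", "font-size": "12px"})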
59 changes: 59 additions & 0 deletions src/gretel_client/factories.py
@@ -0,0 +1,59 @@
import logging
import sys

from typing import Optional

from gretel_client.config import configure_session
from gretel_client.inference_api.base import (
BaseInferenceAPI,
GretelInferenceAPIError,
InferenceAPIModelType,
)
from gretel_client.inference_api.tabular import (
TABLLM_DEFAULT_MODEL,
TabularLLMInferenceAPI,
)

logger = logging.getLogger(__name__)
logger.propagate = False
logger.addHandler(logging.StreamHandler(sys.stdout))
logger.setLevel(logging.INFO)


class GretelFactories:
"""A class for creating objects that interact with Gretel's APIs."""

def __init__(self, **session_kwargs):
if len(session_kwargs) > 0:
configure_session(**session_kwargs)

def initialize_inference_api(
self,
model_type: InferenceAPIModelType = InferenceAPIModelType.TABULAR_LLM,
*,
backend_model: Optional[str] = None,
) -> BaseInferenceAPI:
"""Initializes and returns a gretel inference API object.
Args:
model_type: The type of the inference API model.
backend_model: The model used under the hood by the inference API.
Raises:
GretelInferenceAPIError: If the specified model type is not valid.
Returns:
An instance of the initialized inference API object.
"""
if model_type == InferenceAPIModelType.TABULAR_LLM:
gretel_api = TabularLLMInferenceAPI(
backend_model=backend_model or TABLLM_DEFAULT_MODEL,
)
else:
raise GretelInferenceAPIError(
f"{model_type} is not a valid inference API model type."
f"Valid types are {[t.value for t in InferenceAPIModelType]}"
)
logger.info("API path: %s%s", gretel_api.endpoint, gretel_api.api_path)
logger.info("Initialized %s 🚀", gretel_api.name)
return gretel_api
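A hedged usage sketch of the new factory entry point (assumes a valid Gretel Cloud session; the api_key value shown is a placeholder forwarded to configure_session):

from gretel_client.factories import GretelFactories
from gretel_client.inference_api.base import InferenceAPIModelType

# Session kwargs, if any, are forwarded to configure_session.
factories = GretelFactories(api_key="prompt")

# Defaults to TABLLM_DEFAULT_MODEL when backend_model is omitted.
tabllm = factories.initialize_inference_api(InferenceAPIModelType.TABULAR_LLM)

# Passing an unknown model type raises GretelInferenceAPIError listing the valid values.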
25 changes: 16 additions & 9 deletions src/gretel_client/gretel/config_setup.py
@@ -22,7 +22,14 @@
)


class ModelName(str, Enum):
@dataclass(frozen=True)
class TabLLMDefaultParams:
temperature: float = 0.7
top_k: int = 40
top_p: float = 0.95


class ModelType(str, Enum):
"""Name of the model parameter dict in the config.
Note: The values are the names used in the model configs.
@@ -61,21 +68,21 @@ class ModelConfigSections:


CONFIG_SETUP_DICT = {
ModelName.ACTGAN: ModelConfigSections(
ModelType.ACTGAN: ModelConfigSections(
model_name="actgan",
config_sections=["params", "generate", "privacy_filters", "evaluate"],
data_source_optional=False,
report_type=ReportType.SQS,
extra_kwargs=["ref_data"],
),
ModelName.AMPLIFY: ModelConfigSections(
ModelType.AMPLIFY: ModelConfigSections(
model_name="amplify",
config_sections=["params", "evaluate"],
data_source_optional=False,
report_type=ReportType.SQS,
extra_kwargs=["ref_data"],
),
ModelName.LSTM: ModelConfigSections(
ModelType.LSTM: ModelConfigSections(
model_name="lstm",
config_sections=[
"params",
@@ -90,14 +97,14 @@ class ModelConfigSections:
report_type=ReportType.SQS,
extra_kwargs=["ref_data"],
),
ModelName.TABULAR_DP: ModelConfigSections(
ModelType.TABULAR_DP: ModelConfigSections(
model_name="tabular_dp",
config_sections=["params", "generate", "evaluate"],
data_source_optional=False,
report_type=ReportType.SQS,
extra_kwargs=["ref_data"],
),
ModelName.GPT_X: ModelConfigSections(
ModelType.GPT_X: ModelConfigSections(
model_name="gpt",
config_sections=["params", "generate"],
data_source_optional=True,
@@ -110,7 +117,7 @@
"ref_data",
],
),
ModelName.DGAN: ModelConfigSections(
ModelType.DGAN: ModelConfigSections(
model_name="dgan",
config_sections=["params", "generate"],
data_source_optional=False,
Expand All @@ -137,7 +144,7 @@ def _backwards_compat_transform_config(
"""
model_type, model_config_section = extract_model_config_section(config)
if (
model_type == ModelName.GPT_X.value
model_type == ModelType.GPT_X.value
and "params" in non_default_settings
and "params" not in model_config_section
):
@@ -192,7 +199,7 @@ def create_model_config_from_base(
"""
config = smart_read_model_config(base_config)
model_type, model_config_section = extract_model_config_section(config)
setup = CONFIG_SETUP_DICT[ModelName(model_type)]
setup = CONFIG_SETUP_DICT[ModelType(model_type)]

config = _backwards_compat_transform_config(config, non_default_settings)

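As a quick illustration of the ModelName to ModelType rename and the new defaults dataclass, a sketch using only names visible in this diff:

from gretel_client.gretel.config_setup import CONFIG_SETUP_DICT, ModelType, TabLLMDefaultParams

# Lookup now goes through ModelType instead of the old ModelName enum.
setup = CONFIG_SETUP_DICT[ModelType.ACTGAN]
print(setup.model_name)       # "actgan"
print(setup.config_sections)  # ["params", "generate", "privacy_filters", "evaluate"]

# Frozen dataclass holding the Tabular LLM sampling defaults.
defaults = TabLLMDefaultParams()
print(defaults.temperature, defaults.top_k, defaults.top_p)  # 0.7 40 0.95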
25 changes: 16 additions & 9 deletions src/gretel_client/gretel/interface.py
@@ -7,6 +7,7 @@

from gretel_client.config import configure_session
from gretel_client.dataframe import _DataFrameT
from gretel_client.factories import GretelFactories
from gretel_client.gretel.artifact_fetching import (
fetch_final_model_config,
fetch_model_logs,
@@ -54,9 +55,10 @@ def _convert_to_valid_data_source(
class Gretel:
"""High-level interface for interacting with Gretel's APIs.
An instance of this class is bound to a single Gretel project. If a project
name is not provided at instantiation, a new project will be created with the
first job submission. You can change projects using the `set_project` method.
To bind an instance of this class to a Gretel project, provide a project
name at instantiation or use the `set_project` method. If a job is submitted
(via a `submit_*` method) without a project set, a randomly-named project will
be created and set as the current project.
Args:
project_name (str): Name of new or existing project. If a new project name
@@ -95,8 +97,9 @@ def __init__(
):
configure_session(**session_kwargs)

self._project: Optional[Project] = None
self._user_id: str = get_me()["_id"][9:]
self._project: Optional[Project] = None
self.factories = GretelFactories()

if project_name is not None:
self.set_project(name=project_name, display_name=project_display_name)
@@ -105,18 +108,22 @@ def _assert_project_is_set(self):
"""Raise an error if a project has not been set."""
if self._project is None:
raise GretelProjectNotSetError(
"A project must be set to fetch models and their artifacts. "
"A project must be set to run this method. "
"Use `set_project` to create or select an existing project."
)

def get_project(self) -> Project:
def _generate_random_label(self) -> str:
return f"{uuid.uuid4().hex[:5]}-{self._user_id}"

def get_project(self, **kwargs) -> Project:
"""Returns the current Gretel project.
If a project has not been set, a new one will be created.
If a project has not been set, a new one will be created. The optional
kwargs are the same as those available for the `set_project` method.
"""
if self._project is None:
logger.info("No project set -> creating a new one...")
self.set_project()
self.set_project(**kwargs)
return self._project

def set_project(
@@ -138,7 +145,7 @@ def set_project(
Raises:
ApiException: If an error occurs while creating the project.
"""
name = name or f"gretel-sdk-{uuid.uuid4().hex[:5]}-{self._user_id}"
name = name or f"gretel-sdk-{self._generate_random_label()}"

try:
project = get_project(
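A brief sketch of how the updated high-level interface is meant to be used after this change; the display name is a placeholder and api_key is forwarded to configure_session:

from gretel_client.gretel.interface import Gretel

gretel = Gretel(api_key="prompt")  # no project yet; one is created on demand

# The new `factories` attribute exposes the inference API factory directly.
tabllm = gretel.factories.initialize_inference_api()

# get_project now forwards kwargs to set_project, so a randomly named project
# (gretel-sdk-<hex>-<user id>) is created here if none has been set.
project = gretel.get_project(display_name="Inference demo")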
4 changes: 0 additions & 4 deletions src/gretel_client/gretel/job_results.py
@@ -52,10 +52,6 @@ def project_url(self) -> str:
def model_url(self) -> str:
...

@abstractproperty
def job_status(self) -> Status:
...


@dataclass
class TrainJobResults(GretelJobResults):
Empty file.
116 changes: 116 additions & 0 deletions src/gretel_client/inference_api/base.py
@@ -0,0 +1,116 @@
import json

from abc import ABC, abstractproperty
from enum import Enum
from typing import Any, Dict, Optional

from gretel_client import analysis_utils
from gretel_client.config import configure_session, get_session_config
from gretel_client.dataframe import _DataFrameT

MODELS_API_PATH = "/v1/inference/models"


class GretelInferenceAPIError(Exception):
"""Raised when an error occurs with the Inference API."""


class InferenceAPIModelType(str, Enum):
TABULAR_LLM = "tabllm"


class BaseInferenceAPI(ABC):
"""Base class for Gretel Inference API objects."""

def __init__(self, **session_kwargs):
if len(session_kwargs) > 0:
configure_session(**session_kwargs)
session_config = get_session_config()
if session_config.default_runner != "cloud":
raise GretelInferenceAPIError(
"Gretel's Inference API is currently only "
"available within Gretel Cloud. Your current runner "
f"is configured to: {session_config.default_runner}"
)
self._api_client = session_config._get_api_client()
self.endpoint = session_config.endpoint
self._available_backend_models = [
m for m in self._call_api("get", self.models_api_path).get("models", [])
]

@abstractproperty
def api_path(self) -> str:
...

@property
def models_api_path(self) -> str:
return MODELS_API_PATH

@property
def name(self) -> str:
return self.__class__.__name__

def display_dataframe_in_notebook(
self, dataframe: _DataFrameT, settings: Optional[dict] = None
) -> None:
"""Display pandas DataFrame in notebook with better settings for readability.
This function is intended to be used in a Jupyter notebook.
Args:
dataframe: The pandas DataFrame to display.
settings: Optional properties to set on the DataFrame's style.
If None, default settings with text wrapping are used.
"""
analysis_utils.display_dataframe_in_notebook(dataframe, settings)

def _call_api(
self,
method: str,
path: str,
query_params: Optional[dict] = None,
body: Optional[dict] = None,
headers: Optional[dict] = None,
) -> Dict[str, Any]:
"""Make a direct API call to Gretel Cloud.
Args:
method: "get", "post", etc
path: The full request path, any path params must be already included.
query_params: Optional URL based query parameters
body: An optional JSON payload to send
headers: Any custom headers that need to be set.
NOTE:
This function will automatically inject the appropriate API hostname and
authentication from the Gretel configuration.
"""
if headers is None:
headers = {}

method = method.upper()

if not path.startswith("/"):
path = "/" + path

# Utilize the ApiClient method to inject the proper authentication
# into our headers, since Gretel only uses header-based auth we don't
# need to pass any other data into this
#
# NOTE: This function does a pointer-like update of ``headers``
self._api_client.update_params_for_auth(
headers=headers,
querys=None,
auth_settings=self._api_client.configuration.auth_settings(),
resource_path=None,
method=None,
body=None,
)

url = self._api_client.configuration.host + path

response = self._api_client.request(
method, url, query_params=query_params, body=body, headers=headers
)

return json.loads(response.data.decode())
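To show how the base class is intended to be extended, a toy subclass sketch; the API path and payload below are hypothetical and not the real Tabular LLM contract implemented in inference_api/tabular.py:

from gretel_client.inference_api.base import BaseInferenceAPI


class EchoInferenceAPI(BaseInferenceAPI):
    """Toy subclass illustrating what a concrete inference API provides."""

    @property
    def api_path(self) -> str:
        # Concrete implementations point this at their own /v1/inference/... route.
        return "/v1/inference/example/generate"

    def generate(self, prompt: str) -> dict:
        # _call_api injects the Gretel Cloud host and auth headers automatically.
        return self._call_api("post", self.api_path, body={"prompt": prompt})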