-
Notifications
You must be signed in to change notification settings - Fork 19
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
High-level interface for inference API
* implement high-level staging models and inference api * some refactoring; preparing for more * implement factory pattern * remove staging from this PR; add integration tests * try to fix config issue in test * hardcode endpoint in test * move default parameters * add test for streaming * address matt k feedback * add unit test for invalid backend model * remove f string in logging; unit test for factories; minor cleanup * address PR comments; update doc strings; add examples * analysis utils module created * add analysis utils to build * add progress bar when not streaming * make progress bar suppressible GitOrigin-RevId: 71591768b016301927ae29b601a2092c08f88092
- Loading branch information
1 parent
77cdcf8
commit 89c4512
Showing
11 changed files
with
789 additions
and
34 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
from typing import Optional | ||
|
||
from gretel_client.dataframe import _DataFrameT | ||
|
||
try: | ||
import pandas as pd | ||
|
||
PANDAS_IS_INSTALLED = True | ||
except ImportError: | ||
PANDAS_IS_INSTALLED = False | ||
|
||
try: | ||
import IPython | ||
|
||
IPYTHON_IS_INSTALLED = True | ||
except ImportError: | ||
IPYTHON_IS_INSTALLED = False | ||
|
||
|
||
def display_dataframe_in_notebook(
    dataframe: _DataFrameT, settings: Optional[dict] = None
) -> None:
    """Display pandas DataFrame in notebook with better settings for readability.

    This function is intended to be used in a Jupyter notebook.

    Args:
        dataframe: The pandas DataFrame to display.
        settings: Optional properties to set on the DataFrame's style.
            If None, default settings with text wrapping are used.

    Raises:
        ImportError: If pandas or IPython is not installed.
        TypeError: If ``dataframe`` is not a pandas DataFrame.
    """
    if not PANDAS_IS_INSTALLED:
        raise ImportError("Pandas is required to display dataframes in notebooks.")
    if not IPYTHON_IS_INSTALLED:
        raise ImportError("IPython is required to display dataframes in notebooks.")
    if not isinstance(dataframe, pd.DataFrame):
        raise TypeError(
            "Expected `dataframe` to be of type pandas.DataFrame, "
            f"you gave {type(dataframe)}"
        )
    # Fall back to the defaults only when no settings were supplied at all, so
    # an explicitly passed (even empty) dict is respected as documented.
    if settings is None:
        settings = {
            "text-align": "left",
            "white-space": "normal",
            "height": "auto",
        }
    IPython.display.display(dataframe.style.set_properties(**settings))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import logging | ||
import sys | ||
|
||
from typing import Optional | ||
|
||
from gretel_client.config import configure_session | ||
from gretel_client.inference_api.base import ( | ||
BaseInferenceAPI, | ||
GretelInferenceAPIError, | ||
InferenceAPIModelType, | ||
) | ||
from gretel_client.inference_api.tabular import ( | ||
TABLLM_DEFAULT_MODEL, | ||
TabularLLMInferenceAPI, | ||
) | ||
|
||
# Module-level logger: emits INFO and above to stdout through its own handler
# and does not propagate records to ancestor loggers.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler(sys.stdout))
logger.propagate = False
|
||
|
||
class GretelFactories:
    """A class for creating objects that interact with Gretel's APIs.

    Args:
        **session_kwargs: Optional keyword arguments forwarded to
            ``configure_session``. When none are given, the session
            configuration is left untouched.
    """

    def __init__(self, **session_kwargs):
        if session_kwargs:
            configure_session(**session_kwargs)

    def initialize_inference_api(
        self,
        model_type: InferenceAPIModelType = InferenceAPIModelType.TABULAR_LLM,
        *,
        backend_model: Optional[str] = None,
    ) -> BaseInferenceAPI:
        """Initializes and returns a gretel inference API object.

        Args:
            model_type: The type of the inference API model.
            backend_model: The model used under the hood by the inference API.
                Falls back to ``TABLLM_DEFAULT_MODEL`` when not provided.

        Raises:
            GretelInferenceAPIError: If the specified model type is not valid.

        Returns:
            An instance of the initialized inference API object.
        """
        if model_type != InferenceAPIModelType.TABULAR_LLM:
            valid_types = [t.value for t in InferenceAPIModelType]
            # The trailing space keeps the two sentences from running together.
            raise GretelInferenceAPIError(
                f"{model_type} is not a valid inference API model type. "
                f"Valid types are {valid_types}"
            )

        gretel_api = TabularLLMInferenceAPI(
            backend_model=backend_model or TABLLM_DEFAULT_MODEL,
        )
        logger.info("API path: %s%s", gretel_api.endpoint, gretel_api.api_path)
        logger.info("Initialized %s 🚀", gretel_api.name)
        return gretel_api
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
import json | ||
|
||
from abc import ABC, abstractproperty | ||
from enum import Enum | ||
from typing import Any, Dict, Optional | ||
|
||
from gretel_client import analysis_utils | ||
from gretel_client.config import configure_session, get_session_config | ||
from gretel_client.dataframe import _DataFrameT | ||
|
||
# Request path queried (via GET) to list the available backend models.
MODELS_API_PATH = "/v1/inference/models"
|
||
|
||
class GretelInferenceAPIError(Exception):
    """Signals a failure encountered while working with the Inference API."""
|
||
|
||
class InferenceAPIModelType(str, Enum):
    """String-valued identifiers for inference API model types."""

    TABULAR_LLM = "tabllm"
|
||
|
||
class BaseInferenceAPI(ABC):
    """Base class for Gretel Inference API objects.

    Args:
        **session_kwargs: Optional keyword arguments forwarded to
            ``configure_session`` before the session config is read.

    Raises:
        GretelInferenceAPIError: If the session's default runner is not
            ``"cloud"``.
    """

    def __init__(self, **session_kwargs):
        if session_kwargs:
            configure_session(**session_kwargs)
        session_config = get_session_config()
        if session_config.default_runner != "cloud":
            raise GretelInferenceAPIError(
                "Gretel's Inference API is currently only "
                "available within Gretel Cloud. Your current runner "
                f"is configured to: {session_config.default_runner}"
            )
        self._api_client = session_config._get_api_client()
        self.endpoint = session_config.endpoint
        # Snapshot of the "models" entry returned by the models endpoint;
        # empty when the response has no such key.
        self._available_backend_models = list(
            self._call_api("get", self.models_api_path).get("models", [])
        )

    # NOTE: ``abstractproperty`` is deprecated in favour of stacking
    # ``@property`` on ``@abstractmethod``; kept here because only
    # ``abstractproperty`` is imported by this module.
    @abstractproperty
    def api_path(self) -> str:
        """Request path of the concrete inference API; set by subclasses."""
        ...

    @property
    def models_api_path(self) -> str:
        """Request path used to list the available backend models."""
        return MODELS_API_PATH

    @property
    def name(self) -> str:
        """Name of the concrete class of this object."""
        return self.__class__.__name__

    def display_dataframe_in_notebook(
        self, dataframe: _DataFrameT, settings: Optional[dict] = None
    ) -> None:
        """Display pandas DataFrame in notebook with better settings for readability.

        This function is intended to be used in a Jupyter notebook.

        Args:
            dataframe: The pandas DataFrame to display.
            settings: Optional properties to set on the DataFrame's style.
                If None, default settings with text wrapping are used.
        """
        analysis_utils.display_dataframe_in_notebook(dataframe, settings)

    def _call_api(
        self,
        method: str,
        path: str,
        query_params: Optional[dict] = None,
        body: Optional[dict] = None,
        headers: Optional[dict] = None,
    ) -> Dict[str, Any]:
        """Make a direct API call to Gretel Cloud.

        Args:
            method: "get", "post", etc
            path: The full request path, any path params must be already included.
            query_params: Optional URL based query parameters
            body: An optional JSON payload to send
            headers: Any custom headers that need to be set. The mapping
                passed in is not modified.

        Returns:
            The response body decoded and parsed as JSON.

        NOTE:
            This function will automatically inject the appropriate API hostname and
            authentication from the Gretel configuration.
        """
        # Work on a private copy so the authentication values injected below
        # never end up in the caller's dict.
        headers = dict(headers) if headers else {}

        method = method.upper()

        if not path.startswith("/"):
            path = "/" + path

        # Utilize the ApiClient method to inject the proper authentication
        # into our headers, since Gretel only uses header-based auth we don't
        # need to pass any other data into this
        #
        # NOTE: This function does a pointer-like update of ``headers``
        self._api_client.update_params_for_auth(
            headers=headers,
            querys=None,
            auth_settings=self._api_client.configuration.auth_settings(),
            resource_path=None,
            method=None,
            body=None,
        )

        url = self._api_client.configuration.host + path

        response = self._api_client.request(
            method, url, query_params=query_params, body=body, headers=headers
        )

        return json.loads(response.data.decode())
Oops, something went wrong.