Skip to content

Commit

Permalink
[FEATURE] Add new get_response_metadata to Navigator SDK
Browse files Browse the repository at this point in the history
GitOrigin-RevId: 36bdc36b35645afd86a01166097a761603b0d6f3
  • Loading branch information
matthewgrossman committed Nov 21, 2024
1 parent 8e4f879 commit ac55e06
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 11 deletions.
13 changes: 12 additions & 1 deletion src/gretel_client/inference_api/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

from gretel_client.config import ClientConfig, configure_session, get_session_config
from gretel_client.rest.api_client import ApiClient
from gretel_client.rest.configuration import Configuration

MODELS_API_PATH = "/v1/inference/models"

Expand Down Expand Up @@ -143,6 +142,7 @@ class BaseInferenceAPI(ABC):

_available_backend_models: list[str]
_model_type: str
_response_metadata: dict

def __init__(
self,
Expand Down Expand Up @@ -170,6 +170,7 @@ def __init__(
self.endpoint = session.endpoint
self._api_client = session._get_api_client(verify_ssl=verify_ssl)
self._available_backend_models = get_full_navigator_model_list(self._api_client)
self._response_metadata = {}
self.backend_model = backend_model

@abstractproperty
Expand Down Expand Up @@ -240,3 +241,13 @@ def _call_api(
body=body,
headers=headers,
)

def _set_response_metadata(self, response_metadata: dict) -> None:
    """Record metadata from the most recent API response.

    Called internally once a request (or stream) completes; the stored
    value is what ``get_response_metadata`` later returns.
    """
    self._response_metadata = response_metadata

def get_response_metadata(self) -> dict:
    """Return metadata captured from the most recently completed request.

    Raises:
        GretelInferenceAPIError: If no request has completed yet, so no
            metadata has been recorded.
    """
    if self._response_metadata:
        return self._response_metadata
    raise GretelInferenceAPIError(
        "Response metadata is only set after a request has completed."
    )
1 change: 1 addition & 0 deletions src/gretel_client/inference_api/natural_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,4 +86,5 @@ def generate(
},
},
)
self._set_response_metadata(response["metadata"])
return response["text"]
33 changes: 29 additions & 4 deletions src/gretel_client/inference_api/tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def _reset_stream(self) -> None:
self._last_stream_read = None
self._next_iter = None
self._generated_count = 0
self._set_response_metadata({})

def _create_stream_if_needed(
self,
Expand Down Expand Up @@ -181,8 +182,12 @@ def _stream(
done_generation = False

attempt_count = 0
response_metadata: dict = {}

while self._generated_count < num_records:
# Keep going until we've given the user all their records and the most recent
# stream has been closed. We need to ensure the last stream is closed because
# that's when we receive additional metadata (model_ids, billing, etc).
while self._generated_count < num_records or self._curr_stream_id:
if history_buffer:
this_ref_data["sample_data"] = {
"table_headers": list(history_buffer[0].keys()),
Expand Down Expand Up @@ -215,18 +220,24 @@ def _stream(
for record in data_list:
if record["data_type"] == "logger.info":
logger.debug("%s: %s", self.__class__.__name__, record["data"])
elif record["data_type"] == "ResponseMetadata":
response_metadata = _combine_response_metadata(
response_metadata, json.loads(record["data"])
)
self._set_response_metadata(response_metadata)
elif record["data_type"] == "TabularResponse":
row_data = record["data"]["table_data"]
for row in row_data:
self._generated_count += 1
self._last_stream_read = time.time()
if self._generated_count > num_records:
# If we've already sent the user back all of their requested results,
# don't bother returning more data than they expect.
continue
if sample_buffer_size > 0:
history_buffer.append(row)
history_buffer = history_buffer[-sample_buffer_size:]
yield row
if self._generated_count >= num_records:
done_generation = True
break
elif record["data_type"] == "logger.error":
attempt_count += 1
err_string = record["data"]
Expand Down Expand Up @@ -522,3 +533,17 @@ def generate(
pbar_desc="Generating records",
disable_pbar=disable_progress_bar,
)


def _combine_response_metadata(
response_metadata: dict, new_response_data: dict
) -> dict:
if not response_metadata:
return new_response_data

for k, v in new_response_data["usage"].items():
if not isinstance(v, int):
continue
response_metadata["usage"][k] += v

return response_metadata
28 changes: 26 additions & 2 deletions tests/gretel_client/inference_api/test_tabular.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import json

from contextlib import nullcontext
from unittest.mock import Mock, patch

Expand All @@ -10,6 +12,21 @@

tabular.STREAM_SLEEP_TIME = 0

# Two final payloads that cap off a mocked stream:
#   1. a "ResponseMetadata" record — the server reports usage/billing info
#      as the stream finishes (see the note in tabular._stream), then
#   2. a "closed" stream_state, which tells the client the stream is done.
ENDING_STREAM_DATA = [
    {
        "data": [
            {
                "data_type": "ResponseMetadata",
                "data": json.dumps(
                    {"gretel": "a_model_id", "usage": {"input_bytes": 42}}
                ),
            }
        ]
    },
    {"stream_state": {"status": "closed"}, "data": []},
]


@patch.object(api_base, "get_full_navigator_model_list")
@patch.object(api_base, "get_model")
Expand Down Expand Up @@ -121,7 +138,9 @@ def test_generate_timeout(
# this new stream ID which is what we assert
if len(records) == target_count - 5:
api._call_api = Mock(
side_effect=[{"stream_id": "new_stream_123"}] + [api_response] * 15
side_effect=[{"stream_id": "new_stream_123"}]
+ [api_response] * 15
+ ENDING_STREAM_DATA
)
api._last_stream_read -= timeout + 1

Expand All @@ -130,10 +149,15 @@ def test_generate_timeout(

assert len(records) == target_count
assert records[0] == {"foo": "bar"}
assert api._curr_stream_id == "new_stream_123"

# Because we've closed the stream, the stream_id should be reset
assert api._curr_stream_id == None

# We should have a history buffer that was sent as well
stream_body = api._call_api.call_args_list[0].kwargs["body"]
assert stream_body["ref_data"]["sample_data"]["table_headers"] == ["foo"]
assert len(stream_body["ref_data"]["sample_data"]["table_data"]) == 5
assert stream_body["ref_data"]["sample_data"]["table_data"][0] == {"foo": "bar"}

# Since the last stream closed, we should be able to get the metadata
assert api.get_response_metadata()["usage"]["input_bytes"] == 42
35 changes: 31 additions & 4 deletions tests/gretel_client/integration/test_inference_api.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import os

from unittest.mock import patch

import pandas as pd
import pytest

from gretel_client.inference_api import tabular
from gretel_client.inference_api.base import GretelInferenceAPIError
from gretel_client.inference_api.natural_language import NaturalLanguageInferenceAPI
from gretel_client.inference_api.tabular import TabularInferenceAPI

PROMPT = """\
Generate a dataset of characters from the Simpsons.
Expand Down Expand Up @@ -72,7 +74,7 @@ def llm():

@pytest.fixture(scope="module")
def nav():
return TabularInferenceAPI(
return tabular.TabularInferenceAPI(
api_key=os.getenv("GRETEL_API_KEY"),
endpoint="https://api-dev.gretel.cloud",
validate=False,
Expand All @@ -91,17 +93,40 @@ def test_llm_inference_api_generate(llm):


def test_nav_inference_api_generate(nav):
    """Batch generation returns a DataFrame and records usage metadata."""
    # Metadata is unavailable until a generation request has completed.
    with pytest.raises(GretelInferenceAPIError):
        nav.get_response_metadata()

    frame = nav.generate(PROMPT, num_records=NUM_RECORDS)
    assert isinstance(frame, pd.DataFrame)
    assert len(frame) == NUM_RECORDS

    usage = nav.get_response_metadata()["usage"]
    assert usage["input_bytes"] == len(PROMPT)


def test_nav_inference_api_generate_stream(nav):
def test_nav_inference_api_generate_stream(nav: tabular.TabularInferenceAPI):
    """Streamed generation yields dict records and records usage metadata."""
    records = list(nav.generate(PROMPT, num_records=NUM_RECORDS, stream=True))
    assert all(isinstance(rec, dict) for rec in records)
    assert len(records) == NUM_RECORDS

    usage = nav.get_response_metadata()["usage"]
    assert usage["input_bytes"] == len(PROMPT)


@patch.object(tabular, "MAX_ROWS_PER_STREAM", 3)
def test_nav_inference_api_generate_multiple_streams(nav: tabular.TabularInferenceAPI):
    """Usage metadata accumulates across every stream used for one request.

    With MAX_ROWS_PER_STREAM patched to 3, producing 10 records forces
    ceil(10 / 3) == 4 distinct streams, and each stream re-sends the prompt.
    """
    num_records = 10
    expected_stream_count = 4

    records = list(nav.generate(PROMPT, num_records=num_records, stream=True))
    assert all(isinstance(rec, dict) for rec in records)
    assert len(records) == num_records

    usage = nav.get_response_metadata()["usage"]
    assert usage["input_bytes"] == len(PROMPT) * expected_stream_count


@pytest.mark.parametrize(
Expand Down Expand Up @@ -146,7 +171,9 @@ def test_nav_inference_api_edit_stream(nav):

def test_nav_inference_api_invalid_backend_model():
    """Constructing the API with an unknown backend model should raise."""
    with pytest.raises(GretelInferenceAPIError):
        tabular.TabularInferenceAPI(
            backend_model="invalid_model", skip_configure_session=True
        )


def test_nav_inference_api_edit_invalid_seed_data_type(nav):
Expand Down

0 comments on commit ac55e06

Please sign in to comment.