Merge pull request #1105 from guardrails-ai/fix_server_export_mismatch
fix missing exports for server
zsimjee authored Sep 30, 2024
2 parents 575a3bc + 618ad21 commit 667c4a4
Showing 10 changed files with 272 additions and 10 deletions.
12 changes: 10 additions & 2 deletions guardrails/async_guard.py
@@ -369,7 +369,11 @@ async def _exec(
output=llm_output,
base_model=self._base_model,
full_schema_reask=full_schema_reask,
disable_tracer=(not self._allow_metrics_collection),
disable_tracer=(
not self._allow_metrics_collection
if isinstance(self._allow_metrics_collection, bool)
else None
),
exec_options=self._exec_opts,
)
# Here we have an async generator
@@ -391,7 +395,11 @@ async def _exec(
output=llm_output,
base_model=self._base_model,
full_schema_reask=full_schema_reask,
disable_tracer=(not self._allow_metrics_collection),
disable_tracer=(
not self._allow_metrics_collection
if isinstance(self._allow_metrics_collection, bool)
else None
),
exec_options=self._exec_opts,
)
# Why are we using a different method here instead of just overriding?
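The two hunks above replace a hard boolean with a tri-state value: an explicit True or False for _allow_metrics_collection is inverted, while anything else collapses to None so the runner can fall back to its own default. The same substitution is repeated in guard.py below. A minimal standalone sketch of that logic (illustrative only, not part of this commit; the helper name and the fallback behavior are assumptions):

from typing import Optional

def resolve_disable_tracer(allow_metrics_collection: Optional[bool]) -> Optional[bool]:
    # Invert an explicit boolean: allowing metrics means the tracer stays on.
    # Anything else (typically None) passes through as None so downstream code
    # can apply its own default instead of being forced on or off.
    if isinstance(allow_metrics_collection, bool):
        return not allow_metrics_collection
    return None

assert resolve_disable_tracer(True) is False
assert resolve_disable_tracer(False) is True
assert resolve_disable_tracer(None) is None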
12 changes: 10 additions & 2 deletions guardrails/guard.py
@@ -908,7 +908,11 @@ def _exec(
output=llm_output,
base_model=self._base_model,
full_schema_reask=full_schema_reask,
disable_tracer=(not self._allow_metrics_collection),
disable_tracer=(
not self._allow_metrics_collection
if isinstance(self._allow_metrics_collection, bool)
else None
),
exec_options=self._exec_opts,
)
return runner(call_log=call_log, prompt_params=prompt_params)
@@ -927,7 +931,11 @@ def _exec(
output=llm_output,
base_model=self._base_model,
full_schema_reask=full_schema_reask,
disable_tracer=(not self._allow_metrics_collection),
disable_tracer=(
not self._allow_metrics_collection
if isinstance(self._allow_metrics_collection, bool)
else None
),
exec_options=self._exec_opts,
)
call = runner(call_log=call_log, prompt_params=prompt_params)
2 changes: 1 addition & 1 deletion guardrails/hub_telemetry/hub_tracing.py
@@ -253,7 +253,7 @@ def wrapper(*args, **kwargs):
nonlocal origin
origin = origin if origin is not None else name
add_attributes(span, attrs, name, origin, *args, **kwargs)
return _run_async_gen(fn, *args, **kwargs)
return fn(*args, **kwargs)
else:
return fn(*args, **kwargs)

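The one-line change above stops routing async generator functions through _run_async_gen and simply returns the call result. Calling an async generator function does not execute its body; it synchronously returns an async generator object that the caller iterates, so the tracing wrapper can hand it back untouched. A hedged sketch of that behavior (names are illustrative, not the Guardrails decorator):

import asyncio

async def stream_chunks():
    # An async generator function: calling it only builds the generator object.
    for chunk in ("a", "b", "c"):
        yield chunk

def traced(fn):
    def wrapper(*args, **kwargs):
        # No awaiting or pre-consumption here; the async generator is returned
        # as-is and driven by whoever called the wrapper.
        return fn(*args, **kwargs)
    return wrapper

async def main():
    gen = traced(stream_chunks)()
    assert [c async for c in gen] == ["a", "b", "c"]

asyncio.run(main())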
6 changes: 5 additions & 1 deletion guardrails/run/async_stream_runner.py
@@ -153,6 +153,10 @@ async def async_step(
validate_subschema=True,
stream=True,
)
# TODO why? how does it happen in the other places we handle streams
if validated_fragment is None:
validated_fragment = ""

if isinstance(validated_fragment, SkeletonReAsk):
raise ValueError(
"Received fragment schema is an invalid sub-schema "
@@ -165,7 +169,7 @@
"Reasks are not yet supported with streaming. Please "
"remove reasks from schema or disable streaming."
)
validation_response += cast(str, validated_fragment)
validation_response += validated_fragment
passed = call_log.status == pass_status
yield ValidationOutcome(
call_id=call_log.id, # type: ignore
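The added None check above, together with dropping the cast(str, ...) further down, guards the accumulation step: concatenating None onto validation_response would raise a TypeError when a validator withholds output for a chunk. A small sketch of the failure mode and the coercion, using assumed names rather than the runner's actual ones:

from typing import Optional

def accumulate(validation_response: str, validated_fragment: Optional[str]) -> str:
    # A None fragment (for example, a filtered chunk) is treated as an empty
    # string so the string concatenation below cannot raise TypeError.
    if validated_fragment is None:
        validated_fragment = ""
    return validation_response + validated_fragment

assert accumulate("Hello, ", None) == "Hello, "
assert accumulate("Hello, ", "world") == "Hello, world"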
20 changes: 17 additions & 3 deletions guardrails/telemetry/guard_tracing.py
@@ -11,7 +11,7 @@
)

from opentelemetry import context, trace
from opentelemetry.trace import StatusCode, Tracer, Span
from opentelemetry.trace import StatusCode, Tracer, Span, Link, get_tracer

from guardrails.settings import settings
from guardrails.classes.generic.stack import Stack
@@ -22,6 +22,10 @@
from guardrails.telemetry.common import add_user_attributes
from guardrails.version import GUARDRAILS_VERSION

import sys

if sys.version_info.minor < 10:
from guardrails.utils.polyfills import anext

# from sentence_transformers import SentenceTransformer
# import numpy as np
@@ -195,8 +199,18 @@ async def trace_async_stream_guard(
while next_exists:
try:
res = await anext(result) # type: ignore
add_guard_attributes(guard_span, history, res)
add_user_attributes(guard_span)
if not guard_span.is_recording():
# Assuming you have a tracer instance
tracer = get_tracer(__name__)
# Create a new span and link it to the previous span
with tracer.start_as_current_span(
"new_guard_span", # type: ignore
links=[Link(guard_span.get_span_context())],
) as new_span:
guard_span = new_span

add_guard_attributes(guard_span, history, res)
add_user_attributes(guard_span)
yield res
except StopIteration:
next_exists = False
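The new branch above covers the case where the guard span has already ended by the time a streamed result arrives: attributes set on a non-recording span are dropped, so a fresh span is started and tied back to the original one with a span link. A minimal sketch of that OpenTelemetry pattern (span and attribute names are illustrative; without an SDK configured the calls are no-ops but still run):

from opentelemetry import trace
from opentelemetry.trace import Link

tracer = trace.get_tracer(__name__)

with tracer.start_as_current_span("guard") as guard_span:
    pass  # the span ends when this block exits

if not guard_span.is_recording():
    # Attributes on the finished span would be lost, so record them on a new
    # span that is linked back to the original for trace continuity.
    with tracer.start_as_current_span(
        "new_guard_span",
        links=[Link(guard_span.get_span_context())],
    ) as new_span:
        new_span.set_attribute("guardrails.resumed", True)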
4 changes: 4 additions & 0 deletions guardrails/telemetry/runner_tracing.py
@@ -21,6 +21,10 @@
from guardrails.utils.safe_get import safe_get
from guardrails.version import GUARDRAILS_VERSION

import sys

if sys.version_info.minor < 10:
from guardrails.utils.polyfills import anext

#########################################
### START Runner.step Instrumentation ###
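Both telemetry modules above import anext from guardrails.utils.polyfills when the interpreter predates Python 3.10, where the builtin does not exist. A hedged sketch of what such a polyfill can look like and how it is called; this is an illustration, not necessarily the shipped implementation:

import asyncio
import sys

if sys.version_info < (3, 10):
    async def anext(async_iterator):
        # Mirror the 3.10+ builtin: return an awaitable for the next item.
        return await async_iterator.__anext__()

async def demo():
    async def gen():
        yield 1
        yield 2
    items = gen()
    assert await anext(items) == 1
    assert await anext(items) == 2

asyncio.run(demo())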
1 change: 0 additions & 1 deletion guardrails/utils/hub_telemetry_utils.py
@@ -57,7 +57,6 @@ def initialize_tracer(
"""Initializes a tracer for Guardrails Hub."""
if enabled is None:
enabled = settings.rc.enable_metrics or False

self._enabled = enabled
self._carrier = {}
self._service_name = service_name
8 changes: 8 additions & 0 deletions guardrails/utils/openai_utils/__init__.py
@@ -6,6 +6,10 @@
is_static_openai_chat_acreate_func,
is_static_openai_chat_create_func,
is_static_openai_create_func,
get_static_openai_create_func,
get_static_openai_chat_create_func,
get_static_openai_acreate_func,
get_static_openai_chat_acreate_func,
)

__all__ = [
@@ -16,4 +20,8 @@
"is_static_openai_acreate_func",
"is_static_openai_chat_acreate_func",
"OpenAIServiceUnavailableError",
"get_static_openai_create_func",
"get_static_openai_chat_create_func",
"get_static_openai_acreate_func",
"get_static_openai_chat_acreate_func",
]
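Restoring the four get_static_openai_* helpers to the package imports and __all__ is the fix named in the PR title: the server still imports these symbols, and they had been dropped from the package surface. With the exports back, an import like the following resolves again instead of raising ImportError (illustrative only; the exact server-side usage is an assumption):

from guardrails.utils.openai_utils import (
    get_static_openai_create_func,
    get_static_openai_chat_create_func,
)

completion_fn = get_static_openai_create_func()             # openai.completions.create
chat_completion_fn = get_static_openai_chat_create_func()   # openai.chat.completions.create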
33 changes: 33 additions & 0 deletions guardrails/utils/openai_utils/v1.py
@@ -2,6 +2,7 @@

import openai

import warnings
from guardrails.classes.llm.llm_response import LLMResponse
from guardrails.utils.openai_utils.base import BaseOpenAIClient
from guardrails.utils.openai_utils.streaming_utils import (
@@ -12,6 +13,38 @@
from guardrails.telemetry import trace_llm_call, trace_operation


def get_static_openai_create_func():
warnings.warn(
"This function is deprecated. " " and will be removed in 0.6.0",
DeprecationWarning,
)
return openai.completions.create


def get_static_openai_chat_create_func():
warnings.warn(
"This function is deprecated and will be removed in 0.6.0",
DeprecationWarning,
)
return openai.chat.completions.create


def get_static_openai_acreate_func():
warnings.warn(
"This function is deprecated and will be removed in 0.6.0",
DeprecationWarning,
)
return None


def get_static_openai_chat_acreate_func():
warnings.warn(
"This function is deprecated and will be removed in 0.6.0",
DeprecationWarning,
)
return None


def is_static_openai_create_func(llm_api: Optional[Callable]) -> bool:
try:
return llm_api == openai.completions.create
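The getters above emit DeprecationWarning, which Python hides by default outside __main__ code, so callers who want the notice have to opt in through the warnings filters. A short usage sketch, assuming guardrails and openai import cleanly in your environment:

import warnings

from guardrails.utils.openai_utils.v1 import get_static_openai_chat_create_func

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    create_fn = get_static_openai_chat_create_func()  # still returns openai.chat.completions.create
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)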
184 changes: 184 additions & 0 deletions tests/integration_tests/test_async_streaming.py
@@ -0,0 +1,184 @@
# 3 tests
# 1. Test streaming with OpenAICallable (mock openai.Completion.create)
# 2. Test streaming with OpenAIChatCallable (mock openai.ChatCompletion.create)
# 3. Test string schema streaming
# Using the LowerCase Validator, and a custom validator to show new streaming behavior
from typing import Any, Callable, Dict, List, Optional, Union

import asyncio
import pytest

import guardrails as gd
from guardrails.utils.casting_utils import to_int
from guardrails.validator_base import (
ErrorSpan,
FailResult,
OnFailAction,
PassResult,
ValidationResult,
Validator,
register_validator,
)
from tests.integration_tests.test_assets.validators import LowerCase, MockDetectPII


@register_validator(name="minsentencelength", data_type=["string", "list"])
class MinSentenceLengthValidator(Validator):
def __init__(
self,
min: Optional[int] = None,
max: Optional[int] = None,
on_fail: Optional[Callable] = None,
):
super().__init__(
on_fail=on_fail,
min=min,
max=max,
)
self._min = to_int(min)
self._max = to_int(max)

def sentence_split(self, value):
return list(map(lambda x: x + ".", value.split(".")[:-1]))

def validate(self, value: Union[str, List], metadata: Dict) -> ValidationResult:
sentences = self.sentence_split(value)
error_spans = []
index = 0
for sentence in sentences:
if len(sentence) < self._min:
error_spans.append(
ErrorSpan(
start=index,
end=index + len(sentence),
reason=f"Sentence has length less than {self._min}. "
f"Please return a longer output, "
f"that is shorter than {self._max} characters.",
)
)
if len(sentence) > self._max:
error_spans.append(
ErrorSpan(
start=index,
end=index + len(sentence),
reason=f"Sentence has length greater than {self._max}. "
f"Please return a shorter output, "
f"that is shorter than {self._max} characters.",
)
)
index = index + len(sentence)
if len(error_spans) > 0:
return FailResult(
validated_chunk=value,
error_spans=error_spans,
error_message=f"Sentence has length less than {self._min}. "
f"Please return a longer output, "
f"that is shorter than {self._max} characters.",
)
return PassResult(validated_chunk=value)

def validate_stream(self, chunk: Any, metadata: Dict, **kwargs) -> ValidationResult:
return super().validate_stream(chunk, metadata, **kwargs)


class Delta:
content: str

def __init__(self, content):
self.content = content


class Choice:
text: str
finish_reason: str
index: int
delta: Delta

def __init__(self, text, delta, finish_reason, index=0):
self.index = index
self.delta = delta
self.text = text
self.finish_reason = finish_reason


class MockOpenAIV1ChunkResponse:
choices: list
model: str

def __init__(self, choices, model):
self.choices = choices
self.model = model


class Response:
def __init__(self, chunks):
self.chunks = chunks

async def gen():
for chunk in self.chunks:
yield MockOpenAIV1ChunkResponse(
choices=[
Choice(
delta=Delta(content=chunk),
text=chunk,
finish_reason=None,
)
],
model="OpenAI model name",
)
await asyncio.sleep(0) # Yield control to the event loop

self.completion_stream = gen()


POETRY_CHUNKS = [
"John, under ",
"GOLDEN bridges",
", roams,\n",
"SAN Francisco's ",
"hills, his HOME.\n",
"Dreams of",
" FOG, and salty AIR,\n",
"In his HEART",
", he's always THERE.",
]


@pytest.mark.asyncio
async def test_filter_behavior(mocker):
mocker.patch(
"litellm.acompletion",
return_value=Response(POETRY_CHUNKS),
)

guard = gd.AsyncGuard().use_many(
MockDetectPII(
on_fail=OnFailAction.FIX,
pii_entities="pii",
replace_map={"John": "<PERSON>", "SAN Francisco's": "<LOCATION>"},
),
LowerCase(on_fail=OnFailAction.FILTER),
)
prompt = """Write me a 4 line poem about John in San Francisco.
Make every third word all caps."""
gen = await guard(
model="gpt-3.5-turbo",
max_tokens=10,
temperature=0,
stream=True,
prompt=prompt,
)

text = ""
final_res = None
async for res in gen:
final_res = res
text += res.validated_output

assert final_res.raw_llm_output == ", he's always THERE."
# TODO deep dive this
assert text == (
"John, under GOLDEN bridges, roams,\n"
"SAN Francisco's Dreams of FOG, and salty AIR,\n"
"In his HEART"
)
