From 618c79915cd0096f606eb236853cb9dc22b6f939 Mon Sep 17 00:00:00 2001 From: Stainless Bot Date: Wed, 15 Nov 2023 16:54:12 +0000 Subject: [PATCH] feat(client): support reading the base url from an env variable --- README.md | 1 + src/openai/_client.py | 4 ++++ tests/test_client.py | 12 ++++++++++++ tests/utils.py | 17 ++++++++++++++++- 4 files changed, 33 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index e7e65828b8..82eeb57ccb 100644 --- a/README.md +++ b/README.md @@ -437,6 +437,7 @@ import httpx from openai import OpenAI client = OpenAI( + # Or use the `OPENAI_BASE_URL` env var base_url="http://my.test.server.example.com:8083", http_client=httpx.Client( proxies="http://my.test.proxy.example.com", diff --git a/src/openai/_client.py b/src/openai/_client.py index 7820d5f96d..6664dc4233 100644 --- a/src/openai/_client.py +++ b/src/openai/_client.py @@ -99,6 +99,8 @@ def __init__( organization = os.environ.get("OPENAI_ORG_ID") self.organization = organization + if base_url is None: + base_url = os.environ.get("OPENAI_BASE_URL") if base_url is None: base_url = f"https://api.openai.com/v1" @@ -307,6 +309,8 @@ def __init__( organization = os.environ.get("OPENAI_ORG_ID") self.organization = organization + if base_url is None: + base_url = os.environ.get("OPENAI_BASE_URL") if base_url is None: base_url = f"https://api.openai.com/v1" diff --git a/tests/test_client.py b/tests/test_client.py index e3daa4d2b1..e295d193e8 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -26,6 +26,8 @@ make_request_options, ) +from .utils import update_env + base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") api_key = "My API Key" @@ -399,6 +401,11 @@ class Model2(BaseModel): assert isinstance(response, Model1) assert response.foo == 1 + def test_base_url_env(self) -> None: + with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"): + client = OpenAI(api_key=api_key, _strict_response_validation=True) + assert client.base_url == "http://localhost:5000/from/env/" + @pytest.mark.parametrize( "client", [ @@ -932,6 +939,11 @@ class Model2(BaseModel): assert isinstance(response, Model1) assert response.foo == 1 + def test_base_url_env(self) -> None: + with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"): + client = AsyncOpenAI(api_key=api_key, _strict_response_validation=True) + assert client.base_url == "http://localhost:5000/from/env/" + @pytest.mark.parametrize( "client", [ diff --git a/tests/utils.py b/tests/utils.py index 3cccab223a..b513794017 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,7 +1,9 @@ from __future__ import annotations +import os import traceback -from typing import Any, TypeVar, cast +import contextlib +from typing import Any, TypeVar, Iterator, cast from datetime import date, datetime from typing_extensions import Literal, get_args, get_origin, assert_type @@ -103,3 +105,16 @@ def _assert_list_type(type_: type[object], value: object) -> None: inner_type = get_args(type_)[0] for entry in value: assert_type(inner_type, entry) # type: ignore + + +@contextlib.contextmanager +def update_env(**new_env: str) -> Iterator[None]: + old = os.environ.copy() + + try: + os.environ.update(new_env) + + yield None + finally: + os.environ.clear() + os.environ.update(old)