Skip to content

Commit

Permalink
Backport PR jupyterlab#653: Use new langchain-openai partner package
Browse files Browse the repository at this point in the history
  • Loading branch information
startakovsky authored and meeseeksmachine committed Mar 27, 2024
1 parent c2ba62a commit 1a0740e
Show file tree
Hide file tree
Showing 7 changed files with 115 additions and 118 deletions.
4 changes: 2 additions & 2 deletions docs/source/users/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,8 @@ Jupyter AI supports the following model providers:
| GPT4All | `gpt4all` | N/A | `gpt4all` |
| Hugging Face Hub | `huggingface_hub` | `HUGGINGFACEHUB_API_TOKEN` | `huggingface_hub`, `ipywidgets`, `pillow` |
| NVIDIA | `nvidia-chat` | `NVIDIA_API_KEY` | `langchain_nvidia_ai_endpoints` |
| OpenAI | `openai` | `OPENAI_API_KEY` | `openai` |
| OpenAI (chat) | `openai-chat` | `OPENAI_API_KEY` | `openai` |
| OpenAI | `openai` | `OPENAI_API_KEY` | `langchain-openai` |
| OpenAI (chat) | `openai-chat` | `OPENAI_API_KEY` | `langchain-openai` |
| SageMaker | `sagemaker-endpoint` | N/A | `boto3` |

The environment variable names shown above are also the names of the settings keys used when setting up the chat interface.
Expand Down
4 changes: 0 additions & 4 deletions packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
CohereEmbeddingsProvider,
GPT4AllEmbeddingsProvider,
HfHubEmbeddingsProvider,
OpenAIEmbeddingsProvider,
QianfanEmbeddingsEndpointProvider,
)
from .exception import store_exception
Expand All @@ -17,16 +16,13 @@
from .providers import (
AI21Provider,
AnthropicProvider,
AzureChatOpenAIProvider,
BaseProvider,
BedrockChatProvider,
BedrockProvider,
ChatAnthropicProvider,
ChatOpenAIProvider,
CohereProvider,
GPT4AllProvider,
HfHubProvider,
OpenAIProvider,
QianfanProvider,
SmEndpointProvider,
TogetherAIProvider,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,19 +68,6 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs, **model_kwargs)


class OpenAIEmbeddingsProvider(BaseEmbeddingsProvider, OpenAIEmbeddings):
id = "openai"
name = "OpenAI"
models = [
"text-embedding-ada-002",
"text-embedding-3-small",
"text-embedding-3-large",
]
model_id_key = "model"
pypi_package_deps = ["openai"]
auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY")


class CohereEmbeddingsProvider(BaseEmbeddingsProvider, CohereEmbeddings):
id = "cohere"
name = "Cohere"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from langchain_openai import AzureChatOpenAI, ChatOpenAI, OpenAI, OpenAIEmbeddings

from ..embedding_providers import BaseEmbeddingsProvider
from ..providers import BaseProvider, EnvAuthStrategy, TextField


class OpenAIProvider(BaseProvider, OpenAI):
id = "openai"
name = "OpenAI"
models = ["babbage-002", "davinci-002", "gpt-3.5-turbo-instruct"]
model_id_key = "model_name"
pypi_package_deps = ["langchain_openai"]
auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY")

@classmethod
def is_api_key_exc(cls, e: Exception):
"""
Determine if the exception is an OpenAI API key error.
"""
import openai

if isinstance(e, openai.AuthenticationError):
error_details = e.json_body.get("error", {})
return error_details.get("code") == "invalid_api_key"
return False


class ChatOpenAIProvider(BaseProvider, ChatOpenAI):
id = "openai-chat"
name = "OpenAI"
models = [
"gpt-3.5-turbo",
"gpt-3.5-turbo-0125",
"gpt-3.5-turbo-0301", # Deprecated as of 2024-06-13
"gpt-3.5-turbo-0613", # Deprecated as of 2024-06-13
"gpt-3.5-turbo-1106",
"gpt-3.5-turbo-16k",
"gpt-3.5-turbo-16k-0613", # Deprecated as of 2024-06-13
"gpt-4",
"gpt-4-turbo-preview",
"gpt-4-0613",
"gpt-4-32k",
"gpt-4-32k-0613",
"gpt-4-0125-preview",
"gpt-4-1106-preview",
]
model_id_key = "model_name"
pypi_package_deps = ["langchain_openai"]
auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY")

fields = [
TextField(
key="openai_api_base", label="Base API URL (optional)", format="text"
),
TextField(
key="openai_organization", label="Organization (optional)", format="text"
),
TextField(key="openai_proxy", label="Proxy (optional)", format="text"),
]

@classmethod
def is_api_key_exc(cls, e: Exception):
"""
Determine if the exception is an OpenAI API key error.
"""
import openai

if isinstance(e, openai.AuthenticationError):
error_details = e.json_body.get("error", {})
return error_details.get("code") == "invalid_api_key"
return False


class AzureChatOpenAIProvider(BaseProvider, AzureChatOpenAI):
id = "azure-chat-openai"
name = "Azure OpenAI"
models = ["*"]
model_id_key = "deployment_name"
model_id_label = "Deployment name"
pypi_package_deps = ["langchain_openai"]
auth_strategy = EnvAuthStrategy(name="AZURE_OPENAI_API_KEY")
registry = True

fields = [
TextField(
key="openai_api_base", label="Base API URL (required)", format="text"
),
TextField(
key="openai_api_version", label="API version (required)", format="text"
),
TextField(
key="openai_organization", label="Organization (optional)", format="text"
),
TextField(key="openai_proxy", label="Proxy (optional)", format="text"),
]


class OpenAIEmbeddingsProvider(BaseEmbeddingsProvider, OpenAIEmbeddings):
id = "openai"
name = "OpenAI"
models = [
"text-embedding-ada-002",
"text-embedding-3-small",
"text-embedding-3-large",
]
model_id_key = "model"
pypi_package_deps = ["langchain_openai"]
auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY")
93 changes: 0 additions & 93 deletions packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,8 @@
from langchain.schema import LLMResult
from langchain.utils import get_from_dict_or_env
from langchain_community.chat_models import (
AzureChatOpenAI,
BedrockChat,
ChatAnthropic,
ChatOpenAI,
QianfanChatEndpoint,
)
from langchain_community.llms import (
Expand Down Expand Up @@ -587,97 +585,6 @@ async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
return await self._call_in_executor(*args, **kwargs)


class OpenAIProvider(BaseProvider, OpenAI):
id = "openai"
name = "OpenAI"
models = ["babbage-002", "davinci-002", "gpt-3.5-turbo-instruct"]
model_id_key = "model_name"
pypi_package_deps = ["openai"]
auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY")

@classmethod
def is_api_key_exc(cls, e: Exception):
"""
Determine if the exception is an OpenAI API key error.
"""
import openai

if isinstance(e, openai.AuthenticationError):
error_details = e.json_body.get("error", {})
return error_details.get("code") == "invalid_api_key"
return False


class ChatOpenAIProvider(BaseProvider, ChatOpenAI):
id = "openai-chat"
name = "OpenAI"
models = [
"gpt-3.5-turbo",
"gpt-3.5-turbo-0125",
"gpt-3.5-turbo-0301", # Deprecated as of 2024-06-13
"gpt-3.5-turbo-0613", # Deprecated as of 2024-06-13
"gpt-3.5-turbo-1106",
"gpt-3.5-turbo-16k",
"gpt-3.5-turbo-16k-0613", # Deprecated as of 2024-06-13
"gpt-4",
"gpt-4-turbo-preview",
"gpt-4-0613",
"gpt-4-32k",
"gpt-4-32k-0613",
"gpt-4-0125-preview",
"gpt-4-1106-preview",
]
model_id_key = "model_name"
pypi_package_deps = ["openai"]
auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY")

fields = [
TextField(
key="openai_api_base", label="Base API URL (optional)", format="text"
),
TextField(
key="openai_organization", label="Organization (optional)", format="text"
),
TextField(key="openai_proxy", label="Proxy (optional)", format="text"),
]

@classmethod
def is_api_key_exc(cls, e: Exception):
"""
Determine if the exception is an OpenAI API key error.
"""
import openai

if isinstance(e, openai.AuthenticationError):
error_details = e.json_body.get("error", {})
return error_details.get("code") == "invalid_api_key"
return False


class AzureChatOpenAIProvider(BaseProvider, AzureChatOpenAI):
id = "azure-chat-openai"
name = "Azure OpenAI"
models = ["*"]
model_id_key = "deployment_name"
model_id_label = "Deployment name"
pypi_package_deps = ["openai"]
auth_strategy = EnvAuthStrategy(name="AZURE_OPENAI_API_KEY")
registry = True

fields = [
TextField(
key="openai_api_base", label="Base API URL (required)", format="text"
),
TextField(
key="openai_api_version", label="API version (required)", format="text"
),
TextField(
key="openai_organization", label="Organization (optional)", format="text"
),
TextField(key="openai_proxy", label="Proxy (optional)", format="text"),
]


class JsonContentHandler(LLMContentHandler):
content_type = "application/json"
accepts = "application/json"
Expand Down
10 changes: 5 additions & 5 deletions packages/jupyter-ai-magics/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ all = [
"huggingface_hub",
"ipywidgets",
"langchain_nvidia_ai_endpoints",
"langchain-openai",
"pillow",
"openai~=1.6.1",
"boto3",
"qianfan",
"together",
Expand All @@ -56,9 +56,9 @@ anthropic = "jupyter_ai_magics:AnthropicProvider"
cohere = "jupyter_ai_magics:CohereProvider"
gpt4all = "jupyter_ai_magics:GPT4AllProvider"
huggingface_hub = "jupyter_ai_magics:HfHubProvider"
openai = "jupyter_ai_magics:OpenAIProvider"
openai-chat = "jupyter_ai_magics:ChatOpenAIProvider"
azure-chat-openai = "jupyter_ai_magics:AzureChatOpenAIProvider"
openai = "jupyter_ai_magics.partner_providers.openai:OpenAIProvider"
openai-chat = "jupyter_ai_magics.partner_providers.openai:ChatOpenAIProvider"
azure-chat-openai = "jupyter_ai_magics.partner_providers.openai:AzureChatOpenAIProvider"
sagemaker-endpoint = "jupyter_ai_magics:SmEndpointProvider"
amazon-bedrock = "jupyter_ai_magics:BedrockProvider"
anthropic-chat = "jupyter_ai_magics:ChatAnthropicProvider"
Expand All @@ -73,7 +73,7 @@ bedrock = "jupyter_ai_magics:BedrockEmbeddingsProvider"
cohere = "jupyter_ai_magics:CohereEmbeddingsProvider"
gpt4all = "jupyter_ai_magics:GPT4AllEmbeddingsProvider"
huggingface_hub = "jupyter_ai_magics:HfHubEmbeddingsProvider"
openai = "jupyter_ai_magics:OpenAIEmbeddingsProvider"
openai = "jupyter_ai_magics.partner_providers.openai:OpenAIEmbeddingsProvider"
qianfan = "jupyter_ai_magics:QianfanEmbeddingsEndpointProvider"

[tool.hatch.version]
Expand Down
1 change: 0 additions & 1 deletion packages/jupyter-ai/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ dependencies = [
"jupyterlab>=3.5,<4",
"aiosqlite>=0.18",
"importlib_metadata>=5.2.0",
"tiktoken", # required for OpenAIEmbeddings
"jupyter_ai_magics",
"dask[distributed]",
"faiss-cpu", # Not distributed by official repo
Expand Down

0 comments on commit 1a0740e

Please sign in to comment.