feat(add): LLM integration of Cloudflare Workers AI (langchain-ai#14322)
Add [Text Generation by Cloudflare Workers AI](https://developers.cloudflare.com/workers-ai/models/text-generation/) as a new LLM integration.

- Dependencies: N/A
cxumol authored and aymeric-roucher committed Dec 11, 2023
1 parent 69c1011 commit 746a49d
Showing 3 changed files with 300 additions and 0 deletions.
127 changes: 127 additions & 0 deletions docs/docs/integrations/llms/cloudflare_workersai.ipynb
@@ -0,0 +1,127 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Cloudflare Workers AI\n",
"\n",
"[Cloudflare AI documentation](https://developers.cloudflare.com/workers-ai/models/text-generation/) listed all generative text models available.\n",
"\n",
"Both Cloudflare account ID and API token are required. Find how to obtain them from [this document](https://developers.cloudflare.com/workers-ai/get-started/rest-api/)."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from langchain.chains import LLMChain\n",
"from langchain.llms.cloudflare_workersai import CloudflareWorkersAI\n",
"from langchain.prompts import PromptTemplate\n",
"\n",
"template = \"\"\"Human: {question}\n",
"\n",
"AI Assistant: \"\"\"\n",
"\n",
"prompt = PromptTemplate(template=template, input_variables=[\"question\"])"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Get authentication before running LLM."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import getpass\n",
"\n",
"my_account_id = getpass.getpass(\"Enter your Cloudflare account ID:\\n\\n\")\n",
"my_api_token = getpass.getpass(\"Enter your Cloudflare API token:\\n\\n\")\n",
"llm = CloudflareWorkersAI(account_id=my_account_id, api_token=my_api_token)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\"AI Assistant: Ah, a fascinating question! The answer to why roses are red is a bit complex, but I'll do my best to explain it in a simple and polite manner.\\nRoses are red due to the presence of a pigment called anthocyanin. Anthocyanin is a type of flavonoid, a class of plant compounds that are responsible for the red, purple, and blue colors found in many fruits and vegetables.\\nNow, you might be wondering why roses specifically have this pigment. The answer lies in the evolutionary history of roses. You see, roses have been around for millions of years, and their red color has likely played a crucial role in attracting pollinators like bees and butterflies. These pollinators are drawn to the bright colors of roses, which helps the plants reproduce and spread their seeds.\\nSo, to summarize, the reason roses are red is because of the anthocyanin pigment, which is a result of millions of years of evolutionary pressures shaping the plant's coloration to attract pollinators. I hope that helps clarify things for\""
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"llm_chain = LLMChain(prompt=prompt, llm=llm)\n",
"\n",
"question = \"Why are roses red?\"\n",
"llm_chain.run(question)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Ah | , | a | most | excellent | question | , | my | dear | human | ! | * | ad | just | s | glass | es | * | The | sky | appears | blue | due | to | a | phenomen | on | known | as | Ray | le | igh | scatter | ing | . | When | sun | light | enters | Earth | ' | s | atmosphere | , | it | enc | oun | ters | tiny | mole | cules | of | g | ases | such | as | nit | ro | gen | and | o | xygen | . | These | mole | cules | scatter | the | light | in | all | directions | , | but | they | scatter | shorter | ( | blue | ) | w | avel | ength | s | more | than | longer | ( | red | ) | w | avel | ength | s | . | This | is | known | as | Ray | le | igh | scatter | ing | . | \n",
" | As | a | result | , | the | blue | light | is | dispers | ed | throughout | the | atmosphere | , | giving | the | sky | its | characteristic | blue | h | ue | . | The | blue | light | appears | more | prominent | during | sun | r | ise | and | sun | set | due | to | the | scatter | ing | of | light | by | the | Earth | ' | s | atmosphere | at | these | times | . | \n",
" | I | hope | this | explanation | has | been | helpful | , | my | dear | human | ! | Is | there | anything | else | you | would | like | to | know | ? | * | sm | iles | * | * | | "
]
}
],
"source": [
"# Using streaming\n",
"for chunk in llm.stream(\"Why is sky blue?\"):\n",
" print(chunk, end=\" | \", flush=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
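
For reference, a minimal sketch of the raw REST call this integration issues under the hood, mirroring _call_api and _process_response in the module below (the account ID, API token, and prompt are placeholders):

import requests

account_id = "my_account_id"  # placeholder
api_token = "my_api_token"  # placeholder
model = "@cf/meta/llama-2-7b-chat-int8"

url = f"https://api.cloudflare.com/client/v4/accounts/{account_id}/ai/run/{model}"
headers = {"Authorization": f"Bearer {api_token}"}
payload = {"prompt": "Why are roses red?", "stream": False}

# The non-streaming response body has the shape {"result": {"response": "..."}}.
response = requests.post(url, headers=headers, json=payload)
print(response.json()["result"]["response"])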
127 changes: 127 additions & 0 deletions libs/langchain/langchain/llms/cloudflare_workersai.py
@@ -0,0 +1,127 @@
import json
import logging
from typing import Any, Dict, Iterator, List, Optional

import requests
from langchain_core.outputs import GenerationChunk

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM

logger = logging.getLogger(__name__)


class CloudflareWorkersAI(LLM):
    """LangChain LLM class for accessing the Cloudflare Workers AI service.

    To use, you must provide a Cloudflare account ID and an API token
    with access to Workers AI, passed as named parameters to the
    constructor.

    Example:
        .. code-block:: python

            from langchain.llms.cloudflare_workersai import CloudflareWorkersAI

            my_account_id = "my_account_id"
            my_api_token = "my_secret_api_token"
            llm_model = "@cf/meta/llama-2-7b-chat-int8"

            cf_ai = CloudflareWorkersAI(
                account_id=my_account_id,
                api_token=my_api_token,
                model=llm_model
            )
    """

    account_id: str
    api_token: str
    model: str = "@cf/meta/llama-2-7b-chat-int8"
    base_url: str = "https://api.cloudflare.com/client/v4/accounts"
    streaming: bool = False
    endpoint_url: str = ""

    def __init__(self, **kwargs: Any) -> None:
        """Initialize the Cloudflare Workers AI class."""
        super().__init__(**kwargs)

        self.endpoint_url = f"{self.base_url}/{self.account_id}/ai/run/{self.model}"

    @property
    def _llm_type(self) -> str:
        """Return type of LLM."""
        return "cloudflare"

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Default parameters"""
        return {}

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Identifying parameters"""
        return {
            "account_id": self.account_id,
            "api_token": self.api_token,
            "model": self.model,
            "base_url": self.base_url,
        }

    def _call_api(self, prompt: str, params: Dict[str, Any]) -> requests.Response:
        """Call Cloudflare Workers API"""
        headers = {"Authorization": f"Bearer {self.api_token}"}
        data = {"prompt": prompt, "stream": self.streaming, **params}
        response = requests.post(self.endpoint_url, headers=headers, json=data)
        return response

    def _process_response(self, response: requests.Response) -> str:
        """Process API response"""
        if response.ok:
            data = response.json()
            return data["result"]["response"]
        else:
            raise ValueError(f"Request failed with status {response.status_code}")

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Streaming prediction"""
        original_streaming: bool = self.streaming
        self.streaming = True
        _response_prefix_count = len("data: ")
        _response_stream_end = b"data: [DONE]"
        for chunk in self._call_api(prompt, kwargs).iter_lines():
            if chunk == _response_stream_end:
                break
            if len(chunk) > _response_prefix_count:
                try:
                    data = json.loads(chunk[_response_prefix_count:])
                except Exception as e:
                    logger.debug(chunk)
                    raise e
                if data is not None and "response" in data:
                    yield GenerationChunk(text=data["response"])
                    if run_manager:
                        run_manager.on_llm_new_token(data["response"])
        logger.debug("stream end")
        self.streaming = original_streaming

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Regular prediction"""
        if self.streaming:
            return "".join(
                [c.text for c in self._stream(prompt, stop, run_manager, **kwargs)]
            )
        else:
            response = self._call_api(prompt, kwargs)
            return self._process_response(response)
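
The streaming endpoint emits server-sent-event-style lines of the form data: <json>, terminated by data: [DONE]. Below is a standalone sketch of the parsing that _stream performs above (parse_event_stream is a hypothetical helper name):

import json

def parse_event_stream(lines):
    """Yield response text from SSE-style byte lines, as _stream does."""
    prefix = len("data: ")
    for line in lines:
        if line == b"data: [DONE]":
            break
        if len(line) > prefix:
            data = json.loads(line[prefix:])
            if data is not None and "response" in data:
                yield data["response"]

# Wire format matches the one mocked in the unit tests below:
chunks = parse_event_stream(
    [b'data: {"response": "Hello"}', b'data: {"response": " world"}', b"data: [DONE]"]
)
assert "".join(chunks) == "Hello world"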
46 changes: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
import responses

from langchain.llms.cloudflare_workersai import CloudflareWorkersAI


@responses.activate
def test_cloudflare_workersai_call() -> None:
    responses.add(
        responses.POST,
        "https://api.cloudflare.com/client/v4/accounts/my_account_id/ai/run/@cf/meta/llama-2-7b-chat-int8",
        json={"result": {"response": "4"}},
        status=200,
    )

    llm = CloudflareWorkersAI(
        account_id="my_account_id",
        api_token="my_api_token",
        model="@cf/meta/llama-2-7b-chat-int8",
    )
    output = llm("What is 2 + 2?")

    assert output == "4"


@responses.activate
def test_cloudflare_workersai_stream() -> None:
    response_body = ['data: {"response": "Hello"}', "data: [DONE]"]
    responses.add(
        responses.POST,
        "https://api.cloudflare.com/client/v4/accounts/my_account_id/ai/run/@cf/meta/llama-2-7b-chat-int8",
        body="\n".join(response_body),
        status=200,
    )

    llm = CloudflareWorkersAI(
        account_id="my_account_id",
        api_token="my_api_token",
        model="@cf/meta/llama-2-7b-chat-int8",
        streaming=True,
    )

    outputs = []
    for chunk in llm.stream("Say Hello"):
        outputs.append(chunk)

    assert "".join(outputs) == "Hello"
