Add support for Ollama streaming output and integrate Open-WebUI as the chat UI demo #407

Merged
merged 2 commits into from
Dec 6, 2024
140 changes: 140 additions & 0 deletions examples/lightrag_api_open_webui_demo.py
@@ -0,0 +1,140 @@
from datetime import datetime, timezone
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
import inspect
import json
from pydantic import BaseModel
from typing import Optional

import os
import logging
from lightrag import LightRAG, QueryParam
from lightrag.llm import ollama_model_complete, ollama_embed
from lightrag.utils import EmbeddingFunc

import nest_asyncio

WORKING_DIR = "./dickens"

logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)

if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)

rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=ollama_model_complete,
    llm_model_name="qwen2.5:latest",
    llm_model_max_async=4,
    llm_model_max_token_size=32768,
    llm_model_kwargs={"host": "http://localhost:11434", "options": {"num_ctx": 32768}},
    embedding_func=EmbeddingFunc(
        embedding_dim=1024,
        max_token_size=8192,
        func=lambda texts: ollama_embed(
            texts=texts, embed_model="bge-m3:latest", host="http://127.0.0.1:11434"
        ),
    ),
)

with open("./book.txt", "r", encoding="utf-8") as f:
    rag.insert(f.read())

# Apply nest_asyncio to solve event loop issues
nest_asyncio.apply()

app = FastAPI(title="LightRAG", description="LightRAG API open-webui")


# Data models
MODEL_NAME = "LightRAG:latest"


class Message(BaseModel):
    role: Optional[str] = None
    content: str


class OpenWebUIRequest(BaseModel):
    stream: Optional[bool] = None
    model: Optional[str] = None
    messages: list[Message]


# API routes


@app.get("/")
async def index():
return "Set Ollama link to http://ip:port/ollama in Open-WebUI Settings"


@app.get("/ollama/api/version")
async def ollama_version():
return {"version": "0.4.7"}


@app.get("/ollama/api/tags")
async def ollama_tags():
return {
"models": [
{
"name": MODEL_NAME,
"model": MODEL_NAME,
"modified_at": "2024-11-12T20:22:37.561463923+08:00",
"size": 4683087332,
"digest": "845dbda0ea48ed749caafd9e6037047aa19acfcfd82e704d7ca97d631a0b697e",
"details": {
"parent_model": "",
"format": "gguf",
"family": "qwen2",
"families": ["qwen2"],
"parameter_size": "7.6B",
"quantization_level": "Q4_K_M",
},
}
]
}


@app.post("/ollama/api/chat")
async def ollama_chat(request: OpenWebUIRequest):
resp = rag.query(
request.messages[-1].content, param=QueryParam(mode="hybrid", stream=True)
)
if inspect.isasyncgen(resp):

async def ollama_resp(chunks):
async for chunk in chunks:
yield (
json.dumps(
{
"model": MODEL_NAME,
"created_at": datetime.now(timezone.utc).strftime(
"%Y-%m-%dT%H:%M:%S.%fZ"
),
"message": {
"role": "assistant",
"content": chunk,
},
"done": False,
},
ensure_ascii=False,
).encode("utf-8")
+ b"\n"
) # the b"\n" is important

return StreamingResponse(ollama_resp(resp), media_type="application/json")
else:
return resp


@app.get("/health")
async def health_check():
return {"status": "healthy"}


if __name__ == "__main__":
import uvicorn

uvicorn.run(app, host="0.0.0.0", port=8020)
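
For a quick check of the streaming endpoint without Open-WebUI, a minimal client sketch (not part of the diff): it assumes the demo above is running on localhost:8020 and that the requests package is available; the payload mirrors the Ollama chat format that this API emulates.

import json
import requests

# Assumes the demo server above is running: python lightrag_api_open_webui_demo.py
payload = {
    "model": "LightRAG:latest",
    "stream": True,
    "messages": [{"role": "user", "content": "What are the top themes in this story?"}],
}

with requests.post(
    "http://localhost:8020/ollama/api/chat", json=payload, stream=True
) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines():
        if not line:
            continue
        # One NDJSON object per line, exactly as emitted by ollama_chat above.
        chunk = json.loads(line)
        print(chunk["message"]["content"], end="", flush=True)

In Open-WebUI itself the same endpoint is reached by pointing the Ollama API URL at http://<host>:8020/ollama, as the index route above indicates.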
19 changes: 19 additions & 0 deletions examples/lightrag_ollama_demo.py
@@ -1,4 +1,6 @@
import asyncio
import os
import inspect
import logging
from lightrag import LightRAG, QueryParam
from lightrag.llm import ollama_model_complete, ollama_embedding
@@ -49,3 +51,20 @@
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
)

# stream response
resp = rag.query(
    "What are the top themes in this story?",
    param=QueryParam(mode="hybrid", stream=True),
)


async def print_stream(stream):
    async for chunk in stream:
        print(chunk, end="", flush=True)


if inspect.isasyncgen(resp):
    asyncio.run(print_stream(resp))
else:
    print(resp)
1 change: 1 addition & 0 deletions lightrag/base.py
@@ -19,6 +19,7 @@ class QueryParam:
    only_need_context: bool = False
    only_need_prompt: bool = False
    response_type: str = "Multiple Paragraphs"
    stream: bool = False
    # Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode.
    top_k: int = 60
    # Number of document chunks to retrieve.
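
The new field is opt-in, so existing callers are unaffected; a small sketch of both modes (variable names are illustrative only):

from lightrag import QueryParam

# Default: stream=False, rag.query() keeps returning the finished answer as a string.
blocking_param = QueryParam(mode="hybrid")

# Opt-in: with stream=True the Ollama-backed model function returns an async
# iterator of text chunks instead of a string (see the llm.py changes below).
streaming_param = QueryParam(mode="hybrid", stream=True)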
63 changes: 41 additions & 22 deletions lightrag/llm.py
@@ -4,7 +4,7 @@
import os
import struct
from functools import lru_cache
from typing import List, Dict, Callable, Any
from typing import List, Dict, Callable, Any, Union

import aioboto3
import aiohttp
@@ -36,6 +36,13 @@
    get_best_cached_response,
)

import sys

if sys.version_info < (3, 9):
    from typing import AsyncIterator
else:
    from collections.abc import AsyncIterator

os.environ["TOKENIZERS_PARALLELISM"] = "false"


@@ -474,7 +481,8 @@ async def ollama_model_if_cache(
    system_prompt=None,
    history_messages=[],
    **kwargs,
) -> str:
) -> Union[str, AsyncIterator[str]]:
    stream = True if kwargs.get("stream") else False
    kwargs.pop("max_tokens", None)
    # kwargs.pop("response_format", None) # allow json
    host = kwargs.pop("host", None)
@@ -517,28 +525,39 @@
                return if_cache_return["return"]

    response = await ollama_client.chat(model=model, messages=messages, **kwargs)
    if stream:
        """ cannot cache stream response """

    result = response["message"]["content"]
        async def inner():
            async for chunk in response:
                yield chunk["message"]["content"]

    if hashing_kv is not None:
        await hashing_kv.upsert(
            {
                args_hash: {
                    "return": result,
                    "model": model,
                    "embedding": quantized.tobytes().hex()
                    if is_embedding_cache_enabled
                    else None,
                    "embedding_shape": quantized.shape
                    if is_embedding_cache_enabled
                    else None,
                    "embedding_min": min_val if is_embedding_cache_enabled else None,
                    "embedding_max": max_val if is_embedding_cache_enabled else None,
                    "original_prompt": prompt,
        return inner()
    else:
        result = response["message"]["content"]
        if hashing_kv is not None:
            await hashing_kv.upsert(
                {
                    args_hash: {
                        "return": result,
                        "model": model,
                        "embedding": quantized.tobytes().hex()
                        if is_embedding_cache_enabled
                        else None,
                        "embedding_shape": quantized.shape
                        if is_embedding_cache_enabled
                        else None,
                        "embedding_min": min_val
                        if is_embedding_cache_enabled
                        else None,
                        "embedding_max": max_val
                        if is_embedding_cache_enabled
                        else None,
                        "original_prompt": prompt,
                    }
                }
            }
        )
    return result
            )
        return result


@lru_cache(maxsize=1)
@@ -816,7 +835,7 @@ async def hf_model_complete(

async def ollama_model_complete(
    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
) -> Union[str, AsyncIterator[str]]:
    keyword_extraction = kwargs.pop("keyword_extraction", None)
    if keyword_extraction:
        kwargs["format"] = "json"
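
The streaming branch leans on the ollama Python client: when stream=True is forwarded through **kwargs, AsyncClient.chat returns an async iterator of partial messages instead of a single response dict. A standalone sketch of that behaviour, assuming the host and model name from the demo configuration above:

import asyncio

import ollama


async def main():
    client = ollama.AsyncClient(host="http://localhost:11434")
    stream = await client.chat(
        model="qwen2.5:latest",
        messages=[{"role": "user", "content": "Say hello."}],
        stream=True,
    )
    # Each chunk carries the next piece of the reply in chunk["message"]["content"],
    # which is exactly what inner() yields in ollama_model_if_cache above.
    async for chunk in stream:
        print(chunk["message"]["content"], end="", flush=True)


asyncio.run(main())

Because the generator can only be consumed once, the streamed path skips the hashing_kv cache, as the "cannot cache stream response" docstring notes.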
3 changes: 2 additions & 1 deletion lightrag/operate.py
@@ -534,8 +534,9 @@ async def kg_query(
    response = await use_model_func(
        query,
        system_prompt=sys_prompt,
        stream=query_param.stream,
    )
    if len(response) > len(sys_prompt):
    if isinstance(response, str) and len(response) > len(sys_prompt):
        response = (
            response.replace(sys_prompt, "")
            .replace("user", "")