Merge pull request #407 from partoneplay/main
Add support for Ollama streaming output and integrate Open-WebUI as the chat UI demo
LarFii authored Dec 6, 2024
2 parents 3264fa0 + ad991f9 commit 69b6a9f
Showing 5 changed files with 203 additions and 23 deletions.
140 changes: 140 additions & 0 deletions examples/lightrag_api_open_webui_demo.py
@@ -0,0 +1,140 @@
from datetime import datetime, timezone
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
import inspect
import json
from pydantic import BaseModel
from typing import Optional

import os
import logging
from lightrag import LightRAG, QueryParam
from lightrag.llm import ollama_model_complete, ollama_embed
from lightrag.utils import EmbeddingFunc

import nest_asyncio

WORKING_DIR = "./dickens"

logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)

if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)

rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=ollama_model_complete,
    llm_model_name="qwen2.5:latest",
    llm_model_max_async=4,
    llm_model_max_token_size=32768,
    llm_model_kwargs={"host": "http://localhost:11434", "options": {"num_ctx": 32768}},
    embedding_func=EmbeddingFunc(
        embedding_dim=1024,
        max_token_size=8192,
        func=lambda texts: ollama_embed(
            texts=texts, embed_model="bge-m3:latest", host="http://127.0.0.1:11434"
        ),
    ),
)

with open("./book.txt", "r", encoding="utf-8") as f:
    rag.insert(f.read())

# Apply nest_asyncio to solve event loop issues
nest_asyncio.apply()

app = FastAPI(title="LightRAG", description="LightRAG API open-webui")


# Data models
MODEL_NAME = "LightRAG:latest"


class Message(BaseModel):
    role: Optional[str] = None
    content: str


class OpenWebUIRequest(BaseModel):
    stream: Optional[bool] = None
    model: Optional[str] = None
    messages: list[Message]


# API routes


@app.get("/")
async def index():
return "Set Ollama link to http://ip:port/ollama in Open-WebUI Settings"


@app.get("/ollama/api/version")
async def ollama_version():
return {"version": "0.4.7"}


@app.get("/ollama/api/tags")
async def ollama_tags():
    return {
        "models": [
            {
                "name": MODEL_NAME,
                "model": MODEL_NAME,
                "modified_at": "2024-11-12T20:22:37.561463923+08:00",
                "size": 4683087332,
                "digest": "845dbda0ea48ed749caafd9e6037047aa19acfcfd82e704d7ca97d631a0b697e",
                "details": {
                    "parent_model": "",
                    "format": "gguf",
                    "family": "qwen2",
                    "families": ["qwen2"],
                    "parameter_size": "7.6B",
                    "quantization_level": "Q4_K_M",
                },
            }
        ]
    }


@app.post("/ollama/api/chat")
async def ollama_chat(request: OpenWebUIRequest):
    resp = rag.query(
        request.messages[-1].content, param=QueryParam(mode="hybrid", stream=True)
    )
    if inspect.isasyncgen(resp):

        async def ollama_resp(chunks):
            async for chunk in chunks:
                yield (
                    json.dumps(
                        {
                            "model": MODEL_NAME,
                            "created_at": datetime.now(timezone.utc).strftime(
                                "%Y-%m-%dT%H:%M:%S.%fZ"
                            ),
                            "message": {
                                "role": "assistant",
                                "content": chunk,
                            },
                            "done": False,
                        },
                        ensure_ascii=False,
                    ).encode("utf-8")
                    + b"\n"
                )  # the b"\n" is important

        return StreamingResponse(ollama_resp(resp), media_type="application/json")
    else:
        return resp


@app.get("/health")
async def health_check():
return {"status": "healthy"}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8020)
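
The following client snippet is not part of the commit; it is a minimal sketch of how the newline-delimited JSON stream emitted by the /ollama/api/chat route above could be consumed outside Open-WebUI. The host and port (localhost:8020) match the uvicorn call in the demo, while the use of httpx is an assumption.

# Hedged example (not in the repository): stream a chat completion from the demo server.
# Assumes the demo above is running on localhost:8020 and that httpx is installed.
import json
import httpx

payload = {
    "model": "LightRAG:latest",
    "stream": True,
    "messages": [{"role": "user", "content": "What are the top themes in this story?"}],
}

with httpx.stream(
    "POST", "http://localhost:8020/ollama/api/chat", json=payload, timeout=None
) as response:
    for line in response.iter_lines():
        if not line:
            continue
        chunk = json.loads(line)
        if not chunk.get("done"):
            # Each NDJSON line carries one partial assistant message chunk.
            print(chunk["message"]["content"], end="", flush=True)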
19 changes: 19 additions & 0 deletions examples/lightrag_ollama_demo.py
@@ -1,4 +1,6 @@
import asyncio
import os
import inspect
import logging
from lightrag import LightRAG, QueryParam
from lightrag.llm import ollama_model_complete, ollama_embedding
@@ -49,3 +51,20 @@
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
)

# stream response
resp = rag.query(
    "What are the top themes in this story?",
    param=QueryParam(mode="hybrid", stream=True),
)


async def print_stream(stream):
    async for chunk in stream:
        print(chunk, end="", flush=True)


if inspect.isasyncgen(resp):
    asyncio.run(print_stream(resp))
else:
    print(resp)
1 change: 1 addition & 0 deletions lightrag/base.py
@@ -19,6 +19,7 @@ class QueryParam:
    only_need_context: bool = False
    only_need_prompt: bool = False
    response_type: str = "Multiple Paragraphs"
    stream: bool = False
    # Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode.
    top_k: int = 60
    # Number of document chunks to retrieve.
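
For reference, a short usage sketch (not part of the diff) of the new stream field on QueryParam; rag is assumed to be an already-initialized LightRAG instance as in the example scripts above.

from lightrag import QueryParam

# Default behavior is unchanged: rag.query() returns the full answer as a str.
answer = rag.query("Summarize the story.", param=QueryParam(mode="hybrid"))

# With stream=True the Ollama backend returns an async generator of text chunks,
# which the caller is expected to iterate (see the llm.py changes below).
chunks = rag.query("Summarize the story.", param=QueryParam(mode="hybrid", stream=True))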
63 changes: 41 additions & 22 deletions lightrag/llm.py
@@ -4,7 +4,7 @@
import os
import struct
from functools import lru_cache
from typing import List, Dict, Callable, Any
from typing import List, Dict, Callable, Any, Union

import aioboto3
import aiohttp
@@ -36,6 +36,13 @@
    get_best_cached_response,
)

import sys

if sys.version_info < (3, 9):
    from typing import AsyncIterator
else:
    from collections.abc import AsyncIterator

os.environ["TOKENIZERS_PARALLELISM"] = "false"


@@ -474,7 +481,8 @@ async def ollama_model_if_cache(
    system_prompt=None,
    history_messages=[],
    **kwargs,
) -> str:
) -> Union[str, AsyncIterator[str]]:
    stream = True if kwargs.get("stream") else False
    kwargs.pop("max_tokens", None)
    # kwargs.pop("response_format", None) # allow json
    host = kwargs.pop("host", None)
@@ -517,28 +525,39 @@ async def ollama_model_if_cache(
return if_cache_return["return"]

    response = await ollama_client.chat(model=model, messages=messages, **kwargs)
    if stream:
        """ cannot cache stream response """

        async def inner():
            async for chunk in response:
                yield chunk["message"]["content"]

        return inner()
    else:
        result = response["message"]["content"]
        if hashing_kv is not None:
            await hashing_kv.upsert(
                {
                    args_hash: {
                        "return": result,
                        "model": model,
                        "embedding": quantized.tobytes().hex()
                        if is_embedding_cache_enabled
                        else None,
                        "embedding_shape": quantized.shape
                        if is_embedding_cache_enabled
                        else None,
                        "embedding_min": min_val
                        if is_embedding_cache_enabled
                        else None,
                        "embedding_max": max_val
                        if is_embedding_cache_enabled
                        else None,
                        "original_prompt": prompt,
                    }
                }
            )
        return result


@lru_cache(maxsize=1)
@@ -816,7 +835,7 @@ async def hf_model_complete(

async def ollama_model_complete(
    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
) -> Union[str, AsyncIterator[str]]:
    keyword_extraction = kwargs.pop("keyword_extraction", None)
    if keyword_extraction:
        kwargs["format"] = "json"
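
Because ollama_model_complete and ollama_model_if_cache can now return either a complete string or an async iterator of chunks, callers that need the full text must branch on the return type. A minimal helper sketch follows; it is not part of the commit, and the name collect_response is hypothetical.

# Hedged helper (not in the repository) that collapses the new
# Union[str, AsyncIterator[str]] return value into a single string.
import inspect


async def collect_response(resp) -> str:
    if inspect.isasyncgen(resp):
        parts = []
        async for chunk in resp:
            parts.append(chunk)
        return "".join(parts)
    return resp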
3 changes: 2 additions & 1 deletion lightrag/operate.py
@@ -534,8 +534,9 @@ async def kg_query(
    response = await use_model_func(
        query,
        system_prompt=sys_prompt,
        stream=query_param.stream,
    )
    if len(response) > len(sys_prompt):
    if isinstance(response, str) and len(response) > len(sys_prompt):
        response = (
            response.replace(sys_prompt, "")
            .replace("user", "")
