From 335179196aba2e510979cbce9b09bdbae6d2fd4d Mon Sep 17 00:00:00 2001 From: partoneplay Date: Fri, 6 Dec 2024 08:48:55 +0800 Subject: [PATCH] Add support for Ollama streaming output and integrate Open-WebUI as the chat UI demo --- examples/lightrag_api_open_webui_demo.py | 140 +++++++++++++++++++++++ examples/lightrag_ollama_demo.py | 19 +++ lightrag/base.py | 1 + lightrag/llm.py | 63 ++++++---- lightrag/operate.py | 3 +- 5 files changed, 203 insertions(+), 23 deletions(-) create mode 100644 examples/lightrag_api_open_webui_demo.py diff --git a/examples/lightrag_api_open_webui_demo.py b/examples/lightrag_api_open_webui_demo.py new file mode 100644 index 00000000..17e1817e --- /dev/null +++ b/examples/lightrag_api_open_webui_demo.py @@ -0,0 +1,140 @@ +from datetime import datetime, timezone +from fastapi import FastAPI +from fastapi.responses import StreamingResponse +import inspect +import json +from pydantic import BaseModel +from typing import Optional + +import os +import logging +from lightrag import LightRAG, QueryParam +from lightrag.llm import ollama_model_complete, ollama_embed +from lightrag.utils import EmbeddingFunc + +import nest_asyncio + +WORKING_DIR = "./dickens" + +logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO) + +if not os.path.exists(WORKING_DIR): + os.mkdir(WORKING_DIR) + +rag = LightRAG( + working_dir=WORKING_DIR, + llm_model_func=ollama_model_complete, + llm_model_name="qwen2.5:latest", + llm_model_max_async=4, + llm_model_max_token_size=32768, + llm_model_kwargs={"host": "http://localhost:11434", "options": {"num_ctx": 32768}}, + embedding_func=EmbeddingFunc( + embedding_dim=1024, + max_token_size=8192, + func=lambda texts: ollama_embed( + texts=texts, embed_model="bge-m3:latest", host="http://127.0.0.1:11434" + ), + ), +) + +with open("./book.txt", "r", encoding="utf-8") as f: + rag.insert(f.read()) + +# Apply nest_asyncio to solve event loop issues +nest_asyncio.apply() + +app = FastAPI(title="LightRAG", description="LightRAG API open-webui") + + +# Data models +MODEL_NAME = "LightRAG:latest" + + +class Message(BaseModel): + role: Optional[str] = None + content: str + + +class OpenWebUIRequest(BaseModel): + stream: Optional[bool] = None + model: Optional[str] = None + messages: list[Message] + + +# API routes + + +@app.get("/") +async def index(): + return "Set Ollama link to http://ip:port/ollama in Open-WebUI Settings" + + +@app.get("/ollama/api/version") +async def ollama_version(): + return {"version": "0.4.7"} + + +@app.get("/ollama/api/tags") +async def ollama_tags(): + return { + "models": [ + { + "name": MODEL_NAME, + "model": MODEL_NAME, + "modified_at": "2024-11-12T20:22:37.561463923+08:00", + "size": 4683087332, + "digest": "845dbda0ea48ed749caafd9e6037047aa19acfcfd82e704d7ca97d631a0b697e", + "details": { + "parent_model": "", + "format": "gguf", + "family": "qwen2", + "families": ["qwen2"], + "parameter_size": "7.6B", + "quantization_level": "Q4_K_M", + }, + } + ] + } + + +@app.post("/ollama/api/chat") +async def ollama_chat(request: OpenWebUIRequest): + resp = rag.query( + request.messages[-1].content, param=QueryParam(mode="hybrid", stream=True) + ) + if inspect.isasyncgen(resp): + + async def ollama_resp(chunks): + async for chunk in chunks: + yield ( + json.dumps( + { + "model": MODEL_NAME, + "created_at": datetime.now(timezone.utc).strftime( + "%Y-%m-%dT%H:%M:%S.%fZ" + ), + "message": { + "role": "assistant", + "content": chunk, + }, + "done": False, + }, + ensure_ascii=False, + ).encode("utf-8") + + b"\n" + ) # the b"\n" is important + + return StreamingResponse(ollama_resp(resp), media_type="application/json") + else: + return resp + + +@app.get("/health") +async def health_check(): + return {"status": "healthy"} + + +if __name__ == "__main__": + import uvicorn + + uvicorn.run(app, host="0.0.0.0", port=8020) diff --git a/examples/lightrag_ollama_demo.py b/examples/lightrag_ollama_demo.py index 1a320d13..162900c4 100644 --- a/examples/lightrag_ollama_demo.py +++ b/examples/lightrag_ollama_demo.py @@ -1,4 +1,6 @@ +import asyncio import os +import inspect import logging from lightrag import LightRAG, QueryParam from lightrag.llm import ollama_model_complete, ollama_embedding @@ -49,3 +51,20 @@ print( rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")) ) + +# stream response +resp = rag.query( + "What are the top themes in this story?", + param=QueryParam(mode="hybrid", stream=True), +) + + +async def print_stream(stream): + async for chunk in stream: + print(chunk, end="", flush=True) + + +if inspect.isasyncgen(resp): + asyncio.run(print_stream(resp)) +else: + print(resp) diff --git a/lightrag/base.py b/lightrag/base.py index ea84c000..f5a6e0c0 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -19,6 +19,7 @@ class QueryParam: only_need_context: bool = False only_need_prompt: bool = False response_type: str = "Multiple Paragraphs" + stream: bool = False # Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode. top_k: int = 60 # Number of document chunks to retrieve. diff --git a/lightrag/llm.py b/lightrag/llm.py index 33fdd182..418a66d2 100644 --- a/lightrag/llm.py +++ b/lightrag/llm.py @@ -27,7 +27,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM import torch from pydantic import BaseModel, Field -from typing import List, Dict, Callable, Any +from typing import List, Dict, Callable, Any, Union from .base import BaseKVStorage from .utils import ( compute_args_hash, @@ -37,6 +37,13 @@ get_best_cached_response, ) +import sys + +if sys.version_info < (3, 9): + from typing import AsyncIterator +else: + from collections.abc import AsyncIterator + os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -454,7 +461,8 @@ async def ollama_model_if_cache( system_prompt=None, history_messages=[], **kwargs, -) -> str: +) -> Union[str, AsyncIterator[str]]: + stream = True if kwargs.get("stream") else False kwargs.pop("max_tokens", None) # kwargs.pop("response_format", None) # allow json host = kwargs.pop("host", None) @@ -494,28 +502,39 @@ async def ollama_model_if_cache( return if_cache_return["return"] response = await ollama_client.chat(model=model, messages=messages, **kwargs) + if stream: + """ cannot cache stream response """ - result = response["message"]["content"] + async def inner(): + async for chunk in response: + yield chunk["message"]["content"] - if hashing_kv is not None: - await hashing_kv.upsert( - { - args_hash: { - "return": result, - "model": model, - "embedding": quantized.tobytes().hex() - if is_embedding_cache_enabled - else None, - "embedding_shape": quantized.shape - if is_embedding_cache_enabled - else None, - "embedding_min": min_val if is_embedding_cache_enabled else None, - "embedding_max": max_val if is_embedding_cache_enabled else None, - "original_prompt": prompt, + return inner() + else: + result = response["message"]["content"] + if hashing_kv is not None: + await hashing_kv.upsert( + { + args_hash: { + "return": result, + "model": model, + "embedding": quantized.tobytes().hex() + if is_embedding_cache_enabled + else None, + "embedding_shape": quantized.shape + if is_embedding_cache_enabled + else None, + "embedding_min": min_val + if is_embedding_cache_enabled + else None, + "embedding_max": max_val + if is_embedding_cache_enabled + else None, + "original_prompt": prompt, + } } - } - ) - return result + ) + return result @lru_cache(maxsize=1) @@ -785,7 +804,7 @@ async def hf_model_complete( async def ollama_model_complete( prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs -) -> str: +) -> Union[str, AsyncIterator[str]]: keyword_extraction = kwargs.pop("keyword_extraction", None) if keyword_extraction: kwargs["format"] = "json" diff --git a/lightrag/operate.py b/lightrag/operate.py index a846cfc5..fa6695bc 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -534,8 +534,9 @@ async def kg_query( response = await use_model_func( query, system_prompt=sys_prompt, + stream=query_param.stream, ) - if len(response) > len(sys_prompt): + if isinstance(response, str) and len(response) > len(sys_prompt): response = ( response.replace(sys_prompt, "") .replace("user", "")