diff --git a/lightrag/api/azure_openai_lightrag_server.py b/lightrag/api/azure_openai_lightrag_server.py index ec106d61..084f3caa 100644 --- a/lightrag/api/azure_openai_lightrag_server.py +++ b/lightrag/api/azure_openai_lightrag_server.py @@ -136,6 +136,7 @@ class SearchMode(str, Enum): class QueryRequest(BaseModel): query: str mode: SearchMode = SearchMode.hybrid + only_need_context: bool = False # stream: bool = False @@ -308,7 +309,7 @@ async def query_text(request: QueryRequest): try: response = await rag.aquery( request.query, - param=QueryParam(mode=request.mode, stream=False), + param=QueryParam(mode=request.mode, stream=False, only_need_context=request.only_need_context), ) return QueryResponse(response=response) except Exception as e: @@ -319,7 +320,7 @@ async def query_text_stream(request: QueryRequest): try: response = await rag.aquery( request.query, - param=QueryParam(mode=request.mode, stream=True), + param=QueryParam(mode=request.mode, stream=True, only_need_context=request.only_need_context), ) if inspect.isasyncgen(response): diff --git a/lightrag/api/lollms_lightrag_server.py b/lightrag/api/lollms_lightrag_server.py index e5444b9a..45b4dc4b 100644 --- a/lightrag/api/lollms_lightrag_server.py +++ b/lightrag/api/lollms_lightrag_server.py @@ -130,6 +130,7 @@ class QueryRequest(BaseModel): query: str mode: SearchMode = SearchMode.hybrid stream: bool = False + only_need_context: bool = False class QueryResponse(BaseModel): @@ -266,7 +267,7 @@ async def query_text(request: QueryRequest): try: response = await rag.aquery( request.query, - param=QueryParam(mode=request.mode, stream=request.stream), + param=QueryParam(mode=request.mode, stream=request.stream, only_need_context=request.only_need_context), ) if request.stream: @@ -283,7 +284,7 @@ async def query_text(request: QueryRequest): async def query_text_stream(request: QueryRequest): try: response = rag.query( - request.query, param=QueryParam(mode=request.mode, stream=True) + request.query, param=QueryParam(mode=request.mode, stream=True, only_need_context=request.only_need_context) ) async def stream_generator(): diff --git a/lightrag/api/ollama_lightrag_server.py b/lightrag/api/ollama_lightrag_server.py index 0d73a4d1..5bbc32c2 100644 --- a/lightrag/api/ollama_lightrag_server.py +++ b/lightrag/api/ollama_lightrag_server.py @@ -130,6 +130,7 @@ class QueryRequest(BaseModel): query: str mode: SearchMode = SearchMode.hybrid stream: bool = False + only_need_context: bool = False class QueryResponse(BaseModel): @@ -266,7 +267,7 @@ async def query_text(request: QueryRequest): try: response = await rag.aquery( request.query, - param=QueryParam(mode=request.mode, stream=request.stream), + param=QueryParam(mode=request.mode, stream=request.stream, only_need_context=request.only_need_context), ) if request.stream: @@ -283,7 +284,7 @@ async def query_text(request: QueryRequest): async def query_text_stream(request: QueryRequest): try: response = rag.query( - request.query, param=QueryParam(mode=request.mode, stream=True) + request.query, param=QueryParam(mode=request.mode, stream=True, only_need_context=request.only_need_context) ) async def stream_generator(): diff --git a/lightrag/api/openai_lightrag_server.py b/lightrag/api/openai_lightrag_server.py index 59cdd80d..051bb1ff 100644 --- a/lightrag/api/openai_lightrag_server.py +++ b/lightrag/api/openai_lightrag_server.py @@ -119,6 +119,7 @@ class QueryRequest(BaseModel): query: str mode: SearchMode = SearchMode.hybrid stream: bool = False + only_need_context: bool = False class QueryResponse(BaseModel): @@ -270,7 +271,7 @@ async def query_text(request: QueryRequest): try: response = await rag.aquery( request.query, - param=QueryParam(mode=request.mode, stream=request.stream), + param=QueryParam(mode=request.mode, stream=request.stream, only_need_context=request.only_need_context), ) if request.stream: @@ -287,7 +288,7 @@ async def query_text(request: QueryRequest): async def query_text_stream(request: QueryRequest): try: response = rag.query( - request.query, param=QueryParam(mode=request.mode, stream=True) + request.query, param=QueryParam(mode=request.mode, stream=True, only_need_context=request.only_need_context) ) async def stream_generator():