Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RAGFlow streaming output suggestions #3738 #3881

Open
wants to merge 92 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
92 commits
Select commit Hold shift + click to select a range
b071797
Add test for document (#3497)
Feiue Nov 19, 2024
b1001bf
fix: laws.py added missing import logging (#3501)
michalmasrna1 Nov 19, 2024
f424f19
Fix bugs (#3502)
JinHai-CN Nov 20, 2024
36e75b3
fix synonym bug (#3506)
KevinHuSh Nov 20, 2024
e16b7c5
smooth term weight (#3510)
KevinHuSh Nov 20, 2024
81f92d0
feat: Add Datasets component to home page #3221 (#3508)
cike8899 Nov 20, 2024
9314b03
fix: keyerror issue (#3512)
KevinHuSh Nov 20, 2024
95dc59d
Added kb_id filter to knn. Fix #3458 (#3513)
yuzhichang Nov 20, 2024
2062d7f
Make spark model robuster to model name (#3514)
KevinHuSh Nov 20, 2024
273678d
Fix: potential risk (#3515)
KevinHuSh Nov 20, 2024
c55231b
Fix set_output type hint (#3516)
yuzhichang Nov 20, 2024
cbef6fd
Merge remote-tracking branch 'remote/main'
Nov 21, 2024
f5ef1fb
Merge remote-tracking branch 'remote/main'
Nov 21, 2024
f4a7b92
Merge remote-tracking branch 'remote/main'
Nov 21, 2024
6976db1
Merge remote-tracking branch 'remote/main'
Nov 22, 2024
b8c31d5
Merge remote-tracking branch 'remote/main'
Nov 22, 2024
e9140ae
Merge remote-tracking branch 'remote/main'
Nov 22, 2024
70359e0
Merge remote-tracking branch 'remote/main'
Nov 25, 2024
22b0ad9
Merge remote-tracking branch 'remote/main'
Nov 25, 2024
ccb4e2f
Merge remote-tracking branch 'remote/main'
Nov 25, 2024
3f3e073
Merge remote-tracking branch 'remote/main'
Nov 26, 2024
accbe5f
Merge remote-tracking branch 'remote/main'
Nov 26, 2024
facc2d6
Merge remote-tracking branch 'remote/main'
Nov 26, 2024
d030a23
Merge remote-tracking branch 'remote/main'
Nov 26, 2024
f0c7e25
Merge remote-tracking branch 'remote/main'
Nov 26, 2024
32f6517
Merge remote-tracking branch 'remote/main'
Nov 27, 2024
1211e22
Merge remote-tracking branch 'remote/main'
Nov 27, 2024
7934014
Merge remote-tracking branch 'remote/main'
Nov 28, 2024
49ad2bd
Merge remote-tracking branch 'remote/main'
Nov 28, 2024
00c1b41
Merge remote-tracking branch 'remote/main'
Nov 29, 2024
dd9fec8
Merge remote-tracking branch 'remote/main'
Nov 29, 2024
202ada4
Merge remote-tracking branch 'remote/main'
Dec 3, 2024
851ad89
Merge remote-tracking branch 'remote/main'
Dec 3, 2024
39ab46e
Merge remote-tracking branch 'remote/main'
Dec 3, 2024
c61bd86
Merge remote-tracking branch 'remote/main'
Dec 4, 2024
5dfc60d
Merge remote-tracking branch 'remote/main'
Dec 4, 2024
2160487
test: add session.py logs
Dec 4, 2024
c0dfab5
test: add dialog_service.py logs
Dec 4, 2024
742a871
test: chat_streamly stream
Dec 4, 2024
121d78a
Merge remote-tracking branch 'remote/main' into main_lz
Dec 4, 2024
762dea0
Merge remote-tracking branch 'remote/main'
Dec 4, 2024
d932105
Merge remote-tracking branch 'remote/main' into main_lz
Dec 4, 2024
c964b36
Merge remote-tracking branch 'remote/main' into main_lz
Dec 5, 2024
e51fcf8
Merge remote-tracking branch 'refs/remotes/remote/main' into main_lz
Dec 5, 2024
d090b1c
test: chat_streamly delta
Dec 5, 2024
3570ce4
Merge remote-tracking branch 'origin/main_lz'
Dec 5, 2024
30e9c29
Merge remote-tracking branch 'remote/main'
Dec 5, 2024
16cc7ec
Merge remote-tracking branch 'remote/main'
Dec 5, 2024
7c9c42f
Test: Comment log printing
Dec 5, 2024
57375be
Test: delete log printing
Dec 5, 2024
89466ed
Merge remote-tracking branch 'origin/main' into main_remote_lz
Dec 5, 2024
a6d21a1
Merge remote-tracking branch 'remote/main'
Dec 5, 2024
c94dc7f
Merge remote-tracking branch 'remote/main' into main_remote_lz
Dec 6, 2024
7e4c5fe
Merge remote-tracking branch 'remote/main'
Dec 6, 2024
c26bdf7
Fix: The issue of truncation of the streaming output of the char mode…
Dec 6, 2024
06c8745
Fix: dialog_service.py The issue of truncation of the streaming outpu…
Dec 6, 2024
96dd427
Merge remote-tracking branch 'remote/main'
Dec 6, 2024
144f4e8
Merge remote-tracking branch 'refs/remotes/origin/main' into main_rem…
Dec 6, 2024
47b9c0c
Merge remote-tracking branch 'remote/main' into main_remote_lz
Dec 6, 2024
df7defc
Merge remote-tracking branch 'remote/main'
Dec 6, 2024
fa86cdf
Merge remote-tracking branch 'remote_lz/main'
Dec 6, 2024
4be5f57
Merge remote-tracking branch 'remote/main'
Dec 6, 2024
67f5ad1
Merge remote-tracking branch 'refs/remotes/remote/main'
Dec 9, 2024
5cf6dc1
Merge remote-tracking branch 'remote/main' into main_remote_lz
Dec 9, 2024
2c58f3c
Merge remote-tracking branch 'remote_lz/main' into main_remote_lz
Dec 9, 2024
5b0d908
Merge remote-tracking branch 'refs/remotes/origin/main' into main_rem…
Dec 9, 2024
6903acc
Fix: Delete the content of the comment
Dec 9, 2024
3a64029
Fix: Resolve conflicts
Dec 9, 2024
12175ab
Merge remote-tracking branch 'remote/main'
Dec 9, 2024
46f6d28
Merge remote-tracking branch 'remote/main'
Dec 9, 2024
ea93fb7
Merge remote-tracking branch 'remote/main'
Dec 9, 2024
a3667dc
Merge remote-tracking branch 'remote/main'
Dec 9, 2024
602f392
Merge remote-tracking branch 'remote/main'
Dec 10, 2024
ba26f10
Merge remote-tracking branch 'remote/main' into main_remote_lz
Dec 10, 2024
2937810
Merge remote-tracking branch 'remote_lz/main' into main_remote_lz
Dec 10, 2024
ba1161e
Merge remote-tracking branch 'remote/main'
Dec 10, 2024
b7455a8
Merge remote-tracking branch 'remote/main'
Dec 10, 2024
273fded
Merge remote-tracking branch 'remote/main'
Dec 10, 2024
23aca28
Fix: Remove redundant references
Dec 10, 2024
2a6c252
Merge remote-tracking branch 'remote/main' into main_remote_lz
Dec 10, 2024
d573164
Merge remote-tracking branch 'origin/main' into main_remote_lz
Dec 10, 2024
eaf622b
Merge remote-tracking branch 'remote/main'
Dec 10, 2024
b55f21e
Merge remote-tracking branch 'remote/main' into main_remote_lz
Dec 10, 2024
d9d196d
Merge remote-tracking branch 'origin/main' into main_remote_lz
Dec 10, 2024
5d62a80
Merge remote-tracking branch 'remote/main' into main_remote_lz
Dec 11, 2024
70e9e73
Merge remote-tracking branch 'remote/main' into main_remote_lz
Dec 11, 2024
cfb877d
Merge remote-tracking branch 'remote/main' into main_remote_lz
Dec 12, 2024
b5b7397
Merge remote-tracking branch 'remote/main' into main_remote_lz
Dec 12, 2024
8e8ad6c
Merge remote-tracking branch 'remote/main' into main_remote_lz
Dec 13, 2024
34a1b3e
Merge remote-tracking branch 'remote/main' into main_remote_lz
Dec 13, 2024
b020cf8
Merge remote-tracking branch 'remote/main' into main_remote_lz
Dec 13, 2024
69344cb
Merge remote-tracking branch 'remote/main' into main_remote_lz
Dec 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions api/apps/sdk/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import re
import json
from api.db import LLMType
Expand Down
8 changes: 7 additions & 1 deletion api/apps/tenant_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def create(tenant_id):
@manager.route('/<tenant_id>/user/<user_id>', methods=['DELETE']) # noqa: F821
@login_required
def rm(tenant_id, user_id):
if current_user.id != tenant_id and current_user.id != user_id:
if current_user.id != tenant_id:
return get_json_result(
data=False,
message='No authorization.',
Expand Down Expand Up @@ -111,6 +111,12 @@ def tenant_list():
@manager.route("/agree/<tenant_id>", methods=["PUT"]) # noqa: F821
@login_required
def agree(tenant_id):
if current_user.id != tenant_id:
return get_json_result(
data=False,
message='No authorization.',
code=settings.RetCode.AUTHENTICATION_ERROR)

try:
UserTenantService.filter_update([UserTenant.tenant_id == tenant_id, UserTenant.user_id == current_user.id], {"role": UserTenantRole.NORMAL})
return get_json_result(data=True)
Expand Down
76 changes: 64 additions & 12 deletions api/db/services/dialog_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,20 +280,72 @@ def decorate_answer(answer):
(done_tm - retrieval_tm) * 1000)
return {"answer": answer, "reference": refs, "prompt": prompt}


#注释原先流式代码
# if stream:
# last_ans = ""
# answer = ""
# for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
# answer = ans
# logging.info("answer_stream : {}".format(ans))
# delta_ans = ans[len(last_ans):]
# if num_tokens_from_string(delta_ans) < 16:
# continue
# last_ans = answer
# yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
# delta_ans = answer[len(last_ans):]
# if delta_ans:
# yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
# yield decorate_answer(answer)

if stream:
last_ans = ""
# logging.info("stream_mode : {}".format(msg[1:]))
answer = ""
for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
answer = ans
delta_ans = ans[len(last_ans):]
if num_tokens_from_string(delta_ans) < 16:
continue
last_ans = answer
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
delta_ans = answer[len(last_ans):]
if delta_ans:
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
yield decorate_answer(answer)
for delta in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
# 检查是否为总令牌数或通知信息
if isinstance(delta, str):
if delta.isdigit():
# 处理总令牌数(如果需要)
total_tokens = int(delta)
continue
elif "\n**ERROR**:" in delta:
# 处理错误信息
answer += delta
yield {"answer": answer, "reference": {}, "audio_binary": b''} # 错误时不生成音频
continue

# 处理增量文本
delta_ans = delta
# if num_tokens_from_string(delta_ans) < 16:
# continue # 根据需求调整阈值

# 更新完整的答案
answer += delta_ans

# 生成音频
audio = tts(tts_mdl, delta_ans)
# logging.info(f"Generated audio for delta: {delta_ans}")
yield {"answer": delta_ans, "reference": {}, "audio_binary": audio}
elif isinstance(delta, dict):
# 如果 chat_streamly 仍返回字典(不推荐)
# 例如: {"new_text": "新增内容", "position": 10}
new_text = delta.get("new_text", "")
if not new_text:
continue
if num_tokens_from_string(new_text) < 16:
continue

# 更新完整的答案
answer += new_text

# 生成音频
audio = tts(tts_mdl, new_text)
yield {"answer": answer, "reference": {}, "audio_binary": audio}

# 最终装饰答案
decorated_answer = decorate_answer(answer)
# logging.info(f"Final decorated answer: {decorated_answer}")
yield decorated_answer
else:
answer = chat_mdl.chat(prompt, msg[1:], gen_conf)
logging.debug("User: {}|Assistant: {}".format(
Expand Down
4 changes: 3 additions & 1 deletion poetry.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
[virtualenvs]
in-project = true
create = true
prefer-active-python = true
prefer-active-python = true
[repositories.tuna]
url = "https://pypi.tuna.tsinghua.edu.cn/simple"
81 changes: 68 additions & 13 deletions rag/llm/chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,50 @@ def chat(self, system, history, gen_conf):
except openai.APIError as e:
return "**ERROR**: " + str(e), 0

# def chat_streamly(self, system, history, gen_conf):
# if system:
# history.insert(0, {"role": "system", "content": system})
# ans = ""
# total_tokens = 0
# try:
# response = self.client.chat.completions.create(
# model=self.model_name,
# messages=history,
# stream=True,
# **gen_conf)
# for resp in response:
# if not resp.choices:
# continue
# if not resp.choices[0].delta.content:
# resp.choices[0].delta.content = ""
# ans += resp.choices[0].delta.content
#
# if not hasattr(resp, "usage") or not resp.usage:
# total_tokens = (
# total_tokens
# + num_tokens_from_string(resp.choices[0].delta.content)
# )
# elif isinstance(resp.usage, dict):
# total_tokens = resp.usage.get("total_tokens", total_tokens)
# else:
# total_tokens = resp.usage.total_tokens
#
# if resp.choices[0].finish_reason == "length":
# if is_chinese(ans):
# ans += LENGTH_NOTIFICATION_CN
# else:
# ans += LENGTH_NOTIFICATION_EN
# yield ans
#
# except openai.APIError as e:
# yield ans + "\n**ERROR**: " + str(e)
#
# yield total_tokens

def chat_streamly(self, system, history, gen_conf):
if system:
history.insert(0, {"role": "system", "content": system})

ans = ""
total_tokens = 0
try:
Expand All @@ -71,30 +112,44 @@ def chat_streamly(self, system, history, gen_conf):
for resp in response:
if not resp.choices:
continue
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
ans += resp.choices[0].delta.content

if not hasattr(resp, "usage") or not resp.usage:
total_tokens = (
finish_reason = resp.choices[0].finish_reason
delta_content = resp.choices[0].delta.content if resp.choices[0].delta.content else ""

# 如果有新增文本,累积并输出增量
if delta_content:
ans += delta_content

# 更新令牌计数
if not hasattr(resp, "usage") or not resp.usage:
total_tokens = (
total_tokens
+ num_tokens_from_string(resp.choices[0].delta.content)
)
elif isinstance(resp.usage, dict):
total_tokens = resp.usage.get("total_tokens", total_tokens)
else:
total_tokens = resp.usage.total_tokens
elif isinstance(resp.usage, dict):
total_tokens = resp.usage.get("total_tokens", total_tokens)
else:
total_tokens = resp.usage.total_tokens

if resp.choices[0].finish_reason == "length":
yield delta_content

# 即使delta_content为空,也要检查finish_reason
if finish_reason == "length":
# 长度受限时添加提示信息
if is_chinese(ans):
ans += LENGTH_NOTIFICATION_CN
notification = LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
yield ans
notification = LENGTH_NOTIFICATION_EN
yield notification

# 如果finish_reason为"stop"或其他值,可以在此添加相应逻辑
# (本示例中未对"stop"做额外处理,因为通常这意味着回答正常结束)

except openai.APIError as e:
# 返回错误信息
yield ans + "\n**ERROR**: " + str(e)

# 最终返回总令牌数
yield total_tokens


Expand Down
5 changes: 4 additions & 1 deletion rag/llm/cv_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from rag.nlp import is_english
from api.utils import get_uuid
from api.utils.file_utils import get_project_base_directory
from google.generativeai import client, GenerativeModel, GenerationConfig


class Base(ABC):
Expand Down Expand Up @@ -57,6 +58,7 @@ def chat(self, system, history, gen_conf, image=""):
except Exception as e:
return "**ERROR**: " + str(e), 0


def chat_streamly(self, system, history, gen_conf, image=""):
if system:
history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
Expand Down Expand Up @@ -92,7 +94,8 @@ def chat_streamly(self, system, history, gen_conf, image=""):
yield ans + "\n**ERROR**: " + str(e)

yield tk_count



def image2base64(self, image):
if isinstance(image, bytes):
return base64.b64encode(image).decode("utf-8")
Expand Down
Loading