Skip to content

Commit

Permalink
modify history
Browse files Browse the repository at this point in the history
  • Loading branch information
qingzhong1 committed Sep 6, 2023
1 parent 2e8daab commit 960c2d5
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 88 deletions.
121 changes: 45 additions & 76 deletions pipelines/examples/chatpaper/chat_paper.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,8 @@
import arxiv
import erniebot as eb
import gradio as gr
from utils import (
_apply_token,
get_shown_context,
load_all_json_path,
pdf2image,
retrieval,
)
from utils import _apply_token, load_all_json_path, pdf2image, retrieval, tackle_history

paper_id_list = []
single_paper_id = ""
parser = argparse.ArgumentParser()
parser.add_argument("--api_type", type=str, default="qianfan")
parser.add_argument("--api_key", type=str, default="", help="The API Key.")
Expand Down Expand Up @@ -70,38 +62,19 @@
请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。 </指令>
<已知信息>{documents}</已知信息>
<问题>{query}</问题>"""

all_json_id = load_all_json_path(args.txt_file)


def clear_input():
    """Reset the currently selected paper and blank out the paper input widgets.

    Returns:
        tuple: empty file-name text, empty file-url text, and ``None`` for the
        file-upload component.
    """
    global single_paper_id
    # Forget the previously selected paper so the next upload starts fresh.
    single_paper_id = ""
    return ("", "", None)


def retrieval_clear_session():
    """Drop the ids of all previously retrieved papers.

    Returns:
        tuple: ``None`` to clear the chatbot widget and an empty dict for the
        session state.
    """
    global paper_id_list
    # Start the retrieval session over with no remembered papers.
    paper_id_list = []
    return (None, {})


def retrieval_papers(query, state={}):
def retrieval_papers(query, history=[]):
"""
Retrieve papers
"""
query = query.strip().replace("<br>", "\n")
context = state.setdefault("context", [])
global paper_id_list
context = tackle_history(history)
if query:
if not paper_id_list:
context.append({"system": args.system_prompt, "role": "user", "content": query})
if len(history) == 1:
paper_id_list = []
context.append({"role": "user", "content": query})
prediction = retrieval(
query=query,
es_host=args.es_host,
Expand Down Expand Up @@ -131,9 +104,11 @@ def retrieval_papers(query, state={}):
"**" + str(len(paper_id_list)) + "." + title + "**" + "\n" + key_words + "\n" + abstract
)
all_content += paper_content + "\n\n"
context.append({"role": "assistant", "content": all_content})
shown_context = get_shown_context(context)
history.append(["下面请基于这几篇论文进行问题,单篇文档问答请使用单篇问答精读翻译", ",".join(paper_id_list)])
history.append([query, all_content])
else:
# history = [[user_msg(None),system_msg],[user_hint(None),paper_id]]
paper_id_list = history[1][1].split(",")
content = ""
for id in paper_id_list:
prediction = retrieval(
Expand All @@ -157,19 +132,15 @@ def retrieval_papers(query, state={}):
content += "\n".join([item.content for item in prediction["documents"]])
content = PROMPT_RETRIVER.format(documents=content, query=query)
content = content[: args.max_token]
context.append({"system": args.system_prompt, "role": "user", "content": content})
context.append({"role": "user", "content": content})
eb.api_type = args.api_type
access_token = _apply_token(args.api_key, args.secret_key)
eb.access_token = access_token
model = "ernie-bot-3.5" if args.ernie_model is None or args.ernie_model.strip() == "" else args.ernie_model
response = eb.ChatCompletion.create(model=model, messages=context, stream=False)
bot_response = response.result
context.append({"role": "assistant", "content": bot_response})
context[-2]["content"] = query
shown_context = get_shown_context(context)
else:
shown_context = get_shown_context(context)
return None, shown_context, state
history.append([query, bot_response])
return None, history


def retrieval_title(title):
Expand Down Expand Up @@ -198,13 +169,14 @@ def retrieval_title(title):
return None


def infer(query, state):
def infer(query, history=[]):
"""Model inference."""
eb.api_type = args.api_type
access_token = _apply_token(args.api_key, args.secret_key)
eb.access_token = access_token
query = query.strip().replace("<br>", "\n")
context = state.setdefault("context", [])
context = tackle_history(history)
single_paper_id = history[1][1]
if query:
if single_paper_id:
prediction = retrieval(
Expand Down Expand Up @@ -232,44 +204,36 @@ def infer(query, state):
model = "ernie-bot-3.5" if args.ernie_model is None or args.ernie_model.strip() == "" else args.ernie_model
response = eb.ChatCompletion.create(model=model, messages=context, stream=False)
bot_response = response.result
context.append({"role": "assistant", "content": bot_response})
context[-2]["content"] = query
shown_context = get_shown_context(context)
history.append([query, bot_response])
else:
context.append({"system": args.system_prompt, "role": "user", "content": query})
response = eb.ChatFile.create(messages=context, stream=False)
bot_response = response.result
context.append({"role": "assistant", "content": bot_response})
shown_context = get_shown_context(context)
else:
shown_context = get_shown_context(context)
return None, shown_context, state
history.append([query, bot_response])
return None, history


def upload_file(file_name, file_url, file_upload, state={}):
def upload_file(file_name, file_url, file_upload, history=[]):
"""
Upload the file to bos or retrieve the json_file of the paper
"""
global single_paper_id
single_paper_id = ""
if file_name:
try:
json_file_path, file_id = retrieval_title(file_name)
json_file_path = json_file_path.replace("/", "_").replace(".pdf", "")
json_file_path = all_json_id[json_file_path]
single_paper_id = file_id
with open(json_file_path, mode="r") as json_file:
json_content = json.load(json_file)
content = json_content["content"]
except:
content = "这篇论文目前尚未加入到论文库中"
history.append([None, file_id])
return (
gr.Gallery.update(visible=False),
gr.File.update(visible=False),
None,
state,
history,
gr.Chatbot.update(
[["", content]],
[[None, content]],
visible=True,
scale=30,
height=600,
Expand All @@ -292,22 +256,21 @@ def upload_file(file_name, file_url, file_upload, state={}):
url = eb.utils.upload_file_to_bos(
file_name, filename_in_bos, access_key_id=args.bos_ak, secret_access_key=args.bos_sk
)
history.append([None, None])
content = "<file>{}</file><url>{}</url>".format(real_filename, url)
content = content.strip().replace("<br>", "\n")
context = state.setdefault("context", [])
context.append({"system": "你是一位AI小助手", "role": "user", "content": content})
context = tackle_history(history)
context.append({"role": "user", "content": content})
access_token = _apply_token(args.api_key, args.secret_key)
eb.api_type = args.api_type
eb.access_token = access_token
response = eb.ChatFile.create(messages=context, stream=False)
bot_response = response.result
context.append({"role": "assistant", "content": bot_response})
shown_context = get_shown_context(context)
history.append([content, bot_response])
return (
gr.Gallery.update(imgs, visible=True),
gr.File.update(file_name, label="原文下载链接", visible=True),
shown_context,
state,
history,
gr.Chatbot.update(visible=False),
)

Expand All @@ -334,16 +297,19 @@ def upload_file(file_name, file_url, file_upload, state={}):
height=600, value=[[None, "你好, 我是维普Chatpaper小助手, 我这里收录了100w篇论文,可以提供您专业的学术咨询.请问有什么可以帮您的吗?"]]
) # height聊天框高度, value 默认语句
retrieval_textbox = gr.Textbox(placeholder="最近自监督学习论文有哪些?")
retrieval_state = gr.State({})
with gr.Row():
retrieval_submit_btn = gr.Button("🚀 提交", variant="primary", scale=2, min_width=0)
retrieval_clear_btn = gr.Button("清除", variant="primary", scale=2, min_width=0)
retrieval_submit_btn.click(
retrieval_papers,
inputs=[retrieval_textbox, retrieval_state],
outputs=[retrieval_textbox, retrieval_chatbot, retrieval_state],
inputs=[retrieval_textbox, retrieval_chatbot],
outputs=[retrieval_textbox, retrieval_chatbot],
)
retrieval_clear_btn.click(
lambda _: ([[None, "你好, 我是维普Chatpaper文章精读翻译小助手,可以提供您专业的学术咨询.请问有什么可以帮您的吗?"]]),
inputs=[retrieval_clear_btn],
outputs=[retrieval_chatbot],
)
retrieval_clear_btn.click(retrieval_clear_session, inputs=[], outputs=[retrieval_chatbot, retrieval_state])
with gr.Tab("单篇精读翻译"): # 封装chatFile的能力
with gr.Accordion("文章精读翻译:输入区(输入方式三选一,三种输入方式优先级依次降低)", open=True, elem_id="input-panel") as area_input_primary:
with gr.Row():
Expand Down Expand Up @@ -385,22 +351,25 @@ def upload_file(file_name, file_url, file_upload, state={}):
scale=30,
height=600,
)
state = gr.State({})
message = gr.Textbox(placeholder="请问具体描述这篇文章的方法?", scale=7)
with gr.Row():
submit_btn = gr.Button("🚀 提交", variant="primary", scale=2, min_width=0)
clear_btn = gr.Button("清除", variant="primary", scale=2, min_width=0)
submit.click(
upload_file,
inputs=[file_name, file_url, file_upload, state],
outputs=[ori_paper, ori_pdf, chatbot, state, ori_json],
inputs=[file_name, file_url, file_upload, chatbot],
outputs=[ori_paper, ori_pdf, chatbot, ori_json],
)
clear.click(
lambda _: ("", "", None, [[None, "你好, 我是维普Chatpaper文章精读翻译小助手,可以提供您专业的学术咨询.请问有什么可以帮您的吗?"]]),
inputs=[],
outputs=[file_name, file_url, file_upload, chatbot],
)
clear.click(clear_input, inputs=[], outputs=[file_name, file_url, file_upload])
submit_btn.click(infer, inputs=[message, state], outputs=[message, chatbot, state])
submit_btn.click(infer, inputs=[message, chatbot], outputs=[message, chatbot])
clear_btn.click(
lambda _: (None, {}),
lambda _: ([[None, "你好, 我是维普Chatpaper文章精读翻译小助手,可以提供您专业的学术咨询.请问有什么可以帮您的吗?"]]),
inputs=clear_btn,
outputs=[chatbot, state],
outputs=[chatbot],
api_name="clear",
show_progress=False,
)
Expand Down
35 changes: 23 additions & 12 deletions pipelines/examples/chatpaper/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,18 @@
from pipelines.pipelines import Pipeline


def load_all_json_path(path):
    """Load a whitespace-separated id/path mapping file.

    Each line of the file is expected to hold exactly two tokens:
    a json id and the corresponding json file name. Malformed lines
    (wrong number of tokens) are skipped.

    Args:
        path (str): Path to the mapping text file.

    Returns:
        dict: Mapping from json id to json file name.
    """
    json_path = {}
    with open(path, encoding="utf-8", mode="r") as f:
        for line in f:
            try:
                json_id, json_name = line.strip().split()
            except ValueError:
                # Line does not contain exactly two tokens; skip it.
                # (A bare `except:` here would also hide real errors such as
                # KeyboardInterrupt.)
                continue
            json_path[json_id] = json_name
    return json_path


def pdf2image(pdfPath, imgPath, zoom_x=10, zoom_y=10, rotation_angle=0):
"""
Convert PDF to Image
Expand All @@ -37,18 +49,6 @@ def pdf2image(pdfPath, imgPath, zoom_x=10, zoom_y=10, rotation_angle=0):
return image_path


def load_all_json_path(path):
    """Load a whitespace-separated id/path mapping file.

    Each line of the file is expected to hold exactly two tokens:
    a json id and the corresponding json file name. Malformed lines
    (wrong number of tokens) are skipped.

    Args:
        path (str): Path to the mapping text file.

    Returns:
        dict: Mapping from json id to json file name.
    """
    json_path = {}
    with open(path, encoding="utf-8", mode="r") as f:
        for line in f:
            try:
                json_id, json_name = line.strip().split()
            except ValueError:
                # Line does not contain exactly two tokens; skip it.
                # (A bare `except:` here would also hide real errors such as
                # KeyboardInterrupt.)
                continue
            json_path[json_id] = json_name
    return json_path


def _apply_token(api_key, secret_key):
"""
Gererate an access token.
Expand All @@ -72,6 +72,17 @@ def get_shown_context(context):
return shown_context


def tackle_history(history=None):
    """Convert Gradio chatbot history into an erniebot message list.

    The first two history rows are reserved (greeting / paper-id bookkeeping
    rows written by the app) and are skipped; every later row becomes a
    user/assistant message pair.

    Args:
        history (list | None): List of ``[user_text, bot_text]`` pairs as kept
            by the Gradio Chatbot component. ``None`` is treated as empty.
            (Previously this used a mutable default ``[]``, which is shared
            across calls — replaced with the ``None`` sentinel.)

    Returns:
        list: Messages of the form ``{"role": ..., "content": ...}`` suitable
        for ``eb.ChatCompletion.create``; empty if fewer than three rows.
    """
    if history is None:
        history = []
    messages = []
    # Rows 0 and 1 are the greeting and paper-id bookkeeping rows, not real turns.
    for user_text, bot_text in history[2:]:
        messages.append({"role": "user", "content": user_text})
        messages.append({"role": "assistant", "content": bot_text})
    return messages


def retrieval(
query: str,
file_id: Optional[str] = None,
Expand Down

0 comments on commit 960c2d5

Please sign in to comment.