Major Agent update merge #1666

Merged: 9 commits, Oct 7, 2023
4 changes: 2 additions & 2 deletions requirements.txt
@@ -2,8 +2,8 @@ langchain>=0.0.302
fschat[model_worker]==0.2.30
openai
sentence_transformers
-transformers==4.33.3
-torch>=2.0.1
+transformers>=4.34
+torch>=2.0.1  # 2.1 recommended
torchvision
torchaudio
fastapi>=0.103.2
4 changes: 2 additions & 2 deletions requirements_api.txt
@@ -1,8 +1,8 @@
langchain>=0.0.302
-fschat[model_worker]==0.2.30
+fschat[model_worker]>=0.2.30
openai
sentence_transformers
-transformers>=4.33.3
+transformers>=4.34
torch>=2.0.1
torchvision
torchaudio
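Not part of the diff: a quick sketch for checking that an existing environment satisfies the loosened pins (`packaging` is a dependency of pip, so it is normally available):

```python
from importlib.metadata import version
from packaging.version import Version

# The PR relaxes transformers to >=4.34 and recommends torch 2.1
assert Version(version("transformers")) >= Version("4.34")
assert Version(version("torch")) >= Version("2.0.1")
assert Version(version("fschat")) >= Version("0.2.30")
```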
40 changes: 22 additions & 18 deletions server/agent/callbacks.py
@@ -20,7 +20,7 @@ class Status:
agent_action: int = 4
agent_finish: int = 5
error: int = 6
-make_tool: int = 7
+tool_finish: int = 7


class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
@@ -29,11 +29,19 @@ def __init__(self):
self.queue = asyncio.Queue()
self.done = asyncio.Event()
self.cur_tool = {}
-self.out = True

async def on_tool_start(self, serialized: Dict[str, Any], input_str: str, *, run_id: UUID,
parent_run_id: UUID | None = None, tags: List[str] | None = None,
metadata: Dict[str, Any] | None = None, **kwargs: Any) -> None:

# For LLMs that cannot truncate their own output, cut the tool input off here
stop_words = ["Observation:", "Thought","\"","(", "\n","\t"]
for stop_word in stop_words:
index = input_str.find(stop_word)
if index != -1:
input_str = input_str[:index]
break

self.cur_tool = {
"tool_name": serialized["name"],
"input_str": input_str,
@@ -44,13 +52,13 @@ async def on_tool_start(self, serialized: Dict[str, Any], input_str: str, *, run
"final_answer": "",
"error": "",
}
+# print("\nInput Str:",self.cur_tool["input_str"])
self.queue.put_nowait(dumps(self.cur_tool))

async def on_tool_end(self, output: str, *, run_id: UUID, parent_run_id: UUID | None = None,
tags: List[str] | None = None, **kwargs: Any) -> None:
-self.out = True
self.cur_tool.update(
-status=Status.agent_finish,
+status=Status.tool_finish,
output_str=output.replace("Answer:", ""),
)
self.queue.put_nowait(dumps(self.cur_tool))
@@ -65,19 +73,11 @@ async def on_tool_error(self, error: Exception | KeyboardInterrupt, *, run_id: U

async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
if token:
if "Action" in token:
self.out = False
self.cur_tool.update(
status=Status.running,
llm_token="\n\n",
)
self.queue.put_nowait(dumps(self.cur_tool))
if self.out:
self.cur_tool.update(
self.cur_tool.update(
status=Status.running,
llm_token=token,
)
self.queue.put_nowait(dumps(self.cur_tool))
)
self.queue.put_nowait(dumps(self.cur_tool))

async def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
self.cur_tool.update(
@@ -87,15 +87,13 @@ async def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **k
self.queue.put_nowait(dumps(self.cur_tool))

async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
-self.out = True
self.cur_tool.update(
status=Status.complete,
llm_token="",
llm_token="\n",
)
self.queue.put_nowait(dumps(self.cur_tool))

async def on_llm_error(self, error: Exception | KeyboardInterrupt, **kwargs: Any) -> None:
-self.out = True
self.cur_tool.update(
status=Status.error,
error=str(error),
@@ -107,4 +105,10 @@ async def on_agent_finish(
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
# Return the final answer
self.cur_tool.update(
status=Status.agent_finish,
final_answer=finish.return_values["output"],
)
self.queue.put_nowait(dumps(self.cur_tool))
self.cur_tool = {}
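Not part of the diff: a sketch of how a caller might consume this handler's status stream. It assumes the handler is attached to a running agent call, that something eventually sets handler.done (the base class does this in its own on_llm_end), and that dumps() produces plain JSON:

```python
import json

async def consume(handler: CustomAsyncIteratorCallbackHandler) -> None:
    # aiter() is inherited from AsyncIteratorCallbackHandler and yields
    # each dumps(self.cur_tool) payload queued by the callbacks above
    async for chunk in handler.aiter():
        data = json.loads(chunk)
        if data["status"] == Status.tool_finish:
            print("tool finished:", data["tool_name"], "->", data["output_str"])
        elif data["status"] == Status.agent_finish:
            print("final answer:", data["final_answer"])
```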
67 changes: 52 additions & 15 deletions server/agent/custom_template.py
@@ -1,10 +1,12 @@
from __future__ import annotations
from langchain.agents import Tool, AgentOutputParser
from langchain.prompts import StringPromptTemplate
-from typing import List, Union
+from typing import List, Union, Tuple, Dict
from langchain.schema import AgentAction, AgentFinish
import re
from configs.model_config import LLM_MODEL, TEMPERATURE, HISTORY_LEN

begin = False
class CustomPromptTemplate(StringPromptTemplate):
# The template to use
template: str
@@ -19,47 +21,82 @@ def format(self, **kwargs) -> str:
for action, observation in intermediate_steps:
thoughts += action.log
thoughts += f"\nObservation: {observation}\nThought: "
# Set the agent_scratchpad variable to that value
kwargs["agent_scratchpad"] = thoughts
# Create a tools variable from the list of tools provided
kwargs["tools"] = "\n".join([f"{tool.name}: {tool.description}" for tool in self.tools])
# Create a list of tool names for the tools provided
kwargs["tool_names"] = ", ".join([tool.name for tool in self.tools])
# Return the formatted template
# print( self.template.format(**kwargs), end="\n\n")
return self.template.format(**kwargs)


class CustomOutputParser(AgentOutputParser):
begin: bool = False
def __init__(self):
super().__init__()
self.begin = True

-def parse(self, llm_output: str) -> AgentFinish | AgentAction | str:
+def parse(self, llm_output: str) -> AgentFinish | tuple[dict[str, str], str] | AgentAction:
# Check if agent should finish
support_agent = ["gpt","Qwen","qwen-api","baichuan-api"]
if not any(agent in LLM_MODEL for agent in support_agent) and self.begin:
self.begin = False
stop_words = ["Observation:"]
min_index = len(llm_output)
for stop_word in stop_words:
index = llm_output.find(stop_word)
if index != -1 and index < min_index:
min_index = index
llm_output = llm_output[:min_index]

if "Final Answer:" in llm_output:
output = llm_output.split("Final Answer:", 1)[-1].strip()
self.begin = True
return AgentFinish(
# Return values is generally always a dictionary with a single `output` key
# It is not recommended to try anything else at the moment :)
return_values={"output": llm_output.replace("Final Answer:", "").strip()},
# return_values={"output": llm_output.replace("Final Answer:", "").strip()},
return_values={"output": output},
log=llm_output,
)

# Parse out the action and action input
regex = r"Action\s*\d*\s*:(.*?)\nAction\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)"
match = re.search(regex, llm_output, re.DOTALL)
if not match:
parts = llm_output.split("Action:")
if len(parts) < 2:
return AgentFinish(
return_values={"output": f"调用agent失败: `{llm_output}`"},
log=llm_output,
)
-action = match.group(1).strip()
-action_input = match.group(2)

+action = parts[1].split("Action Input:")[0].strip()
+action_input = parts[1].split("Action Input:")[1].strip()

# The original regex-based check
# regex = r"Action\s*\d*\s*:(.*?)\nAction\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)"
# print("llm_output",llm_output)
# match = re.search(regex, llm_output, re.DOTALL)
# print("match",match)
# if not match:
# return AgentFinish(
# return_values={"output": f"调用agent失败: `{llm_output}`"},
# log=llm_output,
# )
# action = match.group(1).strip()
# action_input = match.group(2)

# Return the action and action input

try:
ans = AgentAction(
tool=action,
tool_input=action_input.strip(" ").strip('"'),
log=llm_output
)
return ans
except:
return AgentFinish(
return_values={"output": f"调用agent失败: `{llm_output}`"},
log=llm_output,
)
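Not part of the diff: what the new split-based parsing yields for a typical ReAct-style completion (tool name borrowed from tools.py below). Note that if "Action Input:" is missing, the [1] index raises an IndexError that the try/except above does not catch, since it only wraps the AgentAction construction:

```python
llm_output = """Thought: I should check the weather.
Action: 天气查询工具
Action Input: 上海"""

parts = llm_output.split("Action:")
action = parts[1].split("Action Input:")[0].strip()        # '天气查询工具'
action_input = parts[1].split("Action Input:")[1].strip()  # '上海'
```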



18 changes: 13 additions & 5 deletions server/agent/math.py
@@ -1,11 +1,14 @@
## Needed when running this file standalone
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

from langchain.prompts import PromptTemplate
from langchain.chains import LLMMathChain
-from server.utils import wrap_done, get_ChatOpenAI
+from server.utils import get_ChatOpenAI
from configs.model_config import LLM_MODEL, TEMPERATURE
-from langchain.chat_models import ChatOpenAI
from langchain.callbacks.manager import CallbackManagerForToolRun

_PROMPT_TEMPLATE = """将数学问题翻译成可以使用Python的numexpr库执行的表达式。使用运行此代码的输出来回答问题。
_PROMPT_TEMPLATE = """
将数学问题翻译成可以使用Python的numexpr库执行的表达式。使用运行此代码的输出来回答问题。
问题: ${{包含数学问题的问题。}}
```text
${{解决问题的单行数学表达式}}
@@ -68,3 +71,8 @@ def calculate(query: str):
llm_math = LLMMathChain.from_llm(model, verbose=True, prompt=PROMPT)
ans = llm_math.run(query)
return ans

if __name__ == "__main__":
result = calculate("2的三次方")
print("答案:",result)

8 changes: 4 additions & 4 deletions server/agent/tools.py
@@ -20,12 +20,12 @@
Tool.from_function(
func=translate,
name="翻译工具",
description="翻译各种语言"
description="如果你无法访问互联网,并且需要翻译各种语言,应该使用这个工具"
),
Tool.from_function(
func=weathercheck,
name="天气查询工具",
description="查询天气",
description="如果你无法访问互联网,并需要查询中国各地未来24小时的天气,你应该使用这个工具,每轮对话仅能使用一次",
),
Tool.from_function(
func=shell,
@@ -35,12 +35,12 @@
Tool.from_function(
func=search_knowledge,
name="知识库查询工具",
description="使用西交利物浦大学大数据专业的本专业数据库来解答问题",
description="访问知识库来获取答案",
),
Tool.from_function(
func=search_internet,
name="互联网查询工具",
description="访问Bing互联网来解答问题",
description="如果你无法访问互联网,这个工具可以帮助你访问Bing互联网来解答问题",
),

]
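Not part of the diff: these description strings are what CustomPromptTemplate.format (earlier in this PR) splices into the agent prompt, roughly as below; it assumes the list defined in this file is named tools:

```python
# Mirrors CustomPromptTemplate.format in server/agent/custom_template.py
tools_block = "\n".join(f"{tool.name}: {tool.description}" for tool in tools)
tool_names = ", ".join(tool.name for tool in tools)
```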
39 changes: 14 additions & 25 deletions server/agent/translator.py
@@ -1,11 +1,10 @@
-from langchain.prompts import PromptTemplate
-from langchain.chains import LLMChain
+## Needed when running this file standalone
import sys
import os
-
-from server.utils import get_ChatOpenAI
-
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
+from langchain.prompts import PromptTemplate
+from langchain.chains import LLMChain
+from server.utils import get_ChatOpenAI
from langchain.chains.llm_math.prompt import PROMPT
from configs.model_config import LLM_MODEL, TEMPERATURE

@@ -16,25 +15,12 @@
2. 无论提供的是陈述句或疑问句,只进行翻译
3. 不添加与原文无关的内容

原文: ${{用户需要翻译的原文和目标语言}}
{question}
```output
${{翻译结果}}
```
答案: ${{答案}}

-以下是两个例子
-问题: 翻译13成英语
-```text
-13 英语
-```output
-thirteen
-以下是两个例子
-问题: 翻译 我爱你 成法语
-```text
-13 法语
-```output
-Je t'aime.
问题: ${{用户需要翻译的原文和目标语言}}
答案: 你翻译结果

现在,这是我的问题:
问题: {question}

'''

PROMPT = PromptTemplate(
@@ -51,5 +37,8 @@ def translate(query: str):
)
llm_translate = LLMChain(llm=model, prompt=PROMPT)
ans = llm_translate.run(query)
return ans

if __name__ == "__main__":
result = translate("Can Love remember the question and the answer? 这句话如何诗意的翻译成中文")
print("答案:",result)