Commit

Fix THUDM#618: streaming fails when tools are used (THUDM#618)
jurnea committed Oct 30, 2024
1 parent 6bf9f85 commit 8fed0d9
Showing 3 changed files with 151 additions and 3 deletions.
41 changes: 38 additions & 3 deletions basic_demo/glm_server.py
@@ -444,23 +444,35 @@ async def predict_stream(model_id, gen_params):
    system_fingerprint = generate_id('fp_', 9)
    tools = {tool['function']['name'] for tool in gen_params['tools']} if gen_params['tools'] else {}
    delta_text = ""
    delta_confirming_texts = []
    confirm_tool_state = 'un_confirm' if tools else 'none'
    # With tools present, the maximum length needed to confirm a tool call
    # = the longest tool name + 3 chars for a possible leading "\n" and the trailing "\n{".
    max_confirm_tool_length = len(max(tools, key=len)) + 3 if tools else 0
    async for new_response in generate_stream_glm4(gen_params):
        decoded_unicode = new_response["text"]
        delta_text += decoded_unicode[len(output):]
        if confirm_tool_state == 'un_confirm':
            delta_confirming_texts.append(decoded_unicode[len(output):])

        output = decoded_unicode
        lines = output.strip().split("\n")

        # Check whether the output is a tool call.
        # This is a simple comparison; it cannot intercept every non-tool output,
        # e.g. special cases such as misaligned arguments.
        # TODO: extend the logic here if you want stricter handling.

-       if not is_function_call and len(lines) >= 2:
+       if confirm_tool_state == 'un_confirm' and len(lines) >= 2 and lines[1].startswith("{"):
            first_line = lines[0].strip()
            if first_line in tools:
                is_function_call = True
                function_name = first_line
                delta_text = lines[1]
                confirm_tool_state = 'confirmed'
            else:
                confirm_tool_state = 'none'

        # With tools passed in, after a few rounds of model output we can already
        # confirm that no tool call is needed.
        if confirm_tool_state == 'un_confirm' and max_confirm_tool_length < len(delta_text):
            confirm_tool_state = 'none'
        # Tool-call response
        if is_function_call:
            if not has_send_first_chunk:
@@ -524,7 +536,7 @@ async def predict_stream(model_id, gen_params):
            yield chunk.model_dump_json(exclude_unset=True)

        # The user requested a function call, but the framework has not yet confirmed that this is one
-       elif (gen_params["tools"] and gen_params["tool_choice"] != "none") or is_function_call:
+       elif confirm_tool_state == 'un_confirm':
            continue

        # Regular response
@@ -552,6 +564,29 @@ async def predict_stream(model_id, gen_params):
                yield chunk.model_dump_json(exclude_unset=True)
                has_send_first_chunk = True

                # Replay the chunks held back while confirming whether this was a tool call
                for text in delta_confirming_texts:
                    message = DeltaMessage(
                        content=text,
                        role="assistant",
                        function_call=None,
                    )
                    choice_data = ChatCompletionResponseStreamChoice(
                        index=0,
                        delta=message,
                        finish_reason=finish_reason
                    )
                    chunk = ChatCompletionResponse(
                        model=model_id,
                        id=response_id,
                        choices=[choice_data],
                        created=created_time,
                        system_fingerprint=system_fingerprint,
                        object="chat.completion.chunk"
                    )
                    yield chunk.model_dump_json(exclude_unset=True)
                delta_confirming_texts = []
                delta_text = ""

            message = DeltaMessage(
                content=delta_text,
                role="assistant",
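The core of the fix, in isolation: while confirm_tool_state is 'un_confirm', incoming deltas are buffered in delta_confirming_texts rather than emitted; once the accumulated text is longer than the longest tool name plus the three surrounding characters, the output can no longer be a bare tool-call header, so the buffer is flushed as ordinary content chunks. A minimal self-contained sketch of that heuristic (the name classify_stream and the sample chunks are illustrative, not part of the commit):

def classify_stream(chunks, tools):
    """Buffer streamed chunks until tool call vs. plain text is decided."""
    state = 'un_confirm' if tools else 'none'
    # longest tool name + up to 3 extra chars: a leading "\n" and the trailing "\n{"
    max_confirm_len = len(max(tools, key=len)) + 3 if tools else 0
    buffered, output = [], ""
    for chunk in chunks:
        buffered.append(chunk)
        output += chunk
        lines = output.strip().split("\n")
        if state == 'un_confirm' and len(lines) >= 2 and lines[1].startswith("{"):
            state = 'confirmed' if lines[0].strip() in tools else 'none'
        if state == 'un_confirm' and max_confirm_len < len(output):
            state = 'none'  # already too long to be a bare tool name
        if state == 'none':
            yield from buffered  # flush everything held back so far
            buffered = []
        elif state == 'confirmed':
            return  # a real tool call: handled by the function-call branch instead
    if state == 'un_confirm':
        yield from buffered  # stream ended while still undecided: treat as plain text

# Plain text is released as soon as it outgrows the longest tool name:
for piece in classify_stream(["Hel", "lo, ", "world!"], {"search"}):
    print(piece, end="")

The worst case is a hold-back of roughly one tool name's worth of characters, after which streaming proceeds normally.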
110 changes: 110 additions & 0 deletions basic_demo/langgraph_agent_stream_demo.py
@@ -0,0 +1,110 @@
import asyncio

from typing import Annotated
from typing_extensions import TypedDict
from langgraph.graph.message import add_messages
from langchain_core.tools import tool
from langgraph.prebuilt import ToolNode
from langchain_openai import ChatOpenAI
from langchain_core.runnables import RunnableConfig
from langgraph.graph import END, START, StateGraph
from langchain_core.messages import AIMessageChunk, HumanMessage, SystemMessage, AnyMessage

"""
This script builds an agent with LangGraph and streams LLM tokens.
pip install langchain==0.2.16
pip install langgraph==0.2.34
pip install langchain_openai==0.1.9
"""


class State(TypedDict):
    messages: Annotated[list, add_messages]


@tool
def search(query: str):
    """Call to surf the web."""
    return ["Cloudy with a chance of hail."]


tools = [search]

model = ChatOpenAI(
    temperature=0,
    # model="glm-4",
    model="GLM-4-Flash",
    openai_api_key="[Your Key]",
    # openai_api_base="https://open.bigmodel.cn/api/paas/v4/",  # the official Zhipu endpoint streams normally
    openai_api_base="[Your glm_server.py base URL]",
    streaming=True
)


class Agent:

    def __init__(self, model, tools, system=""):
        self.system = system
        workflow = StateGraph(State)
        workflow.add_node("agent", self.call_model)
        workflow.add_node("tools", ToolNode(tools))
        workflow.add_edge(START, "agent")
        workflow.add_conditional_edges(
            # First, we define the start node. We use `agent`.
            # This means these are the edges taken after the `agent` node is called.
            "agent",
            # Next, we pass in the function that will determine which node is called next.
            self.should_continue,
            # Next we pass in the path map - all the nodes this edge could go to
            ["tools", END],
        )
        workflow.add_edge("tools", "agent")
        self.model = model.bind_tools(tools)
        self.app = workflow.compile()

    def should_continue(self, state: State):
        messages = state["messages"]
        last_message = messages[-1]
        # If there is no function call, then we finish
        if not last_message.tool_calls:
            return END
        # Otherwise if there is, we continue
        else:
            return "tools"

    async def call_model(self, state: State, config: RunnableConfig):
        messages = state["messages"]
        if self.system:
            messages = [SystemMessage(content=self.system)] + messages
        response = await self.model.ainvoke(messages, config)
        # We return a list, because this will get added to the existing list
        return {"messages": response}

    async def query(self, user_input: str):
        inputs = [HumanMessage(content=user_input)]
        first = True
        async for msg, metadata in self.app.astream({"messages": inputs}, stream_mode="messages"):
            if msg.content and not isinstance(msg, HumanMessage):
                # If streaming works, the tokens printed here arrive piece by piece
                print(msg.content, end="|", flush=True)

            if isinstance(msg, AIMessageChunk):
                if first:
                    gathered = msg
                    first = False
                else:
                    gathered = gathered + msg

                if msg.tool_call_chunks:
                    print('tool_call_chunks...', gathered.tool_calls)


if __name__ == '__main__':
    user_input = "what is the weather in sf"
    prompt = """
    You are a smart research assistant. Use the search engine ...
    """
    agent = Agent(model, tools, prompt)
    asyncio.run(agent.query(user_input))
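To verify the server-side fix without LangGraph, a plain OpenAI-client request against the patched glm_server.py should now stream content deltas even when tools are supplied. A sketch under assumptions: the base URL, API key, and tool schema below are placeholders for a local deployment, not part of the commit.

from openai import OpenAI

# Assumed local endpoint for glm_server.py; adjust host/port to your deployment.
client = OpenAI(base_url="http://127.0.0.1:8000/v1", api_key="EMPTY")

# A hypothetical tool schema, included only to exercise the tools code path.
tools = [{
    "type": "function",
    "function": {
        "name": "search",
        "description": "Call to surf the web.",
        "parameters": {
            "type": "object",
            "properties": {"query": {"type": "string"}},
            "required": ["query"],
        },
    },
}]

stream = client.chat.completions.create(
    model="GLM-4-Flash",
    messages=[{"role": "user", "content": "what is the weather in sf"}],
    tools=tools,
    stream=True,
)

# Before the patch, content stalled whenever tools were passed;
# with it, deltas should print incrementally.
for chunk in stream:
    delta = chunk.choices[0].delta
    if delta.content:
        print(delta.content, end="", flush=True)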

3 changes: 3 additions & 0 deletions basic_demo/requirements.txt
@@ -16,6 +16,9 @@ einops>=0.8.0
pillow>=10.4.0
sse-starlette>=2.1.3
bitsandbytes>=0.43.3 # INT4 Loading
langchain_openai>=0.1.9
langchain>=0.2.16
langgraph>=0.2.34

# vllm>=0.6.4 # using with VLLM Framework
# flash-attn>=2.6.3 # using with flash-attention 2
