diff --git a/basic_demo/glm_server.py b/basic_demo/glm_server.py
index 2ae8b22..b08ea34 100644
--- a/basic_demo/glm_server.py
+++ b/basic_demo/glm_server.py
@@ -444,23 +444,35 @@ async def predict_stream(model_id, gen_params):
     system_fingerprint = generate_id('fp_', 9)
     tools = {tool['function']['name'] for tool in gen_params['tools']} if gen_params['tools'] else {}
     delta_text = ""
+    delta_confirming_texts = []
+    confirm_tool_state = 'un_confirm' if tools else 'none'
+    # When tools are passed in, the maximum number of characters needed to decide whether a tool is being called = length of the longest tool name + 3 characters for a possible leading "\n" and trailing "\n{".
+    max_confirm_tool_length = len(max(tools, key=len)) + 3 if tools else 0
     async for new_response in generate_stream_glm4(gen_params):
         decoded_unicode = new_response["text"]
         delta_text += decoded_unicode[len(output):]
+        if confirm_tool_state == 'un_confirm':
+            delta_confirming_texts.append(decoded_unicode[len(output):])
+
         output = decoded_unicode
         lines = output.strip().split("\n")

         # 检查是否为工具
         # 这是一个简单的工具比较函数,不能保证拦截所有非工具输出的结果,比如参数未对齐等特殊情况。
         ##TODO 如果你希望做更多处理,可以在这里进行逻辑完善。
-
-        if not is_function_call and len(lines) >= 2:
+        if confirm_tool_state == 'un_confirm' and len(lines) >= 2 and lines[1].startswith("{"):
             first_line = lines[0].strip()
             if first_line in tools:
                 is_function_call = True
                 function_name = first_line
                 delta_text = lines[1]
+                confirm_tool_state = 'confirmed'
+            else:
+                confirm_tool_state = 'none'
+        # When tools are passed in, after the model has streamed a few rounds of output we can already confirm that no tool call is needed.
+        if confirm_tool_state == 'un_confirm' and max_confirm_tool_length < len(delta_text):
+            confirm_tool_state = 'none'

         # 工具调用返回
         if is_function_call:
             if not has_send_first_chunk:
@@ -524,7 +536,7 @@ async def predict_stream(model_id, gen_params):
                 yield chunk.model_dump_json(exclude_unset=True)

         # 用户请求了 Function Call 但是框架还没确定是否为Function Call
-        elif (gen_params["tools"] and gen_params["tool_choice"] != "none") or is_function_call:
+        elif confirm_tool_state == 'un_confirm':
             continue

         # 常规返回
@@ -552,6 +564,29 @@ async def predict_stream(model_id, gen_params):
                     yield chunk.model_dump_json(exclude_unset=True)
                     has_send_first_chunk = True

+                for text in delta_confirming_texts:
+                    message = DeltaMessage(
+                        content=text,
+                        role="assistant",
+                        function_call=None,
+                    )
+                    choice_data = ChatCompletionResponseStreamChoice(
+                        index=0,
+                        delta=message,
+                        finish_reason=finish_reason
+                    )
+                    chunk = ChatCompletionResponse(
+                        model=model_id,
+                        id=response_id,
+                        choices=[choice_data],
+                        created=created_time,
+                        system_fingerprint=system_fingerprint,
+                        object="chat.completion.chunk"
+                    )
+                    yield chunk.model_dump_json(exclude_unset=True)
+                delta_confirming_texts = []
+                delta_text = ""
+
             message = DeltaMessage(
                 content=delta_text,
                 role="assistant",
diff --git a/basic_demo/langgraph_agent_stream_demo.py b/basic_demo/langgraph_agent_stream_demo.py
new file mode 100644
index 0000000..66c8209
--- /dev/null
+++ b/basic_demo/langgraph_agent_stream_demo.py
@@ -0,0 +1,110 @@
+import asyncio
+
+from typing import Annotated
+from typing_extensions import TypedDict
+from langgraph.graph.message import add_messages
+from langchain_core.tools import tool
+from langgraph.prebuilt import ToolNode
+from langchain_openai import ChatOpenAI
+from langchain_core.runnables import RunnableConfig
+from langgraph.graph import END, START, StateGraph
+from langchain_core.messages import AIMessageChunk, HumanMessage, SystemMessage, AnyMessage
+
+"""
+This script builds an agent with LangGraph and streams LLM tokens.
+pip install langchain==0.2.16
+pip install langgraph==0.2.34
+pip install langchain_openai==0.1.9
+"""
+
+
+class State(TypedDict):
+    messages: Annotated[list, add_messages]
+
+
+@tool
+def search(query: str):
+    """Call to surf the web."""
+    return ["Cloudy with a chance of hail."]
+
+
+tools = [search]
+
+model = ChatOpenAI(
+    temperature=0,
+    # model="glm-4",
+    model="GLM-4-Flash",
+    openai_api_key="[Your Key]",
+    # openai_api_base="https://open.bigmodel.cn/api/paas/v4/",  # streaming works normally against the official ZhipuAI endpoint
+    openai_api_base="[Your URL served by glm_server.py]",
+    streaming=True
+)
+
+
+class Agent:
+
+    def __init__(self, model, tools, system=""):
+        self.system = system
+        workflow = StateGraph(State)
+        workflow.add_node("agent", self.call_model)
+        workflow.add_node("tools", ToolNode(tools))
+        workflow.add_edge(START, "agent")
+        workflow.add_conditional_edges(
+            # First, we define the start node. We use `agent`.
+            # This means these are the edges taken after the `agent` node is called.
+            "agent",
+            # Next, we pass in the function that will determine which node is called next.
+            self.should_continue,
+            # Next we pass in the path map - all the nodes this edge could go to
+            ["tools", END],
+        )
+        workflow.add_edge("tools", "agent")
+        self.model = model.bind_tools(tools)
+        self.app = workflow.compile()
+
+    def should_continue(self, state: State):
+        messages = state["messages"]
+        last_message = messages[-1]
+        # If there is no function call, then we finish
+        if not last_message.tool_calls:
+            return END
+        # Otherwise if there is, we continue
+        else:
+            return "tools"
+
+    async def call_model(self, state: State, config: RunnableConfig):
+        messages = state["messages"]
+        if self.system:
+            messages = [SystemMessage(content=self.system)] + messages
+        response = await self.model.ainvoke(messages, config)
+        # We return a list, because this will get added to the existing list
+        return {"messages": response}
+
+    async def query(self, user_input: str):
+        inputs = [HumanMessage(content=user_input)]
+        first = True
+        async for msg, metadata in self.app.astream({"messages": inputs}, stream_mode="messages"):
+            if msg.content and not isinstance(msg, HumanMessage):
+                # Here you can verify whether tokens are streamed incrementally
+                print(msg.content, end="|", flush=True)
+
+            if isinstance(msg, AIMessageChunk):
+                if first:
+                    gathered = msg
+                    first = False
+                else:
+                    gathered = gathered + msg
+
+                if msg.tool_call_chunks:
+                    print('tool_call_chunks...', gathered.tool_calls)
+
+
+if __name__ == '__main__':
+
+    question = "what is the weather in sf"
+    prompt = """
+    You are a smart research assistant. Use the search engine ...
+    """
+    agent = Agent(model, tools, prompt)
+    asyncio.run(agent.query(question))
+
diff --git a/basic_demo/requirements.txt b/basic_demo/requirements.txt
index 4ff1483..771c759 100644
--- a/basic_demo/requirements.txt
+++ b/basic_demo/requirements.txt
@@ -16,6 +16,9 @@ einops>=0.8.0
 pillow>=10.4.0
 sse-starlette>=2.1.3
 bitsandbytes>=0.43.3 # INT4 Loading
+langchain_openai>=0.1.9
+langchain>=0.2.16
+langgraph>=0.2.34
 # vllm>=0.6.4 # using with VLLM Framework
 # flash-attn>=2.6.3 # using with flash-attention 2