From 037f6cf9eef3699d48ac37ef4a4d6d376f86fc9a Mon Sep 17 00:00:00 2001 From: Tat Dat Duong Date: Fri, 1 Nov 2024 19:15:26 +0100 Subject: [PATCH 1/4] feat(graph): passthrough input and output types to invoke/stream --- libs/langgraph/src/graph/graph.ts | 4 +++- libs/langgraph/src/graph/state.ts | 1 + libs/langgraph/src/pregel/index.ts | 18 ++++++++++-------- 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/libs/langgraph/src/graph/graph.ts b/libs/langgraph/src/graph/graph.ts index 30ef7854..250ecc54 100644 --- a/libs/langgraph/src/graph/graph.ts +++ b/libs/langgraph/src/graph/graph.ts @@ -450,7 +450,9 @@ export class CompiledGraph< Record>, Record, // eslint-disable-next-line @typescript-eslint/no-explicit-any - ConfigurableFieldType & Record + ConfigurableFieldType & Record, + RunInput, + RunOutput > { declare NodeType: N; diff --git a/libs/langgraph/src/graph/state.ts b/libs/langgraph/src/graph/state.ts index b9bc67d4..e80140b0 100644 --- a/libs/langgraph/src/graph/state.ts +++ b/libs/langgraph/src/graph/state.ts @@ -49,6 +49,7 @@ import type { RetryPolicy } from "../pregel/utils/index.js"; import { isConfiguredManagedValue, ManagedValueSpec } from "../managed/base.js"; import type { LangGraphRunnableConfig } from "../pregel/runnable_types.js"; import { isPregelLike } from "../pregel/utils/subgraph.js"; +import { MessagesAnnotation } from "./messages_annotation.js"; const ROOT = "__root__"; diff --git a/libs/langgraph/src/pregel/index.ts b/libs/langgraph/src/pregel/index.ts index cbdf3b1a..a660056c 100644 --- a/libs/langgraph/src/pregel/index.ts +++ b/libs/langgraph/src/pregel/index.ts @@ -194,11 +194,13 @@ export class Pregel< Nn extends StrRecord, Cc extends StrRecord, // eslint-disable-next-line @typescript-eslint/no-explicit-any - ConfigurableFieldType extends Record = StrRecord + ConfigurableFieldType extends Record = StrRecord, + InputType = PregelInputType, + OutputType = PregelOutputType > extends Runnable< - PregelInputType, - PregelOutputType, + InputType, + OutputType, PregelOptions > implements @@ -901,9 +903,9 @@ export class Pregel< * @param options.debug Whether to print debug information during execution. */ override async stream( - input: PregelInputType, + input: InputType, options?: Partial> - ): Promise> { + ): Promise> { return super.stream(input, options); } @@ -1221,9 +1223,9 @@ export class Pregel< * @param options.debug Whether to print debug information during execution. */ override async invoke( - input: PregelInputType, + input: InputType, options?: Partial> - ): Promise { + ): Promise { const streamMode = options?.streamMode ?? "values"; const config = { ...options, @@ -1238,6 +1240,6 @@ export class Pregel< if (streamMode === "values") { return chunks[chunks.length - 1]; } - return chunks; + return chunks as OutputType; } } From af7d8c1f58610d6a6d697b2d20d536a38031054f Mon Sep 17 00:00:00 2001 From: Tat Dat Duong Date: Fri, 1 Nov 2024 19:17:26 +0100 Subject: [PATCH 2/4] Fix lint --- libs/langgraph/src/graph/state.ts | 1 - 1 file changed, 1 deletion(-) diff --git a/libs/langgraph/src/graph/state.ts b/libs/langgraph/src/graph/state.ts index e80140b0..b9bc67d4 100644 --- a/libs/langgraph/src/graph/state.ts +++ b/libs/langgraph/src/graph/state.ts @@ -49,7 +49,6 @@ import type { RetryPolicy } from "../pregel/utils/index.js"; import { isConfiguredManagedValue, ManagedValueSpec } from "../managed/base.js"; import type { LangGraphRunnableConfig } from "../pregel/runnable_types.js"; import { isPregelLike } from "../pregel/utils/subgraph.js"; -import { MessagesAnnotation } from "./messages_annotation.js"; const ROOT = "__root__"; From 17ab0e3275813447dd152ea45d42f2ce71afb4a3 Mon Sep 17 00:00:00 2001 From: Tat Dat Duong Date: Fri, 1 Nov 2024 19:18:28 +0100 Subject: [PATCH 3/4] Accept null as input as well (to continue) --- libs/langgraph/src/pregel/index.ts | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/libs/langgraph/src/pregel/index.ts b/libs/langgraph/src/pregel/index.ts index a660056c..968a8f5d 100644 --- a/libs/langgraph/src/pregel/index.ts +++ b/libs/langgraph/src/pregel/index.ts @@ -199,7 +199,7 @@ export class Pregel< OutputType = PregelOutputType > extends Runnable< - InputType, + InputType | null, OutputType, PregelOptions > @@ -903,7 +903,7 @@ export class Pregel< * @param options.debug Whether to print debug information during execution. */ override async stream( - input: InputType, + input: InputType | null, options?: Partial> ): Promise> { return super.stream(input, options); @@ -1223,7 +1223,7 @@ export class Pregel< * @param options.debug Whether to print debug information during execution. */ override async invoke( - input: InputType, + input: InputType | null, options?: Partial> ): Promise { const streamMode = options?.streamMode ?? "values"; From ab7a649f6938159d3adffcff800f2bdd04fb6922 Mon Sep 17 00:00:00 2001 From: Tat Dat Duong Date: Mon, 4 Nov 2024 16:21:53 +0100 Subject: [PATCH 4/4] Opt out of return value typing for now, swap StateType with UpdateType for correct semantics, use Annotation.Root for prebuilt agent executor --- libs/langgraph/src/graph/graph.ts | 28 +++++++++---------- .../src/prebuilt/react_agent_executor.ts | 24 ++++------------ libs/langgraph/src/pregel/index.ts | 2 +- libs/langgraph/src/tests/graph.test.ts | 21 ++++++++++---- libs/langgraph/src/tests/prebuilt.int.test.ts | 4 ++- libs/langgraph/src/tests/pregel.test.ts | 3 ++ libs/langgraph/src/tests/tracing.int.test.ts | 14 +++++++--- 7 files changed, 52 insertions(+), 44 deletions(-) diff --git a/libs/langgraph/src/graph/graph.ts b/libs/langgraph/src/graph/graph.ts index 250ecc54..70e033d9 100644 --- a/libs/langgraph/src/graph/graph.ts +++ b/libs/langgraph/src/graph/graph.ts @@ -441,39 +441,39 @@ export class Graph< export class CompiledGraph< N extends string, // eslint-disable-next-line @typescript-eslint/no-explicit-any - RunInput = any, + State = any, // eslint-disable-next-line @typescript-eslint/no-explicit-any - RunOutput = any, + Update = any, // eslint-disable-next-line @typescript-eslint/no-explicit-any ConfigurableFieldType extends Record = Record > extends Pregel< - Record>, + Record>, Record, // eslint-disable-next-line @typescript-eslint/no-explicit-any ConfigurableFieldType & Record, - RunInput, - RunOutput + Update, + State > { declare NodeType: N; - declare RunInput: RunInput; + declare RunInput: State; - declare RunOutput: RunOutput; + declare RunOutput: Update; - builder: Graph; + builder: Graph; constructor({ builder, ...rest - }: { builder: Graph } & PregelParams< - Record>, + }: { builder: Graph } & PregelParams< + Record>, Record >) { super(rest); this.builder = builder; } - attachNode(key: N, node: NodeSpec): void { + attachNode(key: N, node: NodeSpec): void { this.channels[key] = new EphemeralValue(); this.nodes[key] = new PregelNode({ channels: [], @@ -505,7 +505,7 @@ export class CompiledGraph< attachBranch( start: N | typeof START, name: string, - branch: Branch + branch: Branch ) { // add hidden start node if (start === START && this.nodes[START]) { @@ -590,7 +590,7 @@ export class CompiledGraph< for (const [key, nodeSpec] of Object.entries(this.builder.nodes) as [ N, - NodeSpec + NodeSpec ][]) { const displayKey = _escapeMermaidKeywords(key); const node = nodeSpec.runnable; @@ -771,7 +771,7 @@ export class CompiledGraph< for (const [key, nodeSpec] of Object.entries(this.builder.nodes) as [ N, - NodeSpec + NodeSpec ][]) { const displayKey = _escapeMermaidKeywords(key); const node = nodeSpec.runnable; diff --git a/libs/langgraph/src/prebuilt/react_agent_executor.ts b/libs/langgraph/src/prebuilt/react_agent_executor.ts index 1c7793b7..5e6ca079 100644 --- a/libs/langgraph/src/prebuilt/react_agent_executor.ts +++ b/libs/langgraph/src/prebuilt/react_agent_executor.ts @@ -20,14 +20,9 @@ import { } from "@langchain/core/language_models/base"; import { ChatPromptTemplate } from "@langchain/core/prompts"; import { All, BaseCheckpointSaver } from "@langchain/langgraph-checkpoint"; -import { - END, - messagesStateReducer, - START, - StateGraph, -} from "../graph/index.js"; +import { END, START, StateGraph } from "../graph/index.js"; import { MessagesAnnotation } from "../graph/messages_annotation.js"; -import { CompiledStateGraph, StateGraphArgs } from "../graph/state.js"; +import { CompiledStateGraph } from "../graph/state.js"; import { ToolNode } from "./tool_node.js"; export interface AgentState { @@ -107,11 +102,12 @@ export type CreateReactAgentParams = { * // Returns the messages in the state at each step of execution * ``` */ + export function createReactAgent( params: CreateReactAgentParams ): CompiledStateGraph< - AgentState, - Partial, + (typeof MessagesAnnotation)["State"], + (typeof MessagesAnnotation)["Update"], typeof START | "agent" | "tools" > { const { @@ -122,12 +118,6 @@ export function createReactAgent( interruptBefore, interruptAfter, } = params; - const schema: StateGraphArgs["channels"] = { - messages: { - value: messagesStateReducer, - default: () => [], - }, - }; let toolClasses: (StructuredToolInterface | DynamicTool | RunnableToolLike)[]; if (!Array.isArray(tools)) { @@ -160,9 +150,7 @@ export function createReactAgent( return { messages: [await modelRunnable.invoke(messages, config)] }; }; - const workflow = new StateGraph({ - channels: schema, - }) + const workflow = new StateGraph(MessagesAnnotation) .addNode("agent", callModel) .addNode("tools", new ToolNode(toolClasses)) .addEdge(START, "agent") diff --git a/libs/langgraph/src/pregel/index.ts b/libs/langgraph/src/pregel/index.ts index 968a8f5d..fd6c7035 100644 --- a/libs/langgraph/src/pregel/index.ts +++ b/libs/langgraph/src/pregel/index.ts @@ -905,7 +905,7 @@ export class Pregel< override async stream( input: InputType | null, options?: Partial> - ): Promise> { + ): Promise> { return super.stream(input, options); } diff --git a/libs/langgraph/src/tests/graph.test.ts b/libs/langgraph/src/tests/graph.test.ts index cee5f1ff..852c72c1 100644 --- a/libs/langgraph/src/tests/graph.test.ts +++ b/libs/langgraph/src/tests/graph.test.ts @@ -23,9 +23,13 @@ describe("State", () => { it("should allow reducers with different argument types", async () => { const StateAnnotation = Annotation.Root({ val: Annotation, - testval: Annotation({ - reducer: (left, right) => - right ? left.concat([right.toString()]) : left, + testval: Annotation({ + reducer: (left, right) => { + if (typeof right === "string") { + return right ? left.concat([right.toString()]) : left; + } + return right.length ? left.concat(right) : left; + }, }), }); const stateGraph = new StateGraph(StateAnnotation); @@ -41,6 +45,7 @@ describe("State", () => { .addEdge(START, "testnode") .addEdge("testnode", END) .compile(); + expect(await graph.invoke({ testval: ["hello"] })).toEqual({ testval: ["hello", "hi!"], val: 3, @@ -51,12 +56,16 @@ describe("State", () => { const stateGraph = new StateGraph< unknown, { testval: string[] }, - { testval: string } + { testval: string | string[] } >({ channels: { testval: { - reducer: (left: string[], right?: string) => - right ? left.concat([right.toString()]) : left, + reducer: (left, right) => { + if (typeof right === "string") { + return right ? left.concat([right.toString()]) : left; + } + return right.length ? left.concat(right) : left; + }, }, }, }); diff --git a/libs/langgraph/src/tests/prebuilt.int.test.ts b/libs/langgraph/src/tests/prebuilt.int.test.ts index d2745bda..9c69b5c0 100644 --- a/libs/langgraph/src/tests/prebuilt.int.test.ts +++ b/libs/langgraph/src/tests/prebuilt.int.test.ts @@ -63,7 +63,9 @@ describe("createReactAgent", () => { expect(response.messages.length > 1).toBe(true); const lastMessage = response.messages[response.messages.length - 1]; expect(lastMessage._getType()).toBe("ai"); - expect(lastMessage.content.toLowerCase()).toContain("not too cold"); + expect((lastMessage.content as string).toLowerCase()).toContain( + "not too cold" + ); }); it("can stream a tool call with a checkpointer", async () => { diff --git a/libs/langgraph/src/tests/pregel.test.ts b/libs/langgraph/src/tests/pregel.test.ts index ad593ef0..6856e6df 100644 --- a/libs/langgraph/src/tests/pregel.test.ts +++ b/libs/langgraph/src/tests/pregel.test.ts @@ -3541,6 +3541,7 @@ export function runPregelTests( hello: "there", bye: "world", messages: ["hello"], + // @ts-expect-error This should emit a TS error now: 345, // ignored because not in input schema }) ).toEqual({ @@ -3553,6 +3554,7 @@ export function runPregelTests( hello: "there", bye: "world", messages: ["hello"], + // @ts-expect-error This should emit a TS error now: 345, // ignored because not in input schema }) ) @@ -3712,6 +3714,7 @@ export function runPregelTests( }; const res = await app.invoke( { + // @ts-expect-error Messages is not in schema messages: ["initial input"], }, config diff --git a/libs/langgraph/src/tests/tracing.int.test.ts b/libs/langgraph/src/tests/tracing.int.test.ts index 77dbfcef..9aa92a14 100644 --- a/libs/langgraph/src/tests/tracing.int.test.ts +++ b/libs/langgraph/src/tests/tracing.int.test.ts @@ -470,10 +470,16 @@ Only add steps to the plan that still NEED to be done. Do not return previously state: PlanExecuteState ): Promise> { const task = state.input; - const agentResponse = await agentExecutor.invoke({ input: task }); - return { - pastSteps: [task, agentResponse.agentOutcome.returnValues.output], - }; + const agentResponse = await agentExecutor.invoke({ + input: task ?? undefined, + }); + + const outcome = agentResponse.agentOutcome; + if (!outcome || !("returnValues" in outcome)) { + throw new Error("Agent did not return a valid outcome."); + } + + return { pastSteps: [task, outcome.returnValues.output] }; } async function planStep(