diff --git a/libs/langgraph/src/graph/graph.ts b/libs/langgraph/src/graph/graph.ts index 1b121ee8..02650f9a 100644 --- a/libs/langgraph/src/graph/graph.ts +++ b/libs/langgraph/src/graph/graph.ts @@ -41,7 +41,7 @@ export interface BranchOptions { export class Branch { condition: ( input: IO, - config?: RunnableConfig + config: RunnableConfig ) => | string | Send diff --git a/libs/langgraph/src/prebuilt/agent_executor.ts b/libs/langgraph/src/prebuilt/agent_executor.ts index 21611833..d6a6e462 100644 --- a/libs/langgraph/src/prebuilt/agent_executor.ts +++ b/libs/langgraph/src/prebuilt/agent_executor.ts @@ -44,10 +44,7 @@ export function createAgentExecutor({ return "continue"; }; - const runAgent = async ( - data: AgentExecutorState, - config?: RunnableConfig - ) => { + const runAgent = async (data: AgentExecutorState, config: RunnableConfig) => { const agentOutcome = await agentRunnable.invoke(data, config); return { agentOutcome, @@ -56,7 +53,7 @@ export function createAgentExecutor({ const executeTools = async ( data: AgentExecutorState, - config?: RunnableConfig + config: RunnableConfig ): Promise> => { const agentAction = data.agentOutcome; if (!agentAction || "returnValues" in agentAction) { diff --git a/libs/langgraph/src/prebuilt/chat_agent_executor.ts b/libs/langgraph/src/prebuilt/chat_agent_executor.ts index 861d7f68..6cdd0715 100644 --- a/libs/langgraph/src/prebuilt/chat_agent_executor.ts +++ b/libs/langgraph/src/prebuilt/chat_agent_executor.ts @@ -70,7 +70,7 @@ export function createFunctionCallingExecutor({ // Define the function that calls the model const callModel = async ( state: FunctionCallingExecutorState, - config?: RunnableConfig + config: RunnableConfig ) => { const { messages } = state; const response = await newModel.invoke(messages, config); @@ -104,7 +104,7 @@ export function createFunctionCallingExecutor({ const callTool = async ( state: FunctionCallingExecutorState, - config?: RunnableConfig + config: RunnableConfig ) => { const action = _getAction(state); // We call the tool_executor and get back a response diff --git a/libs/langgraph/src/prebuilt/react_agent_executor.ts b/libs/langgraph/src/prebuilt/react_agent_executor.ts index bba625aa..4463c090 100644 --- a/libs/langgraph/src/prebuilt/react_agent_executor.ts +++ b/libs/langgraph/src/prebuilt/react_agent_executor.ts @@ -154,7 +154,7 @@ export function createReactAgent( } }; - const callModel = async (state: AgentState, config?: RunnableConfig) => { + const callModel = async (state: AgentState, config: RunnableConfig) => { const { messages } = state; // TODO: Auto-promote streaming. return { messages: [await modelRunnable.invoke(messages, config)] }; diff --git a/libs/langgraph/src/prebuilt/tool_executor.ts b/libs/langgraph/src/prebuilt/tool_executor.ts index a0366082..bbe03a98 100644 --- a/libs/langgraph/src/prebuilt/tool_executor.ts +++ b/libs/langgraph/src/prebuilt/tool_executor.ts @@ -50,7 +50,7 @@ export class ToolExecutor extends RunnableBinding< ...fields, }; const bound = RunnableLambda.from( - async (input: ToolInvocationInterface, config?: RunnableConfig) => + async (input: ToolInvocationInterface, config: RunnableConfig) => this._execute(input, config) ); super({ @@ -74,7 +74,7 @@ export class ToolExecutor extends RunnableBinding< */ async _execute( toolInvocation: ToolInvocationInterface, - config?: RunnableConfig + config: RunnableConfig ): Promise { if (!(toolInvocation.tool in this.toolMap)) { return this.invalidToolMsgTemplate diff --git a/libs/langgraph/src/pregel/write.ts b/libs/langgraph/src/pregel/write.ts index 9c276443..0d306834 100644 --- a/libs/langgraph/src/pregel/write.ts +++ b/libs/langgraph/src/pregel/write.ts @@ -53,8 +53,8 @@ export class ChannelWrite< .join(",")}>`; super({ ...{ writes, name, tags }, - func: async (input: RunInput, config?: RunnableConfig) => { - return this._write(input, config ?? {}); + func: async (input: RunInput, config: RunnableConfig) => { + return this._write(input, config); }, }); diff --git a/libs/langgraph/src/tests/pregel.test.ts b/libs/langgraph/src/tests/pregel.test.ts index f1a675e6..8b384a09 100644 --- a/libs/langgraph/src/tests/pregel.test.ts +++ b/libs/langgraph/src/tests/pregel.test.ts @@ -3787,20 +3787,14 @@ describe("StateGraph", () => { hello: Annotation, }); - const nodeA = ( - _: typeof StateAnnotation.State, - config?: RunnableConfig - ) => { + const nodeA = (_: typeof StateAnnotation.State, config: RunnableConfig) => { // Unfortunately can't infer input types at runtime :( - expect(config?.configurable?.foo).toEqual("bar"); + expect(config.configurable?.foo).toEqual("bar"); return {}; }; - const nodeB = ( - _: typeof StateAnnotation.State, - config?: RunnableConfig - ) => { - expect(config?.configurable?.foo).toEqual("bar"); + const nodeB = (_: typeof StateAnnotation.State, config: RunnableConfig) => { + expect(config.configurable?.foo).toEqual("bar"); return { hello: "again", now: 123, @@ -3897,6 +3891,20 @@ describe("StateGraph", () => { hello: "again", }); }); + + it("can be passed a conditional edge with required config arg", async () => { + const workflow = new StateGraph(MessagesAnnotation) + .addNode("nodeOne", () => ({})) + .addConditionalEdges("nodeOne", (_, config) => { + expect(config).toBeDefined(); + if (!config) { + throw new Error("config must be defined."); + } + return END; + }); + const app = workflow.compile(); + await app.invoke({ messages: [] }); + }); }); describe("PreBuilt", () => { @@ -5252,11 +5260,8 @@ describe("Managed Values (context) can be passed through state", () => { it("should be passed through state but not stored in checkpointer", async () => { const nodeOne = async ( data: typeof AgentAnnotation.State, - config?: RunnableConfig + config: RunnableConfig ): Promise> => { - if (!config) { - throw new Error("config is undefined"); - } expect(config.configurable?.thread_id).toEqual(threadId); expect(data.sharedStateKey).toEqual({}); @@ -5272,13 +5277,8 @@ describe("Managed Values (context) can be passed through state", () => { }; const nodeTwo = async ( - data: typeof AgentAnnotation.State, - config?: RunnableConfig + data: typeof AgentAnnotation.State ): Promise> => { - if (!config) { - throw new Error("config is undefined"); - } - expect(data.sharedStateKey).toEqual({ sharedStateValue: { value: "shared", @@ -5314,13 +5314,8 @@ describe("Managed Values (context) can be passed through state", () => { }; const nodeThree = async ( - data: typeof AgentAnnotation.State, - config?: RunnableConfig + data: typeof AgentAnnotation.State ): Promise> => { - if (!config) { - throw new Error("config is undefined"); - } - expect(data.sharedStateKey).toEqual({ sharedStateValue: { value: "updated", @@ -5389,11 +5384,8 @@ describe("Managed Values (context) can be passed through state", () => { it("can not access shared values from other 'on' keys", async () => { const nodeOne = async ( data: typeof AgentAnnotation.State, - config?: RunnableConfig + config: RunnableConfig ): Promise> => { - if (!config) { - throw new Error("config is undefined"); - } expect(config.configurable?.thread_id).toBe(threadId); expect(config.configurable?.assistant_id).toBe("a"); @@ -5410,11 +5402,8 @@ describe("Managed Values (context) can be passed through state", () => { const nodeTwo = async ( data: typeof AgentAnnotation.State, - config?: RunnableConfig + config: RunnableConfig ): Promise> => { - if (!config) { - throw new Error("config is undefined"); - } expect(config.configurable?.thread_id).toBe(threadId); expect(config.configurable?.assistant_id).toBe("b"); @@ -5431,12 +5420,8 @@ describe("Managed Values (context) can be passed through state", () => { const nodeThree = async ( data: typeof AgentAnnotation.State, - config?: RunnableConfig + config: RunnableConfig ): Promise> => { - if (!config) { - throw new Error("config is undefined"); - } - expect(config.configurable?.thread_id).toBe(threadId); expect(config.configurable?.assistant_id).toBe("a"); @@ -5451,12 +5436,8 @@ describe("Managed Values (context) can be passed through state", () => { const nodeFour = async ( data: typeof AgentAnnotation.State, - config?: RunnableConfig + config: RunnableConfig ): Promise> => { - if (!config) { - throw new Error("config is undefined"); - } - expect(config.configurable?.thread_id).toBe(threadId); expect(config.configurable?.assistant_id).toBe("b"); @@ -5611,11 +5592,8 @@ describe("Managed Values (context) can be passed through state", () => { // Define nodeOne that sets sharedStateKey and adds a message const nodeOne = async ( data: typeof AgentAnnotation.State, - config?: RunnableConfig + config: RunnableConfig ): Promise> => { - if (!config) { - throw new Error("config is undefined"); - } expect(config.configurable?.thread_id).toEqual(threadId); expect(data.sharedStateKey).toEqual({}); @@ -5632,13 +5610,8 @@ describe("Managed Values (context) can be passed through state", () => { // Define nodeTwo that updates sharedStateKey const nodeTwo = async ( - data: typeof AgentAnnotation.State, - config?: RunnableConfig + data: typeof AgentAnnotation.State ): Promise> => { - if (!config) { - throw new Error("config is undefined"); - } - expect(data.sharedStateKey).toEqual({ data: { value: "shared",