Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(langgraph): Update code to reflect config being a required arg #501

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion libs/langgraph/src/graph/graph.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ export interface BranchOptions<IO, N extends string> {
export class Branch<IO, N extends string> {
condition: (
input: IO,
config?: RunnableConfig
config: RunnableConfig
) =>
| string
| Send
Expand Down
7 changes: 2 additions & 5 deletions libs/langgraph/src/prebuilt/agent_executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,7 @@ export function createAgentExecutor({
return "continue";
};

const runAgent = async (
data: AgentExecutorState,
config?: RunnableConfig
) => {
const runAgent = async (data: AgentExecutorState, config: RunnableConfig) => {
const agentOutcome = await agentRunnable.invoke(data, config);
return {
agentOutcome,
Expand All @@ -56,7 +53,7 @@ export function createAgentExecutor({

const executeTools = async (
data: AgentExecutorState,
config?: RunnableConfig
config: RunnableConfig
): Promise<Partial<AgentExecutorState>> => {
const agentAction = data.agentOutcome;
if (!agentAction || "returnValues" in agentAction) {
Expand Down
4 changes: 2 additions & 2 deletions libs/langgraph/src/prebuilt/chat_agent_executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ export function createFunctionCallingExecutor<Model extends object>({
// Define the function that calls the model
const callModel = async (
state: FunctionCallingExecutorState,
config?: RunnableConfig
config: RunnableConfig
) => {
const { messages } = state;
const response = await newModel.invoke(messages, config);
Expand Down Expand Up @@ -104,7 +104,7 @@ export function createFunctionCallingExecutor<Model extends object>({

const callTool = async (
state: FunctionCallingExecutorState,
config?: RunnableConfig
config: RunnableConfig
) => {
const action = _getAction(state);
// We call the tool_executor and get back a response
Expand Down
2 changes: 1 addition & 1 deletion libs/langgraph/src/prebuilt/react_agent_executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ export function createReactAgent(
}
};

const callModel = async (state: AgentState, config?: RunnableConfig) => {
const callModel = async (state: AgentState, config: RunnableConfig) => {
const { messages } = state;
// TODO: Auto-promote streaming.
return { messages: [await modelRunnable.invoke(messages, config)] };
Expand Down
4 changes: 2 additions & 2 deletions libs/langgraph/src/prebuilt/tool_executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ export class ToolExecutor extends RunnableBinding<
...fields,
};
const bound = RunnableLambda.from(
async (input: ToolInvocationInterface, config?: RunnableConfig) =>
async (input: ToolInvocationInterface, config: RunnableConfig) =>
this._execute(input, config)
);
super({
Expand All @@ -74,7 +74,7 @@ export class ToolExecutor extends RunnableBinding<
*/
async _execute(
toolInvocation: ToolInvocationInterface,
config?: RunnableConfig
config: RunnableConfig
): Promise<ToolExecutorOutputType> {
if (!(toolInvocation.tool in this.toolMap)) {
return this.invalidToolMsgTemplate
Expand Down
4 changes: 2 additions & 2 deletions libs/langgraph/src/pregel/write.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ export class ChannelWrite<
.join(",")}>`;
super({
...{ writes, name, tags },
func: async (input: RunInput, config?: RunnableConfig) => {
return this._write(input, config ?? {});
func: async (input: RunInput, config: RunnableConfig) => {
return this._write(input, config);
},
});

Expand Down
81 changes: 27 additions & 54 deletions libs/langgraph/src/tests/pregel.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3787,20 +3787,14 @@ describe("StateGraph", () => {
hello: Annotation<string>,
});

const nodeA = (
_: typeof StateAnnotation.State,
config?: RunnableConfig
) => {
const nodeA = (_: typeof StateAnnotation.State, config: RunnableConfig) => {
// Unfortunately can't infer input types at runtime :(
expect(config?.configurable?.foo).toEqual("bar");
expect(config.configurable?.foo).toEqual("bar");
return {};
};

const nodeB = (
_: typeof StateAnnotation.State,
config?: RunnableConfig
) => {
expect(config?.configurable?.foo).toEqual("bar");
const nodeB = (_: typeof StateAnnotation.State, config: RunnableConfig) => {
expect(config.configurable?.foo).toEqual("bar");
return {
hello: "again",
now: 123,
Expand Down Expand Up @@ -3897,6 +3891,20 @@ describe("StateGraph", () => {
hello: "again",
});
});

it("can be passed a conditional edge with required config arg", async () => {
const workflow = new StateGraph(MessagesAnnotation)
.addNode("nodeOne", () => ({}))
.addConditionalEdges("nodeOne", (_, config) => {
expect(config).toBeDefined();
if (!config) {
throw new Error("config must be defined.");
}
return END;
});
const app = workflow.compile();
await app.invoke({ messages: [] });
});
});

describe("PreBuilt", () => {
Expand Down Expand Up @@ -5252,11 +5260,8 @@ describe("Managed Values (context) can be passed through state", () => {
it("should be passed through state but not stored in checkpointer", async () => {
const nodeOne = async (
data: typeof AgentAnnotation.State,
config?: RunnableConfig
config: RunnableConfig
): Promise<Partial<typeof AgentAnnotation.State>> => {
if (!config) {
throw new Error("config is undefined");
}
expect(config.configurable?.thread_id).toEqual(threadId);

expect(data.sharedStateKey).toEqual({});
Expand All @@ -5272,13 +5277,8 @@ describe("Managed Values (context) can be passed through state", () => {
};

const nodeTwo = async (
data: typeof AgentAnnotation.State,
config?: RunnableConfig
data: typeof AgentAnnotation.State
): Promise<Partial<typeof AgentAnnotation.State>> => {
if (!config) {
throw new Error("config is undefined");
}

expect(data.sharedStateKey).toEqual({
sharedStateValue: {
value: "shared",
Expand Down Expand Up @@ -5314,13 +5314,8 @@ describe("Managed Values (context) can be passed through state", () => {
};

const nodeThree = async (
data: typeof AgentAnnotation.State,
config?: RunnableConfig
data: typeof AgentAnnotation.State
): Promise<Partial<typeof AgentAnnotation.State>> => {
if (!config) {
throw new Error("config is undefined");
}

expect(data.sharedStateKey).toEqual({
sharedStateValue: {
value: "updated",
Expand Down Expand Up @@ -5389,11 +5384,8 @@ describe("Managed Values (context) can be passed through state", () => {
it("can not access shared values from other 'on' keys", async () => {
const nodeOne = async (
data: typeof AgentAnnotation.State,
config?: RunnableConfig
config: RunnableConfig
): Promise<Partial<typeof AgentAnnotation.State>> => {
if (!config) {
throw new Error("config is undefined");
}
expect(config.configurable?.thread_id).toBe(threadId);
expect(config.configurable?.assistant_id).toBe("a");

Expand All @@ -5410,11 +5402,8 @@ describe("Managed Values (context) can be passed through state", () => {

const nodeTwo = async (
data: typeof AgentAnnotation.State,
config?: RunnableConfig
config: RunnableConfig
): Promise<Partial<typeof AgentAnnotation.State>> => {
if (!config) {
throw new Error("config is undefined");
}
expect(config.configurable?.thread_id).toBe(threadId);
expect(config.configurable?.assistant_id).toBe("b");

Expand All @@ -5431,12 +5420,8 @@ describe("Managed Values (context) can be passed through state", () => {

const nodeThree = async (
data: typeof AgentAnnotation.State,
config?: RunnableConfig
config: RunnableConfig
): Promise<Partial<typeof AgentAnnotation.State>> => {
if (!config) {
throw new Error("config is undefined");
}

expect(config.configurable?.thread_id).toBe(threadId);
expect(config.configurable?.assistant_id).toBe("a");

Expand All @@ -5451,12 +5436,8 @@ describe("Managed Values (context) can be passed through state", () => {

const nodeFour = async (
data: typeof AgentAnnotation.State,
config?: RunnableConfig
config: RunnableConfig
): Promise<Partial<typeof AgentAnnotation.State>> => {
if (!config) {
throw new Error("config is undefined");
}

expect(config.configurable?.thread_id).toBe(threadId);
expect(config.configurable?.assistant_id).toBe("b");

Expand Down Expand Up @@ -5611,11 +5592,8 @@ describe("Managed Values (context) can be passed through state", () => {
// Define nodeOne that sets sharedStateKey and adds a message
const nodeOne = async (
data: typeof AgentAnnotation.State,
config?: RunnableConfig
config: RunnableConfig
): Promise<Partial<typeof AgentAnnotation.State>> => {
if (!config) {
throw new Error("config is undefined");
}
expect(config.configurable?.thread_id).toEqual(threadId);

expect(data.sharedStateKey).toEqual({});
Expand All @@ -5632,13 +5610,8 @@ describe("Managed Values (context) can be passed through state", () => {

// Define nodeTwo that updates sharedStateKey
const nodeTwo = async (
data: typeof AgentAnnotation.State,
config?: RunnableConfig
data: typeof AgentAnnotation.State
): Promise<Partial<typeof AgentAnnotation.State>> => {
if (!config) {
throw new Error("config is undefined");
}

expect(data.sharedStateKey).toEqual({
data: {
value: "shared",
Expand Down
Loading