Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(graph): passthrough input types to invoke/stream #650

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 15 additions & 13 deletions libs/langgraph/src/graph/graph.ts
Original file line number Diff line number Diff line change
Expand Up @@ -441,37 +441,39 @@ export class Graph<
export class CompiledGraph<
N extends string,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
RunInput = any,
State = any,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
RunOutput = any,
Update = any,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
ConfigurableFieldType extends Record<string, any> = Record<string, any>
> extends Pregel<
Record<N | typeof START, PregelNode<RunInput, RunOutput>>,
Record<N | typeof START, PregelNode<State, Update>>,
Record<N | typeof START | typeof END | string, BaseChannel>,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
ConfigurableFieldType & Record<string, any>
ConfigurableFieldType & Record<string, any>,
Update,
State
> {
declare NodeType: N;

declare RunInput: RunInput;
declare RunInput: State;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a little strange that RunInput = StateType<SD>, I think we should change it to Update?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea I agree, this looks to be the other way around, no? input should be update, output should be based on stream mode (sometimes update, sometimes state, etc)


declare RunOutput: RunOutput;
declare RunOutput: Update;

builder: Graph<N, RunInput, RunOutput>;
builder: Graph<N, State, Update>;

constructor({
builder,
...rest
}: { builder: Graph<N, RunInput, RunOutput> } & PregelParams<
Record<N | typeof START, PregelNode<RunInput, RunOutput>>,
}: { builder: Graph<N, State, Update> } & PregelParams<
Record<N | typeof START, PregelNode<State, Update>>,
Record<N | typeof START | typeof END | string, BaseChannel>
>) {
super(rest);
this.builder = builder;
}

attachNode(key: N, node: NodeSpec<RunInput, RunOutput>): void {
attachNode(key: N, node: NodeSpec<State, Update>): void {
this.channels[key] = new EphemeralValue();
this.nodes[key] = new PregelNode({
channels: [],
Expand Down Expand Up @@ -503,7 +505,7 @@ export class CompiledGraph<
attachBranch(
start: N | typeof START,
name: string,
branch: Branch<RunInput, N>
branch: Branch<State, N>
) {
// add hidden start node
if (start === START && this.nodes[START]) {
Expand Down Expand Up @@ -588,7 +590,7 @@ export class CompiledGraph<

for (const [key, nodeSpec] of Object.entries(this.builder.nodes) as [
N,
NodeSpec<RunInput, RunOutput>
NodeSpec<State, Update>
][]) {
const displayKey = _escapeMermaidKeywords(key);
const node = nodeSpec.runnable;
Expand Down Expand Up @@ -769,7 +771,7 @@ export class CompiledGraph<

for (const [key, nodeSpec] of Object.entries(this.builder.nodes) as [
N,
NodeSpec<RunInput, RunOutput>
NodeSpec<State, Update>
][]) {
const displayKey = _escapeMermaidKeywords(key);
const node = nodeSpec.runnable;
Expand Down
24 changes: 6 additions & 18 deletions libs/langgraph/src/prebuilt/react_agent_executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,9 @@ import {
} from "@langchain/core/language_models/base";
import { ChatPromptTemplate } from "@langchain/core/prompts";
import { All, BaseCheckpointSaver } from "@langchain/langgraph-checkpoint";
import {
END,
messagesStateReducer,
START,
StateGraph,
} from "../graph/index.js";
import { END, START, StateGraph } from "../graph/index.js";
import { MessagesAnnotation } from "../graph/messages_annotation.js";
import { CompiledStateGraph, StateGraphArgs } from "../graph/state.js";
import { CompiledStateGraph } from "../graph/state.js";
import { ToolNode } from "./tool_node.js";

export interface AgentState {
Expand Down Expand Up @@ -107,11 +102,12 @@ export type CreateReactAgentParams = {
* // Returns the messages in the state at each step of execution
* ```
*/

export function createReactAgent(
params: CreateReactAgentParams
): CompiledStateGraph<
AgentState,
Partial<AgentState>,
(typeof MessagesAnnotation)["State"],
(typeof MessagesAnnotation)["Update"],
typeof START | "agent" | "tools"
> {
const {
Expand All @@ -122,12 +118,6 @@ export function createReactAgent(
interruptBefore,
interruptAfter,
} = params;
const schema: StateGraphArgs<AgentState>["channels"] = {
messages: {
value: messagesStateReducer,
default: () => [],
},
};

let toolClasses: (StructuredToolInterface | DynamicTool | RunnableToolLike)[];
if (!Array.isArray(tools)) {
Expand Down Expand Up @@ -160,9 +150,7 @@ export function createReactAgent(
return { messages: [await modelRunnable.invoke(messages, config)] };
};

const workflow = new StateGraph<AgentState>({
channels: schema,
})
const workflow = new StateGraph(MessagesAnnotation)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is done primarily to get the correct input types (based off the messagesReducer, but open to different solutions.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's good to make this change, the other syntax is deprecated

.addNode("agent", callModel)
.addNode("tools", new ToolNode<AgentState>(toolClasses))
.addEdge(START, "agent")
Expand Down
16 changes: 9 additions & 7 deletions libs/langgraph/src/pregel/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -194,11 +194,13 @@ export class Pregel<
Nn extends StrRecord<string, PregelNode>,
Cc extends StrRecord<string, BaseChannel | ManagedValueSpec>,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
ConfigurableFieldType extends Record<string, any> = StrRecord<string, any>
ConfigurableFieldType extends Record<string, any> = StrRecord<string, any>,
InputType = PregelInputType,
OutputType = PregelOutputType
>
extends Runnable<
PregelInputType,
PregelOutputType,
InputType | null,
OutputType,
PregelOptions<Nn, Cc, ConfigurableFieldType>
>
implements
Expand Down Expand Up @@ -901,7 +903,7 @@ export class Pregel<
* @param options.debug Whether to print debug information during execution.
*/
override async stream(
input: PregelInputType,
input: InputType | null,
options?: Partial<PregelOptions<Nn, Cc, ConfigurableFieldType>>
): Promise<IterableReadableStream<PregelOutputType>> {
return super.stream(input, options);
Expand Down Expand Up @@ -1221,9 +1223,9 @@ export class Pregel<
* @param options.debug Whether to print debug information during execution.
*/
override async invoke(
input: PregelInputType,
input: InputType | null,
options?: Partial<PregelOptions<Nn, Cc, ConfigurableFieldType>>
): Promise<PregelOutputType> {
): Promise<OutputType> {
const streamMode = options?.streamMode ?? "values";
const config = {
...options,
Expand All @@ -1238,6 +1240,6 @@ export class Pregel<
if (streamMode === "values") {
return chunks[chunks.length - 1];
}
return chunks;
return chunks as OutputType;
}
}
21 changes: 15 additions & 6 deletions libs/langgraph/src/tests/graph.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,13 @@ describe("State", () => {
it("should allow reducers with different argument types", async () => {
const StateAnnotation = Annotation.Root({
val: Annotation<number>,
testval: Annotation<string[], string>({
reducer: (left, right) =>
right ? left.concat([right.toString()]) : left,
testval: Annotation<string[], string | string[]>({
reducer: (left, right) => {
if (typeof right === "string") {
return right ? left.concat([right.toString()]) : left;
}
return right.length ? left.concat(right) : left;
},
}),
});
const stateGraph = new StateGraph(StateAnnotation);
Expand All @@ -41,6 +45,7 @@ describe("State", () => {
.addEdge(START, "testnode")
.addEdge("testnode", END)
.compile();

expect(await graph.invoke({ testval: ["hello"] })).toEqual({
testval: ["hello", "hi!"],
val: 3,
Expand All @@ -51,12 +56,16 @@ describe("State", () => {
const stateGraph = new StateGraph<
unknown,
{ testval: string[] },
{ testval: string }
{ testval: string | string[] }
>({
channels: {
testval: {
reducer: (left: string[], right?: string) =>
right ? left.concat([right.toString()]) : left,
reducer: (left, right) => {
if (typeof right === "string") {
return right ? left.concat([right.toString()]) : left;
}
return right.length ? left.concat(right) : left;
},
},
},
});
Expand Down
4 changes: 3 additions & 1 deletion libs/langgraph/src/tests/prebuilt.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,9 @@ describe("createReactAgent", () => {
expect(response.messages.length > 1).toBe(true);
const lastMessage = response.messages[response.messages.length - 1];
expect(lastMessage._getType()).toBe("ai");
expect(lastMessage.content.toLowerCase()).toContain("not too cold");
expect((lastMessage.content as string).toLowerCase()).toContain(
"not too cold"
);
});

it("can stream a tool call with a checkpointer", async () => {
Expand Down
3 changes: 3 additions & 0 deletions libs/langgraph/src/tests/pregel.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3541,6 +3541,7 @@ export function runPregelTests(
hello: "there",
bye: "world",
messages: ["hello"],
// @ts-expect-error This should emit a TS error
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can now actually catch excessive keys!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ha nice!

now: 345, // ignored because not in input schema
})
).toEqual({
Expand All @@ -3553,6 +3554,7 @@ export function runPregelTests(
hello: "there",
bye: "world",
messages: ["hello"],
// @ts-expect-error This should emit a TS error
now: 345, // ignored because not in input schema
})
)
Expand Down Expand Up @@ -3712,6 +3714,7 @@ export function runPregelTests(
};
const res = await app.invoke(
{
// @ts-expect-error Messages is not in schema
messages: ["initial input"],
},
config
Expand Down
14 changes: 10 additions & 4 deletions libs/langgraph/src/tests/tracing.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -470,10 +470,16 @@ Only add steps to the plan that still NEED to be done. Do not return previously
state: PlanExecuteState
): Promise<Partial<PlanExecuteState>> {
const task = state.input;
const agentResponse = await agentExecutor.invoke({ input: task });
return {
pastSteps: [task, agentResponse.agentOutcome.returnValues.output],
};
const agentResponse = await agentExecutor.invoke({
input: task ?? undefined,
});

const outcome = agentResponse.agentOutcome;
if (!outcome || !("returnValues" in outcome)) {
throw new Error("Agent did not return a valid outcome.");
}

return { pastSteps: [task, outcome.returnValues.output] };
}

async function planStep(
Expand Down
Loading