Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(langgraph): Allow option to prevent subgraphs inheriting checkpointer #570

Merged
merged 4 commits into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions libs/langgraph/src/errors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,29 @@ export class InvalidUpdateError extends Error {
return "InvalidUpdateError";
}
}

export class MultipleSubgraphsError extends Error {
constructor(message?: string) {
super(message);
this.name = "MultipleSubgraphError";
}

static get unminifiable_name() {
return "MultipleSubgraphError";
}
}

/**
* Used for subgraph detection.
*/
export const getSubgraphsSeenSet = () => {
if (
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(globalThis as any)[Symbol.for("LG_CHECKPOINT_SEEN_NS_SET")] === undefined
) {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(globalThis as any)[Symbol.for("LG_CHECKPOINT_SEEN_NS_SET")] = new Set();
}
// eslint-disable-next-line @typescript-eslint/no-explicit-any
return (globalThis as any)[Symbol.for("LG_CHECKPOINT_SEEN_NS_SET")];
};
2 changes: 1 addition & 1 deletion libs/langgraph/src/graph/graph.ts
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ export class Graph<
interruptBefore,
interruptAfter,
}: {
checkpointer?: BaseCheckpointSaver;
checkpointer?: BaseCheckpointSaver | false;
interruptBefore?: N[] | All;
interruptAfter?: N[] | All;
} = {}): CompiledGraph<N> {
Expand Down
2 changes: 1 addition & 1 deletion libs/langgraph/src/graph/state.ts
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ export class StateGraph<
interruptBefore,
interruptAfter,
}: {
checkpointer?: BaseCheckpointSaver;
checkpointer?: BaseCheckpointSaver | false;
store?: BaseStore;
interruptBefore?: N[] | All;
interruptAfter?: N[] | All;
Expand Down
8 changes: 5 additions & 3 deletions libs/langgraph/src/pregel/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ export class Pregel<

debug: boolean = false;

checkpointer?: BaseCheckpointSaver;
checkpointer?: BaseCheckpointSaver | false;

retryPolicy?: RetryPolicy;

Expand Down Expand Up @@ -877,15 +877,16 @@ export class Pregel<
}

let defaultCheckpointer: BaseCheckpointSaver | undefined;
if (
if (this.checkpointer === false) {
defaultCheckpointer = undefined;
} else if (
config !== undefined &&
config.configurable?.[CONFIG_KEY_CHECKPOINTER] !== undefined
) {
defaultCheckpointer = config.configurable[CONFIG_KEY_CHECKPOINTER];
} else {
defaultCheckpointer = this.checkpointer;
}

const defaultStore: BaseStore | undefined = config.store ?? this.store;

return [
Expand Down Expand Up @@ -1000,6 +1001,7 @@ export class Pregel<
}
if (
this.checkpointer !== undefined &&
this.checkpointer !== false &&
inputConfig.configurable === undefined
) {
throw new Error(
Expand Down
2 changes: 1 addition & 1 deletion libs/langgraph/src/pregel/io.ts
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ export function* mapInput<C extends PropertyKey>(
}
} else if (Array.isArray(inputChannels)) {
throw new Error(
"Input chunk must be an object when inputChannels is an array"
`Input chunk must be an object when "inputChannels" is an array`
);
} else {
yield [inputChannels, chunk];
Expand Down
19 changes: 18 additions & 1 deletion libs/langgraph/src/pregel/loop.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,11 @@ import {
readChannels,
} from "./io.js";
import {
getSubgraphsSeenSet,
EmptyInputError,
GraphInterrupt,
isGraphInterrupt,
MultipleSubgraphsError,
} from "../errors.js";
import { getNewChannelVersions, patchConfigurable } from "./utils/index.js";
import {
Expand Down Expand Up @@ -84,6 +86,7 @@ export type PregelLoopInitializeParams = {
managed: ManagedValueMapping;
stream: StreamProtocol;
store?: BaseStore;
checkSubgraphs?: boolean;
};

type PregelLoopParams = {
Expand Down Expand Up @@ -232,6 +235,7 @@ export class PregelLoop {

static async initialize(params: PregelLoopInitializeParams) {
let { config, stream } = params;
const { checkSubgraphs = true } = params;
if (
stream !== undefined &&
config.configurable?.[CONFIG_KEY_STREAM] !== undefined
Expand Down Expand Up @@ -311,7 +315,20 @@ export class PregelLoop {
// Start the store. This is a batch store, so it will run continuously
store.start();
}

if (checkSubgraphs && isNested && params.checkpointer !== undefined) {
if (getSubgraphsSeenSet().has(config.configurable?.checkpoint_ns)) {
throw new MultipleSubgraphsError(
[
"Detected the same subgraph called multiple times by the same node.",
"This is not allowed if checkpointing is enabled.",
"",
`You can disable checkpointing for a subgraph by compiling it with ".compile({ checkpointer: false });"`,
].join("\n")
);
} else {
getSubgraphsSeenSet().add(config.configurable?.checkpoint_ns);
}
}
return new PregelLoop({
input: params.input,
config,
Expand Down
12 changes: 11 additions & 1 deletion libs/langgraph/src/pregel/retry.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { isGraphInterrupt } from "../errors.js";
import { getSubgraphsSeenSet, isGraphInterrupt } from "../errors.js";
import { PregelExecutableTask } from "./types.js";
import type { RetryPolicy } from "./utils/index.js";

Expand Down Expand Up @@ -167,6 +167,16 @@ async function _runWithRetry(
2
)} seconds (attempt ${attempts}) after ${errorName}: ${error}`
);
// Clear checkpoint_ns seen (for subgraph detection)
const checkpointNs = pregelTask.config?.configurable?.checkpoint_ns;
if (checkpointNs) {
getSubgraphsSeenSet().delete(checkpointNs);
}
} finally {
const checkpointNs = pregelTask.config?.configurable?.checkpoint_ns;
if (checkpointNs) {
getSubgraphsSeenSet().delete(checkpointNs);
}
}
}
return {
Expand Down
2 changes: 1 addition & 1 deletion libs/langgraph/src/pregel/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ export interface PregelInterface<
*/
debug?: boolean;

checkpointer?: BaseCheckpointSaver;
checkpointer?: BaseCheckpointSaver | false;

retryPolicy?: RetryPolicy;

Expand Down
4 changes: 3 additions & 1 deletion libs/langgraph/src/tests/pregel.io.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,9 @@ describe("mapInput", () => {
// do nothing, error will be thrown
continue;
}
}).toThrow("Input chunk must be an object when inputChannels is an array");
}).toThrow(
`Input chunk must be an object when "inputChannels" is an array`
);
});
});

Expand Down
Loading
Loading