diff --git a/invokeai/frontend/web/src/features/controlLayers/components/Filters/Filter.tsx b/invokeai/frontend/web/src/features/controlLayers/components/Filters/Filter.tsx
index b05adbc097c..5bc449fac6f 100644
--- a/invokeai/frontend/web/src/features/controlLayers/components/Filters/Filter.tsx
+++ b/invokeai/frontend/web/src/features/controlLayers/components/Filters/Filter.tsx
@@ -106,7 +106,6 @@ const FilterBox = memo(({ adapter }: { adapter: CanvasEntityAdapterRasterLayer |
variant="ghost"
leftIcon={}
onClick={adapter.filterer.cancel}
- isLoading={isProcessing}
loadingText={t('controlLayers.filter.cancel')}
>
{t('controlLayers.filter.cancel')}
diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasCompositorModule.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasCompositorModule.ts
index cc457a4f15a..1ee23fdec50 100644
--- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasCompositorModule.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasCompositorModule.ts
@@ -14,7 +14,7 @@ import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
import { atom, computed } from 'nanostores';
import type { Logger } from 'roarr';
import type { UploadOptions } from 'services/api/endpoints/images';
-import { getImageDTO, uploadImage } from 'services/api/endpoints/images';
+import { getImageDTOSafe, uploadImage } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import stableHash from 'stable-hash';
import { assert } from 'tsafe';
@@ -210,7 +210,7 @@ export class CanvasCompositorModule extends CanvasModuleBase {
const cachedImageName = this.manager.cache.imageNameCache.get(hash);
if (cachedImageName) {
- imageDTO = await getImageDTO(cachedImageName);
+ imageDTO = await getImageDTOSafe(cachedImageName);
if (imageDTO) {
this.log.trace({ rect, imageName: cachedImageName, imageDTO }, 'Using cached composite raster layer image');
return imageDTO;
@@ -374,7 +374,7 @@ export class CanvasCompositorModule extends CanvasModuleBase {
const cachedImageName = this.manager.cache.imageNameCache.get(hash);
if (cachedImageName) {
- imageDTO = await getImageDTO(cachedImageName);
+ imageDTO = await getImageDTOSafe(cachedImageName);
if (imageDTO) {
this.log.trace({ rect, cachedImageName, imageDTO }, 'Using cached composite inpaint mask image');
return imageDTO;
diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityFilterer.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityFilterer.ts
index 5b66a7c332e..3fd501a4682 100644
--- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityFilterer.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityFilterer.ts
@@ -1,5 +1,4 @@
-import type { SerializableObject } from 'common/types';
-import { withResultAsync } from 'common/util/result';
+import { withResult, withResultAsync } from 'common/util/result';
import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer';
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
@@ -13,9 +12,9 @@ import { imageDTOToImageObject } from 'features/controlLayers/store/util';
import { debounce } from 'lodash-es';
import { atom } from 'nanostores';
import type { Logger } from 'roarr';
-import { getImageDTO } from 'services/api/endpoints/images';
+import { serializeError } from 'serialize-error';
import { buildSelectModelConfig } from 'services/api/hooks/modelsByType';
-import { type BatchConfig, type ImageDTO, isControlNetOrT2IAdapterModelConfig, type S } from 'services/api/types';
+import { isControlNetOrT2IAdapterModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
type CanvasEntityFiltererConfig = {
@@ -38,6 +37,11 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
subscriptions = new Set<() => void>();
config: CanvasEntityFiltererConfig = DEFAULT_CONFIG;
+ /**
+ * The AbortController used to cancel the filter processing.
+ */
+ abortController: AbortController | null = null;
+
$isFiltering = atom(false);
$hasProcessed = atom(false);
$isProcessing = atom(false);
@@ -100,63 +104,82 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
processImmediate = async () => {
const config = this.$filterConfig.get();
- const isValid = IMAGE_FILTERS[config.type].validateConfig?.(config as never) ?? true;
+ const filterData = IMAGE_FILTERS[config.type];
+
+ // Cannot get TS to be happy with `config`, thinks it should be `never`... eh...
+ const isValid = filterData.validateConfig?.(config as never) ?? true;
if (!isValid) {
+ this.log.error({ config }, 'Invalid filter config');
return;
}
- this.log.trace({ config }, 'Previewing filter');
+ this.log.trace({ config }, 'Processing filter');
const rect = this.parent.transformer.getRelativeRect();
- const imageDTO = await this.parent.renderer.rasterize({ rect, attrs: { filters: [], opacity: 1 } });
- const nodeId = getPrefixedId('filter_node');
- const batch = this.buildBatchConfig(imageDTO, config, nodeId);
-
- // Listen for the filter processing completion event
- const completedListener = async (event: S['InvocationCompleteEvent']) => {
- if (event.origin !== this.id || event.invocation_source_id !== nodeId) {
- return;
- }
- this.manager.socket.off('invocation_complete', completedListener);
- this.manager.socket.off('invocation_error', errorListener);
-
- this.log.trace({ event } as SerializableObject, 'Handling filter processing completion');
- const { result } = event;
- assert(result.type === 'image_output', `Processor did not return an image output, got: ${result}`);
-
- const imageDTO = await getImageDTO(result.image.image_name);
- assert(imageDTO, "Failed to fetch processor output's image DTO");
+ const rasterizeResult = await withResultAsync(() =>
+ this.parent.renderer.rasterize({ rect, attrs: { filters: [], opacity: 1 } })
+ );
+ if (rasterizeResult.isErr()) {
+ this.log.error({ error: serializeError(rasterizeResult.error) }, 'Error rasterizing entity');
+ this.$isProcessing.set(false);
+ return;
+ }
- this.imageState = imageDTOToImageObject(imageDTO);
+ this.$isProcessing.set(true);
- await this.parent.bufferRenderer.setBuffer(this.imageState, true);
+ const imageDTO = rasterizeResult.value;
+ // Cannot get TS to be happy with `config`, thinks it should be `never`... eh...
+ const buildGraphResult = withResult(() => filterData.buildGraph(imageDTO, config as never));
+ if (buildGraphResult.isErr()) {
+ this.log.error({ error: serializeError(buildGraphResult.error) }, 'Error building filter graph');
this.$isProcessing.set(false);
- this.$hasProcessed.set(true);
- };
- const errorListener = (event: S['InvocationErrorEvent']) => {
- if (event.origin !== this.id || event.invocation_source_id !== nodeId) {
- return;
- }
- this.manager.socket.off('invocation_complete', completedListener);
- this.manager.socket.off('invocation_error', errorListener);
-
- this.log.error({ event } as SerializableObject, 'Error processing filter');
+ return;
+ }
+
+ const controller = new AbortController();
+ this.abortController = controller;
+
+ const { graph, outputNodeId } = buildGraphResult.value;
+ const filterResult = await withResultAsync(() =>
+ this.manager.stateApi.runGraphAndReturnImageOutput({
+ graph,
+ outputNodeId,
+ // The filter graph should always be prepended to the queue so it's processed ASAP.
+ prepend: true,
+ /**
+ * The filter node may need to download a large model. Currently, the models required by the filter nodes are
+ * downloaded just-in-time, as required by the filter. If we use a timeout here, we might get into a catch-22
+ * where the filter node is waiting for the model to download, but the download gets canceled if the filter
+ * node times out.
+ *
+ * (I suspect the model download will actually _not_ be canceled if the graph is canceled, but let's not chance it!)
+ *
+ * TODO(psyche): Figure out a better way to handle this. Probably need to download the models ahead of time.
+ */
+ // timeout: 5000,
+ /**
+ * The filter node should be able to cancel the request if it's taking too long. This will cancel the graph's
+ * queue item and clear any event listeners on the request.
+ */
+ signal: controller.signal,
+ })
+ );
+ if (filterResult.isErr()) {
+ this.log.error({ error: serializeError(filterResult.error) }, 'Error processing filter');
this.$isProcessing.set(false);
- };
+ this.abortController = null;
+ return;
+ }
- this.manager.socket.on('invocation_complete', completedListener);
- this.manager.socket.on('invocation_error', errorListener);
+ this.log.trace({ imageDTO: filterResult.value }, 'Filter processed');
+ this.imageState = imageDTOToImageObject(filterResult.value);
- this.log.trace({ batch } as SerializableObject, 'Enqueuing filter batch');
+ await this.parent.bufferRenderer.setBuffer(this.imageState, true);
- this.$isProcessing.set(true);
- const req = this.manager.stateApi.enqueueBatch(batch);
- const result = await withResultAsync(req.unwrap);
- if (result.isErr()) {
- this.$isProcessing.set(false);
- }
- req.reset();
+ this.$isProcessing.set(false);
+ this.$hasProcessed.set(true);
+ this.abortController = null;
};
process = debounce(this.processImmediate, this.config.processDebounceMs);
@@ -188,6 +211,8 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
reset = () => {
this.log.trace('Resetting filter');
+ this.abortController?.abort();
+ this.abortController = null;
this.parent.bufferRenderer.clearBuffer();
this.parent.transformer.updatePosition();
this.parent.renderer.syncCache(true);
@@ -205,31 +230,6 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
this.manager.stateApi.$filteringAdapter.set(null);
};
- buildBatchConfig = (imageDTO: ImageDTO, config: FilterConfig, id: string): BatchConfig => {
- // TODO(psyche): I can't get TS to be happy, it thinkgs `config` is `never` but it should be inferred from the generic... I'll just cast it for now
- const node = IMAGE_FILTERS[config.type].buildNode(imageDTO, config as never);
- node.id = id;
- const batch: BatchConfig = {
- prepend: true,
- batch: {
- graph: {
- nodes: {
- [node.id]: {
- ...node,
- // filtered images are always intermediate - do not save to gallery
- is_intermediate: true,
- },
- },
- edges: [],
- },
- origin: this.id,
- runs: 1,
- },
- };
-
- return batch;
- };
-
repr = () => {
return {
id: this.id,
diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer.ts
index af1f5c35aaa..da7bc9d7999 100644
--- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer.ts
@@ -1,5 +1,6 @@
import { $authToken } from 'app/store/nanostores/authToken';
import { rgbColorToString } from 'common/util/colorCodeTransformers';
+import { withResult } from 'common/util/result';
import { SyncableMap } from 'common/util/SyncableMap/SyncableMap';
import type { CanvasEntityAdapter } from 'features/controlLayers/konva/CanvasEntity/types';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
@@ -27,7 +28,7 @@ import { debounce } from 'lodash-es';
import { atom } from 'nanostores';
import type { Logger } from 'roarr';
import { serializeError } from 'serialize-error';
-import { getImageDTO, uploadImage } from 'services/api/endpoints/images';
+import { getImageDTOSafe, uploadImage } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import { assert } from 'tsafe';
@@ -356,14 +357,25 @@ export class CanvasEntityObjectRenderer extends CanvasModuleBase {
};
/**
- * Rasterizes the parent entity. If the entity has a rasterization cache for the given rect, the cached image is
- * returned. Otherwise, the entity is rasterized and the image is uploaded to the server.
+ * Rasterizes the parent entity, returning a promise that resolves to the image DTO.
+ *
+ * If the entity has a rasterization cache for the given rect, the cached image is returned. Otherwise, the entity is
+ * rasterized and the image is uploaded to the server.
*
* The rasterization cache is reset when the entity's state changes. The buffer object is not considered part of the
* entity state for this purpose as it is a temporary object.
*
- * @param rect The rect to rasterize. If omitted, the entity's full rect will be used.
- * @returns A promise that resolves to the rasterized image DTO.
+ * If rasterization fails for any reason, the promise will reject.
+ *
+ * @param options The rasterization options.
+ * @param options.rect The region of the entity to rasterize.
+ * @param options.replaceObjects Whether to replace the entity's objects with the rasterized image. If you just want
+ * the entity's image, omit or set this to false.
+ * @param options.attrs The Konva node attributes to apply to the rasterized image group. For example, you might want
+ * to disable filters or set the opacity to the rasterized image.
+ * @param options.bg Draws the entity on a canvas with the given background color. If omitted, the entity is drawn on
+ * a transparent canvas.
+ * @returns A promise that resolves to the rasterized image DTO or rejects if rasterization fails.
*/
rasterize = async (options: {
rect: Rect;
@@ -383,7 +395,7 @@ export class CanvasEntityObjectRenderer extends CanvasModuleBase {
const cachedImageName = this.manager.cache.imageNameCache.get(hash);
if (cachedImageName) {
- imageDTO = await getImageDTO(cachedImageName);
+ imageDTO = await getImageDTOSafe(cachedImageName);
if (imageDTO) {
this.log.trace({ rect, cachedImageName, imageDTO }, 'Using cached rasterized image');
return imageDTO;
@@ -423,26 +435,38 @@ export class CanvasEntityObjectRenderer extends CanvasModuleBase {
if (this.parent.transformer.$isPendingRectCalculation.get()) {
return;
}
+
const pixelRect = this.parent.transformer.$pixelRect.get();
if (pixelRect.width === 0 || pixelRect.height === 0) {
return;
}
- try {
- // TODO(psyche): This is an internal Konva method, so it may break in the future. Can we make this API public?
- const canvas = this.konva.objectGroup._getCachedSceneCanvas()._canvas as HTMLCanvasElement | undefined | null;
- if (canvas) {
- const nodeRect = this.parent.transformer.$nodeRect.get();
- const rect = {
- x: pixelRect.x - nodeRect.x,
- y: pixelRect.y - nodeRect.y,
- width: pixelRect.width,
- height: pixelRect.height,
- };
- this.$canvasCache.set({ rect, canvas });
- }
- } catch (error) {
+
+ /**
+ * TODO(psyche): This is an internal Konva method, so it may break in the future. Can we make this API public?
+ *
+ * This method's API is unknown. It has been experimentally determined that it may throw, so we need to handle
+ * errors.
+ */
+ const getCacheCanvasResult = withResult(
+ () => this.konva.objectGroup._getCachedSceneCanvas()._canvas as HTMLCanvasElement | undefined | null
+ );
+ if (getCacheCanvasResult.isErr()) {
// We are using an internal Konva method, so we need to catch any errors that may occur.
- this.log.warn({ error: serializeError(error) }, 'Failed to update preview canvas');
+ this.log.warn({ error: serializeError(getCacheCanvasResult.error) }, 'Failed to update preview canvas');
+ return;
+ }
+
+ const canvas = getCacheCanvasResult.value;
+
+ if (canvas) {
+ const nodeRect = this.parent.transformer.$nodeRect.get();
+ const rect = {
+ x: pixelRect.x - nodeRect.x,
+ y: pixelRect.y - nodeRect.y,
+ width: pixelRect.width,
+ height: pixelRect.height,
+ };
+ this.$canvasCache.set({ rect, canvas });
}
}, 300);
diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityTransformer.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityTransformer.ts
index d32ea51726b..7b0208bf41c 100644
--- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityTransformer.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityTransformer.ts
@@ -1,3 +1,4 @@
+import { withResultAsync } from 'common/util/result';
import { roundToMultiple } from 'common/util/roundDownToMultiple';
import type { CanvasEntityAdapter } from 'features/controlLayers/konva/CanvasEntity/types';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
@@ -15,6 +16,7 @@ import type { GroupConfig } from 'konva/lib/Group';
import { debounce, get } from 'lodash-es';
import { atom } from 'nanostores';
import type { Logger } from 'roarr';
+import { serializeError } from 'serialize-error';
import { assert } from 'tsafe';
type CanvasEntityTransformerConfig = {
@@ -575,7 +577,12 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
this.log.debug('Applying transform');
this.$isProcessing.set(true);
const rect = this.getRelativeRect();
- await this.parent.renderer.rasterize({ rect, replaceObjects: true, attrs: { opacity: 1, filters: [] } });
+ const rasterizeResult = await withResultAsync(() =>
+ this.parent.renderer.rasterize({ rect, replaceObjects: true, attrs: { opacity: 1, filters: [] } })
+ );
+ if (rasterizeResult.isErr()) {
+ this.log.error({ error: serializeError(rasterizeResult.error) }, 'Failed to rasterize entity');
+ }
this.requestRectCalculation();
this.stopTransform();
};
diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/CanvasObjectImage.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/CanvasObjectImage.ts
index f9b981bca41..eeacad2f49a 100644
--- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/CanvasObjectImage.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/CanvasObjectImage.ts
@@ -11,7 +11,7 @@ import type { CanvasImageState } from 'features/controlLayers/store/types';
import { t } from 'i18next';
import Konva from 'konva';
import type { Logger } from 'roarr';
-import { getImageDTO } from 'services/api/endpoints/images';
+import { getImageDTOSafe } from 'services/api/endpoints/images';
export class CanvasObjectImage extends CanvasModuleBase {
readonly type = 'object_image';
@@ -100,7 +100,7 @@ export class CanvasObjectImage extends CanvasModuleBase {
this.konva.placeholder.text.text(t('common.loadingImage', 'Loading Image'));
}
- const imageDTO = await getImageDTO(imageName);
+ const imageDTO = await getImageDTOSafe(imageName);
if (imageDTO === null) {
this.onFailedToLoadImage();
return;
diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApiModule.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApiModule.ts
index 5e7cbbad0f4..7bfbbe9dd46 100644
--- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApiModule.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApiModule.ts
@@ -2,6 +2,7 @@ import { $alt, $ctrl, $meta, $shift } from '@invoke-ai/ui-library';
import type { Selector } from '@reduxjs/toolkit';
import { addAppListener } from 'app/store/middleware/listenerMiddleware';
import type { AppStore, RootState } from 'app/store/store';
+import { withResultAsync } from 'common/util/result';
import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer';
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
@@ -38,10 +39,13 @@ import type {
RgbaColor,
} from 'features/controlLayers/store/types';
import { RGBA_BLACK } from 'features/controlLayers/store/types';
+import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import { atom, computed } from 'nanostores';
import type { Logger } from 'roarr';
+import { getImageDTO } from 'services/api/endpoints/images';
import { queueApi } from 'services/api/endpoints/queue';
-import type { BatchConfig } from 'services/api/types';
+import type { BatchConfig, ImageDTO, S } from 'services/api/types';
+import { QueueError } from 'services/events/errors';
import { assert } from 'tsafe';
import type { CanvasEntityAdapter } from './CanvasEntity/types';
@@ -187,14 +191,200 @@ export class CanvasStateApiModule extends CanvasModuleBase {
};
/**
- * Enqueues a batch, pushing state to redux.
- */
- enqueueBatch = (batch: BatchConfig) => {
- return this.store.dispatch(
+ * Run a graph and return an image output. The specified output node must return an image output, else the promise
+ * will reject with an error.
+ *
+ * @param arg The arguments for the function.
+ * @param arg.graph The graph to execute.
+ * @param arg.outputNodeId The id of the node whose output will be retrieved.
+ * @param arg.destination The destination to assign to the batch. If omitted, the destination is not set.
+ * @param arg.prepend Whether to prepend the graph to the front of the queue. If omitted, the graph is appended to the end of the queue.
+ * @param arg.timeout The timeout for the batch. If omitted, there is no timeout.
+ * @param arg.signal An optional signal to cancel the operation. If omitted, the operation cannot be canceled!
+ *
+ * @returns A promise that resolves to the image output or rejects with an error.
+ *
+ * @example
+ *
+ * ```ts
+ * const graph = new Graph();
+ * const outputNode = graph.addNode({ id: 'my-resize-node', type: 'img_resize', image: { image_name: 'my-image.png' } });
+ * const controller = new AbortController();
+ * const imageDTO = await this.manager.stateApi.runGraphAndReturnImageOutput({
+ * graph,
+ * outputNodeId: outputNode.id,
+ * prepend: true,
+ * signal: controller.signal,
+ * });
+ * // To cancel the operation:
+ * controller.abort();
+ * ```
+ */
+ runGraphAndReturnImageOutput = async (arg: {
+ graph: Graph;
+ outputNodeId: string;
+ destination?: string;
+ prepend?: boolean;
+ timeout?: number;
+ signal?: AbortSignal;
+ }): Promise => {
+ const { graph, outputNodeId, destination, prepend, timeout, signal } = arg;
+
+ /**
+ * We will use the origin to handle events from the graph. Ideally we'd just use the queue item's id, but there's a
+ * race condition:
+ * - The queue item id is not available until the graph is enqueued
+ * - The graph may complete before we can set up the listeners to handle the completion event
+ *
+ * The origin is the only unique identifier we have that is guaranteed to be available before the graph is enqueued,
+ * so we will use that to filter events.
+ */
+ const origin = getPrefixedId(graph.id);
+
+ const batch: BatchConfig = {
+ prepend,
+ batch: {
+ graph: graph.getGraph(),
+ origin,
+ destination,
+ runs: 1,
+ },
+ };
+
+ /**
+ * If a timeout is provided, we will cancel the graph if it takes too long - but we need a way to clear the timeout
+ * if the graph completes or errors before the timeout.
+ */
+ let timeoutId: number | null = null;
+ const _clearTimeout = () => {
+ if (timeoutId !== null) {
+ window.clearTimeout(timeoutId);
+ timeoutId = null;
+ }
+ };
+
+ /**
+ * First, enqueue the graph - we need the `batch_id` to cancel the graph. But to get the `batch_id`, we need to
+ * `await` the request. You might be tempted to `await` the request inside the result promise, but we should not
+ * `await` inside a promise executor.
+ *
+ * See: https://eslint.org/docs/latest/rules/no-async-promise-executor
+ */
+ const enqueueRequest = this.store.dispatch(
queueApi.endpoints.enqueueBatch.initiate(batch, {
+ // Use the same cache key for all enqueueBatch requests, so that all consumers of this query get the same status
+ // updates.
fixedCacheKey: 'enqueueBatch',
+ // We do not need RTK to track this request in the store
+ track: false,
})
);
+
+ // The `batch_id` should _always_ be present - the OpenAPI schema from which the types are generated is incorrect.
+ // TODO(psyche): Fix the OpenAPI schema.
+ const { batch_id } = (await enqueueRequest.unwrap()).batch;
+ assert(batch_id, 'Enqueue result is missing batch_id');
+
+ const resultPromise = new Promise((resolve, reject) => {
+ const invocationCompleteHandler = async (event: S['InvocationCompleteEvent']) => {
+ // Ignore events that are not for this graph
+ if (event.origin !== origin) {
+ return;
+ }
+ // Ignore events that are not from the output node
+ if (event.invocation_source_id !== outputNodeId) {
+ return;
+ }
+
+ // If we get here, the event is for the correct graph and output node.
+
+ // Clear the timeout and socket listeners
+ _clearTimeout();
+ clearListeners();
+
+ // The result must be an image output
+ const { result } = event;
+ if (result.type !== 'image_output') {
+ reject(new Error(`Graph output node did not return an image output, got: ${result}`));
+ return;
+ }
+
+ // Get the result image DTO
+ const getImageDTOResult = await withResultAsync(() => getImageDTO(result.image.image_name));
+ if (getImageDTOResult.isErr()) {
+ reject(getImageDTOResult.error);
+ return;
+ }
+
+ // Ok!
+ resolve(getImageDTOResult.value);
+ };
+
+ const queueItemStatusChangedHandler = (event: S['QueueItemStatusChangedEvent']) => {
+ // Ignore events that are not for this graph
+ if (event.origin !== origin) {
+ return;
+ }
+
+ // Ignore events where the status is pending or in progress - no need to do anything for these
+ if (event.status === 'pending' || event.status === 'in_progress') {
+ return;
+ }
+
+ // event.status is 'failed', 'canceled' or 'completed' - something has gone awry
+ _clearTimeout();
+ clearListeners();
+
+ if (event.status === 'completed') {
+ // If we get a queue item completed event, that means we never got a completion event for the output node!
+ reject(new Error('Queue item completed without output node completion event'));
+ } else if (event.status === 'failed') {
+ // We expect the event to have error details, but technically it's possible that it doesn't
+ const { error_type, error_message, error_traceback } = event;
+ if (error_type && error_message && error_traceback) {
+ reject(new QueueError(error_type, error_message, error_traceback));
+ } else {
+ reject(new Error('Queue item failed, but no error details were provided'));
+ }
+ } else {
+ // event.status is 'canceled'
+ reject(new Error('Graph canceled'));
+ }
+ };
+
+ this.manager.socket.on('invocation_complete', invocationCompleteHandler);
+ this.manager.socket.on('queue_item_status_changed', queueItemStatusChangedHandler);
+
+ const clearListeners = () => {
+ this.manager.socket.off('invocation_complete', invocationCompleteHandler);
+ this.manager.socket.off('queue_item_status_changed', queueItemStatusChangedHandler);
+ };
+
+ const cancelGraph = () => {
+ this.store.dispatch(queueApi.endpoints.cancelByBatchIds.initiate({ batch_ids: [batch_id] }, { track: false }));
+ };
+
+ if (timeout) {
+ timeoutId = window.setTimeout(() => {
+ this.log.trace('Graph canceled by timeout');
+ clearListeners();
+ cancelGraph();
+ reject(new Error('Graph timed out'));
+ }, timeout);
+ }
+
+ if (signal) {
+ signal.addEventListener('abort', () => {
+ this.log.trace('Graph canceled by signal');
+ _clearTimeout();
+ clearListeners();
+ cancelGraph();
+ reject(new Error('Graph canceled'));
+ });
+ }
+ });
+
+ return resultPromise;
};
/**
diff --git a/invokeai/frontend/web/src/features/controlLayers/store/filters.ts b/invokeai/frontend/web/src/features/controlLayers/store/filters.ts
index bd138666b27..5cfec9ee9f2 100644
--- a/invokeai/frontend/web/src/features/controlLayers/store/filters.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/store/filters.ts
@@ -1,7 +1,8 @@
import { getPrefixedId } from 'features/controlLayers/konva/util';
import type { ImageWithDims } from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common';
-import type { AnyInvocation, ControlNetModelConfig, Invocation, T2IAdapterModelConfig } from 'services/api/types';
+import { Graph } from 'features/nodes/util/graph/generation/Graph';
+import type { ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
import { z } from 'zod';
@@ -132,7 +133,10 @@ export const isFilterType = (v: unknown): v is FilterType => zFilterType.safePar
type ImageFilterData = {
type: T;
buildDefaults(): Extract;
- buildNode(imageDTO: ImageWithDims, config: Extract): AnyInvocation;
+ buildGraph(
+ imageDTO: ImageWithDims,
+ config: Extract
+ ): { graph: Graph; outputNodeId: string };
validateConfig?(config: Extract): boolean;
};
@@ -144,13 +148,20 @@ export const IMAGE_FILTERS: { [key in FilterConfig['type']]: ImageFilterData => ({
- id: getPrefixedId('canny_edge_detection'),
- type: 'canny_edge_detection',
- image: { image_name },
- low_threshold,
- high_threshold,
- }),
+ buildGraph: ({ image_name }, { low_threshold, high_threshold }) => {
+ const graph = new Graph(getPrefixedId('canny_edge_detection_filter'));
+ const node = graph.addNode({
+ id: getPrefixedId('canny_edge_detection'),
+ type: 'canny_edge_detection',
+ image: { image_name },
+ low_threshold,
+ high_threshold,
+ });
+ return {
+ graph,
+ outputNodeId: node.id,
+ };
+ },
},
color_map: {
type: 'color_map',
@@ -158,12 +169,19 @@ export const IMAGE_FILTERS: { [key in FilterConfig['type']]: ImageFilterData => ({
- id: getPrefixedId('color_map'),
- type: 'color_map',
- image: { image_name },
- tile_size,
- }),
+ buildGraph: ({ image_name }, { tile_size }) => {
+ const graph = new Graph(getPrefixedId('color_map_filter'));
+ const node = graph.addNode({
+ id: getPrefixedId('color_map'),
+ type: 'color_map',
+ image: { image_name },
+ tile_size,
+ });
+ return {
+ graph,
+ outputNodeId: node.id,
+ };
+ },
},
content_shuffle: {
type: 'content_shuffle',
@@ -171,12 +189,19 @@ export const IMAGE_FILTERS: { [key in FilterConfig['type']]: ImageFilterData => ({
- id: getPrefixedId('content_shuffle'),
- type: 'content_shuffle',
- image: { image_name },
- scale_factor,
- }),
+ buildGraph: ({ image_name }, { scale_factor }) => {
+ const graph = new Graph(getPrefixedId('content_shuffle_filter'));
+ const node = graph.addNode({
+ id: getPrefixedId('content_shuffle'),
+ type: 'content_shuffle',
+ image: { image_name },
+ scale_factor,
+ });
+ return {
+ graph,
+ outputNodeId: node.id,
+ };
+ },
},
depth_anything_depth_estimation: {
type: 'depth_anything_depth_estimation',
@@ -184,12 +209,19 @@ export const IMAGE_FILTERS: { [key in FilterConfig['type']]: ImageFilterData => ({
- id: getPrefixedId('depth_anything_depth_estimation'),
- type: 'depth_anything_depth_estimation',
- image: { image_name },
- model_size,
- }),
+ buildGraph: ({ image_name }, { model_size }) => {
+ const graph = new Graph(getPrefixedId('depth_anything_depth_estimation_filter'));
+ const node = graph.addNode({
+ id: getPrefixedId('depth_anything_depth_estimation'),
+ type: 'depth_anything_depth_estimation',
+ image: { image_name },
+ model_size,
+ });
+ return {
+ graph,
+ outputNodeId: node.id,
+ };
+ },
},
hed_edge_detection: {
type: 'hed_edge_detection',
@@ -197,23 +229,37 @@ export const IMAGE_FILTERS: { [key in FilterConfig['type']]: ImageFilterData => ({
- id: getPrefixedId('hed_edge_detection'),
- type: 'hed_edge_detection',
- image: { image_name },
- scribble,
- }),
+ buildGraph: ({ image_name }, { scribble }) => {
+ const graph = new Graph(getPrefixedId('hed_edge_detection_filter'));
+ const node = graph.addNode({
+ id: getPrefixedId('hed_edge_detection'),
+ type: 'hed_edge_detection',
+ image: { image_name },
+ scribble,
+ });
+ return {
+ graph,
+ outputNodeId: node.id,
+ };
+ },
},
lineart_anime_edge_detection: {
type: 'lineart_anime_edge_detection',
buildDefaults: () => ({
type: 'lineart_anime_edge_detection',
}),
- buildNode: ({ image_name }): Invocation<'lineart_anime_edge_detection'> => ({
- id: getPrefixedId('lineart_anime_edge_detection'),
- type: 'lineart_anime_edge_detection',
- image: { image_name },
- }),
+ buildGraph: ({ image_name }) => {
+ const graph = new Graph(getPrefixedId('lineart_anime_edge_detection_filter'));
+ const node = graph.addNode({
+ id: getPrefixedId('lineart_anime_edge_detection'),
+ type: 'lineart_anime_edge_detection',
+ image: { image_name },
+ });
+ return {
+ graph,
+ outputNodeId: node.id,
+ };
+ },
},
lineart_edge_detection: {
type: 'lineart_edge_detection',
@@ -221,12 +267,19 @@ export const IMAGE_FILTERS: { [key in FilterConfig['type']]: ImageFilterData => ({
- id: getPrefixedId('lineart_edge_detection'),
- type: 'lineart_edge_detection',
- image: { image_name },
- coarse,
- }),
+ buildGraph: ({ image_name }, { coarse }) => {
+ const graph = new Graph(getPrefixedId('lineart_edge_detection_filter'));
+ const node = graph.addNode({
+ id: getPrefixedId('lineart_edge_detection'),
+ type: 'lineart_edge_detection',
+ image: { image_name },
+ coarse,
+ });
+ return {
+ graph,
+ outputNodeId: node.id,
+ };
+ },
},
mediapipe_face_detection: {
type: 'mediapipe_face_detection',
@@ -235,13 +288,20 @@ export const IMAGE_FILTERS: { [key in FilterConfig['type']]: ImageFilterData => ({
- id: getPrefixedId('mediapipe_face_detection'),
- type: 'mediapipe_face_detection',
- image: { image_name },
- max_faces,
- min_confidence,
- }),
+ buildGraph: ({ image_name }, { max_faces, min_confidence }) => {
+ const graph = new Graph(getPrefixedId('mediapipe_face_detection_filter'));
+ const node = graph.addNode({
+ id: getPrefixedId('mediapipe_face_detection'),
+ type: 'mediapipe_face_detection',
+ image: { image_name },
+ max_faces,
+ min_confidence,
+ });
+ return {
+ graph,
+ outputNodeId: node.id,
+ };
+ },
},
mlsd_detection: {
type: 'mlsd_detection',
@@ -250,24 +310,38 @@ export const IMAGE_FILTERS: { [key in FilterConfig['type']]: ImageFilterData => ({
- id: getPrefixedId('mlsd_detection'),
- type: 'mlsd_detection',
- image: { image_name },
- score_threshold,
- distance_threshold,
- }),
+ buildGraph: ({ image_name }, { score_threshold, distance_threshold }) => {
+ const graph = new Graph(getPrefixedId('mlsd_detection_filter'));
+ const node = graph.addNode({
+ id: getPrefixedId('mlsd_detection'),
+ type: 'mlsd_detection',
+ image: { image_name },
+ score_threshold,
+ distance_threshold,
+ });
+ return {
+ graph,
+ outputNodeId: node.id,
+ };
+ },
},
normal_map: {
type: 'normal_map',
buildDefaults: () => ({
type: 'normal_map',
}),
- buildNode: ({ image_name }): Invocation<'normal_map'> => ({
- id: getPrefixedId('normal_map'),
- type: 'normal_map',
- image: { image_name },
- }),
+ buildGraph: ({ image_name }) => {
+ const graph = new Graph(getPrefixedId('normal_map_filter'));
+ const node = graph.addNode({
+ id: getPrefixedId('normal_map'),
+ type: 'normal_map',
+ image: { image_name },
+ });
+ return {
+ graph,
+ outputNodeId: node.id,
+ };
+ },
},
pidi_edge_detection: {
type: 'pidi_edge_detection',
@@ -276,13 +350,20 @@ export const IMAGE_FILTERS: { [key in FilterConfig['type']]: ImageFilterData => ({
- id: getPrefixedId('pidi_edge_detection'),
- type: 'pidi_edge_detection',
- image: { image_name },
- quantize_edges,
- scribble,
- }),
+ buildGraph: ({ image_name }, { quantize_edges, scribble }) => {
+ const graph = new Graph(getPrefixedId('pidi_edge_detection_filter'));
+ const node = graph.addNode({
+ id: getPrefixedId('pidi_edge_detection'),
+ type: 'pidi_edge_detection',
+ image: { image_name },
+ quantize_edges,
+ scribble,
+ });
+ return {
+ graph,
+ outputNodeId: node.id,
+ };
+ },
},
dw_openpose_detection: {
type: 'dw_openpose_detection',
@@ -292,14 +373,21 @@ export const IMAGE_FILTERS: { [key in FilterConfig['type']]: ImageFilterData => ({
- id: getPrefixedId('dw_openpose_detection'),
- type: 'dw_openpose_detection',
- image: { image_name },
- draw_body,
- draw_face,
- draw_hands,
- }),
+ buildGraph: ({ image_name }, { draw_body, draw_face, draw_hands }) => {
+ const graph = new Graph(getPrefixedId('dw_openpose_detection_filter'));
+ const node = graph.addNode({
+ id: getPrefixedId('dw_openpose_detection'),
+ type: 'dw_openpose_detection',
+ image: { image_name },
+ draw_body,
+ draw_face,
+ draw_hands,
+ });
+ return {
+ graph,
+ outputNodeId: node.id,
+ };
+ },
},
spandrel_filter: {
type: 'spandrel_filter',
@@ -309,29 +397,30 @@ export const IMAGE_FILTERS: { [key in FilterConfig['type']]: ImageFilterData => {
+ buildGraph: ({ image_name }, { model, scale, autoScale }) => {
assert(model !== null);
- if (autoScale) {
- const node: Invocation<'spandrel_image_to_image_autoscale'> = {
- id: getPrefixedId('spandrel_image_to_image_autoscale'),
- type: 'spandrel_image_to_image_autoscale',
- image_to_image_model: model,
- image: { image_name },
- scale,
- };
- return node;
- } else {
- const node: Invocation<'spandrel_image_to_image'> = {
- id: getPrefixedId('spandrel_image_to_image'),
- type: 'spandrel_image_to_image',
- image_to_image_model: model,
- image: { image_name },
- };
- return node;
- }
+ const graph = new Graph(getPrefixedId('spandrel_filter'));
+ const node = graph.addNode(
+ autoScale
+ ? {
+ id: getPrefixedId('spandrel_image_to_image_autoscale'),
+ type: 'spandrel_image_to_image_autoscale',
+ image_to_image_model: model,
+ image: { image_name },
+ scale,
+ }
+ : {
+ id: getPrefixedId('spandrel_image_to_image'),
+ type: 'spandrel_image_to_image',
+ image_to_image_model: model,
+ image: { image_name },
+ }
+ );
+
+ return {
+ graph,
+ outputNodeId: node.id,
+ };
},
validateConfig: (config): boolean => {
if (!config.model) {
diff --git a/invokeai/frontend/web/src/features/controlLayers/store/types.ts b/invokeai/frontend/web/src/features/controlLayers/store/types.ts
index 8befe563bb0..007e3567645 100644
--- a/invokeai/frontend/web/src/features/controlLayers/store/types.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/store/types.ts
@@ -7,7 +7,7 @@ import {
zParameterNegativePrompt,
zParameterPositivePrompt,
} from 'features/parameters/types/parameterSchemas';
-import { getImageDTO } from 'services/api/endpoints/images';
+import { getImageDTOSafe } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import { z } from 'zod';
@@ -31,7 +31,7 @@ const zImageWithDims = z
})
.refine(async (v) => {
const { image_name } = v;
- const imageDTO = await getImageDTO(image_name, true);
+ const imageDTO = await getImageDTOSafe(image_name, { forceRefetch: true });
return imageDTO !== null;
});
export type ImageWithDims = z.infer;
diff --git a/invokeai/frontend/web/src/features/metadata/util/parsers.ts b/invokeai/frontend/web/src/features/metadata/util/parsers.ts
index f47ec375199..49b50225cc8 100644
--- a/invokeai/frontend/web/src/features/metadata/util/parsers.ts
+++ b/invokeai/frontend/web/src/features/metadata/util/parsers.ts
@@ -67,7 +67,7 @@ import {
isParameterWidth,
} from 'features/parameters/types/parameterSchemas';
import { get, isArray, isString } from 'lodash-es';
-import { getImageDTO } from 'services/api/endpoints/images';
+import { getImageDTOSafe } from 'services/api/endpoints/images';
import {
isControlNetModelConfig,
isIPAdapterModelConfig,
@@ -603,7 +603,7 @@ const parseIPAdapterToIPAdapterLayer: MetadataParseFunc {
it('should create a new graph with the correct id', () => {
const g = new Graph('test-id');
expect(g._graph.id).toBe('test-id');
+ expect(g.id).toBe('test-id');
});
- it('should create a new graph with a uuid id if none is provided', () => {
+ it('should create an id if none is provided', () => {
const g = new Graph();
expect(g._graph.id).not.toBeUndefined();
- expect(validate(g._graph.id)).toBeTruthy();
+ expect(g.id).not.toBeUndefined();
});
});
diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/Graph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/Graph.ts
index 950b25d71f6..9da719509b8 100644
--- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/Graph.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/Graph.ts
@@ -32,10 +32,12 @@ export type GraphType = { id: string; nodes: Record; edge
export class Graph {
_graph: GraphType;
_metadataNodeId = getPrefixedId('core_metadata');
+ id: string;
constructor(id?: string) {
+ this.id = id ?? Graph.getId('graph');
this._graph = {
- id: id ?? uuidv4(),
+ id: this.id,
nodes: {},
edges: [],
};
diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlAdapters.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlAdapters.ts
index deaddab38f6..347ca4fba4b 100644
--- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlAdapters.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlAdapters.ts
@@ -1,3 +1,5 @@
+import { logger } from 'app/logging/logger';
+import { withResultAsync } from 'common/util/result';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import type {
CanvasControlLayerState,
@@ -6,9 +8,12 @@ import type {
T2IAdapterConfig,
} from 'features/controlLayers/store/types';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
+import { serializeError } from 'serialize-error';
import type { BaseModelType, ImageDTO, Invocation } from 'services/api/types';
import { assert } from 'tsafe';
+const log = logger('system');
+
type AddControlNetsResult = {
addedControlNets: number;
};
@@ -33,9 +38,17 @@ export const addControlNets = async (
for (const layer of validControlLayers) {
result.addedControlNets++;
- const adapter = manager.adapters.controlLayers.get(layer.id);
- assert(adapter, 'Adapter not found');
- const imageDTO = await adapter.renderer.rasterize({ rect, attrs: { opacity: 1, filters: [] }, bg: 'black' });
+ const getImageDTOResult = await withResultAsync(() => {
+ const adapter = manager.adapters.controlLayers.get(layer.id);
+ assert(adapter, 'Adapter not found');
+ return adapter.renderer.rasterize({ rect, attrs: { opacity: 1, filters: [] }, bg: 'black' });
+ });
+ if (getImageDTOResult.isErr()) {
+ log.warn({ error: serializeError(getImageDTOResult.error) }, 'Error rasterizing control layer');
+ continue;
+ }
+
+ const imageDTO = getImageDTOResult.value;
addControlNetToGraph(g, layer, imageDTO, collector);
}
@@ -66,9 +79,17 @@ export const addT2IAdapters = async (
for (const layer of validControlLayers) {
result.addedT2IAdapters++;
- const adapter = manager.adapters.controlLayers.get(layer.id);
- assert(adapter, 'Adapter not found');
- const imageDTO = await adapter.renderer.rasterize({ rect, attrs: { opacity: 1, filters: [], bg: 'black' } });
+ const getImageDTOResult = await withResultAsync(() => {
+ const adapter = manager.adapters.controlLayers.get(layer.id);
+ assert(adapter, 'Adapter not found');
+ return adapter.renderer.rasterize({ rect, attrs: { opacity: 1, filters: [] }, bg: 'black' });
+ });
+ if (getImageDTOResult.isErr()) {
+ log.warn({ error: serializeError(getImageDTOResult.error) }, 'Error rasterizing control layer');
+ continue;
+ }
+
+ const imageDTO = getImageDTOResult.value;
addT2IAdapterToGraph(g, layer, imageDTO, collector);
}
diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts
index d85e862d508..dcce2046daa 100644
--- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts
@@ -1,4 +1,6 @@
+import { logger } from 'app/logging/logger';
import { deepClone } from 'common/util/deepClone';
+import { withResultAsync } from 'common/util/result';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import type {
@@ -8,9 +10,12 @@ import type {
RegionalGuidanceReferenceImageState,
} from 'features/controlLayers/store/types';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
+import { serializeError } from 'serialize-error';
import type { BaseModelType, Invocation } from 'services/api/types';
import { assert } from 'tsafe';
+const log = logger('system');
+
type AddedRegionResult = {
addedPositivePrompt: boolean;
addedNegativePrompt: boolean;
@@ -64,9 +69,18 @@ export const addRegions = async (
addedAutoNegativePositivePrompt: false,
addedIPAdapters: 0,
};
- const adapter = manager.adapters.regionMasks.get(region.id);
- assert(adapter, 'Adapter not found');
- const imageDTO = await adapter.renderer.rasterize({ rect: bbox });
+
+ const getImageDTOResult = await withResultAsync(() => {
+ const adapter = manager.adapters.regionMasks.get(region.id);
+ assert(adapter, 'Adapter not found');
+ return adapter.renderer.rasterize({ rect: bbox, attrs: { opacity: 1, filters: [] } });
+ });
+ if (getImageDTOResult.isErr()) {
+ log.warn({ error: serializeError(getImageDTOResult.error) }, 'Error rasterizing region mask');
+ continue;
+ }
+
+ const imageDTO = getImageDTOResult.value;
// The main mask-to-tensor node
const maskToTensor = g.addNode({
diff --git a/invokeai/frontend/web/src/services/api/endpoints/images.ts b/invokeai/frontend/web/src/services/api/endpoints/images.ts
index 90c076f283c..0d92302e031 100644
--- a/invokeai/frontend/web/src/services/api/endpoints/images.ts
+++ b/invokeai/frontend/web/src/services/api/endpoints/images.ts
@@ -1,3 +1,4 @@
+import type { StartQueryActionCreatorOptions } from '@reduxjs/toolkit/dist/query/core/buildInitiate';
import { getStore } from 'app/store/nanostores/store';
import type { SerializableObject } from 'common/types';
import type { BoardId } from 'features/gallery/store/types';
@@ -568,25 +569,40 @@ export const {
/**
* Imperative RTKQ helper to fetch an ImageDTO.
* @param image_name The name of the image to fetch
- * @param forceRefetch Whether to force a refetch of the image
- * @returns
+ * @param options The options for the query. By default, the query will not subscribe to the store.
+ * @returns The ImageDTO if found, otherwise null
*/
-export const getImageDTO = async (image_name: string, forceRefetch?: boolean): Promise => {
- const options = {
+export const getImageDTOSafe = async (
+ image_name: string,
+ options?: StartQueryActionCreatorOptions
+): Promise => {
+ const _options = {
subscribe: false,
- forceRefetch,
+ ...options,
};
- const req = getStore().dispatch(imagesApi.endpoints.getImageDTO.initiate(image_name, options));
+ const req = getStore().dispatch(imagesApi.endpoints.getImageDTO.initiate(image_name, _options));
try {
- const imageDTO = await req.unwrap();
- req.unsubscribe();
- return imageDTO;
+ return await req.unwrap();
} catch {
- req.unsubscribe();
return null;
}
};
+/**
+ * Imperative RTKQ helper to fetch an ImageDTO.
+ * @param image_name The name of the image to fetch
+ * @param options The options for the query. By default, the query will not subscribe to the store.
+ * @raises Error if the image is not found or there is an error fetching the image
+ */
+export const getImageDTO = (image_name: string, options?: StartQueryActionCreatorOptions): Promise => {
+ const _options = {
+ subscribe: false,
+ ...options,
+ };
+ const req = getStore().dispatch(imagesApi.endpoints.getImageDTO.initiate(image_name, _options));
+ return req.unwrap();
+};
+
export type UploadOptions = {
blob: Blob;
fileName: string;
@@ -596,7 +612,7 @@ export type UploadOptions = {
board_id?: BoardId;
metadata?: SerializableObject;
};
-export const uploadImage = async (arg: UploadOptions): Promise => {
+export const uploadImage = (arg: UploadOptions): Promise => {
const { blob, fileName, image_category, is_intermediate, crop_visible = false, board_id, metadata } = arg;
const { dispatch } = getStore();
@@ -612,5 +628,5 @@ export const uploadImage = async (arg: UploadOptions): Promise => {
})
);
req.reset();
- return await req.unwrap();
+ return req.unwrap();
};
diff --git a/invokeai/frontend/web/src/services/events/errors.ts b/invokeai/frontend/web/src/services/events/errors.ts
new file mode 100644
index 00000000000..24100939e90
--- /dev/null
+++ b/invokeai/frontend/web/src/services/events/errors.ts
@@ -0,0 +1,23 @@
+/**
+ * A custom error class for queue event errors. These errors have a type, message and traceback.
+ */
+
+export class QueueError extends Error {
+ type: string;
+ traceback: string;
+
+ constructor(type: string, message: string, traceback: string) {
+ super(message);
+ this.name = 'QueueError';
+ this.type = type;
+ this.traceback = traceback;
+
+ if (Error.captureStackTrace) {
+ Error.captureStackTrace(this, QueueError);
+ }
+ }
+
+ toString() {
+ return `${this.name} [${this.type}]: ${this.message}\nTraceback:\n${this.traceback}`;
+ }
+}
diff --git a/invokeai/frontend/web/src/services/events/onInvocationComplete.tsx b/invokeai/frontend/web/src/services/events/onInvocationComplete.tsx
index 7c7e511e48d..177d4599235 100644
--- a/invokeai/frontend/web/src/services/events/onInvocationComplete.tsx
+++ b/invokeai/frontend/web/src/services/events/onInvocationComplete.tsx
@@ -7,7 +7,7 @@ import { boardIdSelected, galleryViewChanged, imageSelected, offsetChanged } fro
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { boardsApi } from 'services/api/endpoints/boards';
-import { getImageDTO, imagesApi } from 'services/api/endpoints/images';
+import { getImageDTOSafe, imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO, S } from 'services/api/types';
import { getCategories, getListImagesUrl } from 'services/api/util';
import { $lastProgressEvent } from 'services/events/stores';
@@ -87,10 +87,8 @@ export const buildOnInvocationComplete = (getState: () => RootState, dispatch: A
const getResultImageDTO = (data: S['InvocationCompleteEvent']) => {
const { result } = data;
- if (result.type === 'image_output') {
- return getImageDTO(result.image.image_name);
- } else if (result.type === 'canvas_v2_mask_and_crop_output') {
- return getImageDTO(result.image.image_name);
+ if (result.type === 'image_output' || result.type === 'canvas_v2_mask_and_crop_output') {
+ return getImageDTOSafe(result.image.image_name);
}
return null;
};