Skip to content

Commit

Permalink
feat(ui): use updated progress event in frontend
Browse files Browse the repository at this point in the history
  • Loading branch information
psychedelicious authored and hipsterusername committed Sep 22, 2024
1 parent 7ab7fa8 commit a9f93c1
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ import { atom } from 'nanostores';
import type { Logger } from 'roarr';
import { selectCanvasQueueCounts } from 'services/api/endpoints/queue';
import type { S } from 'services/api/types';
import type { O } from 'ts-toolbelt';

type ProgressEventWithImage = O.NonNullable<S['InvocationProgressEvent'], 'image'>;
const isProgressEventWithImage = (val: S['InvocationProgressEvent']): val is ProgressEventWithImage =>
Boolean(val.image);

export class CanvasProgressImageModule extends CanvasModuleBase {
readonly type = 'progress_image';
Expand All @@ -26,7 +31,7 @@ export class CanvasProgressImageModule extends CanvasModuleBase {
imageElement: HTMLImageElement | null = null;

subscriptions = new Set<() => void>();
$lastProgressEvent = atom<S['InvocationDenoiseProgressEvent'] | null>(null);
$lastProgressEvent = atom<ProgressEventWithImage | null>(null);
hasActiveGeneration: boolean = false;
mutex: Mutex = new Mutex();

Expand Down Expand Up @@ -62,10 +67,13 @@ export class CanvasProgressImageModule extends CanvasModuleBase {
}

setSocketEventListeners = (): (() => void) => {
const progressListener = (data: S['InvocationDenoiseProgressEvent']) => {
const progressListener = (data: S['InvocationProgressEvent']) => {
if (data.destination !== 'canvas') {
return;
}
if (!isProgressEventWithImage(data)) {
return;
}
if (!this.hasActiveGeneration) {
return;
}
Expand All @@ -76,13 +84,13 @@ export class CanvasProgressImageModule extends CanvasModuleBase {
this.$lastProgressEvent.set(null);
};

this.manager.socket.on('invocation_denoise_progress', progressListener);
this.manager.socket.on('invocation_progress', progressListener);
this.manager.socket.on('connect', clearProgress);
this.manager.socket.on('connect_error', clearProgress);
this.manager.socket.on('disconnect', clearProgress);

return () => {
this.manager.socket.off('invocation_denoise_progress', progressListener);
this.manager.socket.off('invocation_progress', progressListener);
this.manager.socket.off('connect', clearProgress);
this.manager.socket.off('connect_error', clearProgress);
this.manager.socket.off('disconnect', clearProgress);
Expand Down Expand Up @@ -111,9 +119,8 @@ export class CanvasProgressImageModule extends CanvasModuleBase {
this.isLoading = true;

const { x, y, width, height } = this.manager.stateApi.getBbox().rect;
const { dataURL } = event.progress_image;
try {
this.imageElement = await loadImage(dataURL);
this.imageElement = await loadImage(event.image.dataURL);
if (this.konva.image) {
this.konva.image.setAttrs({
image: this.imageElement,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,10 @@ const CurrentImageNode = (props: NodeProps) => {
const imageDTO = useAppSelector(selectLastSelectedImage);
const lastProgressEvent = useStore($lastProgressEvent);

if (lastProgressEvent?.progress_image) {
if (lastProgressEvent?.image) {
return (
<Wrapper nodeProps={props}>
<Image
src={lastProgressEvent?.progress_image.dataURL}
w="full"
h="full"
objectFit="contain"
borderRadius="base"
/>
<Image src={lastProgressEvent?.image.dataURL} w="full" h="full" objectFit="contain" borderRadius="base" />
</Wrapper>
);
}
Expand Down
24 changes: 15 additions & 9 deletions invokeai/frontend/web/src/services/events/setEventListeners.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import { zNodeStatus } from 'features/nodes/types/invocation';
import ErrorToastDescription, { getTitleFromErrorType } from 'features/toast/ErrorToastDescription';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { forEach } from 'lodash-es';
import { forEach, isNil, round } from 'lodash-es';
import { api, LIST_TAG } from 'services/api';
import { modelsApi } from 'services/api/endpoints/models';
import { queueApi, queueItemsAdapter } from 'services/api/endpoints/queue';
Expand Down Expand Up @@ -81,22 +81,28 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis
}
});

socket.on('invocation_denoise_progress', (data) => {
const { invocation_source_id, invocation, step, total_steps, progress_image, origin, percentage } = data;
socket.on('invocation_progress', (data) => {
const { invocation_source_id, invocation, image, origin, percentage, message } = data;

log.trace(
{ data } as SerializableObject,
`Denoise ${Math.round(percentage * 100)}% (${invocation.type}, ${invocation_source_id})`
);
let _message = 'Invocation progress';
if (message) {
_message += `: ${message}`;
}
if (!isNil(percentage)) {
_message += ` ${round(percentage * 100, 2)}%`;
}
_message += ` (${invocation.type}, ${invocation_source_id})`;

log.trace({ data } as SerializableObject, _message);

$lastProgressEvent.set(data);

if (origin === 'workflows') {
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
if (nes) {
nes.status = zNodeStatus.enum.IN_PROGRESS;
nes.progress = (step + 1) / total_steps;
nes.progressImage = progress_image ?? null;
nes.progress = percentage;
nes.progressImage = image ?? null;
upsertExecutionState(nes.nodeId, nes);
}
}
Expand Down
4 changes: 2 additions & 2 deletions invokeai/frontend/web/src/services/events/stores.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import type { ManagerOptions, SocketOptions } from 'socket.io-client';
export const $socket = atom<AppSocket | null>(null);
export const $socketOptions = map<Partial<ManagerOptions & SocketOptions>>({});
export const $isConnected = atom<boolean>(false);
export const $lastProgressEvent = atom<S['InvocationDenoiseProgressEvent'] | null>(null);
export const $lastProgressEvent = atom<S['InvocationProgressEvent'] | null>(null);
export const $hasProgress = computed($lastProgressEvent, (val) => Boolean(val));
export const $progressImage = computed($lastProgressEvent, (val) => val?.progress_image ?? null);
export const $progressImage = computed($lastProgressEvent, (val) => val?.image ?? null);
export const $isProgressFromCanvas = computed($lastProgressEvent, (val) => val?.destination === 'canvas');
2 changes: 1 addition & 1 deletion invokeai/frontend/web/src/services/events/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ type ClientEmitSubscribeBulkDownload = { bulk_download_id: string };
type ClientEmitUnsubscribeBulkDownload = ClientEmitSubscribeBulkDownload;

export type ServerToClientEvents = {
invocation_denoise_progress: (payload: S['InvocationDenoiseProgressEvent']) => void;
invocation_progress: (payload: S['InvocationProgressEvent']) => void;
invocation_complete: (payload: S['InvocationCompleteEvent']) => void;
invocation_error: (payload: S['InvocationErrorEvent']) => void;
invocation_started: (payload: S['InvocationStartedEvent']) => void;
Expand Down

0 comments on commit a9f93c1

Please sign in to comment.