Skip to content

Commit

Permalink
feat(ui): handle control adapter processed images
Browse files Browse the repository at this point in the history
- Add helper functions to build metadata for control adapters, including the processed images
- Update parses to parse the new metadata
  • Loading branch information
psychedelicious authored and brandonrising committed Mar 14, 2024
1 parent 76296cc commit f11e173
Show file tree
Hide file tree
Showing 4 changed files with 187 additions and 71 deletions.
37 changes: 30 additions & 7 deletions invokeai/frontend/web/src/features/metadata/util/parsers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,14 @@ const parseControlNet: MetadataParseFunc<ControlNetConfigMetadata> = async (meta
const control_model = await getProperty(metadataItem, 'control_model');
const key = await getModelKey(control_model, 'controlnet');
const controlNetModel = await fetchModelConfigWithTypeGuard(key, isControlNetModelConfig);
const image = zControlField.shape.image.nullish().catch(null).parse(await getProperty(metadataItem, 'image'));
const image = zControlField.shape.image
.nullish()
.catch(null)
.parse(await getProperty(metadataItem, 'image'));
const processedImage = zControlField.shape.image
.nullish()
.catch(null)
.parse(await getProperty(metadataItem, 'processed_image'));
const control_weight = zControlField.shape.control_weight
.nullish()
.catch(null)
Expand Down Expand Up @@ -259,7 +266,7 @@ const parseControlNet: MetadataParseFunc<ControlNetConfigMetadata> = async (meta
controlMode: control_mode ?? initialControlNet.controlMode,
resizeMode: resize_mode ?? initialControlNet.resizeMode,
controlImage: image?.image_name ?? null,
processedControlImage: image?.image_name ?? null,
processedControlImage: processedImage?.image_name ?? null,
processorType,
processorNode,
shouldAutoConfig: true,
Expand All @@ -283,8 +290,18 @@ const parseT2IAdapter: MetadataParseFunc<T2IAdapterConfigMetadata> = async (meta
const key = await getModelKey(t2i_adapter_model, 't2i_adapter');
const t2iAdapterModel = await fetchModelConfigWithTypeGuard(key, isT2IAdapterModelConfig);

const image = zT2IAdapterField.shape.image.nullish().catch(null).parse(await getProperty(metadataItem, 'image'));
const weight = zT2IAdapterField.shape.weight.nullish().catch(null).parse(await getProperty(metadataItem, 'weight'));
const image = zT2IAdapterField.shape.image
.nullish()
.catch(null)
.parse(await getProperty(metadataItem, 'image'));
const processedImage = zT2IAdapterField.shape.image
.nullish()
.catch(null)
.parse(await getProperty(metadataItem, 'processed_image'));
const weight = zT2IAdapterField.shape.weight
.nullish()
.catch(null)
.parse(await getProperty(metadataItem, 'weight'));
const begin_step_percent = zT2IAdapterField.shape.begin_step_percent
.nullish()
.catch(null)
Expand All @@ -309,7 +326,7 @@ const parseT2IAdapter: MetadataParseFunc<T2IAdapterConfigMetadata> = async (meta
endStepPct: end_step_percent ?? initialT2IAdapter.endStepPct,
resizeMode: resize_mode ?? initialT2IAdapter.resizeMode,
controlImage: image?.image_name ?? null,
processedControlImage: image?.image_name ?? null,
processedControlImage: processedImage?.image_name ?? null,
processorType,
processorNode,
shouldAutoConfig: true,
Expand All @@ -333,8 +350,14 @@ const parseIPAdapter: MetadataParseFunc<IPAdapterConfigMetadata> = async (metada
const key = await getModelKey(ip_adapter_model, 'ip_adapter');
const ipAdapterModel = await fetchModelConfigWithTypeGuard(key, isIPAdapterModelConfig);

const image = zIPAdapterField.shape.image.nullish().catch(null).parse(await getProperty(metadataItem, 'image'));
const weight = zIPAdapterField.shape.weight.nullish().catch(null).parse(await getProperty(metadataItem, 'weight'));
const image = zIPAdapterField.shape.image
.nullish()
.catch(null)
.parse(await getProperty(metadataItem, 'image'));
const weight = zIPAdapterField.shape.weight
.nullish()
.catch(null)
.parse(await getProperty(metadataItem, 'weight'));
const begin_step_percent = zIPAdapterField.shape.begin_step_percent
.nullish()
.catch(null)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import type { RootState } from 'app/store/store';
import { selectValidControlNets } from 'features/controlAdapters/store/controlAdaptersSlice';
import type { ControlAdapterProcessorType, ControlNetConfig } from 'features/controlAdapters/store/types';
import type { ImageField } from 'features/nodes/types/common';
import type {
CollectInvocation,
ControlNetInvocation,
CoreMetadataInvocation,
NonNullableGraph,
S,
} from 'services/api/types';
import { assert } from 'tsafe';

import { CONTROL_NET_COLLECT } from './constants';
import { upsertMetadata } from './metadata';
Expand Down Expand Up @@ -70,34 +74,12 @@ export const addControlNetToLinearGraph = async (
resize_mode: resizeMode,
control_model: model,
control_weight: weight,
image: buildControlImage(controlImage, processedControlImage, processorType),
};

if (processedControlImage && processorType !== 'none') {
// We've already processed the image in the app, so we can just use the processed image
controlNetNode.image = {
image_name: processedControlImage,
};
} else if (controlImage) {
// The control image is preprocessed
controlNetNode.image = {
image_name: controlImage,
};
} else {
// Skip CAs without an unprocessed image - should never happen, we already filtered the list of valid CAs
return;
}
graph.nodes[controlNetNode.id] = controlNetNode;

graph.nodes[controlNetNode.id] = controlNetNode as ControlNetInvocation;

controlNetMetadata.push({
control_model: model,
control_weight: weight,
control_mode: controlMode,
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
resize_mode: resizeMode,
image: controlNetNode.image,
});
controlNetMetadata.push(buildControlNetMetadata(controlNet));

graph.edges.push({
source: { node_id: controlNetNode.id, field: 'control' },
Expand All @@ -110,3 +92,62 @@ export const addControlNetToLinearGraph = async (
upsertMetadata(graph, { controlnets: controlNetMetadata });
}
};

const buildControlImage = (
controlImage: string | null,
processedControlImage: string | null,
processorType: ControlAdapterProcessorType
): ImageField => {
let image: ImageField | null = null;
if (processedControlImage && processorType !== 'none') {
// We've already processed the image in the app, so we can just use the processed image
image = {
image_name: processedControlImage,
};
} else if (controlImage) {
// The control image is preprocessed
image = {
image_name: controlImage,
};
}
assert(image, 'ControlNet image is required');
return image;
};

const buildControlNetMetadata = (controlNet: ControlNetConfig): S['ControlNetMetadataField'] => {
const {
controlImage,
processedControlImage,
beginStepPct,
endStepPct,
controlMode,
resizeMode,
model,
processorType,
weight,
} = controlNet;

assert(model, 'ControlNet model is required');

const processed_image =
processedControlImage && processorType !== 'none'
? {
image_name: processedControlImage,
}
: null;

assert(controlImage, 'ControlNet image is required');

return {
control_model: model,
control_weight: weight,
control_mode: controlMode,
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
resize_mode: resizeMode,
image: {
image_name: controlImage,
},
processed_image,
};
};
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import type { RootState } from 'app/store/store';
import { selectValidIPAdapters } from 'features/controlAdapters/store/controlAdaptersSlice';
import type { IPAdapterConfig } from 'features/controlAdapters/store/types';
import type { ImageField } from 'features/nodes/types/common';
import type {
CollectInvocation,
CoreMetadataInvocation,
IPAdapterInvocation,
NonNullableGraph,
S,
} from 'services/api/types';
import { assert } from 'tsafe';

import { IP_ADAPTER_COLLECT } from './constants';
import { upsertMetadata } from './metadata';
Expand Down Expand Up @@ -44,7 +48,10 @@ export const addIPAdapterToLinearGraph = async (
if (!ipAdapter.model) {
return;
}
const { id, weight, model, beginStepPct, endStepPct } = ipAdapter;
const { id, weight, model, beginStepPct, endStepPct, controlImage } = ipAdapter;

assert(controlImage, 'IP Adapter image is required');

const ipAdapterNode: IPAdapterInvocation = {
id: `ip_adapter_${id}`,
type: 'ip_adapter',
Expand All @@ -53,25 +60,14 @@ export const addIPAdapterToLinearGraph = async (
ip_adapter_model: model,
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
image: {
image_name: controlImage,
},
};

if (ipAdapter.controlImage) {
ipAdapterNode.image = {
image_name: ipAdapter.controlImage,
};
} else {
return;
}

graph.nodes[ipAdapterNode.id] = ipAdapterNode;

ipAdapterMetdata.push({
weight: weight,
ip_adapter_model: model,
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
image: ipAdapterNode.image,
});
ipAdapterMetdata.push(buildIPAdapterMetadata(ipAdapter));

graph.edges.push({
source: { node_id: ipAdapterNode.id, field: 'ip_adapter' },
Expand All @@ -85,3 +81,27 @@ export const addIPAdapterToLinearGraph = async (
upsertMetadata(graph, { ipAdapters: ipAdapterMetdata });
}
};

const buildIPAdapterMetadata = (ipAdapter: IPAdapterConfig): S['IPAdapterMetadataField'] => {
const { controlImage, beginStepPct, endStepPct, model, weight } = ipAdapter;

assert(model, 'IP Adapter model is required');

let image: ImageField | null = null;

if (controlImage) {
image = {
image_name: controlImage,
};
}

assert(image, 'IP Adapter image is required');

return {
ip_adapter_model: model,
weight,
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
image,
};
};
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import type { RootState } from 'app/store/store';
import { selectValidT2IAdapters } from 'features/controlAdapters/store/controlAdaptersSlice';
import type { ControlAdapterProcessorType, T2IAdapterConfig } from 'features/controlAdapters/store/types';
import type { ImageField } from 'features/nodes/types/common';
import type {
CollectInvocation,
CoreMetadataInvocation,
NonNullableGraph,
S,
T2IAdapterInvocation,
} from 'services/api/types';
import { assert } from 'tsafe';

import { T2I_ADAPTER_COLLECT } from './constants';
import { upsertMetadata } from './metadata';
Expand Down Expand Up @@ -68,33 +72,12 @@ export const addT2IAdaptersToLinearGraph = async (
resize_mode: resizeMode,
t2i_adapter_model: model,
weight: weight,
image: buildControlImage(controlImage, processedControlImage, processorType),
};

if (processedControlImage && processorType !== 'none') {
// We've already processed the image in the app, so we can just use the processed image
t2iAdapterNode.image = {
image_name: processedControlImage,
};
} else if (controlImage) {
// The control image is preprocessed
t2iAdapterNode.image = {
image_name: controlImage,
};
} else {
// Skip CAs without an unprocessed image - should never happen, we already filtered the list of valid CAs
return;
}

graph.nodes[t2iAdapterNode.id] = t2iAdapterNode;

t2iAdapterMetadata.push({
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
resize_mode: resizeMode,
t2i_adapter_model: t2iAdapter.model,
weight: weight,
image: t2iAdapterNode.image,
});
t2iAdapterMetadata.push(buildT2IAdapterMetadata(t2iAdapter));

graph.edges.push({
source: { node_id: t2iAdapterNode.id, field: 't2i_adapter' },
Expand All @@ -108,3 +91,52 @@ export const addT2IAdaptersToLinearGraph = async (
upsertMetadata(graph, { t2iAdapters: t2iAdapterMetadata });
}
};

const buildControlImage = (
controlImage: string | null,
processedControlImage: string | null,
processorType: ControlAdapterProcessorType
): ImageField => {
let image: ImageField | null = null;
if (processedControlImage && processorType !== 'none') {
// We've already processed the image in the app, so we can just use the processed image
image = {
image_name: processedControlImage,
};
} else if (controlImage) {
// The control image is preprocessed
image = {
image_name: controlImage,
};
}
assert(image, 'T2I Adapter image is required');
return image;
};

const buildT2IAdapterMetadata = (t2iAdapter: T2IAdapterConfig): S['T2IAdapterMetadataField'] => {
const { controlImage, processedControlImage, beginStepPct, endStepPct, resizeMode, model, processorType, weight } =
t2iAdapter;

assert(model, 'T2I Adapter model is required');

const processed_image =
processedControlImage && processorType !== 'none'
? {
image_name: processedControlImage,
}
: null;

assert(controlImage, 'T2I Adapter image is required');

return {
t2i_adapter_model: model,
weight,
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
resize_mode: resizeMode,
image: {
image_name: controlImage,
},
processed_image,
};
};

0 comments on commit f11e173

Please sign in to comment.