Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ui): regional prompting followups #6247

Merged
merged 11 commits into from
Apr 20, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull
}
const { dispatch } = getStore();
// TODO: Handle non-SDXL
// const isSDXL = state.generation.model?.base === 'sdxl';
const isSDXL = state.generation.model?.base === 'sdxl';
const layers = state.regionalPrompts.present.layers
.filter(isRPLayer) // We only want the prompt region layers
.filter((l) => l.isVisible) // Only visible layers are rendered on the canvas
Expand Down Expand Up @@ -125,12 +125,18 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull

if (layer.positivePrompt) {
// The main positive conditioning node
const regionalPositiveCondNode: S['SDXLCompelPromptInvocation'] = {
type: 'sdxl_compel_prompt',
id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${layer.id}`,
prompt: layer.positivePrompt,
style: layer.positivePrompt, // TODO: Should we put the positive prompt in both fields?
};
const regionalPositiveCondNode: S['SDXLCompelPromptInvocation'] | S['CompelInvocation'] = isSDXL
? {
type: 'sdxl_compel_prompt',
id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${layer.id}`,
prompt: layer.positivePrompt,
style: layer.positivePrompt, // TODO: Should we put the positive prompt in both fields?
}
: {
type: 'compel',
id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${layer.id}`,
prompt: layer.positivePrompt,
};
graph.nodes[regionalPositiveCondNode.id] = regionalPositiveCondNode;

// Connect the mask to the conditioning
Expand Down Expand Up @@ -158,12 +164,18 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull

if (layer.negativePrompt) {
// The main negative conditioning node
const regionalNegativeCondNode: S['SDXLCompelPromptInvocation'] = {
type: 'sdxl_compel_prompt',
id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${layer.id}`,
prompt: layer.negativePrompt,
style: layer.negativePrompt,
};
const regionalNegativeCondNode: S['SDXLCompelPromptInvocation'] | S['CompelInvocation'] = isSDXL
? {
type: 'sdxl_compel_prompt',
id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${layer.id}`,
prompt: layer.negativePrompt,
style: layer.negativePrompt,
}
: {
type: 'compel',
id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${layer.id}`,
prompt: layer.negativePrompt,
};
graph.nodes[regionalNegativeCondNode.id] = regionalNegativeCondNode;

// Connect the mask to the conditioning
Expand Down Expand Up @@ -212,12 +224,18 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull

// Create the conditioning node. It's going to be connected to the negative cond collector, but it uses the
// positive prompt
const regionalPositiveCondInvertedNode: S['SDXLCompelPromptInvocation'] = {
type: 'sdxl_compel_prompt',
id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${layer.id}`,
prompt: layer.positivePrompt,
style: layer.positivePrompt,
};
const regionalPositiveCondInvertedNode: S['SDXLCompelPromptInvocation'] | S['CompelInvocation'] = isSDXL
? {
type: 'sdxl_compel_prompt',
id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${layer.id}`,
prompt: layer.positivePrompt,
style: layer.positivePrompt,
}
: {
type: 'compel',
id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${layer.id}`,
prompt: layer.positivePrompt,
};
graph.nodes[regionalPositiveCondInvertedNode.id] = regionalPositiveCondInvertedNode;
// Connect the inverted mask to the conditioning
graph.edges.push({
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { addRegionalPromptsToGraph } from 'features/nodes/util/graph/addRegionalPromptsToGraph';
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';

Expand Down Expand Up @@ -255,6 +256,8 @@ export const buildLinearTextToImageGraph = async (state: RootState): Promise<Non

await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);

await addRegionalPromptsToGraph(state, graph, DENOISE_LATENTS);

// High resolution fix.
if (state.hrf.hrfEnabled) {
addHrfToGraph(state, graph);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { Flex } from '@invoke-ai/ui-library';
import type { Meta, StoryObj } from '@storybook/react';
import { RegionalPromptsEditor } from 'features/regionalPrompts/components/RegionalPromptsEditor';

Expand All @@ -11,7 +12,7 @@ export default meta;
type Story = StoryObj<typeof RegionalPromptsEditor>;

const Component = () => {
return <RegionalPromptsEditor />;
return <Flex w={1500} h={1500}><RegionalPromptsEditor /></Flex>
};

export const Default: Story = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import {
$tool,
isRPLayer,
rpLayerBboxChanged,
rpLayerSelected,
rpLayerTranslated,
selectRegionalPromptsSlice,
} from 'features/regionalPrompts/store/regionalPromptsSlice';
Expand Down Expand Up @@ -49,12 +50,19 @@ const useStageRenderer = (container: HTMLDivElement | null, wrapper: HTMLDivElem
);

const onBboxChanged = useCallback(
(layerId: string, bbox: IRect) => {
(layerId: string, bbox: IRect | null) => {
dispatch(rpLayerBboxChanged({ layerId, bbox }));
},
[dispatch]
);

const onBboxMouseDown = useCallback(
(layerId: string) => {
dispatch(rpLayerSelected(layerId));
},
[dispatch]
);

useLayoutEffect(() => {
log.trace('Initializing stage');
if (!container) {
Expand Down Expand Up @@ -138,8 +146,8 @@ const useStageRenderer = (container: HTMLDivElement | null, wrapper: HTMLDivElem
if (!stage) {
return;
}
renderBbox(stage, tool, state.selectedLayerId, onBboxChanged);
}, [dispatch, stage, tool, state.selectedLayerId, onBboxChanged]);
renderBbox(stage, state.layers, state.selectedLayerId, tool, onBboxChanged, onBboxMouseDown);
}, [dispatch, stage, state.layers, state.selectedLayerId, tool, onBboxChanged, onBboxMouseDown]);
};

const $container = atom<HTMLDivElement | null>(null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import type { PersistConfig, RootState } from 'app/store/store';
import { moveBackward, moveForward, moveToBack, moveToFront } from 'common/util/arrayUtils';
import type { ParameterAutoNegative } from 'features/parameters/types/parameterSchemas';
import type { IRect, Vector2d } from 'konva/lib/types';
import { isEqual } from 'lodash-es';
import { atom } from 'nanostores';
import type { RgbColor } from 'react-colorful';
import type { UndoableOptions } from 'redux-undo';
Expand Down Expand Up @@ -53,6 +54,8 @@ export type RegionalPromptLayer = LayerBase & {
x: number;
y: number;
bbox: IRect | null;
bboxNeedsUpdate: boolean;
hasEraserStrokes: boolean;
kind: 'regionalPromptLayer';
objects: LayerObject[];
positivePrompt: string;
Expand Down Expand Up @@ -90,8 +93,10 @@ export const regionalPromptsSlice = createSlice({
reducers: {
//#region Meta Layer
layerAdded: {
reducer: (state, action: PayloadAction<Layer['kind'], string, { uuid: string; color: RgbColor }>) => {
reducer: (state, action: PayloadAction<Layer['kind'], string, { uuid: string }>) => {
if (action.payload === 'regionalPromptLayer') {
const lastColor = state.layers[state.layers.length - 1]?.color;
const color = LayerColors.next(lastColor);
const layer: RegionalPromptLayer = {
id: getRPLayerId(action.meta.uuid),
isVisible: true,
Expand All @@ -100,17 +105,19 @@ export const regionalPromptsSlice = createSlice({
positivePrompt: '',
negativePrompt: '',
objects: [],
color: action.meta.color,
color,
x: 0,
y: 0,
autoNegative: 'off',
bboxNeedsUpdate: false,
hasEraserStrokes: false,
};
state.layers.push(layer);
state.selectedLayerId = layer.id;
return;
}
},
prepare: (payload: Layer['kind']) => ({ payload, meta: { uuid: uuidv4(), color: LayerColors.next() } }),
prepare: (payload: Layer['kind']) => ({ payload, meta: { uuid: uuidv4() } }),
},
layerDeleted: (state, action: PayloadAction<string>) => {
state.layers = state.layers.filter((l) => l.id !== action.payload);
Expand Down Expand Up @@ -154,6 +161,8 @@ export const regionalPromptsSlice = createSlice({
layer.objects = [];
layer.bbox = null;
layer.isVisible = true;
layer.hasEraserStrokes = false;
layer.bboxNeedsUpdate = false;
}
},
rpLayerTranslated: (state, action: PayloadAction<{ layerId: string; x: number; y: number }>) => {
Expand All @@ -169,6 +178,7 @@ export const regionalPromptsSlice = createSlice({
const layer = state.layers.find((l) => l.id === layerId);
if (isRPLayer(layer)) {
layer.bbox = bbox;
layer.bboxNeedsUpdate = false;
}
},
allLayersDeleted: (state) => {
Expand Down Expand Up @@ -218,6 +228,10 @@ export const regionalPromptsSlice = createSlice({
points: [points[0] - layer.x, points[1] - layer.y, points[2] - layer.x, points[3] - layer.y],
strokeWidth: state.brushSize,
});
layer.bboxNeedsUpdate = true;
if (!layer.hasEraserStrokes && tool === 'eraser') {
layer.hasEraserStrokes = true;
}
}
},
prepare: (payload: { layerId: string; points: [number, number, number, number]; tool: DrawingTool }) => ({
Expand All @@ -236,6 +250,7 @@ export const regionalPromptsSlice = createSlice({
// Points must be offset by the layer's x and y coordinates
// TODO: Handle this in the event listener
lastLine.points.push(point[0] - layer.x, point[1] - layer.y);
layer.bboxNeedsUpdate = true;
}
},
rpLayerAutoNegativeChanged: (
Expand Down Expand Up @@ -268,18 +283,25 @@ export const regionalPromptsSlice = createSlice({
*/
class LayerColors {
static COLORS: RgbColor[] = [
{ r: 200, g: 0, b: 0 },
{ r: 0, g: 200, b: 0 },
{ r: 0, g: 0, b: 200 },
{ r: 200, g: 200, b: 0 },
{ r: 0, g: 200, b: 200 },
{ r: 200, g: 0, b: 200 },
{ r: 123, g: 159, b: 237 }, // rgb(123, 159, 237)
{ r: 106, g: 222, b: 106 }, // rgb(106, 222, 106)
{ r: 250, g: 225, b: 80 }, // rgb(250, 225, 80)
{ r: 233, g: 137, b: 81 }, // rgb(233, 137, 81)
{ r: 229, g: 96, b: 96 }, // rgb(229, 96, 96)
{ r: 226, g: 122, b: 210 }, // rgb(226, 122, 210)
{ r: 167, g: 116, b: 234 }, // rgb(167, 116, 234)
];
static i = this.COLORS.length - 1;
/**
* Get the next color in the sequence.
* Get the next color in the sequence. If a known color is provided, the next color will be the one after it.
*/
static next(): RgbColor {
static next(currentColor?: RgbColor): RgbColor {
if (currentColor) {
const i = this.COLORS.findIndex((c) => isEqual(c, currentColor));
if (i !== -1) {
this.i = i;
}
}
this.i = (this.i + 1) % this.COLORS.length;
const color = this.COLORS[this.i];
assert(color);
Expand Down Expand Up @@ -343,7 +365,6 @@ const getRPLayerId = (layerId: string) => `rp_layer_${layerId}`;
const getRPLayerLineId = (layerId: string, lineId: string) => `${layerId}.line_${lineId}`;
export const getRPLayerObjectGroupId = (layerId: string, groupId: string) => `${layerId}.objectGroup_${groupId}`;
export const getPRLayerBboxId = (layerId: string) => `${layerId}.bbox`;
export const getRPLayerTransparencyRectId = (layerId: string) => `${layerId}.transparency_rect`;

export const regionalPromptsPersistConfig: PersistConfig<RegionalPromptsState> = {
name: regionalPromptsSlice.name,
Expand All @@ -363,7 +384,8 @@ const undoableGroupByMatcher = isAnyOf(
isEnabledChanged,
rpLayerPositivePromptChanged,
rpLayerNegativePromptChanged,
rpLayerTranslated
rpLayerTranslated,
rpLayerColorChanged
);

const LINE_1 = 'LINE_1';
Expand Down
Loading
Loading