Skip to content

Commit 9e4c5ce

Browse files
feat(ui): add flux redux image influence to canvas
1 parent 237f4e3 commit 9e4c5ce

File tree

10 files changed

+204
-5
lines changed

10 files changed

+204
-5
lines changed

invokeai/frontend/web/public/locales/en.json

+8
Original file line numberDiff line numberDiff line change
@@ -2020,6 +2020,14 @@
20202020
"composition": "Composition Only",
20212021
"compositionDesc": "Replicates layout & structure while ignoring the reference's style."
20222022
},
2023+
"fluxReduxImageInfluence": {
2024+
"imageInfluence": "Image Influence",
2025+
"lowest": "Lowest",
2026+
"low": "Low",
2027+
"medium": "Medium",
2028+
"high": "High",
2029+
"highest": "Highest"
2030+
},
20232031
"fill": {
20242032
"fillColor": "Fill Color",
20252033
"fillStyle": "Fill Style",
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
2+
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
3+
import type { FLUXReduxImageInfluence as FLUXReduxImageInfluenceType } from 'features/controlLayers/store/types';
4+
import { isFLUXReduxImageInfluence } from 'features/controlLayers/store/types';
5+
import { memo, useCallback, useMemo } from 'react';
6+
import { useTranslation } from 'react-i18next';
7+
import { assert } from 'tsafe';
8+
9+
type Props = {
10+
imageInfluence: FLUXReduxImageInfluenceType;
11+
onChange: (imageInfluence: FLUXReduxImageInfluenceType) => void;
12+
};
13+
14+
export const FLUXReduxImageInfluence = memo(({ imageInfluence, onChange }: Props) => {
15+
const { t } = useTranslation();
16+
17+
const options = useMemo(
18+
() =>
19+
[
20+
{
21+
label: t('controlLayers.fluxReduxImageInfluence.lowest'),
22+
value: 'lowest',
23+
},
24+
{
25+
label: t('controlLayers.fluxReduxImageInfluence.low'),
26+
value: 'low',
27+
},
28+
{
29+
label: t('controlLayers.fluxReduxImageInfluence.medium'),
30+
value: 'medium',
31+
},
32+
{
33+
label: t('controlLayers.fluxReduxImageInfluence.high'),
34+
value: 'high',
35+
},
36+
{
37+
label: t('controlLayers.fluxReduxImageInfluence.highest'),
38+
value: 'highest',
39+
},
40+
] satisfies { label: string; value: FLUXReduxImageInfluenceType }[],
41+
[t]
42+
);
43+
const _onChange = useCallback<ComboboxOnChange>(
44+
(v) => {
45+
assert(isFLUXReduxImageInfluence(v?.value));
46+
onChange(v.value);
47+
},
48+
[onChange]
49+
);
50+
const value = useMemo(() => options.find((o) => o.value === imageInfluence), [options, imageInfluence]);
51+
52+
return (
53+
<FormControl>
54+
<FormLabel m={0}>{t('controlLayers.fluxReduxImageInfluence.imageInfluence')}</FormLabel>
55+
<Combobox value={value} options={options} onChange={_onChange} />
56+
</FormControl>
57+
);
58+
});
59+
60+
FLUXReduxImageInfluence.displayName = 'FLUXReduxImageInfluence';

invokeai/frontend/web/src/features/controlLayers/components/IPAdapter/IPAdapterImagePreview.tsx

+1-1
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ export const IPAdapterImagePreview = memo(
6161
)}
6262
{imageDTO && (
6363
<>
64-
<DndImage imageDTO={imageDTO} />
64+
<DndImage imageDTO={imageDTO} borderWidth={1} borderStyle='solid' />
6565
<Flex position="absolute" flexDir="column" top={2} insetInlineEnd={2} gap={1}>
6666
<DndImageIcon
6767
onClick={handleResetControlImage}

invokeai/frontend/web/src/features/controlLayers/components/IPAdapter/IPAdapterMethod.tsx

+1-1
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ export const IPAdapterMethod = memo(({ method, onChange }: Props) => {
5050
return (
5151
<FormControl>
5252
<InformationalPopover feature="ipAdapterMethod">
53-
<FormLabel>{t('controlLayers.ipAdapterMethod.ipAdapterMethod')}</FormLabel>
53+
<FormLabel m={0}>{t('controlLayers.ipAdapterMethod.ipAdapterMethod')}</FormLabel>
5454
</InformationalPopover>
5555
<Combobox value={value} options={options} onChange={_onChange} />
5656
</FormControl>

invokeai/frontend/web/src/features/controlLayers/components/IPAdapter/IPAdapterSettings.tsx

+23-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import { BeginEndStepPct } from 'features/controlLayers/components/common/BeginE
55
import { CanvasEntitySettingsWrapper } from 'features/controlLayers/components/common/CanvasEntitySettingsWrapper';
66
import { Weight } from 'features/controlLayers/components/common/Weight';
77
import { CLIPVisionModel } from 'features/controlLayers/components/IPAdapter/CLIPVisionModel';
8+
import { FLUXReduxImageInfluence } from 'features/controlLayers/components/IPAdapter/FLUXReduxImageInfluence';
89
import { IPAdapterMethod } from 'features/controlLayers/components/IPAdapter/IPAdapterMethod';
910
import { IPAdapterSettingsEmptyState } from 'features/controlLayers/components/IPAdapter/IPAdapterSettingsEmptyState';
1011
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
@@ -13,14 +14,20 @@ import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
1314
import {
1415
referenceImageIPAdapterBeginEndStepPctChanged,
1516
referenceImageIPAdapterCLIPVisionModelChanged,
17+
referenceImageIPAdapterFLUXReduxImageInfluenceChanged,
1618
referenceImageIPAdapterImageChanged,
1719
referenceImageIPAdapterMethodChanged,
1820
referenceImageIPAdapterModelChanged,
1921
referenceImageIPAdapterWeightChanged,
2022
} from 'features/controlLayers/store/canvasSlice';
2123
import { selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
2224
import { selectCanvasSlice, selectEntity, selectEntityOrThrow } from 'features/controlLayers/store/selectors';
23-
import type { CanvasEntityIdentifier, CLIPVisionModelV2, IPMethodV2 } from 'features/controlLayers/store/types';
25+
import type {
26+
CanvasEntityIdentifier,
27+
CLIPVisionModelV2,
28+
FLUXReduxImageInfluence as FLUXReduxImageInfluenceType,
29+
IPMethodV2,
30+
} from 'features/controlLayers/store/types';
2431
import type { SetGlobalReferenceImageDndTargetData } from 'features/dnd/dnd';
2532
import { setGlobalReferenceImageDndTarget } from 'features/dnd/dnd';
2633
import { memo, useCallback, useMemo } from 'react';
@@ -65,6 +72,13 @@ const IPAdapterSettingsContent = memo(() => {
6572
[dispatch, entityIdentifier]
6673
);
6774

75+
const onChangeFLUXReduxImageInfluence = useCallback(
76+
(imageInfluence: FLUXReduxImageInfluenceType) => {
77+
dispatch(referenceImageIPAdapterFLUXReduxImageInfluenceChanged({ entityIdentifier, imageInfluence }));
78+
},
79+
[dispatch, entityIdentifier]
80+
);
81+
6882
const onChangeModel = useCallback(
6983
(modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig) => {
7084
dispatch(referenceImageIPAdapterModelChanged({ entityIdentifier, modelConfig }));
@@ -124,6 +138,14 @@ const IPAdapterSettingsContent = memo(() => {
124138
<BeginEndStepPct beginEndStepPct={ipAdapter.beginEndStepPct} onChange={onChangeBeginEndStepPct} />
125139
</Flex>
126140
)}
141+
{ipAdapter.type === 'flux_redux' && (
142+
<Flex flexDir="column" gap={2} w="full" alignItems="flex-start">
143+
<FLUXReduxImageInfluence
144+
imageInfluence={ipAdapter.imageInfluence ?? 'lowest'}
145+
onChange={onChangeFLUXReduxImageInfluence}
146+
/>
147+
</Flex>
148+
)}
127149
<Flex alignItems="center" justifyContent="center" h={32} w={32} aspectRatio="1/1" flexGrow={1}>
128150
<IPAdapterImagePreview
129151
image={ipAdapter.image}

invokeai/frontend/web/src/features/controlLayers/components/RegionalGuidance/RegionalGuidanceIPAdapterSettings.tsx

+23-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
44
import { BeginEndStepPct } from 'features/controlLayers/components/common/BeginEndStepPct';
55
import { Weight } from 'features/controlLayers/components/common/Weight';
66
import { CLIPVisionModel } from 'features/controlLayers/components/IPAdapter/CLIPVisionModel';
7+
import { FLUXReduxImageInfluence } from 'features/controlLayers/components/IPAdapter/FLUXReduxImageInfluence';
78
import { IPAdapterImagePreview } from 'features/controlLayers/components/IPAdapter/IPAdapterImagePreview';
89
import { IPAdapterMethod } from 'features/controlLayers/components/IPAdapter/IPAdapterMethod';
910
import { IPAdapterModel } from 'features/controlLayers/components/IPAdapter/IPAdapterModel';
@@ -15,13 +16,19 @@ import {
1516
rgIPAdapterBeginEndStepPctChanged,
1617
rgIPAdapterCLIPVisionModelChanged,
1718
rgIPAdapterDeleted,
19+
rgIPAdapterFLUXReduxImageInfluenceChanged,
1820
rgIPAdapterImageChanged,
1921
rgIPAdapterMethodChanged,
2022
rgIPAdapterModelChanged,
2123
rgIPAdapterWeightChanged,
2224
} from 'features/controlLayers/store/canvasSlice';
2325
import { selectCanvasSlice, selectRegionalGuidanceReferenceImage } from 'features/controlLayers/store/selectors';
24-
import type { CanvasEntityIdentifier, CLIPVisionModelV2, IPMethodV2 } from 'features/controlLayers/store/types';
26+
import type {
27+
CanvasEntityIdentifier,
28+
CLIPVisionModelV2,
29+
FLUXReduxImageInfluence as FLUXReduxImageInfluenceType,
30+
IPMethodV2,
31+
} from 'features/controlLayers/store/types';
2532
import type { SetRegionalGuidanceReferenceImageDndTargetData } from 'features/dnd/dnd';
2633
import { setRegionalGuidanceReferenceImageDndTarget } from 'features/dnd/dnd';
2734
import { memo, useCallback, useMemo } from 'react';
@@ -73,6 +80,13 @@ const RegionalGuidanceIPAdapterSettingsContent = memo(({ referenceImageId }: Pro
7380
[dispatch, entityIdentifier, referenceImageId]
7481
);
7582

83+
const onChangeFLUXReduxImageInfluence = useCallback(
84+
(imageInfluence: FLUXReduxImageInfluenceType) => {
85+
dispatch(rgIPAdapterFLUXReduxImageInfluenceChanged({ entityIdentifier, referenceImageId, imageInfluence }));
86+
},
87+
[dispatch, entityIdentifier, referenceImageId]
88+
);
89+
7690
const onChangeModel = useCallback(
7791
(modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig) => {
7892
dispatch(rgIPAdapterModelChanged({ entityIdentifier, referenceImageId, modelConfig }));
@@ -151,6 +165,14 @@ const RegionalGuidanceIPAdapterSettingsContent = memo(({ referenceImageId }: Pro
151165
<BeginEndStepPct beginEndStepPct={ipAdapter.beginEndStepPct} onChange={onChangeBeginEndStepPct} />
152166
</Flex>
153167
)}
168+
{ipAdapter.type === 'flux_redux' && (
169+
<Flex flexDir="column" gap={2} w="full">
170+
<FLUXReduxImageInfluence
171+
imageInfluence={ipAdapter.imageInfluence ?? 'lowest'}
172+
onChange={onChangeFLUXReduxImageInfluence}
173+
/>
174+
</Flex>
175+
)}
154176
<Flex alignItems="center" justifyContent="center" h={32} w={32} aspectRatio="1/1" flexGrow={1}>
155177
<IPAdapterImagePreview
156178
image={ipAdapter.image}

invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts

+37
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import type {
2121
ControlLoRAConfig,
2222
EntityMovedByPayload,
2323
FillStyle,
24+
FLUXReduxImageInfluence,
2425
RegionalGuidanceReferenceImageState,
2526
RgbColor,
2627
} from 'features/controlLayers/store/types';
@@ -626,6 +627,20 @@ export const canvasSlice = createSlice({
626627
}
627628
entity.ipAdapter.method = method;
628629
},
630+
referenceImageIPAdapterFLUXReduxImageInfluenceChanged: (
631+
state,
632+
action: PayloadAction<EntityIdentifierPayload<{ imageInfluence: FLUXReduxImageInfluence }, 'reference_image'>>
633+
) => {
634+
const { entityIdentifier, imageInfluence } = action.payload;
635+
const entity = selectEntity(state, entityIdentifier);
636+
if (!entity) {
637+
return;
638+
}
639+
if (entity.ipAdapter.type !== 'flux_redux') {
640+
return;
641+
}
642+
entity.ipAdapter.imageInfluence = imageInfluence;
643+
},
629644
referenceImageIPAdapterModelChanged: (
630645
state,
631646
action: PayloadAction<
@@ -926,6 +941,26 @@ export const canvasSlice = createSlice({
926941

927942
referenceImage.ipAdapter.method = method;
928943
},
944+
rgIPAdapterFLUXReduxImageInfluenceChanged: (
945+
state,
946+
action: PayloadAction<
947+
EntityIdentifierPayload<
948+
{ referenceImageId: string; imageInfluence: FLUXReduxImageInfluence },
949+
'regional_guidance'
950+
>
951+
>
952+
) => {
953+
const { entityIdentifier, referenceImageId, imageInfluence } = action.payload;
954+
const referenceImage = selectRegionalGuidanceReferenceImage(state, entityIdentifier, referenceImageId);
955+
if (!referenceImage) {
956+
return;
957+
}
958+
if (referenceImage.ipAdapter.type !== 'flux_redux') {
959+
return;
960+
}
961+
962+
referenceImage.ipAdapter.imageInfluence = imageInfluence;
963+
},
929964
rgIPAdapterModelChanged: (
930965
state,
931966
action: PayloadAction<
@@ -1731,6 +1766,7 @@ export const {
17311766
referenceImageIPAdapterCLIPVisionModelChanged,
17321767
referenceImageIPAdapterWeightChanged,
17331768
referenceImageIPAdapterBeginEndStepPctChanged,
1769+
referenceImageIPAdapterFLUXReduxImageInfluenceChanged,
17341770
// Regions
17351771
rgAdded,
17361772
// rgRecalled,
@@ -1746,6 +1782,7 @@ export const {
17461782
rgIPAdapterMethodChanged,
17471783
rgIPAdapterModelChanged,
17481784
rgIPAdapterCLIPVisionModelChanged,
1785+
rgIPAdapterFLUXReduxImageInfluenceChanged,
17491786
// Inpaint mask
17501787
inpaintMaskAdded,
17511788
inpaintMaskConvertedToRegionalGuidance,

invokeai/frontend/web/src/features/controlLayers/store/types.ts

+5
Original file line numberDiff line numberDiff line change
@@ -233,10 +233,15 @@ const zIPAdapterConfig = z.object({
233233
});
234234
export type IPAdapterConfig = z.infer<typeof zIPAdapterConfig>;
235235

236+
const zFLUXReduxImageInfluence = z.enum(['lowest', 'low', 'medium', 'high', 'highest']);
237+
export const isFLUXReduxImageInfluence = (v: unknown): v is FLUXReduxImageInfluence =>
238+
zFLUXReduxImageInfluence.safeParse(v).success;
239+
export type FLUXReduxImageInfluence = z.infer<typeof zFLUXReduxImageInfluence>;
236240
const zFLUXReduxConfig = z.object({
237241
type: z.literal('flux_redux'),
238242
image: zImageWithDims.nullable(),
239243
model: zServerValidatedModelIdentifierField.nullable(),
244+
imageInfluence: zFLUXReduxImageInfluence.default('highest'),
240245
});
241246
export type FLUXReduxConfig = z.infer<typeof zFLUXReduxConfig>;
242247

invokeai/frontend/web/src/features/controlLayers/store/util.ts

+1
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ export const initialFLUXRedux: FLUXReduxConfig = {
7575
type: 'flux_redux',
7676
image: null,
7777
model: null,
78+
imageInfluence: 'highest',
7879
};
7980
export const initialT2IAdapter: T2IAdapterConfig = {
8081
type: 't2i_adapter',

invokeai/frontend/web/src/features/nodes/util/graph/generation/addFLUXRedux.ts

+45-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
import type { CanvasReferenceImageState, FLUXReduxConfig } from 'features/controlLayers/store/types';
1+
import type {
2+
CanvasReferenceImageState,
3+
FLUXReduxConfig,
4+
FLUXReduxImageInfluence,
5+
} from 'features/controlLayers/store/types';
26
import { isFLUXReduxConfig } from 'features/controlLayers/store/types';
37
import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/validators';
48
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
@@ -36,6 +40,45 @@ export const addFLUXReduxes = ({ entities, g, collector, model }: AddFLUXReduxAr
3640
return result;
3741
};
3842

43+
/**
44+
* To fine-tune the image influence, edit this object.
45+
* - downsampling_factor: 1 to 9, where 1 is the most image influence and 9 is the least. 1 is FLUX redux in its original form.
46+
* - downsampling_function: the function used to downsample the image. Defaults to 'area'. Dunno about how it affects the image.
47+
* - weight: 0 to 1. the conditioning is multiplied by the square of this value. 1 means no change.
48+
*
49+
* See invokeai/app/invocations/flux_redux.py for more details.
50+
*/
51+
const IMAGE_INFLUENCE_TO_SETTINGS: Record<
52+
FLUXReduxImageInfluence,
53+
Pick<Invocation<'flux_redux'>, 'downsampling_factor' | 'downsampling_function' | 'weight'>
54+
> = {
55+
lowest: {
56+
downsampling_factor: 5,
57+
// downsampling_function: 'area',
58+
weight: 1,
59+
},
60+
low: {
61+
downsampling_factor: 4,
62+
// downsampling_function: 'area',
63+
weight: 1,
64+
},
65+
medium: {
66+
downsampling_factor: 3,
67+
// downsampling_function: 'area',
68+
weight: 1,
69+
},
70+
high: {
71+
downsampling_factor: 2,
72+
// downsampling_function: 'area',
73+
weight: 1,
74+
},
75+
highest: {
76+
downsampling_factor: 1,
77+
// downsampling_function: 'area',
78+
weight: 1,
79+
},
80+
};
81+
3982
const addFLUXRedux = (id: string, ipAdapter: FLUXReduxConfig, g: Graph, collector: Invocation<'collect'>) => {
4083
const { model: fluxReduxModel, image } = ipAdapter;
4184
assert(image, 'FLUX Redux image is required');
@@ -48,6 +91,7 @@ const addFLUXRedux = (id: string, ipAdapter: FLUXReduxConfig, g: Graph, collecto
4891
image: {
4992
image_name: image.image_name,
5093
},
94+
...IMAGE_INFLUENCE_TO_SETTINGS[ipAdapter.imageInfluence ?? 'highest'],
5195
});
5296

5397
g.addEdge(node, 'redux_cond', collector, 'item');

0 commit comments

Comments
 (0)