Skip to content

Commit 075e040

Browse files
authored
Update Simple Upscale Button to work with spandrel models (#6649)
## Summary Update Simple Upscale Button to work with spandrel models, add UpscaleWarning when models aren't available, clean up ESRGAN logic <!--A description of the changes in this PR. Include the kind of change (fix, feature, docs, etc), the "why" and the "how". Screenshots or videos are useful for frontend changes.--> ## Related Issues / Discussions <!--WHEN APPLICABLE: List any related issues or discussions on github or discord. If this PR closes an issue, please use the "Closes #1234" format, so that the issue will be automatically closed when the PR merges.--> ## QA Instructions <!--WHEN APPLICABLE: Describe how you have tested the changes in this PR. Provide enough detail that a reviewer can reproduce your tests.--> ## Merge Plan <!--WHEN APPLICABLE: Large PRs, or PRs that touch sensitive things like DB schemas, may need some care when merging. For example, a careful rebase by the change author, timing to not interfere with a pending release, or a message to contributors on discord after merging.--> ## Checklist - [ ] _The PR has a short but descriptive title, suitable for a changelog_ - [ ] _Tests added / updated (if applicable)_ - [ ] _Documentation added / updated (if applicable)_
2 parents 5635f65 + bf6066d commit 075e040

File tree

13 files changed

+83
-180
lines changed

13 files changed

+83
-180
lines changed

invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts

+14-12
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import { heightChanged, widthChanged } from 'features/controlLayers/store/contro
1010
import { loraRemoved } from 'features/lora/store/loraSlice';
1111
import { calculateNewSize } from 'features/parameters/components/ImageSize/calculateNewSize';
1212
import { modelChanged, vaeSelected } from 'features/parameters/store/generationSlice';
13-
import { upscaleModelChanged } from 'features/parameters/store/upscaleSlice';
13+
import { simpleUpscaleModelChanged, upscaleModelChanged } from 'features/parameters/store/upscaleSlice';
1414
import { zParameterModel, zParameterVAEModel } from 'features/parameters/types/parameterSchemas';
1515
import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
1616
import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice';
@@ -186,21 +186,23 @@ const handleControlAdapterModels: ModelHandler = (models, state, dispatch, _log)
186186
};
187187

188188
const handleSpandrelImageToImageModels: ModelHandler = (models, state, dispatch, _log) => {
189-
const currentUpscaleModel = state.upscale.upscaleModel;
189+
const { upscaleModel: currentUpscaleModel, simpleUpscaleModel: currentSimpleUpscaleModel } = state.upscale;
190190
const upscaleModels = models.filter(isSpandrelImageToImageModelConfig);
191+
const firstModel = upscaleModels[0] || null;
191192

192-
if (currentUpscaleModel) {
193-
const isCurrentUpscaleModelAvailable = upscaleModels.some((m) => m.key === currentUpscaleModel.key);
194-
if (isCurrentUpscaleModelAvailable) {
195-
return;
196-
}
197-
}
193+
const isCurrentUpscaleModelAvailable = currentUpscaleModel
194+
? upscaleModels.some((m) => m.key === currentUpscaleModel.key)
195+
: false;
198196

199-
const firstModel = upscaleModels[0];
200-
if (firstModel) {
197+
if (!isCurrentUpscaleModelAvailable) {
201198
dispatch(upscaleModelChanged(firstModel));
202-
return;
203199
}
204200

205-
dispatch(upscaleModelChanged(null));
201+
const isCurrentSimpleUpscaleModelAvailable = currentSimpleUpscaleModel
202+
? upscaleModels.some((m) => m.key === currentSimpleUpscaleModel.key)
203+
: false;
204+
205+
if (!isCurrentSimpleUpscaleModelAvailable) {
206+
dispatch(simpleUpscaleModelChanged(firstModel));
207+
}
206208
};

invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/upscaleRequested.ts

+2-3
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ export const addUpscaleRequestedListener = (startAppListening: AppStartListening
1818
const log = logger('session');
1919

2020
const { imageDTO } = action.payload;
21-
const { image_name } = imageDTO;
2221
const state = getState();
2322

2423
const { isAllowedToUpscale, detailTKey } = createIsAllowedToUpscaleSelector(imageDTO)(state);
@@ -40,8 +39,8 @@ export const addUpscaleRequestedListener = (startAppListening: AppStartListening
4039
const enqueueBatchArg: BatchConfig = {
4140
prepend: true,
4241
batch: {
43-
graph: buildAdHocUpscaleGraph({
44-
image_name,
42+
graph: await buildAdHocUpscaleGraph({
43+
image: imageDTO,
4544
state,
4645
}),
4746
runs: 1,

invokeai/frontend/web/src/app/store/store.ts

-3
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ import { nodesPersistConfig, nodesSlice, nodesUndoableConfig } from 'features/no
2525
import { workflowSettingsPersistConfig, workflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
2626
import { workflowPersistConfig, workflowSlice } from 'features/nodes/store/workflowSlice';
2727
import { generationPersistConfig, generationSlice } from 'features/parameters/store/generationSlice';
28-
import { postprocessingPersistConfig, postprocessingSlice } from 'features/parameters/store/postprocessingSlice';
2928
import { upscalePersistConfig, upscaleSlice } from 'features/parameters/store/upscaleSlice';
3029
import { queueSlice } from 'features/queue/store/queueSlice';
3130
import { sdxlPersistConfig, sdxlSlice } from 'features/sdxl/store/sdxlSlice';
@@ -53,7 +52,6 @@ const allReducers = {
5352
[gallerySlice.name]: gallerySlice.reducer,
5453
[generationSlice.name]: generationSlice.reducer,
5554
[nodesSlice.name]: undoable(nodesSlice.reducer, nodesUndoableConfig),
56-
[postprocessingSlice.name]: postprocessingSlice.reducer,
5755
[systemSlice.name]: systemSlice.reducer,
5856
[configSlice.name]: configSlice.reducer,
5957
[uiSlice.name]: uiSlice.reducer,
@@ -104,7 +102,6 @@ const persistConfigs: { [key in keyof typeof allReducers]?: PersistConfig } = {
104102
[galleryPersistConfig.name]: galleryPersistConfig,
105103
[generationPersistConfig.name]: generationPersistConfig,
106104
[nodesPersistConfig.name]: nodesPersistConfig,
107-
[postprocessingPersistConfig.name]: postprocessingPersistConfig,
108105
[systemPersistConfig.name]: systemPersistConfig,
109106
[workflowPersistConfig.name]: workflowPersistConfig,
110107
[uiPersistConfig.name]: uiPersistConfig,

invokeai/frontend/web/src/features/nodes/util/graph/buildAdHocUpscaleGraph.ts

+26-18
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,46 @@
11
import type { RootState } from 'app/store/store';
2-
import { getBoardField } from 'features/nodes/util/graph/graphBuilderUtils';
3-
import type { Graph, Invocation, NonNullableGraph } from 'services/api/types';
2+
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
3+
import {
4+
type ImageDTO,
5+
type Invocation,
6+
isSpandrelImageToImageModelConfig,
7+
type NonNullableGraph,
8+
} from 'services/api/types';
9+
import { assert } from 'tsafe';
410

5-
import { addCoreMetadataNode, upsertMetadata } from './canvas/metadata';
6-
import { ESRGAN } from './constants';
11+
import { addCoreMetadataNode, getModelMetadataField, upsertMetadata } from './canvas/metadata';
12+
import { SPANDREL } from './constants';
713

814
type Arg = {
9-
image_name: string;
15+
image: ImageDTO;
1016
state: RootState;
1117
};
1218

13-
export const buildAdHocUpscaleGraph = ({ image_name, state }: Arg): Graph => {
14-
const { esrganModelName } = state.postprocessing;
19+
export const buildAdHocUpscaleGraph = async ({ image, state }: Arg): Promise<NonNullableGraph> => {
20+
const { simpleUpscaleModel } = state.upscale;
1521

16-
const realesrganNode: Invocation<'esrgan'> = {
17-
id: ESRGAN,
18-
type: 'esrgan',
19-
image: { image_name },
20-
model_name: esrganModelName,
21-
is_intermediate: false,
22-
board: getBoardField(state),
22+
assert(simpleUpscaleModel, 'No upscale model found in state');
23+
24+
const upscaleNode: Invocation<'spandrel_image_to_image'> = {
25+
id: SPANDREL,
26+
type: 'spandrel_image_to_image',
27+
image_to_image_model: simpleUpscaleModel,
28+
tile_size: 500,
29+
image,
2330
};
2431

2532
const graph: NonNullableGraph = {
26-
id: `adhoc-esrgan-graph`,
33+
id: `adhoc-upscale-graph`,
2734
nodes: {
28-
[ESRGAN]: realesrganNode,
35+
[SPANDREL]: upscaleNode,
2936
},
3037
edges: [],
3138
};
39+
const modelConfig = await fetchModelConfigWithTypeGuard(simpleUpscaleModel.key, isSpandrelImageToImageModelConfig);
3240

33-
addCoreMetadataNode(graph, {}, ESRGAN);
41+
addCoreMetadataNode(graph, {}, SPANDREL);
3442
upsertMetadata(graph, {
35-
esrgan_model: esrganModelName,
43+
upscale_model: getModelMetadataField(modelConfig),
3644
});
3745

3846
return graph;

invokeai/frontend/web/src/features/nodes/util/graph/constants.ts

-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ export const CONTROL_NET_COLLECT = 'control_net_collect';
3636
export const IP_ADAPTER_COLLECT = 'ip_adapter_collect';
3737
export const T2I_ADAPTER_COLLECT = 't2i_adapter_collect';
3838
export const METADATA = 'core_metadata';
39-
export const ESRGAN = 'esrgan';
4039
export const SPANDREL = 'spandrel';
4140
export const SDXL_MODEL_LOADER = 'sdxl_model_loader';
4241
export const SDXL_DENOISE_LATENTS = 'sdxl_denoise_latents';

invokeai/frontend/web/src/features/parameters/components/Upscale/ParamRealESRGANModel.tsx

-72
This file was deleted.

invokeai/frontend/web/src/features/parameters/components/Upscale/ParamSpandrelModel.tsx

+13-5
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,21 @@
11
import { Box, Combobox, FormControl, FormLabel, Tooltip } from '@invoke-ai/ui-library';
22
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
33
import { useModelCombobox } from 'common/hooks/useModelCombobox';
4-
import { upscaleModelChanged } from 'features/parameters/store/upscaleSlice';
4+
import { simpleUpscaleModelChanged, upscaleModelChanged } from 'features/parameters/store/upscaleSlice';
55
import { memo, useCallback, useMemo } from 'react';
66
import { useTranslation } from 'react-i18next';
77
import { useSpandrelImageToImageModels } from 'services/api/hooks/modelsByType';
88
import type { SpandrelImageToImageModelConfig } from 'services/api/types';
99

10-
const ParamSpandrelModel = () => {
10+
interface Props {
11+
isMultidiffusion: boolean;
12+
}
13+
14+
const ParamSpandrelModel = ({ isMultidiffusion }: Props) => {
1115
const { t } = useTranslation();
1216
const [modelConfigs, { isLoading }] = useSpandrelImageToImageModels();
1317

14-
const model = useAppSelector((s) => s.upscale.upscaleModel);
18+
const model = useAppSelector((s) => (isMultidiffusion ? s.upscale.upscaleModel : s.upscale.simpleUpscaleModel));
1519
const dispatch = useAppDispatch();
1620

1721
const tooltipLabel = useMemo(() => {
@@ -23,9 +27,13 @@ const ParamSpandrelModel = () => {
2327

2428
const _onChange = useCallback(
2529
(v: SpandrelImageToImageModelConfig | null) => {
26-
dispatch(upscaleModelChanged(v));
30+
if (isMultidiffusion) {
31+
dispatch(upscaleModelChanged(v));
32+
} else {
33+
dispatch(simpleUpscaleModelChanged(v));
34+
}
2735
},
28-
[dispatch]
36+
[isMultidiffusion, dispatch]
2937
);
3038

3139
const { options, value, onChange, placeholder, noOptionsMessage } = useModelCombobox({

invokeai/frontend/web/src/features/parameters/components/Upscale/ParamUpscaleSettings.tsx

+5-3
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,13 @@ import { upscaleRequested } from 'app/store/middleware/listenerMiddleware/listen
1212
import { useAppDispatch } from 'app/store/storeHooks';
1313
import { useIsAllowedToUpscale } from 'features/parameters/hooks/useIsAllowedToUpscale';
1414
import { useIsQueueMutationInProgress } from 'features/queue/hooks/useIsQueueMutationInProgress';
15+
import { UpscaleWarning } from 'features/settingsAccordions/components/UpscaleSettingsAccordion/UpscaleWarning';
1516
import { memo, useCallback } from 'react';
1617
import { useTranslation } from 'react-i18next';
1718
import { PiFrameCornersBold } from 'react-icons/pi';
1819
import type { ImageDTO } from 'services/api/types';
1920

20-
import ParamESRGANModel from './ParamRealESRGANModel';
21+
import ParamSpandrelModel from './ParamSpandrelModel';
2122

2223
type Props = { imageDTO?: ImageDTO };
2324

@@ -48,9 +49,10 @@ const ParamUpscalePopover = (props: Props) => {
4849
/>
4950
</PopoverTrigger>
5051
<PopoverContent>
51-
<PopoverBody minW={96}>
52+
<PopoverBody w={96}>
5253
<Flex flexDirection="column" gap={4}>
53-
<ParamESRGANModel />
54+
<ParamSpandrelModel isMultidiffusion={false} />
55+
<UpscaleWarning usesTile={false} />
5456
<Button
5557
tooltip={detail}
5658
size="sm"

invokeai/frontend/web/src/features/parameters/hooks/useIsAllowedToUpscale.ts

+7-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
22
import { useAppSelector } from 'app/store/storeHooks';
3-
import { selectPostprocessingSlice } from 'features/parameters/store/postprocessingSlice';
3+
import { selectUpscalelice } from 'features/parameters/store/upscaleSlice';
44
import { selectConfigSlice } from 'features/system/store/configSlice';
55
import { useMemo } from 'react';
66
import { useTranslation } from 'react-i18next';
@@ -55,13 +55,16 @@ const getDetailTKey = (isAllowedToUpscale?: ReturnType<typeof getIsAllowedToUpsc
5555
};
5656

5757
export const createIsAllowedToUpscaleSelector = (imageDTO?: ImageDTO) =>
58-
createMemoizedSelector(selectPostprocessingSlice, selectConfigSlice, (postprocessing, config) => {
59-
const { esrganModelName } = postprocessing;
58+
createMemoizedSelector(selectUpscalelice, selectConfigSlice, (upscale, config) => {
59+
const { simpleUpscaleModel } = upscale;
6060
const { maxUpscalePixels } = config;
61+
if (!simpleUpscaleModel) {
62+
return { isAllowedToUpscale: false, detailTKey: undefined };
63+
}
6164

6265
const upscaledPixels = getUpscaledPixels(imageDTO, maxUpscalePixels);
6366
const isAllowedToUpscale = getIsAllowedToUpscale(upscaledPixels, maxUpscalePixels);
64-
const scaleFactor = esrganModelName.includes('x2') ? 2 : 4;
67+
const scaleFactor = simpleUpscaleModel.name.includes('x2') ? 2 : 4;
6568
const detailTKey = getDetailTKey(isAllowedToUpscale, scaleFactor);
6669
return {
6770
isAllowedToUpscale: scaleFactor === 2 ? isAllowedToUpscale.x2 : isAllowedToUpscale.x4,

invokeai/frontend/web/src/features/parameters/store/postprocessingSlice.ts

-53
This file was deleted.

0 commit comments

Comments
 (0)