Skip to content

Commit a21468e

Browse files
committed
Make samplers promisable and ID-based to support arbitrary schemes outside of gRPC
1 parent 8e225da commit a21468e

File tree

9 files changed

+69
-92
lines changed

9 files changed

+69
-92
lines changed

packages/stablestudio-plugin-stability/src/index.ts

+18-51
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ const getStableDiffusionDefaultInputFromPrompt = (prompt: string) => ({
2626
],
2727

2828
model: "stable-diffusion-xl-beta-v2-2-2",
29+
sampler: { id: "0", name: "DDIM" },
2930
style: "enhance",
3031

31-
sampler: { value: 0, name: "DDIM" },
3232
width: 512,
3333
height: 512,
3434

@@ -147,6 +147,8 @@ export const createPlugin = StableStudio.createPlugin<{
147147
);
148148
}
149149

150+
console.log(input.sampler.id ? parseInt(input.sampler.id) : 0);
151+
150152
const imageParams = Generation.ImageParameters.create({
151153
width: BigInt(width),
152154
height: BigInt(height),
@@ -158,7 +160,7 @@ export const createPlugin = StableStudio.createPlugin<{
158160
transform: Generation.TransformType.create({
159161
type: {
160162
oneofKind: "diffusion",
161-
diffusion: input.sampler.value ?? 0,
163+
diffusion: input.sampler.id ? parseInt(input.sampler.id) : 0,
162164
},
163165
}),
164166

@@ -411,55 +413,6 @@ export const createPlugin = StableStudio.createPlugin<{
411413
);
412414
},
413415

414-
getStableDiffusionSamplers: () => {
415-
return [
416-
{
417-
value: 0,
418-
name: "DDIM",
419-
},
420-
{
421-
value: 1,
422-
name: "DDPM",
423-
},
424-
{
425-
value: 2,
426-
name: "K_EULER",
427-
},
428-
{
429-
value: 3,
430-
name: "K_EULER_ANCESTRAL",
431-
},
432-
{
433-
value: 4,
434-
name: "K_HEUN",
435-
},
436-
{
437-
value: 5,
438-
name: "K_DPM_2",
439-
},
440-
{
441-
value: 6,
442-
name: "K_DPM_2_ANCESTRAL",
443-
},
444-
{
445-
value: 7,
446-
name: "K_LMS",
447-
},
448-
{
449-
value: 8,
450-
name: "K_DPMPP_2S_ANCESTRAL",
451-
},
452-
{
453-
value: 9,
454-
name: "K_DPMPP_2M",
455-
},
456-
{
457-
value: 10,
458-
name: "K_DPMPP_SDE",
459-
},
460-
];
461-
},
462-
463416
getStableDiffusionModels: async () => {
464417
const request = await engines.listEngines({});
465418
const allEngines = await request.response.engine;
@@ -484,6 +437,20 @@ export const createPlugin = StableStudio.createPlugin<{
484437
localStorage.getItem("stability-apiKey") ?? undefined
485438
),
486439

440+
getStableDiffusionSamplers: () => [
441+
{ id: "0", name: "DDIM" },
442+
{ id: "1", name: "DDPM" },
443+
{ id: "2", name: "K_EULER" },
444+
{ id: "3", name: "K_EULER_ANCESTRAL" },
445+
{ id: "4", name: "K_HEUN" },
446+
{ id: "5", name: "K_DPM_2" },
447+
{ id: "6", name: "K_DPM_2_ANCESTRAL" },
448+
{ id: "7", name: "K_LMS" },
449+
{ id: "8", name: "K_DPMPP_2S_ANCESTRAL" },
450+
{ id: "9", name: "K_DPMPP_2M" },
451+
{ id: "10", name: "K_DPMPP_SDE" },
452+
],
453+
487454
getStableDiffusionStyles: () => [
488455
{
489456
id: "enhance",

packages/stablestudio-plugin/src/Plugin.ts

+11-6
Original file line numberDiff line numberDiff line change
@@ -181,13 +181,16 @@ export type Plugin<P extends PluginTypeHelper = PluginTypeHelperDefault> = {
181181
StableDiffusionModel[] | undefined
182182
>;
183183

184+
/** If more than one `StableDiffusionSampler`, you can return them via this function and they will be presented as a dropdown in the UI */
185+
getStableDiffusionSamplers?: () => MaybePromise<
186+
StableDiffusionSampler[] | undefined
187+
>;
188+
184189
/** If more than one `StableDiffusionStyle`, you can return them via this function and they will be presented as a dropdown in the UI */
185190
getStableDiffusionStyles?: () => MaybePromise<
186191
StableDiffusionStyle[] | undefined
187192
>;
188193

189-
getStableDiffusionSamplers?: () =>
190-
StableDiffusionSampler[] | undefined;
191194
/** Determines the default count passed to `createStableDiffusionImages` */
192195
getStableDiffusionDefaultCount?: () => number | undefined;
193196

@@ -243,10 +246,6 @@ export type StableDiffusionPrompt = {
243246
weight?: number;
244247
};
245248

246-
export declare type StableDiffusionSampler = {
247-
name?: string;
248-
value?: number;
249-
}
250249
/** This controls how a `StableDiffusionModel` is displayed in the UI */
251250
export type StableDiffusionModel = {
252251
id: ID;
@@ -263,6 +262,12 @@ export type StableDiffusionStyle = {
263262
image?: URLString;
264263
};
265264

265+
/** This controls how a `StableDiffusionSampler` is displayed in the UI */
266+
export declare type StableDiffusionSampler = {
267+
id: ID;
268+
name?: string;
269+
};
270+
266271
export type StableDiffusionInputImage = {
267272
blob?: Blob;
268273

packages/stablestudio-ui/src/Generation/Image/Model/Dropdown.tsx

+2-2
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@ export function Dropdown({ id, className }: Styleable & { id: ID }) {
5555
options={options}
5656
anchor="bottom"
5757
>
58-
{models === undefined && (
59-
<div className="flex flex-col items-center justify-center py-32 px-16">
58+
{!models && (
59+
<div className="flex flex-col items-center justify-center px-16 py-32">
6060
<div className="text-muted-white pb-3">Loading models...</div>
6161
</div>
6262
)}

packages/stablestudio-ui/src/Generation/Image/Sampler/Dropdown.tsx

+9-10
Original file line numberDiff line numberDiff line change
@@ -3,41 +3,40 @@ import { Theme } from "~/Theme";
33

44
export function Dropdown({ id, className }: Styleable & { id: ID }) {
55
const { setInput, input } = Generation.Image.Input.use(id);
6-
const { samplers } = Generation.Image.Sampler.use();
6+
const { data: samplers } = Generation.Image.Samplers.use();
77

88
const options = useMemo(
99
() => [
10-
...(samplers ?? []).map(({ value, name }) => ({
10+
...(samplers ?? []).map(({ id, name }) => ({
1111
name: name,
12-
value: value,
12+
value: id,
1313
})),
1414
],
1515
[samplers]
1616
);
1717

1818
const onClick = useCallback(
19-
(value: number) => {
19+
(id: ID) => {
2020
setInput((input) => {
21-
console.log("sampler", value);
22-
input.sampler = { value: value, name: options.at(value)?.name };
21+
input.sampler = samplers?.find((sampler) => sampler.id === id);
2322
});
2423
},
25-
[setInput, options]
24+
[setInput, samplers]
2625
);
2726

2827
if (!input) return null;
2928
return (
3029
<Theme.Popout
3130
title="Sampler"
3231
label="Sampler"
33-
placeholder={"Select a Sampler"}
34-
value={input.sampler?.value}
32+
placeholder={"Select a sampler"}
33+
value={input.sampler?.id}
3534
className={className}
3635
onClick={onClick}
3736
options={options}
3837
anchor="bottom"
3938
>
40-
{samplers === undefined && (
39+
{!samplers && (
4140
<div className="flex flex-col items-center justify-center px-16 py-32">
4241
<div className="text-muted-white pb-3">Loading samplers...</div>
4342
</div>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import * as StableStudio from "@stability/stablestudio-plugin";
2+
import * as ReactQuery from "@tanstack/react-query";
3+
4+
import { Plugin } from "~/Plugin";
5+
6+
export type Samplers = StableStudio.StableDiffusionSampler[];
7+
export namespace Samplers {
8+
export const use = () => {
9+
const getStableDiffusionSamplers = Plugin.use(
10+
({ getStableDiffusionSamplers }) => getStableDiffusionSamplers
11+
);
12+
13+
return ReactQuery.useQuery({
14+
enabled: !!getStableDiffusionSamplers,
15+
16+
queryKey: ["Generation.Image.Samplers.use"],
17+
queryFn: async () => (await getStableDiffusionSamplers?.()) ?? [],
18+
});
19+
};
20+
21+
export const useAreEnabled = () =>
22+
Plugin.use(({ getStableDiffusionModels }) => !!getStableDiffusionModels);
23+
}
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,15 @@
11
import * as StableStudio from "@stability/stablestudio-plugin";
22

3-
import { Plugin } from "~/Plugin";
4-
53
import { Dropdown } from "./Dropdown";
64

5+
export * from "./Samplers";
6+
77
export type Sampler = StableStudio.StableDiffusionSampler[];
88

99
export declare namespace Sampler {
1010
export { Dropdown };
1111
}
1212

1313
export namespace Sampler {
14-
export const use = () => {
15-
const getStableDiffusionSamplers = Plugin.use(
16-
({ getStableDiffusionSamplers }) => getStableDiffusionSamplers
17-
);
18-
19-
const samplers = useMemo(
20-
() => getStableDiffusionSamplers?.(),
21-
[getStableDiffusionSamplers]
22-
);
23-
return { samplers };
24-
};
25-
26-
export const useAreEnabled = () =>
27-
Plugin.use(
28-
({ getStableDiffusionSamplers }) => !!getStableDiffusionSamplers
29-
);
30-
3114
Sampler.Dropdown = Dropdown;
3215
}

packages/stablestudio-ui/src/Generation/Image/Sidebar/Advanced.tsx

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ export function Advanced({
88
}: App.Sidebar.Section.Props & { id: ID }) {
99
const { setInput, input } = Generation.Image.Input.use(id);
1010
const areModelsEnabled = Generation.Image.Models.useAreEnabled();
11-
const areSamplersEnabled = Generation.Image.Sampler.useAreEnabled();
11+
const areSamplersEnabled = Generation.Image.Samplers.useAreEnabled();
1212

1313
const onPromptStrengthChange = useCallback(
1414
(cfgScale: number) => {

packages/stablestudio-ui/src/Generation/Image/index.tsx

+3-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import { Modal } from "./Modal";
1414
import { Model, Models } from "./Model";
1515
import { Output, Outputs } from "./Output";
1616
import { Prompt, Prompts } from "./Prompt";
17-
import { Sampler } from "./Sampler";
17+
import { Sampler, Samplers } from "./Sampler";
1818
import { Search } from "./Search";
1919
import { Session } from "./Session";
2020
import { Sidebar } from "./Sidebar";
@@ -210,6 +210,7 @@ export declare namespace Image {
210210
Prompt,
211211
Prompts,
212212
Sampler,
213+
Samplers,
213214
Search,
214215
Session,
215216
Sidebar,
@@ -242,6 +243,7 @@ export namespace Image {
242243
Image.Prompt = Prompt;
243244
Image.Prompts = Prompts;
244245
Image.Sampler = Sampler;
246+
Image.Samplers = Samplers;
245247
Image.Search = Search;
246248
Image.Session = Session;
247249
Image.Sidebar = Sidebar;

packages/stablestudio-ui/src/Generation/index.tsx

-2
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
1-
// import { App } from "~/App";
21
import { Shortcut } from "~/Shortcut";
32
import { Theme } from "~/Theme";
43

54
import { Image, Images } from "./Image";
65

76
export function Generation() {
87
const createDream = Generation.Image.Session.useCreateDream();
9-
// App.Breadcrumbs.useSet(["defaultProject", "generate"]);
108

119
Shortcut.use(
1210
useMemo(

0 commit comments

Comments
 (0)