Skip to content

Commit c90b705

Browse files
Adds Negative Prompts (#61)
* Synced to main branch and minimizes line changes * Adds negative prompt argument to CLI Co-authored-by: Wanaldino Antimonio <[email protected]>
1 parent 4c00b32 commit c90b705

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift

+6-3
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
108108
///
109109
/// - Parameters:
110110
/// - prompt: Text prompt to guide sampling
111+
/// - negativePrompt: Negative text prompt to guide sampling
111112
/// - stepCount: Number of inference steps to perform
112113
/// - imageCount: Number of samples/images to generate for the input prompt
113114
/// - seed: Random seed which
@@ -117,6 +118,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
117118
/// The images will be nil if safety checks were performed and found the result to be un-safe
118119
public func generateImages(
119120
prompt: String,
121+
negativePrompt: String = "",
120122
imageCount: Int = 1,
121123
stepCount: Int = 50,
122124
seed: UInt32 = 0,
@@ -125,17 +127,18 @@ public struct StableDiffusionPipeline: ResourceManaging {
125127
progressHandler: (Progress) -> Bool = { _ in true }
126128
) throws -> [CGImage?] {
127129

128-
// Encode the input prompt as well as a blank unconditioned input
130+
// Encode the input prompt and negative prompt
129131
let promptEmbedding = try textEncoder.encode(prompt)
130-
let blankEmbedding = try textEncoder.encode("")
132+
let negativePromptEmbedding = try textEncoder.encode(negativePrompt)
131133

132134
if reduceMemory {
133135
textEncoder.unloadResources()
134136
}
135137

136138
// Convert to Unet hidden state representation
139+
// Concatenate the prompt and negative prompt embeddings
137140
let concatEmbedding = MLShapedArray<Float32>(
138-
concatenating: [blankEmbedding, promptEmbedding],
141+
concatenating: [negativePromptEmbedding, promptEmbedding],
139142
alongAxis: 0
140143
)
141144

swift/StableDiffusionCLI/main.swift

+4
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ struct StableDiffusionSample: ParsableCommand {
1919
@Argument(help: "Input string prompt")
2020
var prompt: String
2121

22+
@Option(help: "Input string negative prompt")
23+
var negativePrompt: String
24+
2225
@Option(
2326
help: ArgumentHelp(
2427
"Path to stable diffusion resources.",
@@ -85,6 +88,7 @@ struct StableDiffusionSample: ParsableCommand {
8588

8689
let images = try pipeline.generateImages(
8790
prompt: prompt,
91+
negativePrompt: negativePrompt,
8892
imageCount: imageCount,
8993
stepCount: stepCount,
9094
seed: seed,

0 commit comments

Comments
 (0)