@@ -108,6 +108,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
108
108
///
109
109
/// - Parameters:
110
110
/// - prompt: Text prompt to guide sampling
111
+ /// - negativePrompt: Text prompt encoded in place of the blank unconditioned input, to guide sampling away from its content
111
112
/// - stepCount: Number of inference steps to perform
112
113
/// - imageCount: Number of samples/images to generate for the input prompt
113
114
/// - seed: Random seed which
@@ -117,6 +118,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
117
118
/// The images will be nil if safety checks were performed and found the result to be unsafe
118
119
public func generateImages(
119
120
prompt: String ,
121
+ negativePrompt: String = " " ,
120
122
imageCount: Int = 1 ,
121
123
stepCount: Int = 50 ,
122
124
seed: UInt32 = 0 ,
@@ -125,17 +127,18 @@ public struct StableDiffusionPipeline: ResourceManaging {
125
127
progressHandler: ( Progress ) -> Bool = { _ in true }
126
128
) throws -> [ CGImage ? ] {
127
129
128
- // Encode the input prompt as well as a blank unconditioned input
130
+ // Encode the input prompt and negative prompt
129
131
let promptEmbedding = try textEncoder. encode ( prompt)
130
- let blankEmbedding = try textEncoder. encode ( " " )
132
+ let negativePromptEmbedding = try textEncoder. encode ( negativePrompt )
131
133
132
134
if reduceMemory {
133
135
textEncoder. unloadResources ( )
134
136
}
135
137
136
138
// Convert to Unet hidden state representation
139
+ // Concatenate the negative prompt and prompt embeddings, in that order, along axis 0
137
140
let concatEmbedding = MLShapedArray < Float32 > (
138
- concatenating: [ blankEmbedding , promptEmbedding] ,
141
+ concatenating: [ negativePromptEmbedding , promptEmbedding] ,
139
142
alongAxis: 0
140
143
)
141
144
0 commit comments