@@ -33,9 +33,6 @@ public struct StableDiffusionPipeline: ResourceManaging {
33
33
/// Optional model for checking safety of generated image
34
34
var safetyChecker : SafetyChecker ? = nil
35
35
36
- /// Controls the influence of the text prompt on sampling process (0=random images)
37
- var guidanceScale : Float = 7.5
38
-
39
36
/// Reports whether this pipeline can perform safety checks
40
37
public var canSafetyCheck : Bool {
41
38
safetyChecker != nil
@@ -56,20 +53,17 @@ public struct StableDiffusionPipeline: ResourceManaging {
56
53
/// - unet: Model for noise prediction on latent samples
57
54
/// - decoder: Model for decoding latent sample to image
58
55
/// - safetyChecker: Optional model for checking safety of generated images
59
- /// - guidanceScale: Influence of the text prompt on generation process
60
56
/// - reduceMemory: Option to enable reduced memory mode
61
57
/// - Returns: Pipeline ready for image generation
62
58
public init ( textEncoder: TextEncoder ,
63
59
unet: Unet ,
64
60
decoder: Decoder ,
65
61
safetyChecker: SafetyChecker ? = nil ,
66
- guidanceScale: Float = 7.5 ,
67
62
reduceMemory: Bool = false ) {
68
63
self . textEncoder = textEncoder
69
64
self . unet = unet
70
65
self . decoder = decoder
71
66
self . safetyChecker = safetyChecker
72
- self . guidanceScale = guidanceScale
73
67
self . reduceMemory = reduceMemory
74
68
}
75
69
@@ -112,6 +106,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
112
106
/// - stepCount: Number of inference steps to perform
113
107
/// - imageCount: Number of samples/images to generate for the input prompt
114
108
/// - seed: Random seed which
109
+ /// - guidanceScale: Controls the influence of the text prompt on sampling process (0=random images)
115
110
/// - disableSafety: Safety checks are only performed if `self.canSafetyCheck && !disableSafety`
116
111
/// - progressHandler: Callback to perform after each step, stops on receiving false response
117
112
/// - Returns: An array of `imageCount` optional images.
@@ -122,6 +117,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
122
117
imageCount: Int = 1 ,
123
118
stepCount: Int = 50 ,
124
119
seed: UInt32 = 0 ,
120
+ guidanceScale: Float = 7.5 ,
125
121
disableSafety: Bool = false ,
126
122
scheduler: StableDiffusionScheduler = . pndmScheduler,
127
123
progressHandler: ( Progress ) -> Bool = { _ in true }
@@ -173,7 +169,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
173
169
hiddenStates: hiddenStates
174
170
)
175
171
176
- noise = performGuidance ( noise)
172
+ noise = performGuidance ( noise, guidanceScale )
177
173
178
174
// Have the scheduler compute the previous (t-1) latent
179
175
// sample given the predicted noise and current sample
@@ -236,11 +232,11 @@ public struct StableDiffusionPipeline: ResourceManaging {
236
232
return states
237
233
}
238
234
239
- func performGuidance( _ noise: [ MLShapedArray < Float32 > ] ) -> [ MLShapedArray < Float32 > ] {
240
- noise. map { performGuidance ( $0) }
235
+ func performGuidance( _ noise: [ MLShapedArray < Float32 > ] , _ guidanceScale : Float ) -> [ MLShapedArray < Float32 > ] {
236
+ noise. map { performGuidance ( $0, guidanceScale ) }
241
237
}
242
238
243
- func performGuidance( _ noise: MLShapedArray < Float32 > ) -> MLShapedArray < Float32 > {
239
+ func performGuidance( _ noise: MLShapedArray < Float32 > , _ guidanceScale : Float ) -> MLShapedArray < Float32 > {
244
240
245
241
let blankNoiseScalars = noise [ 0 ] . scalars
246
242
let textNoiseScalars = noise [ 1 ] . scalars
0 commit comments