Skip to content

Commit e07c4d0

Browse files
authored
Move guidanceScale as generation parameter (apple#46)
* Move guidanceScale as generation parameter * Added guidanceScale in CLI * Reverted identation change
1 parent 877ccb9 commit e07c4d0

File tree

2 files changed

+10
-10
lines changed

2 files changed

+10
-10
lines changed

swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift

+6-10
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,6 @@ public struct StableDiffusionPipeline: ResourceManaging {
3333
/// Optional model for checking safety of generated image
3434
var safetyChecker: SafetyChecker? = nil
3535

36-
/// Controls the influence of the text prompt on sampling process (0=random images)
37-
var guidanceScale: Float = 7.5
38-
3936
/// Reports whether this pipeline can perform safety checks
4037
public var canSafetyCheck: Bool {
4138
safetyChecker != nil
@@ -56,20 +53,17 @@ public struct StableDiffusionPipeline: ResourceManaging {
5653
/// - unet: Model for noise prediction on latent samples
5754
/// - decoder: Model for decoding latent sample to image
5855
/// - safetyChecker: Optional model for checking safety of generated images
59-
/// - guidanceScale: Influence of the text prompt on generation process
6056
/// - reduceMemory: Option to enable reduced memory mode
6157
/// - Returns: Pipeline ready for image generation
6258
public init(textEncoder: TextEncoder,
6359
unet: Unet,
6460
decoder: Decoder,
6561
safetyChecker: SafetyChecker? = nil,
66-
guidanceScale: Float = 7.5,
6762
reduceMemory: Bool = false) {
6863
self.textEncoder = textEncoder
6964
self.unet = unet
7065
self.decoder = decoder
7166
self.safetyChecker = safetyChecker
72-
self.guidanceScale = guidanceScale
7367
self.reduceMemory = reduceMemory
7468
}
7569

@@ -112,6 +106,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
112106
/// - stepCount: Number of inference steps to perform
113107
/// - imageCount: Number of samples/images to generate for the input prompt
114108
/// - seed: Random seed which
109+
/// - guidanceScale: Controls the influence of the text prompt on sampling process (0=random images)
115110
/// - disableSafety: Safety checks are only performed if `self.canSafetyCheck && !disableSafety`
116111
/// - progressHandler: Callback to perform after each step, stops on receiving false response
117112
/// - Returns: An array of `imageCount` optional images.
@@ -122,6 +117,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
122117
imageCount: Int = 1,
123118
stepCount: Int = 50,
124119
seed: UInt32 = 0,
120+
guidanceScale: Float = 7.5,
125121
disableSafety: Bool = false,
126122
scheduler: StableDiffusionScheduler = .pndmScheduler,
127123
progressHandler: (Progress) -> Bool = { _ in true }
@@ -173,7 +169,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
173169
hiddenStates: hiddenStates
174170
)
175171

176-
noise = performGuidance(noise)
172+
noise = performGuidance(noise, guidanceScale)
177173

178174
// Have the scheduler compute the previous (t-1) latent
179175
// sample given the predicted noise and current sample
@@ -236,11 +232,11 @@ public struct StableDiffusionPipeline: ResourceManaging {
236232
return states
237233
}
238234

239-
func performGuidance(_ noise: [MLShapedArray<Float32>]) -> [MLShapedArray<Float32>] {
240-
noise.map { performGuidance($0) }
235+
func performGuidance(_ noise: [MLShapedArray<Float32>], _ guidanceScale: Float) -> [MLShapedArray<Float32>] {
236+
noise.map { performGuidance($0, guidanceScale) }
241237
}
242238

243-
func performGuidance(_ noise: MLShapedArray<Float32>) -> MLShapedArray<Float32> {
239+
func performGuidance(_ noise: MLShapedArray<Float32>, _ guidanceScale: Float) -> MLShapedArray<Float32> {
244240

245241
let blankNoiseScalars = noise[0].scalars
246242
let textNoiseScalars = noise[1].scalars

swift/StableDiffusionCLI/main.swift

+4
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@ struct StableDiffusionSample: ParsableCommand {
5353
@Option(help: "Random seed")
5454
var seed: UInt32 = 93
5555

56+
@Option(help: "Controls the influence of the text prompt on sampling process (0=random images)")
57+
var guidanceScale: Float = 7.5
58+
5659
@Option(help: "Compute units to load model with {all,cpuOnly,cpuAndGPU,cpuAndNeuralEngine}")
5760
var computeUnits: ComputeUnits = .all
5861

@@ -92,6 +95,7 @@ struct StableDiffusionSample: ParsableCommand {
9295
imageCount: imageCount,
9396
stepCount: stepCount,
9497
seed: seed,
98+
guidanceScale: guidanceScale,
9599
scheduler: scheduler.stableDiffusionScheduler
96100
) { progress in
97101
sampleTimer.stop()

0 commit comments

Comments
 (0)