Add SD3 Pipeline #329

Merged: 8 commits, Jul 23, 2024
58 changes: 58 additions & 0 deletions README.md
@@ -246,6 +246,64 @@ An example `<selected-recipe-string-key>` would be `"recipe_4.50_bit_mixedpalett

</details>


## <a name="using-stable-diffusion-3"></a> Using Stable Diffusion 3

<details>
<summary> Details (Click to expand) </summary>

### Model Conversion

Stable Diffusion 3 uses a mix of new models and models shared with earlier versions. For the text encoders, the conversion can be done with a command similar to previous versions:

```bash
python -m python_coreml_stable_diffusion.torch2coreml --convert-text-encoder --xl-version --model-version stabilityai/stable-diffusion-xl-base-1.0 --bundle-resources-for-swift-cli --attention-implementation ORIGINAL -o <output-dir>
```

For the new models (the MMDiT and a new 16-channel VAE), the conversion is done through the [DiffusionKit](https://www.github.com/argmaxinc/DiffusionKit) repo. First, install it:

```bash
git clone https://github.com/argmaxinc/DiffusionKit.git
cd DiffusionKit
pip install -e .
```

Once installed, you can convert the MMDiT model using:

```bash
python -m tests.torch2coreml.test_mmdit --sd3-ckpt-path <path-to-sd3-mmdit.safetensors> --model-version {2b} -o <output-mlpackages-directory> --latent-size {64, 128}
```

Similarly, for the new VAE model:

```bash
python -m tests.torch2coreml.test_vae --sd3-ckpt-path <path-to-sd3-mmdit.safetensors> -o <output-mlpackages-directory> --latent-size {64, 128}
```

### Swift Inference

Swift inference for Stable Diffusion 3 works the same way as for previous versions; the only difference is the `--sd3` flag, which indicates that the resources are for a Stable Diffusion 3 model.

```bash
swift run StableDiffusionSample <prompt> --resource-path <output-mlpackages-directory/Resources> --output-path <output-dir> --compute-units cpuAndGPU --sd3
```

### Python MLX Inference

Python inference is supported via the [MLX](https://github.com/ml-explore) backend in [DiffusionKit](https://www.github.com/argmaxinc/DiffusionKit). The following command can be used to generate images using Stable Diffusion 3:

```bash
diffusionkit-cli --prompt "a photo of a cat" --output-path </path/to/output/image.png> --seed 0 -w16 -a16
```

Some notable optional arguments:

- For image-to-image, use `--image-path` (path to the input image) and `--denoise` (a value between 0.0 and 1.0)
- To use T5 text embeddings, pass `--t5`
- For different resolutions, use `--height` and `--width`

</details>

## <a name="using-stable-diffusion-xl"></a> Using Stable Diffusion XL

<details>
80 changes: 80 additions & 0 deletions swift/StableDiffusion/pipeline/DecoderSD3.swift
@@ -0,0 +1,80 @@
// For licensing see accompanying LICENSE.md file.
// Copyright (C) 2024 Apple Inc. All Rights Reserved.

import Foundation
import CoreML

/// A decoder model which produces RGB images from latent samples
@available(iOS 16.2, macOS 13.1, *)
public struct DecoderSD3: ResourceManaging {

/// VAE decoder model
var model: ManagedMLModel

/// Create decoder from Core ML model
///
/// - Parameters:
/// - url: Location of compiled VAE decoder Core ML model
/// - configuration: configuration to be used when the model is loaded
/// - Returns: A decoder that will lazily load its required resources when needed or requested
public init(modelAt url: URL, configuration: MLModelConfiguration) {
self.model = ManagedMLModel(modelAt: url, configuration: configuration)
}

/// Ensure the model has been loaded into memory
public func loadResources() throws {
try model.loadResources()
}

/// Unload the underlying model to free up memory
public func unloadResources() {
model.unloadResources()
}

/// Batch decode latent samples into images
///
/// - Parameters:
/// - latents: Batch of latent samples to decode
/// - scaleFactor: scalar divisor on latents before decoding
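/// - shiftFactor: scalar offset added to latents after dividing by scaleFactor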
/// - Returns: decoded images
public func decode(
_ latents: [MLShapedArray<Float32>],
scaleFactor: Float32,
shiftFactor: Float32
Collaborator: As far as I can tell, shiftFactor is the only difference between Decoder and DecoderSD3. Let's add the shift to Decoder and default it to 0.

Collaborator: This comment still stands, let's reuse Decoder instead of introducing DecoderSD3.

) throws -> [CGImage] {

// Form batch inputs for model
let inputs: [MLFeatureProvider] = try latents.map { sample in
// Reference pipeline scales the latent samples before decoding
let sampleScaled = MLShapedArray<Float32>(
scalars: sample.scalars.map { $0 / scaleFactor + shiftFactor },
shape: sample.shape)

let dict = [inputName: MLMultiArray(sampleScaled)]
return try MLDictionaryFeatureProvider(dictionary: dict)
}
let batch = MLArrayBatchProvider(array: inputs)

// Batch predict with model
let results = try model.perform { model in
try model.predictions(fromBatch: batch)
}

// Transform the outputs to CGImages
let images: [CGImage] = try (0..<results.count).map { i in
let result = results.features(at: i)
let outputName = result.featureNames.first!
let output = result.featureValue(for: outputName)!.multiArrayValue!
return try CGImage.fromShapedArray(MLShapedArray<Float32>(converting: output))
}

return images
}

var inputName: String {
try! model.perform { model in
model.modelDescription.inputDescriptionsByName.first!.key
}
}

}
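
For reference, a minimal sketch of what the reviewer's suggestion above could look like: folding `shiftFactor` into the existing `Decoder.decode` with a default of 0 so that `DecoderSD3` is not needed. The surrounding `Decoder` members (`model`, `inputName`) are assumed to match `DecoderSD3`; this is a sketch, not part of the diff.

```swift
// Hypothetical change to Decoder (not part of this diff): default shiftFactor to 0
// so SD 1.x/2.x/XL keep their current behavior and SD3 passes its VAE shift.
public func decode(
    _ latents: [MLShapedArray<Float32>],
    scaleFactor: Float32,
    shiftFactor: Float32 = 0
) throws -> [CGImage] {
    // Same per-scalar transform as DecoderSD3 above: z / scaleFactor + shiftFactor
    let inputs: [MLFeatureProvider] = try latents.map { sample in
        let sampleScaled = MLShapedArray<Float32>(
            scalars: sample.scalars.map { $0 / scaleFactor + shiftFactor },
            shape: sample.shape)
        return try MLDictionaryFeatureProvider(dictionary: [inputName: MLMultiArray(sampleScaled)])
    }
    let batch = MLArrayBatchProvider(array: inputs)

    // Batch predict and convert each output back to a CGImage, as in DecoderSD3 above
    let results = try model.perform { model in
        try model.predictions(fromBatch: batch)
    }
    return try (0..<results.count).map { i in
        let feature = results.features(at: i)
        let output = feature.featureValue(for: feature.featureNames.first!)!.multiArrayValue!
        return try CGImage.fromShapedArray(MLShapedArray<Float32>(converting: output))
    }
}
```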
123 changes: 123 additions & 0 deletions swift/StableDiffusion/pipeline/DiscreteFlowScheduler.swift
@@ -0,0 +1,123 @@
// For licensing see accompanying LICENSE.md file.
// Copyright (C) 2024 Apple Inc. All Rights Reserved.

import CoreML

/// A scheduler used to compute a de-noised image
@available(iOS 16.2, macOS 13.1, *)
public final class DiscreteFlowScheduler: Scheduler {
public let trainStepCount: Int
public let inferenceStepCount: Int
public var timeSteps = [Int]()
public var betas = [Float]()
public var alphas = [Float]()
public var alphasCumProd = [Float]()

public private(set) var modelOutputs: [MLShapedArray<Float32>] = []

var trainSteps: Float
var shift: Float
var counter: Int
var sigmas = [Float]()

/// Create a scheduler that uses a discrete flow matching (first-order Euler) update.
///
/// - Parameters:
/// - stepCount: Number of inference steps to schedule
/// - trainStepCount: Number of training diffusion steps
/// - timeStepShift: Amount to shift the timestep schedule
/// - Returns: A scheduler ready for its first step
public init(
stepCount: Int = 50,
trainStepCount: Int = 1000,
timeStepShift: Float = 3.0
) {
self.trainStepCount = trainStepCount
self.inferenceStepCount = stepCount
self.trainSteps = Float(trainStepCount)
self.shift = timeStepShift
self.counter = 0

let sigmaDistribution = linspace(1, trainSteps, Int(trainSteps)).map { sigmaFromTimestep($0) }
let timeStepDistribution = linspace(sigmaDistribution.first!, sigmaDistribution.last!, stepCount).reversed()
self.timeSteps = timeStepDistribution.map { Int($0 * trainSteps) }
self.sigmas = timeStepDistribution.map { sigmaFromTimestep($0 * trainSteps) }
}

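/// Map a training-scale timestep to its flow sigma. With shift > 1 (default 3.0),
/// normalized time is remapped toward larger sigmas, i.e. higher noise levels.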
func sigmaFromTimestep(_ timestep: Float) -> Float {
if shift == 1.0 {
return timestep / trainSteps
} else {
// shift * timestep / (1 + (shift - 1) * timestep)
let t = timestep / trainSteps
return shift * t / (1 + (shift - 1) * t)
}
}

func timestepsFromSigmas() -> [Float] {
return sigmas.map { $0 * trainSteps }
}

/// Convert the model output (predicted noise) into the denoised sample: denoised = sample - noise * sigma.
func convertModelOutput(modelOutput: MLShapedArray<Float32>, timestep: Int, sample: MLShapedArray<Float32>) -> MLShapedArray<Float32> {
assert(modelOutput.scalarCount == sample.scalarCount)
let stepIndex = timeSteps.firstIndex(of: timestep) ?? counter
let sigma = sigmas[stepIndex]

return MLShapedArray<Float>(unsafeUninitializedShape: modelOutput.shape) { result, _ in
modelOutput.withUnsafeShapedBufferPointer { noiseScalars, _, _ in
sample.withUnsafeShapedBufferPointer { latentScalars, _, _ in
for i in 0..<result.count {
let denoised = latentScalars[i] - noiseScalars[i] * sigma
result.initializeElement(
at: i,
to: denoised
)
}
}
}
}
}

public func calculateTimestepsFromSigmas(strength: Float?) -> [Float] {
guard let strength else { return timestepsFromSigmas() }
let startStep = max(inferenceStepCount - Int(Float(inferenceStepCount) * strength), 0)
let actualTimesteps = Array(timestepsFromSigmas()[startStep...])
return actualTimesteps
}

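/// Advance the sample one Euler step along the flow:
/// prev_x = x + (x - denoised) / sigma * dt, where dt = sigma_next - sigma (or sigma on the final step).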
public func step(output: MLShapedArray<Float32>, timeStep t: Int, sample: MLShapedArray<Float32>) -> MLShapedArray<Float32> {
let stepIndex = timeSteps.firstIndex(of: t) ?? counter // TODO: allow float timesteps in scheduler step protocol
let modelOutput = convertModelOutput(modelOutput: output, timestep: t, sample: sample)
modelOutputs.append(modelOutput)

let sigma = sigmas[stepIndex]
var dt = sigma
var prevSigma: Float = 0
if stepIndex < sigmas.count - 1 {
prevSigma = sigmas[stepIndex + 1]
dt = prevSigma - sigma
}

let prevSample: MLShapedArray<Float32> = MLShapedArray<Float>(unsafeUninitializedShape: modelOutput.shape) { result, _ in
modelOutput.withUnsafeShapedBufferPointer { noiseScalars, _, _ in
sample.withUnsafeShapedBufferPointer { latentScalars, _, _ in
for i in 0..<result.count {
let denoised = noiseScalars[i]
let x = latentScalars[i]

let d = (x - denoised) / sigma
let prev_x = x + d * dt
result.initializeElement(
at: i,
to: prev_x
)
}
}
}
}

counter += 1
return prevSample
}
}
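
For orientation, a minimal sketch of how this scheduler could be driven from a denoising loop. The `denoise` function and the `predictNoise` closure are assumptions for illustration (the closure stands in for the MMDiT forward pass); they are not APIs introduced by this diff.

```swift
import CoreML

// Hypothetical driver loop; `predictNoise` stands in for the MMDiT model call.
@available(iOS 16.2, macOS 13.1, *)
func denoise(
    initialLatent: MLShapedArray<Float32>,
    predictNoise: (MLShapedArray<Float32>, Int) throws -> MLShapedArray<Float32>
) throws -> MLShapedArray<Float32> {
    // Defaults defined above: 1000 train steps, 50 inference steps, timestep shift 3.0
    let scheduler = DiscreteFlowScheduler(stepCount: 50, timeStepShift: 3.0)
    var latent = initialLatent
    for t in scheduler.timeSteps {
        // The model predicts the noise; step(...) applies the Euler flow update above
        let noise = try predictNoise(latent, t)
        latent = scheduler.step(output: noise, timeStep: t, sample: latent)
    }
    return latent
}
```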