-
Notifications
You must be signed in to change notification settings - Fork 1k
Add SD3 Pipeline #329
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add SD3 Pipeline #329
Changes from 1 commit
d1f0604
f278214
61ab0f1
a8a2958
fec1ab5
e726229
57ef523
804f0c3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
// For licensing see accompanying LICENSE.md file. | ||
// Copyright (C) 2024 Apple Inc. All Rights Reserved. | ||
|
||
import Foundation | ||
import CoreML | ||
|
||
/// Turns latent samples into RGB images by running a Core ML VAE decoder
@available(iOS 16.2, macOS 13.1, *)
public struct DecoderSD3: ResourceManaging {

    /// Managed wrapper around the VAE decoder Core ML model
    var model: ManagedMLModel

    /// Create decoder from Core ML model
    ///
    /// - Parameters:
    ///   - url: Location of compiled VAE decoder Core ML model
    ///   - configuration: configuration to be used when the model is loaded
    /// - Returns: A decoder that will lazily load its required resources when needed or requested
    public init(modelAt url: URL, configuration: MLModelConfiguration) {
        model = ManagedMLModel(modelAt: url, configuration: configuration)
    }

    /// Ensure the model has been loaded into memory
    ///
    /// - Throws: Whatever the managed model throws while loading
    public func loadResources() throws {
        try model.loadResources()
    }

    /// Unload the underlying model to free up memory
    public func unloadResources() {
        model.unloadResources()
    }

    /// Batch decode latent samples into images
    ///
    /// Every latent scalar is mapped to `value / scaleFactor + shiftFactor` before
    /// the batch is handed to the decoder model.
    ///
    /// - Parameters:
    ///   - latents: Batch of latent samples to decode
    ///   - scaleFactor: scalar divisor applied to the latents before decoding
    ///   - shiftFactor: scalar offset added to the latents after the division
    /// - Returns: decoded images, one per latent sample
    /// - Throws: Errors raised while building the feature providers, running the
    ///   batch prediction, or converting the model output into a `CGImage`
    // NOTE(review): a reviewer remark attached to this parameter list in the PR
    // ("This comment still stands, let's reuse ...") is truncated in this capture;
    // the item to reuse is not named here - confirm against the PR discussion.
    public func decode(
        _ latents: [MLShapedArray<Float32>],
        scaleFactor: Float32,
        shiftFactor: Float32
    ) throws -> [CGImage] {

        // One feature provider per latent sample
        var providers: [MLFeatureProvider] = []
        providers.reserveCapacity(latents.count)
        for latent in latents {
            // Reference pipeline rescales and shifts the latent before decoding
            let adjustedScalars = latent.scalars.map { value in
                value / scaleFactor + shiftFactor
            }
            let adjusted = MLShapedArray<Float32>(scalars: adjustedScalars, shape: latent.shape)

            // The input name is looked up per sample, exactly as many times as there are samples
            let provider = try MLDictionaryFeatureProvider(
                dictionary: [inputName: MLMultiArray(adjusted)]
            )
            providers.append(provider)
        }

        // Run the whole batch through the decoder in a single call
        let inputBatch = MLArrayBatchProvider(array: providers)
        let predictions = try model.perform { try $0.predictions(fromBatch: inputBatch) }

        // Convert the first output feature of every prediction into a CGImage
        var decoded: [CGImage] = []
        decoded.reserveCapacity(predictions.count)
        for index in 0..<predictions.count {
            let features = predictions.features(at: index)
            let firstOutput = features.featureNames.first!
            let pixels = features.featureValue(for: firstOutput)!.multiArrayValue!
            decoded.append(try CGImage.fromShapedArray(MLShapedArray<Float32>(converting: pixels)))
        }
        return decoded
    }

    /// Name of the first input feature declared by the model description
    ///
    /// Accessing this goes through `model.perform`; a failure there, or a model
    /// without inputs, is treated as a programmer error and traps.
    var inputName: String {
        try! model.perform { $0.modelDescription.inputDescriptionsByName.first!.key }
    }

}
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
// For licensing see accompanying LICENSE.md file. | ||
// Copyright (C) 2024 Apple Inc. All Rights Reserved. | ||
|
||
import CoreML | ||
|
||
/// A scheduler used to compute a de-noised image
///
/// Walks a shifted sigma schedule from high to low noise and advances the
/// latent sample with one first-order (Euler) update per step.
@available(iOS 16.2, macOS 13.1, *)
public final class DiscreteFlowScheduler: Scheduler {
    public let trainStepCount: Int
    public let inferenceStepCount: Int
    /// Integer time steps, one per inference step, ordered from high to low noise
    public var timeSteps = [Int]()
    // The three arrays below are never populated by this class;
    // presumably they exist to satisfy the `Scheduler` protocol - confirm.
    public var betas = [Float]()
    public var alphas = [Float]()
    public var alphasCumProd = [Float]()

    /// De-noised estimates produced so far, one appended per call to `step`
    public private(set) var modelOutputs: [MLShapedArray<Float32>] = []

    /// `trainStepCount` as a floating point value
    var trainSteps: Float
    /// Shift applied to the sigma schedule (`timeStepShift` at init)
    var shift: Float
    /// Number of completed `step` calls; used as step index when a time step is not found
    var counter: Int
    /// Sigma for each inference step, in the same order as `timeSteps`
    var sigmas = [Float]()

    /// Create a scheduler that takes first-order Euler steps along a shifted sigma schedule.
    ///
    /// - Parameters:
    ///   - stepCount: Number of inference steps to schedule
    ///   - trainStepCount: Number of training diffusion steps
    ///   - timeStepShift: Amount to shift the timestep schedule
    /// - Returns: A scheduler ready for its first step
    public init(
        stepCount: Int = 50,
        trainStepCount: Int = 1000,
        timeStepShift: Float = 3.0
    ) {
        self.trainStepCount = trainStepCount
        self.inferenceStepCount = stepCount
        self.trainSteps = Float(trainStepCount)
        self.shift = timeStepShift
        self.counter = 0

        // Sigma for every training step, then an evenly spaced, descending
        // selection between its smallest and largest value.
        let sigmaDistribution = linspace(1, trainSteps, Int(trainSteps)).map { sigmaFromTimestep($0) }
        let timeStepDistribution = linspace(sigmaDistribution.first!, sigmaDistribution.last!, stepCount).reversed()
        self.timeSteps = timeStepDistribution.map { Int($0 * trainSteps) }
        self.sigmas = timeStepDistribution.map { sigmaFromTimestep($0 * trainSteps) }
    }

    /// Map a time step in training-step units to its shifted sigma.
    ///
    /// - Parameter timestep: Time step in the range of the training steps
    /// - Returns: `shift * t / (1 + (shift - 1) * t)` with `t = timestep / trainSteps`
    func sigmaFromTimestep(_ timestep: Float) -> Float {
        if shift == 1.0 {
            return timestep / trainSteps
        } else {
            // shift * timestep / (1 + (shift - 1) * timestep)
            let t = timestep / trainSteps
            return shift * t / (1 + (shift - 1) * t)
        }
    }

    /// The scheduled sigmas scaled back to training-step units.
    func timestepsFromSigmas() -> [Float] {
        return sigmas.map { $0 * trainSteps }
    }

    /// Convert the model output to the corresponding type the algorithm needs.
    ///
    /// Computes `sample - modelOutput * sigma` element-wise, where `sigma` belongs
    /// to the step matching `timestep` (or to `counter` when no step matches).
    ///
    /// - Parameters:
    ///   - modelOutput: Raw prediction of the model for the current step
    ///   - timestep: Integer time step of the current step
    ///   - sample: Current latent sample; must have as many scalars as `modelOutput`
    /// - Returns: The de-noised estimate, shaped like `modelOutput`
    func convertModelOutput(modelOutput: MLShapedArray<Float32>, timestep: Int, sample: MLShapedArray<Float32>) -> MLShapedArray<Float32> {
        assert(modelOutput.scalarCount == sample.scalarCount)
        let stepIndex = timeSteps.firstIndex(of: timestep) ?? counter
        let sigma = sigmas[stepIndex]

        return MLShapedArray<Float>(unsafeUninitializedShape: modelOutput.shape) { result, _ in
            modelOutput.withUnsafeShapedBufferPointer { noiseScalars, _, _ in
                sample.withUnsafeShapedBufferPointer { latentScalars, _, _ in
                    for i in 0..<result.count {
                        let denoised = latentScalars[i] - noiseScalars[i] * sigma
                        result.initializeElement(
                            at: i,
                            to: denoised
                        )
                    }
                }
            }
        }
    }

    /// Time steps (in training-step units) to run for the given strength.
    ///
    /// - Parameter strength: Fraction of the schedule to run. `nil` returns the full
    ///   schedule. Values outside `0...1` are clamped, so a negative strength yields
    ///   an empty array instead of trapping on an out-of-range slice.
    /// - Returns: The trailing portion of the schedule selected by `strength`
    public func calculateTimestepsFromSigmas(strength: Float?) -> [Float] {
        let allTimesteps = timestepsFromSigmas()
        guard let strength else { return allTimesteps }

        let clampedStrength = min(max(strength, 0), 1)
        let stepsToRun = Int(Float(inferenceStepCount) * clampedStrength)
        // Never start past the end of the schedule that actually exists
        let startStep = min(max(inferenceStepCount - stepsToRun, 0), allTimesteps.count)
        return Array(allTimesteps[startStep...])
    }

    /// Advance the sample by one Euler step towards the next sigma.
    ///
    /// - Parameters:
    ///   - output: Raw prediction of the model for the current step
    ///   - t: Integer time step of the current step
    ///   - sample: Current latent sample
    /// - Returns: The sample for the next step
    public func step(output: MLShapedArray<Float32>, timeStep t: Int, sample: MLShapedArray<Float32>) -> MLShapedArray<Float32> {
        let stepIndex = timeSteps.firstIndex(of: t) ?? counter // TODO: allow float timesteps in scheduler step protocol
        let modelOutput = convertModelOutput(modelOutput: output, timestep: t, sample: sample)
        modelOutputs.append(modelOutput)

        let sigma = sigmas[stepIndex]
        // The schedule implicitly ends at sigma 0, so the final step must use
        // dt = 0 - sigma. Leaving dt at +sigma would move the sample away from
        // the de-noised estimate instead of onto it.
        let nextIndex = stepIndex + 1
        let prevSigma: Float = nextIndex < sigmas.count ? sigmas[nextIndex] : 0
        let dt = prevSigma - sigma

        let prevSample: MLShapedArray<Float32> = MLShapedArray<Float>(unsafeUninitializedShape: modelOutput.shape) { result, _ in
            modelOutput.withUnsafeShapedBufferPointer { denoisedScalars, _, _ in
                sample.withUnsafeShapedBufferPointer { latentScalars, _, _ in
                    for i in 0..<result.count {
                        let denoised = denoisedScalars[i]
                        let x = latentScalars[i]

                        // Euler update: derivative of x with respect to sigma, times the sigma delta
                        let d = (x - denoised) / sigma
                        let prev_x = x + d * dt
                        result.initializeElement(
                            at: i,
                            to: prev_x
                        )
                    }
                }
            }
        }

        counter += 1
        return prevSample
    }
}
Uh oh!
There was an error while loading. Please reload this page.