// For licensing see accompanying LICENSE.md file.
// Copyright (C) 2022 Apple Inc. and The HuggingFace Team. All Rights Reserved.

import Accelerate
import CoreML

/// A scheduler used to compute a de-noised image
///
/// This implementation matches:
/// [Hugging Face Diffusers DPMSolverMultistepScheduler](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py)
///
/// It uses the DPM-Solver++ algorithm: [code](https://github.com/LuChengTHU/dpm-solver) [paper](https://arxiv.org/abs/2211.01095).
/// Limitations:
/// - Only implemented for DPM-Solver++ algorithm (not DPM-Solver).
/// - Second order only.
/// - Assumes the model predicts epsilon.
/// - No dynamic thresholding.
/// - `midpoint` solver algorithm.
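///
/// Example (a minimal usage sketch; `unet` and `initialLatents` are hypothetical
/// stand-ins for the caller's epsilon-predicting model and starting noise):
/// ```swift
/// let scheduler = DPMSolverMultistepScheduler(stepCount: 25)
/// var sample = initialLatents
/// for t in scheduler.timeSteps {
///     let noisePrediction = unet(sample, t) // predicted epsilon at timestep t
///     sample = scheduler.step(output: noisePrediction, timeStep: t, sample: sample)
/// }
/// ```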
@available(iOS 16.2, macOS 13.1, *)
public final class DPMSolverMultistepScheduler: Scheduler {
    public let trainStepCount: Int
    public let inferenceStepCount: Int
    public let betas: [Float]
    public let alphas: [Float]
    public let alphasCumProd: [Float]
    public let timeSteps: [Int]

    public let alpha_t: [Float]
    public let sigma_t: [Float]
    public let lambda_t: [Float]

    public let solverOrder = 2
    private(set) var lowerOrderStepped = 0

    /// Whether to use lower-order solvers in the final steps. Only applied when using fewer than 15 inference steps.
    /// We empirically find this trick can stabilize the sampling of DPM-Solver, especially with 10 or fewer steps.
    public let useLowerOrderFinal = true

    // Stores solverOrder (2) items
    private(set) var modelOutputs: [MLShapedArray<Float32>] = []

    /// Create a scheduler that uses a second-order DPM-Solver++ algorithm.
    ///
    /// - Parameters:
    ///   - stepCount: Number of inference steps to schedule
    ///   - trainStepCount: Number of training diffusion steps
    ///   - betaSchedule: Method to schedule betas from betaStart to betaEnd
    ///   - betaStart: The starting value of beta for inference
    ///   - betaEnd: The end value for beta for inference
    /// - Returns: A scheduler ready for its first step
    public init(
        stepCount: Int = 50,
        trainStepCount: Int = 1000,
        betaSchedule: BetaSchedule = .scaledLinear,
        betaStart: Float = 0.00085,
        betaEnd: Float = 0.012
    ) {
        self.trainStepCount = trainStepCount
        self.inferenceStepCount = stepCount

        switch betaSchedule {
        case .linear:
            self.betas = linspace(betaStart, betaEnd, trainStepCount)
        case .scaledLinear:
            self.betas = linspace(pow(betaStart, 0.5), pow(betaEnd, 0.5), trainStepCount).map({ $0 * $0 })
        }

        self.alphas = betas.map({ 1.0 - $0 })
        var alphasCumProd = self.alphas
        for i in 1..<alphasCumProd.count {
            alphasCumProd[i] *= alphasCumProd[i - 1]
        }
        self.alphasCumProd = alphasCumProd

        // Currently we only support VP-type noise schedule, where alpha_t^2 + sigma_t^2 = 1
        // and lambda_t = log(alpha_t) - log(sigma_t) is one half of the log-SNR
        self.alpha_t = vForce.sqrt(self.alphasCumProd)
        self.sigma_t = vForce.sqrt(vDSP.subtract([Float](repeating: 1, count: self.alphasCumProd.count), self.alphasCumProd))
        self.lambda_t = zip(self.alpha_t, self.sigma_t).map { α, σ in log(α) - log(σ) }

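        // Evenly spaced timesteps from high noise to low noise; for example,
        // trainStepCount = 1000 with stepCount = 5 yields [999, 749, 500, 250, 0]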
        self.timeSteps = linspace(0, Float(self.trainStepCount - 1), stepCount).reversed().map { Int(round($0)) }
    }

    /// Convert the model output to the corresponding type the algorithm needs.
    /// This implementation is for second-order DPM-Solver++ assuming epsilon prediction.
    func convertModelOutput(modelOutput: MLShapedArray<Float32>, timestep: Int, sample: MLShapedArray<Float32>) -> MLShapedArray<Float32> {
        assert(modelOutput.scalars.count == sample.scalars.count)
        let (alpha_t, sigma_t) = (self.alpha_t[timestep], self.sigma_t[timestep])

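        // Epsilon prediction -> data prediction used by DPM-Solver++:
        // x0 = (x_t - sigma_t * eps) / alpha_t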
        // This could be optimized with a Metal kernel if we find we need to
        let x0_scalars = zip(modelOutput.scalars, sample.scalars).map { m, s in
            (s - m * sigma_t) / alpha_t
        }
        return MLShapedArray(scalars: x0_scalars, shape: modelOutput.shape)
    }

    /// One step for the first-order DPM-Solver (equivalent to DDIM).
    /// See https://arxiv.org/abs/2206.00927 for the detailed derivation.
    /// Variable names and code structure mostly follow https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
    func firstOrderUpdate(
        modelOutput: MLShapedArray<Float32>,
        timestep: Int,
        prevTimestep: Int,
        sample: MLShapedArray<Float32>
    ) -> MLShapedArray<Float32> {
        let (p_lambda_t, lambda_s) = (Double(lambda_t[prevTimestep]), Double(lambda_t[timestep]))
        let p_alpha_t = Double(alpha_t[prevTimestep])
        let (p_sigma_t, sigma_s) = (Double(sigma_t[prevTimestep]), Double(sigma_t[timestep]))
        let h = p_lambda_t - lambda_s
        // x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
        let x_t = weightedSum(
            [p_sigma_t / sigma_s, -p_alpha_t * (exp(-h) - 1)],
            [sample, modelOutput]
        )
        return x_t
    }

    /// One step for the second-order multistep DPM-Solver++ algorithm, using the midpoint method.
    /// Variable names and code structure mostly follow https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
    func secondOrderUpdate(
        modelOutputs: [MLShapedArray<Float32>],
        timesteps: [Int],
        prevTimestep t: Int,
        sample: MLShapedArray<Float32>
    ) -> MLShapedArray<Float32> {
        let (s0, s1) = (timesteps[back: 1], timesteps[back: 2])
        let (m0, m1) = (modelOutputs[back: 1], modelOutputs[back: 2])
        let (p_lambda_t, lambda_s0, lambda_s1) = (Double(lambda_t[t]), Double(lambda_t[s0]), Double(lambda_t[s1]))
        let p_alpha_t = Double(alpha_t[t])
        let (p_sigma_t, sigma_s0) = (Double(sigma_t[t]), Double(sigma_t[s0]))
        let (h, h_0) = (p_lambda_t - lambda_s0, lambda_s0 - lambda_s1)
        let r0 = h_0 / h
        let D0 = m0

        // D1 = (1.0 / r0) * (m0 - m1), a finite-difference estimate of the
        // derivative of the data prediction, rescaled to the current step size
        let D1 = weightedSum(
            [1 / r0, -1 / r0],
            [m0, m1]
        )

        // See https://arxiv.org/abs/2211.01095 for detailed derivations
        // x_t = (
        //    (sigma_t / sigma_s0) * sample
        //    - (alpha_t * (torch.exp(-h) - 1.0)) * D0
        //    - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1
        // )
        let x_t = weightedSum(
            [p_sigma_t / sigma_s0, -p_alpha_t * (exp(-h) - 1), -0.5 * p_alpha_t * (exp(-h) - 1)],
            [sample, D0, D1]
        )
        return x_t
    }

    public func step(output: MLShapedArray<Float32>, timeStep t: Int, sample: MLShapedArray<Float32>) -> MLShapedArray<Float32> {
        let stepIndex = timeSteps.firstIndex(of: t) ?? timeSteps.count - 1
        let prevTimestep = stepIndex == timeSteps.count - 1 ? 0 : timeSteps[stepIndex + 1]

        // Use the first-order update on the very first step (no previous model
        // output exists yet) and, when fewer than 15 inference steps are used,
        // also on the last two steps to stabilize sampling
        let lowerOrderFinal = useLowerOrderFinal && stepIndex == timeSteps.count - 1 && timeSteps.count < 15
        let lowerOrderSecond = useLowerOrderFinal && stepIndex == timeSteps.count - 2 && timeSteps.count < 15
        let lowerOrder = lowerOrderStepped < 1 || lowerOrderFinal || lowerOrderSecond

        let modelOutput = convertModelOutput(modelOutput: output, timestep: t, sample: sample)
        if modelOutputs.count == solverOrder { modelOutputs.removeFirst() }
        modelOutputs.append(modelOutput)

        let prevSample: MLShapedArray<Float32>
        if lowerOrder {
            prevSample = firstOrderUpdate(modelOutput: modelOutput, timestep: t, prevTimestep: prevTimestep, sample: sample)
        } else {
            prevSample = secondOrderUpdate(
                modelOutputs: modelOutputs,
                timesteps: [timeSteps[stepIndex - 1], t],
                prevTimestep: prevTimestep,
                sample: sample
            )
        }
        if lowerOrderStepped < solverOrder {
            lowerOrderStepped += 1
        }

        return prevSample
    }
}
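
// NOTE: `linspace`, `weightedSum`, and the `[back:]` subscript used above are
// assumed to be defined elsewhere in the package (for example alongside the
// `Scheduler` protocol). The sketches below are hypothetical stand-ins that
// match their use in this file, not the package's actual implementations.

/// Evenly spaced values from `start` to `end`, inclusive (assumes `count` >= 2)
func linspace(_ start: Float, _ end: Float, _ count: Int) -> [Float] {
    let scale = (end - start) / Float(count - 1)
    return (0..<count).map { start + Float($0) * scale }
}

extension Array {
    /// Index from the back: `array[back: 1]` is the last element
    subscript(back i: Int) -> Element {
        return self[count - i]
    }
}

@available(iOS 16.2, macOS 13.1, *)
extension DPMSolverMultistepScheduler {
    /// Elementwise weighted sum: result[i] = sum over j of weights[j] * values[j][i]
    func weightedSum(_ weights: [Double], _ values: [MLShapedArray<Float32>]) -> MLShapedArray<Float32> {
        assert(!values.isEmpty && weights.count == values.count)
        var scalars = [Float](repeating: 0, count: values[0].scalars.count)
        for (weight, value) in zip(weights, values) {
            let w = Float(weight)
            for (i, x) in value.scalars.enumerated() {
                scalars[i] += w * x
            }
        }
        return MLShapedArray(scalars: scalars, shape: values[0].shape)
    }
}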