diff --git a/Libraries/MLXVLM/MediaProcessing.swift b/Libraries/MLXVLM/MediaProcessing.swift index fea6d9a7..63c0efb4 100644 --- a/Libraries/MLXVLM/MediaProcessing.swift +++ b/Libraries/MLXVLM/MediaProcessing.swift @@ -16,7 +16,6 @@ public struct ProcessedFrames { let totalDuration: CMTime } -// TODO: verify working color space, rendering color space private let context = CIContext() /// Collection of methods for processing media (images, video, etc.). @@ -27,7 +26,7 @@ private let context = CIContext() /// var image: CIImage /// image = MediaProcessing.inSRGBToneCurveSpace(image) /// -/// // apply user instructions +/// // Apply user instructions /// image = MediaProcessing.apply(image, processing: processing) /// /// image = MediaProcessing.resampleBicubic(image, to: config.size.cgSize) @@ -76,58 +75,58 @@ public enum MediaProcessing { return Float(1 / inputAspectRatio * desiredAspectRatio) } - /// Resample the image using bicubic interpolation. - public static func resampleBicubic(_ image: CIImage, to size: CGSize) -> CIImage { - let filter = CIFilter.bicubicScaleTransform() - let extent = image.extent.size - - filter.inputImage = image - - // set the aspect ratio to match the aspect ratio of the target - filter.aspectRatio = aspectRatioForResample(image, size: size) - - // that image is now the aspect ratio of the target and the size - // of the shorter dimension - let scale: CGFloat - if extent.width < extent.height { - scale = size.width / extent.width - } else { - scale = size.height / extent.height - } - filter.scale = Float(scale) - - let rescaled = filter.outputImage! - - // the image has a DoD larger than the requested size so crop - // it to the desired size - return rescaled.cropped(to: CGRect(origin: .zero, size: size)) - } - /// Resample the image using Lanczos interpolation. static public func resampleLanczos(_ image: CIImage, to size: CGSize) -> CIImage { - let filter = CIFilter.lanczosScaleTransform() - let extent = image.extent.size + // Create a bicubic scale filter + + let yScale = size.height / image.extent.height + let xScale = size.width / image.extent.width + let filter = CIFilter.lanczosScaleTransform() filter.inputImage = image + filter.scale = Float(yScale) + filter.aspectRatio = Float(xScale / yScale) + let scaledImage = filter.outputImage! + + // Create a rect with the exact dimensions we want + let exactRect = CGRect( + x: 0, + y: 0, + width: size.width, + height: size.height + ) - // set the aspect ratio to match the aspect ratio of the target - filter.aspectRatio = aspectRatioForResample(image, size: size) + // Crop to ensure exact dimensions + return scaledImage.cropped(to: exactRect) + } - // that image is now the aspect ratio of the target and the size - // of the shorter dimension - let scale: CGFloat - if extent.width < extent.height { - scale = size.width / extent.width - } else { - scale = size.height / extent.height - } - filter.scale = Float(scale) + /// Resample the image using bicubic interpolation. + /// - Parameters: + /// - image: The image to resample + /// - size: The target size + /// - Returns: The resampled image + public static func resampleBicubic(_ image: CIImage, to size: CGSize) -> CIImage { + // Create a bicubic scale filter + + let yScale = size.height / image.extent.height + let xScale = size.width / image.extent.width - let rescaled = filter.outputImage! 
+ let filter = CIFilter.bicubicScaleTransform() + filter.inputImage = image + filter.scale = Float(yScale) + filter.aspectRatio = Float(xScale / yScale) + let scaledImage = filter.outputImage! + + // Create a rect with the exact dimensions we want + let exactRect = CGRect( + x: 0, + y: 0, + width: size.width, + height: size.height + ) - // the image has a DoD larger than the requested size so crop - // it to the desired size - return rescaled.cropped(to: CGRect(origin: .zero, size: size)) + // Crop to ensure exact dimensions + return scaledImage.cropped(to: exactRect) } /// Normalize the image using the given mean and standard deviation parameters. @@ -137,7 +136,7 @@ public enum MediaProcessing { let filter = CIFilter.colorMatrix() filter.inputImage = image - // this should match + // This should match // https://pytorch.org/vision/main/generated/torchvision.transforms.Normalize.html // // output[channel] = (input[channel] - mean[channel]) / std[channel] @@ -156,6 +155,10 @@ public enum MediaProcessing { } /// Convert the CIImage into a planar 3 channel MLXArray `[1, C, H, W]` + /// - Parameters: + /// - image: The image to convert + /// - colorSpace: Optional color space for rendering + /// - Returns: The MLXArray representation of the image public static func asMLXArray(_ image: CIImage, colorSpace: CGColorSpace? = nil) -> MLXArray { let size = image.extent.size let w = Int(size.width.rounded()) @@ -178,10 +181,10 @@ public enum MediaProcessing { var array = MLXArray(data, [h, w, 4], type: Float32.self) - // drop 4th channel + // Drop 4th channel array = array[0..., 0..., ..<3] - // convert to 1, C, H, W + // Convert to 1, C, H, W array = array.reshaped(1, h, w, 3).transposed(0, 3, 1, 2) return array diff --git a/Libraries/MLXVLM/Models/Idefics3.swift b/Libraries/MLXVLM/Models/Idefics3.swift index df22f182..0a63e87c 100644 --- a/Libraries/MLXVLM/Models/Idefics3.swift +++ b/Libraries/MLXVLM/Models/Idefics3.swift @@ -851,7 +851,7 @@ public class Idefics3Processor: UserInputProcessor { height: fixedImageSize ) image = MediaProcessing.apply(image, processing: input.processing) - image = MediaProcessing.resampleBicubic(image, to: targetSize) + image = try MediaProcessing.resampleBicubic(image, to: targetSize) image = MediaProcessing.normalize( image, mean: config.imageMeanTuple, diff --git a/Libraries/MLXVLM/Models/Paligemma.swift b/Libraries/MLXVLM/Models/Paligemma.swift index de893318..c8a46e8f 100644 --- a/Libraries/MLXVLM/Models/Paligemma.swift +++ b/Libraries/MLXVLM/Models/Paligemma.swift @@ -441,7 +441,7 @@ private enum Vision { /// PaliGemma VLM `UserInputProcessor`. /// /// This is meant to be used with ``PaliGemma`` and is typically created by ``VLMModelFactory``. -public class PaligGemmaProcessor: UserInputProcessor { +public class PaliGemmaProcessor: UserInputProcessor { private let config: PaliGemmaProcessorConfiguration private let tokenizer: any Tokenizer @@ -451,7 +451,7 @@ public class PaligGemmaProcessor: UserInputProcessor { self.tokenizer = tokenizer } - private func prepare(image: CIImage, processing: UserInput.Processing?) -> MLXArray { + private func prepare(image: CIImage, processing: UserInput.Processing?) 
throws -> MLXArray { // based on image_processing_siglip from transformers var image = image @@ -463,7 +463,7 @@ public class PaligGemmaProcessor: UserInputProcessor { // apply user instructions image = MediaProcessing.apply(image, processing: processing) - image = MediaProcessing.resampleBicubic(image, to: config.size.cgSize) + image = try MediaProcessing.resampleBicubic(image, to: config.size.cgSize) image = MediaProcessing.normalize( image, mean: config.imageMeanTuple, std: config.imageStdTuple) @@ -705,7 +705,7 @@ public struct PaliGemmaConfiguration: Codable, Sendable { } } -/// Configuration for ``PaligGemmaProcessor`` +/// Configuration for ``PaliGemmaProcessor`` public struct PaliGemmaProcessorConfiguration: Codable, Sendable { public struct Size: Codable, Sendable { diff --git a/Libraries/MLXVLM/Models/Qwen25VL.swift b/Libraries/MLXVLM/Models/Qwen25VL.swift new file mode 100644 index 00000000..4ddce506 --- /dev/null +++ b/Libraries/MLXVLM/Models/Qwen25VL.swift @@ -0,0 +1,1092 @@ +// Port of https://github.com/Blaizzy/mlx-vlm/tree/main/mlx_vlm/models/qwen2_5_vl + +import CoreImage +import Foundation +import Hub +import MLX +import MLXFast +import MLXLMCommon +import MLXNN +import Tokenizers + +// MARK: - Language + +private enum Language { + + /// Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors + static private func applyMultimodalRotaryPositionEmbedding( + q: MLXArray, k: MLXArray, cos: MLXArray, sin: MLXArray, + positionIds: MLXArray, mropeSection: [Int] + ) -> (MLXArray, MLXArray) { + var cos = cos[positionIds] + var sin = sin[positionIds] + + cos = + concatenated( + // [m[i % 3] for i, m in enumerate(mx.split(cos, mrope_section, axis=-1))] + split(cos, indices: mropeSection, axis: -1).enumerated().map { i, m in m[i % 3] }, + axis: -1 + )[0..., .newAxis, 0..., 0...] + + sin = + concatenated( + split(sin, indices: mropeSection, axis: -1).enumerated().map { i, m in m[i % 3] }, + axis: -1 + )[0..., .newAxis, 0..., 0...] 
+ + // Apply rotary embedding + let qEmbed = (q * cos) + (QwenVL.rotateHalf(q) * sin) + let kEmbed = (k * cos) + (QwenVL.rotateHalf(k) * sin) + return (qEmbed, kEmbed) + } + + fileprivate class Attention: Module { + + let heads: Int + let kvHeads: Int + let headDim: Int + let scale: Float + let mropeSection: [Int] + + @ModuleInfo(key: "q_proj") var wq: Linear + @ModuleInfo(key: "k_proj") var wk: Linear + @ModuleInfo(key: "v_proj") var wv: Linear + @ModuleInfo(key: "o_proj") var wo: Linear + + @ModuleInfo(key: "rotary_emb") var rotaryEmbedding: RoPE + + public init(_ args: Qwen25VLConfiguration.TextConfiguration) { + let dim = args.hiddenSize + self.heads = args.attentionHeads + self.kvHeads = args.kvHeads + self.headDim = dim / heads + self.scale = pow(Float(headDim), -0.5) + + self._wq.wrappedValue = Linear(dim, heads * headDim, bias: true) + self._wk.wrappedValue = Linear(dim, kvHeads * headDim, bias: true) + self._wv.wrappedValue = Linear(dim, kvHeads * headDim, bias: true) + self._wo.wrappedValue = Linear(heads * headDim, dim, bias: false) + + if let v = args.ropeScaling?["mrope_section"], let array = v.asInts() { + // mrope_section = np.cumsum(mrope_section * 2)[:-1].tolist() + self.mropeSection = sequence(state: (0, array.makeIterator())) { state in + if let v = state.1.next() { + // note the *2 + state.0 += v * 2 + return state.0 + } else { + return nil + } + }.dropLast() + } else { + fatalError("rope_scaling['mrope_section'] must be an array of integers") + } + + self._rotaryEmbedding.wrappedValue = RoPE( + dimensions: headDim, traditional: args.ropeTraditional, base: args.ropeTheta) + } + + public func callAsFunction( + _ x: MLXArray, mask: MLXArray? = nil, cache: KVCache? + ) -> MLXArray { + let (B, L) = (x.dim(0), x.dim(1)) + + var queries = wq(x) + var keys = wk(x) + var values = wv(x) + + // prepare the queries, keys and values for the attention computation + queries = queries.reshaped(B, L, heads, headDim).transposed(0, 2, 1, 3) + keys = keys.reshaped(B, L, kvHeads, headDim).transposed(0, 2, 1, 3) + values = values.reshaped(B, L, kvHeads, headDim).transposed(0, 2, 1, 3) + + let offset = cache?.offset ?? 
0 + let mask = mask?[0..., 0 ..< keys.dim(-2)] + + queries = rotaryEmbedding(queries, offset: offset) + keys = rotaryEmbedding(keys, offset: offset) + + if let cache { + (keys, values) = cache.update(keys: keys, values: values) + } + + let output = MLXFast.scaledDotProductAttention( + queries: queries, keys: keys, values: values, scale: scale, mask: mask + ) + .transposed(0, 2, 1, 3) + .reshaped(B, L, -1) + + return wo(output) + } + } + + fileprivate class MLP: Module, UnaryLayer { + @ModuleInfo(key: "gate_proj") var gate: Linear + @ModuleInfo(key: "up_proj") var up: Linear + @ModuleInfo(key: "down_proj") var down: Linear + + public init(dimensions: Int, hiddenDimensions: Int) { + self._gate.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false) + self._up.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false) + self._down.wrappedValue = Linear(hiddenDimensions, dimensions, bias: false) + } + + public func callAsFunction(_ x: MLXArray) -> MLXArray { + down(silu(gate(x)) * up(x)) + } + } + + fileprivate class Qwen25VLDecoderLayer: Module { + + @ModuleInfo(key: "self_attn") var attention: Attention + let mlp: MLP + + @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm + @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm + + public init(_ args: Qwen25VLConfiguration.TextConfiguration) { + self._attention.wrappedValue = Attention(args) + self.mlp = MLP(dimensions: args.hiddenSize, hiddenDimensions: args.intermediateSize) + self._inputLayerNorm.wrappedValue = RMSNorm( + dimensions: args.hiddenSize, eps: args.rmsNormEps) + self._postAttentionLayerNorm.wrappedValue = RMSNorm( + dimensions: args.hiddenSize, eps: args.rmsNormEps) + } + + public func callAsFunction( + _ x: MLXArray, mask: MLXArray? = nil, cache: KVCache? + ) -> MLXArray { + var r = attention(inputLayerNorm(x), mask: mask, cache: cache) + let h = x + r + r = mlp(postAttentionLayerNorm(h)) + let out = h + r + return out + } + } + + fileprivate class Qwen25Model: Module { + + @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding + + fileprivate let layers: [Qwen25VLDecoderLayer] + fileprivate let norm: RMSNorm + + public init(_ args: Qwen25VLConfiguration.TextConfiguration) { + precondition(args.vocabularySize > 0) + + self._embedTokens.wrappedValue = Embedding( + embeddingCount: args.vocabularySize, dimensions: args.hiddenSize) + + self.layers = (0 ..< args.hiddenLayers) + .map { _ in + Qwen25VLDecoderLayer(args) + } + self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps) + } + + public func callAsFunction( + _ inputs: MLXArray?, cache: [KVCache]? = nil, inputEmbedding: MLXArray? = nil + ) -> MLXArray { + var h: MLXArray + if let inputEmbedding { + h = inputEmbedding + } else if let inputs { + h = embedTokens(inputs) + } else { + fatalError("one of inputs or inputEmbedding must be non-nil") + } + + let mask = createAttentionMask(h: h, cache: cache) + + for (i, layer) in layers.enumerated() { + h = layer(h, mask: mask, cache: cache?[i]) + } + + return norm(h) + } + } + + fileprivate class LanguageModel: Module, KVCacheDimensionProvider { + @ModuleInfo var model: Qwen25Model + @ModuleInfo(key: "lm_head") var lmHead: Linear? 
+ + var kvHeads: [Int] + + public init(_ args: Qwen25VLConfiguration.TextConfiguration) { + self.model = Qwen25Model(args) + + if !args.tieWordEmbeddings { + _lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false) + } + + self.kvHeads = (0 ..< args.hiddenLayers).map { _ in args.kvHeads } + } + + public func callAsFunction( + _ inputs: MLXArray?, cache: [KVCache]? = nil, inputEmbedding: MLXArray? = nil + ) -> LMOutput { + var out = model(inputs, cache: cache, inputEmbedding: inputEmbedding) + if let lmHead { + out = lmHead(out) + } else { + out = model.embedTokens.asLinear(out) + } + return LMOutput(logits: out) + } + } +} + +// MARK: - Vision + +private enum Vision { + + static fileprivate func applyMultimodalRotaryPositionEmbedding( + _ tensor: MLXArray, freqs: MLXArray + ) -> MLXArray { + var cos = cos(freqs) + var sin = sin(freqs) + + cos = expandedDimensions(cos, axis: 1) + cos = tiled(cos, repetitions: [1, 1, 2]) + cos = expandedDimensions(cos, axis: 0) + + sin = expandedDimensions(sin, axis: 1) + sin = tiled(sin, repetitions: [1, 1, 2]) + sin = expandedDimensions(sin, axis: 0) + + let output = (tensor * cos) + (QwenVL.rotateHalf(tensor) * sin) + return output.asType(tensor.dtype) + } + + fileprivate class PatchMerger: Module, UnaryLayer { + let hiddenSize: Int + @ModuleInfo(key: "ln_q") var layerNormQ: RMSNorm + @ModuleInfo var mlp: (Linear, GELU, Linear) + + init(dimensions: Int, contextDimensions: Int, spatialMergeSize: Int) { + self.hiddenSize = contextDimensions * (spatialMergeSize * spatialMergeSize) + self._layerNormQ.wrappedValue = RMSNorm(dimensions: contextDimensions, eps: 1e-6) + self.mlp = ( + Linear(hiddenSize, hiddenSize), + GELU(), + Linear(hiddenSize, dimensions) + ) + } + + func callAsFunction(_ x: MLXArray) -> MLXArray { + var x = layerNormQ(x).reshaped(-1, hiddenSize) + x = mlp.0(x) + x = mlp.1(x) + x = mlp.2(x) + return x + } + } + + fileprivate class Attention: Module { + + let numHeads: Int + let scale: Float + + @ModuleInfo(key: "qkv") var qkv: Linear + @ModuleInfo(key: "proj") var proj: Linear + + public init(dims: Int, numHeads: Int) { + self.numHeads = numHeads + let headDim = dims / numHeads + self.scale = pow(Float(headDim), -0.5) + + self._qkv.wrappedValue = Linear(dims, 3 * dims, bias: true) + self._proj.wrappedValue = Linear(dims, dims) + } + + public func callAsFunction( + _ x: MLXArray, attentionMask: MLXArray, rotaryPositionEmbedding: MLXArray + ) -> MLXArray { + let sequenceLength = x.dim(0) + + let qkv = qkv(x) + let s = split(qkv, parts: 3, axis: -1) + var (q, k, v) = (s[0], s[1], s[2]) + + q = q.reshaped(sequenceLength, numHeads, -1) + k = k.reshaped(sequenceLength, numHeads, -1) + v = v.reshaped(sequenceLength, numHeads, -1) + + q = applyMultimodalRotaryPositionEmbedding(q, freqs: rotaryPositionEmbedding) + k = applyMultimodalRotaryPositionEmbedding(k, freqs: rotaryPositionEmbedding) + + q = q.reshaped(1, sequenceLength, numHeads, -1).transposed(0, 2, 1, 3) + k = k.reshaped(1, sequenceLength, numHeads, -1).transposed(0, 2, 1, 3) + v = v.reshaped(1, sequenceLength, numHeads, -1).transposed(0, 2, 1, 3) + + let output = MLXFast.scaledDotProductAttention( + queries: q, keys: k, values: v, scale: scale, mask: attentionMask + ) + .transposed(0, 2, 1, 3) + .reshaped(sequenceLength, -1) + + return proj(output) + } + } + + fileprivate class MLP: Module, UnaryLayer { + + @ModuleInfo(key: "gate_proj") var gate: Linear + @ModuleInfo(key: "up_proj") var up: Linear + @ModuleInfo(key: "down_proj") var down: Linear + + public 
init(dimensions: Int, hiddenDimensions: Int) { + self._gate.wrappedValue = Linear(dimensions, hiddenDimensions) + self._up.wrappedValue = Linear(dimensions, hiddenDimensions) + self._down.wrappedValue = Linear(hiddenDimensions, dimensions) + } + + public func callAsFunction(_ x: MLXArray) -> MLXArray { + down(silu(gate(x)) * up(x)) + } + } + + fileprivate class Qwen25VLVisionBlock: Module { + + @ModuleInfo var norm1: RMSNorm + @ModuleInfo var norm2: RMSNorm + @ModuleInfo(key: "attn") var attention: Attention + @ModuleInfo var mlp: MLP + + public init(_ config: Qwen25VLConfiguration.VisionConfiguration) { + self.norm1 = RMSNorm(dimensions: config.hiddenSize, eps: 1e-6) + self.norm2 = RMSNorm(dimensions: config.hiddenSize, eps: 1e-6) + + self._attention.wrappedValue = Attention( + dims: config.hiddenSize, numHeads: config.numHeads) + + self.mlp = MLP( + dimensions: config.hiddenSize, hiddenDimensions: config.intermediateSize) + } + + func callAsFunction( + _ hiddenStates: MLXArray, attentionMask: MLXArray, rotaryPositionEmbedding: MLXArray + ) -> MLXArray { + var hiddenStates = + hiddenStates + + attention( + norm1(hiddenStates), + attentionMask: attentionMask, + rotaryPositionEmbedding: rotaryPositionEmbedding + ) + hiddenStates = hiddenStates + mlp(norm2(hiddenStates)) + return hiddenStates + } + } + + fileprivate class VisionModel: Module { + + @ModuleInfo(key: "patch_embed") var patchEmbed: QwenVL.PatchEmbed + @ModuleInfo(key: "rotary_pos_emb") var rotaryPositionEmbedding: QwenVL.VisionRotaryEmbedding + @ModuleInfo(key: "blocks") var blocks: [Qwen25VLVisionBlock] + @ModuleInfo(key: "merger") var patchMerger: PatchMerger + + let spatialMergeSize: Int + let windowSize: Int + let patchSize: Int + let spatialMergeUnit: Int + let fullattBlockIndexes: [Int] + + public init(_ config: Qwen25VLConfiguration.VisionConfiguration) { + self.spatialMergeSize = config.spatialMergeSize + self.windowSize = config.windowSize + self.patchSize = config.patchSize + self.spatialMergeUnit = config.spatialMergeSize * config.spatialMergeSize + self.fullattBlockIndexes = config.fullattBlockIndexes + + self._patchEmbed.wrappedValue = QwenVL.PatchEmbed( + patchSize: config.patchSize, + temporalPatchSize: config.temporalPatchSize, + inChannels: config.inChannels, + hiddenSize: config.hiddenSize) + + let headDimensions = config.hiddenSize / config.numHeads + self._rotaryPositionEmbedding.wrappedValue = QwenVL.VisionRotaryEmbedding( + dimensions: headDimensions / 2, theta: 10_000) + + self._blocks.wrappedValue = (0 ..< config.depth).map { _ in + Qwen25VLVisionBlock(config) + } + self._patchMerger.wrappedValue = PatchMerger( + dimensions: config.outHiddenSize, contextDimensions: config.hiddenSize, + spatialMergeSize: config.spatialMergeSize) + } + + func rotaryPositionEmbedding(_ frames: [THW]) -> MLXArray { + var positionIds = [MLXArray]() + + for row in frames { + let (t, h, w) = row.values + + var hposIds = expandedDimensions(MLXArray(0 ..< h), axis: 1) + hposIds = repeated(hposIds, count: w, axis: 1) + hposIds = + hposIds + .reshaped( + h / spatialMergeSize, + spatialMergeSize, + w / spatialMergeSize, + spatialMergeSize + ) + .transposed(0, 2, 1, 3) + .flattened() + + var wposIds = expandedDimensions(MLXArray(0 ..< w), axis: 0) + wposIds = repeated(wposIds, count: h, axis: 0) + wposIds = + wposIds + .reshaped( + h / spatialMergeSize, + spatialMergeSize, + w / spatialMergeSize, + spatialMergeSize + ) + .transposed(0, 2, 1, 3) + .flattened() + + let stackedPosIds = stacked([hposIds, wposIds], axis: -1) + 
positionIds.append(tiled(stackedPosIds, repetitions: [t, 1])) + } + + let indices = concatenated(positionIds, axis: 0) + let maxFrameSize = frames.lazy.map { max($0.h, $0.w) }.max() ?? 0 + let rotaryPositionEmbedFull = rotaryPositionEmbedding(sequenceLength: maxFrameSize)[ + indices] + + return rotaryPositionEmbedFull.reshaped(indices.dim(0), -1) + } + + func getWindowIndex(_ frames: [THW]) -> (MLXArray, MLXArray) { + var windowIndex = [MLXArray]() + var cuWindowSeqlens = [0] + var windowIndexId = 0 + let vitMergerWindowSize = windowSize / spatialMergeSize / patchSize + + for frame in frames { + let (gridT, gridH, gridW) = frame.values + let llmGridH = gridH / spatialMergeSize + let llmGridW = gridW / spatialMergeSize + + let index = MLXArray(0 ..< (gridT * llmGridH * llmGridW)).reshaped( + gridT, llmGridH, llmGridW) + + let padH = vitMergerWindowSize - llmGridH % vitMergerWindowSize + let padW = vitMergerWindowSize - llmGridW % vitMergerWindowSize + let numWindowsH = (llmGridH + padH) / vitMergerWindowSize + let numWindowsW = (llmGridW + padW) / vitMergerWindowSize + + // Pad the index + let indexPadded = padded( + index, + widths: [[0, 0], [0, padH], [0, padW]], + mode: .constant, + value: MLXArray(-100) + ) + + // Reshape and transpose + let indexReshaped = indexPadded.reshaped( + gridT, + numWindowsH, + vitMergerWindowSize, + numWindowsW, + vitMergerWindowSize + ) + + let indexTransposed = indexReshaped.transposed(0, 1, 3, 2, 4).reshaped( + gridT, + numWindowsH * numWindowsW, + vitMergerWindowSize, + vitMergerWindowSize + ) + + // Calculate sequence lengths + let seqlens = sum(indexTransposed .!= -100, axes: [2, 3]).reshaped(-1) + + // Get valid indices + let indexFlattened = indexTransposed.flattened() + let validIndices = indexFlattened.asArray(Int.self).enumerated() + .filter { $0.element != -100 } + .map { $0.offset } + + let validValues = indexFlattened[MLXArray(validIndices)] + + // Add to window index + windowIndex.append(validValues + windowIndexId) + + // Update cumulative sequence lengths + let cuSeqlensTmp = + cumsum(seqlens, axis: 0) * spatialMergeUnit + cuWindowSeqlens.last! 
+                cuWindowSeqlens.append(contentsOf: cuSeqlensTmp.asArray(Int.self))
+
+                windowIndexId += gridT * llmGridH * llmGridW
+            }
+
+            // Concatenate all window indices
+            let combinedWindowIndex = concatenated(windowIndex, axis: 0)
+            let cuWindowSeqlensArray = MLXArray(cuWindowSeqlens)
+
+            // Get unique values in cuWindowSeqlens
+            var seen = Set<Int>()
+            var uniqueIndices = [Int]()
+
+            for (i, value) in cuWindowSeqlens.enumerated() {
+                if !seen.contains(value) {
+                    seen.insert(value)
+                    uniqueIndices.append(i)
+                }
+            }
+
+            let uniqueCuWindowSeqlens = cuWindowSeqlensArray[MLXArray(uniqueIndices)]
+
+            return (combinedWindowIndex, uniqueCuWindowSeqlens)
+        }
+
+        func attentionMask(sequenceLength: Int, cuSeqlens: MLXArray) -> MLXArray {
+            // Create attention mask
+            let attentionMask = full(
+                [1, sequenceLength, sequenceLength],
+                values: Int8(-127))
+
+            // Update mask for each sequence
+            let cuSeqlens = cuSeqlens.asArray(Int.self)
+            for i in 1 ..< cuSeqlens.count {
+                let start = cuSeqlens[i - 1]
+                let end = cuSeqlens[i]
+                attentionMask[0..., start ..< end, start ..< end] = MLXArray(Int8(0))
+            }
+
+            return attentionMask
+        }
+
+        public func callAsFunction(_ hiddenStates: MLXArray, frames: [THW]) -> MLXArray {
+            var hiddenStates = patchEmbed(hiddenStates)
+            let rotaryPosEmb = rotaryPositionEmbedding(frames)
+
+            // Get window indices and sequence lengths
+            let (windowIndex, cuWindowSeqlens) = getWindowIndex(frames)
+
+            // Prepare attention masks
+            let seqLen = hiddenStates.dim(0)
+            var cuSeqlens = [0]
+            for frame in frames {
+                let seqLen = frame.h * frame.w
+                cuSeqlens.append(
+                    contentsOf: Array(repeating: seqLen, count: frame.t).map {
+                        cuSeqlens.last! + $0
+                    })
+            }
+            let cuSeqlensArray = MLXArray(cuSeqlens)
+
+            let fullAttentionMask = attentionMask(sequenceLength: seqLen, cuSeqlens: cuSeqlensArray)
+            let windowAttentionMask = attentionMask(
+                sequenceLength: seqLen, cuSeqlens: cuWindowSeqlens)
+
+            // Reshape and reindex hidden states
+            hiddenStates = hiddenStates.reshaped(seqLen / spatialMergeUnit, spatialMergeUnit, -1)
+            hiddenStates = hiddenStates[windowIndex, 0..., 0...]
+            hiddenStates = hiddenStates.reshaped(seqLen, -1)
+
+            // Reshape and reindex rotary position embeddings
+            var rotaryPosEmbReshaped = rotaryPosEmb.reshaped(
+                seqLen / spatialMergeUnit, spatialMergeUnit, -1)
+            rotaryPosEmbReshaped = rotaryPosEmbReshaped[windowIndex, 0..., 0...]
+            rotaryPosEmbReshaped = rotaryPosEmbReshaped.reshaped(seqLen, -1)
+
+            // Process through blocks
+            for (i, block) in blocks.enumerated() {
+                // Use full attention for specific blocks, window attention for others
+                let attentionMask =
+                    fullattBlockIndexes.contains(i) ? fullAttentionMask : windowAttentionMask
+
+                hiddenStates = block(
+                    hiddenStates,
+                    attentionMask: attentionMask,
+                    rotaryPositionEmbedding: rotaryPosEmbReshaped
+                )
+            }
+
+            // Apply patch merger
+            hiddenStates = patchMerger(hiddenStates)
+
+            // Reorder back to original sequence
+            let reverseIndices = argSort(windowIndex, axis: 0)
+            hiddenStates = hiddenStates[reverseIndices, 0...]
+ + return hiddenStates + } + + private func isMLXWeight(_ array: MLXArray) -> Bool { + if array.ndim != 4, array.ndim != 5 { + return false + } + + if array.dim(-1) == 3 { + return true + } + + let (outChannels, kH, kW) = (array.dim(1), array.dim(2), array.dim(3)) + return outChannels >= kH && outChannels >= kW && kH == kW + } + + func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] { + var sanitizedWeights = [String: MLXArray]() + + for (k, v) in weights { + if k.contains("position_id") { + // Remove unused position_ids + continue + } else if k.contains("patch_embed.proj.weight") { + // PyTorch conv2d weight tensors have shape: + // [B, out_channels, in_channels, kH, KW] + // MLX conv2d expects the weight be of shape: + // [B, out_channels, kH, KW, in_channels] + if isMLXWeight(v) { + sanitizedWeights[k] = v + } else { + sanitizedWeights[k] = v.transposed(0, 2, 3, 4, 1) + } + } else { + sanitizedWeights[k] = v + } + } + + return sanitizedWeights + } + } +} + +// MARK: - Processor + +/// Qwen2.5VL VLM `UserInputProcessor`. +/// +/// This is meant to be used with ``Qwen25VL`` and is typically created by ``VLMModelFactory``. +public class Qwen25VLProcessor: UserInputProcessor { + private let config: Qwen25VLProcessorConfiguration + private let tokenizer: any Tokenizer + + public init(_ config: Qwen25VLProcessorConfiguration, tokenizer: any Tokenizer) { + self.config = config + self.tokenizer = tokenizer + } + + func preprocess(image: CIImage, resizedSize: CGSize) -> CIImage { + image + .toSRGB() + .resampled(to: resizedSize, method: .bicubic) + .normalized(mean: config.imageMeanTuple, std: config.imageStdTuple) + } + + public func preprocess(images: [CIImage], processing: UserInput.Processing?) throws -> ( + MLXArray, THW + ) { + // First apply the user requested resizing, etc. if any + let images = images.map { MediaProcessing.apply($0, processing: processing) } + + // image_processing_qwen2_vl._preprocess + let size = images[0].extent.size + let (resizedHeight, resizedWidth) = try QwenVL.targetSize( + height: Int(size.height), width: Int(size.width), + factor: config.patchSize * config.mergeSize, + minPixels: config.size.minPixels, maxPixels: config.size.maxPixels) + let resizedSize = CGSize(width: resizedWidth, height: resizedHeight) + + // Process images + let processedImages = + try images + .map { + MediaProcessing.inSRGBToneCurveSpace($0) + } + .map { + return try MediaProcessing.resampleBicubic($0, to: resizedSize) + } + .map { + MediaProcessing.normalize( + $0, mean: config.imageMeanTuple, std: config.imageStdTuple) + } + .map { + MediaProcessing.asMLXArray($0) + } + + return try QwenVL.patchify( + images: processedImages, mergeSize: config.mergeSize, patchSize: config.patchSize, + temporalPatchSize: config.temporalPatchSize) + } + + public func prepare(input: UserInput) async throws -> LMInput { + let messages = input.prompt.asMessages() + + var promptTokens = try tokenizer.applyChatTemplate(messages: messages) + + // Text-only input + if input.images.isEmpty, input.videos.isEmpty { + return LMInput(tokens: MLXArray(promptTokens)) + } + + // Process images if any + var processedImage: LMInput.ProcessedImage? 
+ if !input.images.isEmpty { + let imagePixelsAndFrames = try input.images.map { + try preprocess(images: [$0.asCIImage()], processing: input.processing) + } + let imagePixelsConcatenated = concatenated(imagePixelsAndFrames.map { $0.0 }) + processedImage = LMInput.ProcessedImage( + pixels: imagePixelsConcatenated, frames: imagePixelsAndFrames.map { $0.1 }) + + if let imageFrames = processedImage?.frames { + promptTokens = try QwenVL.replacePaddingTokens( + in: promptTokens, frames: imageFrames, paddingToken: "<|image_pad|>", + mergeSize: config.mergeSize, tokenizer: tokenizer) + } + } + + // Process videos if any + var processedVideo: LMInput.ProcessedVideo? + if !input.videos.isEmpty { + var videosAsImageSequences = [[MLXArray]]() + var resizedSize: CGSize = .zero + for video in input.videos { + let imageSequence = try await MediaProcessing.asProcessedSequence( + video.asAVAsset(), samplesPerSecond: 2 + ) { frame in + // first apply the user requested resizing, etc. if any + let resizedImage = MediaProcessing.apply( + frame.frame, processing: input.processing) + if resizedSize == .zero { + let size = resizedImage.extent.size + let (resizedHeight, resizedWidth) = try QwenVL.targetSize( + height: Int(size.height), width: Int(size.width), + factor: config.patchSize * config.mergeSize, + minPixels: config.minPixels, maxPixels: config.maxPixels) + resizedSize = CGSize(width: resizedWidth, height: resizedHeight) + } + let processedImage = preprocess(image: resizedImage, resizedSize: resizedSize) + return VideoFrame(frame: processedImage, timeStamp: frame.timeStamp) + } + videosAsImageSequences.append(imageSequence.frames) + } + let videoPixelsAndFrames = try videosAsImageSequences.map { + try QwenVL.patchify( + images: $0, mergeSize: config.mergeSize, patchSize: config.patchSize, + temporalPatchSize: config.temporalPatchSize) + } + let videoPixelsConcatenated = concatenated(videoPixelsAndFrames.map { $0.0 }) + processedVideo = LMInput.ProcessedVideo( + pixels: videoPixelsConcatenated, frames: videoPixelsAndFrames.map { $0.1 }) + if let videoFrames = processedVideo?.frames { + promptTokens = try QwenVL.replacePaddingTokens( + in: promptTokens, frames: videoFrames, paddingToken: "<|video_pad|>", + mergeSize: config.mergeSize, tokenizer: tokenizer) + } + } + + let promptArray = MLXArray(promptTokens).expandedDimensions(axis: 0) + let mask = ones(like: promptArray).asType(.int8) + return LMInput( + text: .init(tokens: promptArray, mask: mask), + image: processedImage, + video: processedVideo) + } +} + +// MARK: - Model + +/// Qwen2.5VL VLM +/// +/// This is typically created by ``VLMModelFactory``. +public class Qwen25VL: Module, VLMModel, KVCacheDimensionProvider { + + @ModuleInfo(key: "vision_tower") private var visionModel: Vision.VisionModel + @ModuleInfo(key: "language_model") private var languageModel: Language.LanguageModel + + public let config: Qwen25VLConfiguration + + public var vocabularySize: Int { config.baseConfiguration.vocabularySize } + public var kvHeads: [Int] { languageModel.kvHeads } + + public func loraLinearLayers() -> MLXLMCommon.LoRALinearLayers { + languageModel.model.layers.map { ($0.attention, ["q_proj", "v_proj"]) } + } + + public init(_ config: Qwen25VLConfiguration) { + self.config = config + self._visionModel.wrappedValue = Vision.VisionModel(config.visionConfiguration) + self._languageModel.wrappedValue = Language.LanguageModel(config.textConfiguration) + } + + private func inputEmbeddings(inputIds: MLXArray, pixelValues: MLXArray?, frames: [THW]?) 
+        -> MLXArray
+    {
+        guard let pixelValues, let frames else {
+            return languageModel.model.embedTokens(inputIds[.newAxis, .ellipsis])
+        }
+
+        // Get the input embeddings from the language model
+        let inputEmbeds = languageModel.model.embedTokens(inputIds)
+
+        // Get the output hidden states from the vision model
+        var hiddenStates = self.visionModel(pixelValues, frames: frames)
+
+        if hiddenStates.ndim == 2 {
+            hiddenStates = hiddenStates[.newAxis, 0..., 0...]
+        }
+
+        // Insert special image tokens in the input_ids
+        return QwenVL.mergeInputIdsWithImageFeatures(
+            inputIds: inputIds, inputEmbeds: inputEmbeds, imageFeatures: hiddenStates,
+            imageTokenId: config.baseConfiguration.imageTokenId,
+            videoTokenId: config.baseConfiguration.videoTokenId)
+    }
+
+    public func prepare(_ input: LMInput, cache: [any KVCache], windowSize: Int?) throws
+        -> PrepareResult
+    {
+        let dtype = visionModel.patchEmbed.proj.weight.dtype
+
+        // Process both images and videos together
+        var allPixels: MLXArray?
+        var allFrames: [THW] = []
+
+        if let imagePixels = input.image?.pixels, let imageFrames = input.image?.frames {
+            allPixels = imagePixels.asType(dtype)
+            allFrames.append(contentsOf: imageFrames)
+        }
+
+        if let videoPixels = input.video?.pixels, let videoFrames = input.video?.frames {
+            if allPixels == nil {
+                allPixels = videoPixels.asType(dtype)
+            } else {
+                allPixels = concatenated([allPixels!, videoPixels.asType(dtype)])
+            }
+            allFrames.append(contentsOf: videoFrames)
+        }
+
+        let inputEmbeddings = self.inputEmbeddings(
+            inputIds: input.text.tokens, pixelValues: allPixels,
+            frames: allFrames.isEmpty ? nil : allFrames)
+
+        let result = languageModel(nil, cache: cache, inputEmbedding: inputEmbeddings)
+
+        return .logits(result)
+    }
+
+    public func callAsFunction(_ inputs: MLXArray, cache: [any KVCache]?) -> MLXArray {
+        languageModel(inputs, cache: cache).logits
+    }
+
+    public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
+        visionModel.sanitize(
+            weights:
+                Dictionary(
+                    uniqueKeysWithValues: weights.map { key, value in
+                        var key = key
+                        if !key.contains("vision_tower") {
+                            key = key.replacingOccurrences(of: "visual", with: "vision_tower")
+                        }
+                        if !key.contains("language_model") {
+                            key = key.replacingOccurrences(
+                                of: "model", with: "language_model.model")
+                            key = key.replacingOccurrences(
+                                of: "lm_head", with: "language_model.lm_head")
+                        }
+
+                        return (key, value)
+                    })
+        )
+    }
+}
+
+// MARK: - Configuration
+
+/// Configuration for ``Qwen25VL``
+public struct Qwen25VLConfiguration: Codable, Sendable {
+
+    public struct TextConfiguration: Codable, Sendable {
+        public let modelType: String
+        public let hiddenSize: Int
+        public let hiddenLayers: Int
+        public let intermediateSize: Int
+        public let attentionHeads: Int
+        private let _rmsNormEps: Float?
+        public var rmsNormEps: Float { _rmsNormEps ?? 1e-6 }
+        public let vocabularySize: Int
+        public let kvHeads: Int
+        private let _maxPositionEmbeddings: Int?
+        public var maxPositionEmbeddings: Int { _maxPositionEmbeddings ?? 128000 }
+        private let _ropeTheta: Float?
+        public var ropeTheta: Float { _ropeTheta ?? 1_000_000 }
+        private let _ropeTraditional: Bool?
+        public var ropeTraditional: Bool { _ropeTraditional ?? false }
+        public let ropeScaling: [String: StringOrNumber]?
+        private let _tieWordEmbeddings: Bool?
+        public var tieWordEmbeddings: Bool { _tieWordEmbeddings ?? true }
+        private let _slidingWindow: Int?
+        public var slidingWindow: Int { _slidingWindow ?? 32768 }
+        private let _useSlidingWindow: Bool?
+ public var useSlidingWindow: Bool { _useSlidingWindow ?? false } + + enum CodingKeys: String, CodingKey { + case modelType = "model_type" + case hiddenSize = "hidden_size" + case hiddenLayers = "num_hidden_layers" + case intermediateSize = "intermediate_size" + case attentionHeads = "num_attention_heads" + case _rmsNormEps = "rms_norm_eps" + case vocabularySize = "vocab_size" + case kvHeads = "num_key_value_heads" + case _maxPositionEmbeddings = "max_position_embeddings" + case _ropeTheta = "rope_theta" + case _ropeTraditional = "rope_traditional" + case ropeScaling = "rope_scaling" + case _tieWordEmbeddings = "tie_word_embeddings" + case _slidingWindow = "sliding_window" + case _useSlidingWindow = "use_sliding_window" + } + } + + public struct VisionConfiguration: Codable, Sendable { + public let depth: Int + public let hiddenSize: Int + public let intermediateSize: Int + public let outHiddenSize: Int + public let numHeads: Int + public let patchSize: Int + private let _inChans: Int? + public var inChannels: Int { _inChans ?? 3 } + private let _layerNormEps: Float? + public var layerNormEps: Float { _layerNormEps ?? 1e-6 } + public let spatialPatchSize: Int + public let spatialMergeSize: Int + public let temporalPatchSize: Int + public let windowSize: Int + public let fullattBlockIndexes: [Int] + public let tokensPerSecond: Int + private let _skipVision: Bool? + public var skipVision: Bool { _skipVision ?? false } + private let _hiddenAct: String? + public var hiddenAct: String { _hiddenAct ?? "silu" } + + enum CodingKeys: String, CodingKey { + case depth + case hiddenSize = "hidden_size" + case intermediateSize = "intermediate_size" + case outHiddenSize = "out_hidden_size" + case numHeads = "num_heads" + case patchSize = "patch_size" + case _inChans = "in_chans" + case _layerNormEps = "layer_norm_eps" // Added this line + case spatialPatchSize = "spatial_patch_size" + case spatialMergeSize = "spatial_merge_size" + case temporalPatchSize = "temporal_patch_size" + case windowSize = "window_size" + case fullattBlockIndexes = "fullatt_block_indexes" + case tokensPerSecond = "tokens_per_second" + case _skipVision = "skip_vision" + case _hiddenAct = "hidden_act" + } + } + + public struct BaseConfiguration: Codable, Sendable { + public let modelType: String + public let vocabularySize: Int + public let imageTokenId: Int + public let videoTokenId: Int + public let visionStartTokenId: Int + public let visionEndTokenId: Int + public let visionTokenId: Int + public let hiddenSize: Int + public let numAttentionHeads: Int + public let numHiddenLayers: Int + public let intermediateSize: Int + public let numKeyValueHeads: Int + public let slidingWindow: Int + public let useSlidingWindow: Bool + public let maxWindowLayers: Int + + enum CodingKeys: String, CodingKey { + case modelType = "model_type" + case vocabularySize = "vocab_size" + case imageTokenId = "image_token_id" + case videoTokenId = "video_token_id" + case visionStartTokenId = "vision_start_token_id" + case visionEndTokenId = "vision_end_token_id" + case visionTokenId = "vision_token_id" + case hiddenSize = "hidden_size" + case numAttentionHeads = "num_attention_heads" + case numHiddenLayers = "num_hidden_layers" + case intermediateSize = "intermediate_size" + case numKeyValueHeads = "num_key_value_heads" + case slidingWindow = "sliding_window" + case useSlidingWindow = "use_sliding_window" + case maxWindowLayers = "max_window_layers" + } + } + + public let textConfiguration: TextConfiguration + public let visionConfiguration: 
VisionConfiguration + public let baseConfiguration: BaseConfiguration + + enum CodingKeys: String, CodingKey { + case visionConfiguration = "vision_config" + } + + public init(from decoder: any Swift.Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + + // this is a sub-dictionary + self.visionConfiguration = try container.decode( + VisionConfiguration.self, forKey: .visionConfiguration) + + // these are overlaid in the top level + self.textConfiguration = try TextConfiguration(from: decoder) + self.baseConfiguration = try BaseConfiguration(from: decoder) + } +} + +/// Configuration for ``Qwen25VLProcessor`` +public struct Qwen25VLProcessorConfiguration: Codable, Sendable { + public struct Size: Codable, Sendable { + public let maxPixels: Int + public let minPixels: Int + + enum CodingKeys: String, CodingKey { + case maxPixels = "max_pixels" + case minPixels = "min_pixels" + } + } + + public let imageMean: [CGFloat] + public let imageStd: [CGFloat] + public let minPixels: Int + public let maxPixels: Int + public let mergeSize: Int + public let patchSize: Int + public let temporalPatchSize: Int + public let imageProcessorType: String + + public var imageMeanTuple: (CGFloat, CGFloat, CGFloat) { + (imageMean[0], imageMean[1], imageMean[2]) + } + public var imageStdTuple: (CGFloat, CGFloat, CGFloat) { + (imageStd[0], imageStd[1], imageStd[2]) + } + + public var size: Size { + Size(maxPixels: maxPixels, minPixels: minPixels) + } + + enum CodingKeys: String, CodingKey { + case imageMean = "image_mean" + case imageStd = "image_std" + case minPixels = "min_pixels" + case maxPixels = "max_pixels" + case mergeSize = "merge_size" + case patchSize = "patch_size" + case temporalPatchSize = "temporal_patch_size" + case imageProcessorType = "image_processor_type" + } +} diff --git a/Libraries/MLXVLM/Models/Qwen2VL.swift b/Libraries/MLXVLM/Models/Qwen2VL.swift index c129f2df..8b20d576 100644 --- a/Libraries/MLXVLM/Models/Qwen2VL.swift +++ b/Libraries/MLXVLM/Models/Qwen2VL.swift @@ -11,16 +11,6 @@ import MLXLMCommon import MLXNN import Tokenizers -// MARK: - Common - -/// Rotates half the hidden dims of the input -private func rotateHalf(_ x: MLXArray) -> MLXArray { - let index = x.dim(-1) / 2 - let x1 = x[.ellipsis, 0 ..< index] - let x2 = x[.ellipsis, index...] - return concatenated([-x2, x1], axis: -1) -} - // MARK: - Language private enum Language { @@ -47,8 +37,8 @@ private enum Language { )[0..., .newAxis, 0..., 0...] 
// Apply rotary embedding - let qEmbed = (q * cos) + (rotateHalf(q) * sin) - let kEmbed = (k * cos) + (rotateHalf(k) * sin) + let qEmbed = (q * cos) + (QwenVL.rotateHalf(q) * sin) + let kEmbed = (k * cos) + (QwenVL.rotateHalf(k) * sin) return (qEmbed, kEmbed) } @@ -267,64 +257,10 @@ private enum Vision { sin = tiled(sin, repetitions: [1, 1, 2]) sin = expandedDimensions(sin, axis: 0) - let output = (tensor * cos) + (rotateHalf(tensor) * sin) + let output = (tensor * cos) + (QwenVL.rotateHalf(tensor) * sin) return output.asType(tensor.dtype) } - fileprivate class VisionRotaryEmbedding { - let dimensions: Int - let theta: Float - let inverseFreq: MLXArray - - init(dimensions: Int, theta: Float) { - self.dimensions = dimensions - self.theta = theta - let p = MLXArray(stride(from: 0, to: dimensions, by: 2)).asType(.float32) / dimensions - self.inverseFreq = 1.0 / pow(theta, p) - } - - func callAsFunction(sequenceLength: Int) -> MLXArray { - let seq = MLXArray(0 ..< sequenceLength).asType(inverseFreq.dtype) - let freqs = outer(seq, inverseFreq) - return freqs - } - } - - fileprivate class PatchEmbed: Module, UnaryLayer { - @ModuleInfo var proj: Conv3d - - let patchSize: Int - let temporalPatchSize: Int - let inChannels: Int - let embedDimensions: Int - - init(patchSize: Int, temporalPatchSize: Int, inChannels: Int, embedDimensions: Int) { - self.patchSize = patchSize - self.temporalPatchSize = temporalPatchSize - self.inChannels = inChannels - self.embedDimensions = embedDimensions - - let kernelSize = IntOrTriple([temporalPatchSize, patchSize, patchSize]) - self._proj.wrappedValue = Conv3d( - inputChannels: inChannels, - outputChannels: embedDimensions, - kernelSize: kernelSize, - stride: kernelSize, - bias: false - ) - } - - func callAsFunction(_ hiddenStates: MLXArray) -> MLXArray { - var hiddenStates = hiddenStates.reshaped( - -1, inChannels, temporalPatchSize, patchSize, patchSize - ).movedAxis(source: 1, destination: 4) - - hiddenStates = proj(hiddenStates) - hiddenStates = hiddenStates.reshaped(-1, embedDimensions) - return hiddenStates - } - } - fileprivate class PatchMerger: Module, UnaryLayer { let hiddenSize: Int @ModuleInfo(key: "ln_q") var layerNormQ: LayerNorm @@ -451,8 +387,8 @@ private enum Vision { fileprivate class VisionModel: Module { - @ModuleInfo(key: "patch_embed") var patchEmbed: PatchEmbed - @ModuleInfo(key: "rotary_pos_emb") var rotaryPositionEmbedding: VisionRotaryEmbedding + @ModuleInfo(key: "patch_embed") var patchEmbed: QwenVL.PatchEmbed + @ModuleInfo(key: "rotary_pos_emb") var rotaryPositionEmbedding: QwenVL.VisionRotaryEmbedding @ModuleInfo(key: "blocks") var blocks: [Qwen2VLVisionBlock] @ModuleInfo(key: "merger") var patchMerger: PatchMerger @@ -461,14 +397,14 @@ private enum Vision { public init(_ config: Qwen2VLConfiguration.VisionConfiguration) { self.spatialMergeSize = config.spatialMergeSize - self._patchEmbed.wrappedValue = PatchEmbed( + self._patchEmbed.wrappedValue = QwenVL.PatchEmbed( patchSize: config.patchSize, temporalPatchSize: config.temporalPatchSize, inChannels: config.inChannels, embedDimensions: config.embedDimensions) let headDimensions = config.embedDimensions / config.numHeads - self._rotaryPositionEmbedding.wrappedValue = VisionRotaryEmbedding( + self._rotaryPositionEmbedding.wrappedValue = QwenVL.VisionRotaryEmbedding( dimensions: headDimensions / 2, theta: 10_000) self._blocks.wrappedValue = (0 ..< config.depth).map { _ in @@ -592,38 +528,6 @@ public class Qwen2VLProcessor: UserInputProcessor { self.tokenizer = tokenizer } - // 
image_processing_qwen2_vl.smart_resize - private func targetSize(height: Int, width: Int, factor: Int, minPixels: Int, maxPixels: Int) - throws -> (Int, Int) - { - if height < factor { - throw VLMError.imageProcessingFailure( - "height: \(height) must be larger than factor: \(factor)") - } - if width < factor { - throw VLMError.imageProcessingFailure( - "width: \(width) must be larger than factor: \(factor)") - } - if max(height, width) / min(height, width) > 200 { - throw VLMError.imageProcessingFailure( - "absolute aspect ratio must be smaller than 200: \(width)x\(height)") - } - - var hBar = max(factor, Int(round(Float(height) / Float(factor))) * factor) - var wBar = max(factor, Int(round(Float(width) / Float(factor))) * factor) - - if hBar * wBar > maxPixels { - let beta = sqrt(Float(height * width) / Float(maxPixels)) - hBar = Int(floor(Float(height) / beta / Float(factor))) * factor - wBar = Int(floor(Float(width) / beta / Float(factor))) * factor - } else if hBar * wBar < minPixels { - let beta = sqrt(Float(minPixels) / Float(height * width)) - hBar = Int(ceil(Float(height) * beta / Float(factor))) * factor - wBar = Int(ceil(Float(width) * beta / Float(factor))) * factor - } - return (hBar, wBar) - } - func preprocess(image: CIImage, resizedSize: CGSize) -> CIImage { image .toSRGB() @@ -640,7 +544,7 @@ public class Qwen2VLProcessor: UserInputProcessor { // image_processing_qwen2_vl._preprocess let size = images[0].extent.size - let (resizedHeight, resizedWidth) = try targetSize( + let (resizedHeight, resizedWidth) = try QwenVL.targetSize( height: Int(size.height), width: Int(size.width), factor: config.patchSize * config.mergeSize, minPixels: config.minPixels, maxPixels: config.maxPixels) @@ -650,49 +554,9 @@ public class Qwen2VLProcessor: UserInputProcessor { preprocess(image: image, resizedSize: resizedSize).asMLXArray() } - return try patchify(images: processedImages) - } - - public func patchify(images: [MLXArray]) throws -> ( - MLXArray, THW - ) { - guard let firstImage = images.first else { - throw VLMError.imageProcessingFailure("No images in video sequence") - } - let resizedHeight = firstImage.dim(-2) - let resizedWidth = firstImage.dim(-1) - var patches = concatenated(images) - let mod = patches.dim(0) % config.temporalPatchSize - if mod != 0 { - let lastPatch = patches[-1, .ellipsis] - let lastPatchRepeated = tiled( - lastPatch, repetitions: [config.temporalPatchSize - mod, 1, 1, 1]) - patches = concatenated([patches, lastPatchRepeated]) - } - let channel = patches.dim(1) - let gridT = patches.dim(0) / self.config.temporalPatchSize - let gridH = resizedHeight / self.config.patchSize - let gridW = resizedWidth / self.config.patchSize - - patches = patches.reshaped( - gridT, - config.temporalPatchSize, - channel, - gridH / config.mergeSize, - config.mergeSize, - config.patchSize, - gridW / config.mergeSize, - config.mergeSize, - config.patchSize - ) - patches = patches.transposed(0, 3, 6, 4, 7, 2, 1, 5, 8) - - let flattenedPatches = patches.reshaped( - gridT * gridH * gridW, - channel * config.temporalPatchSize * config.patchSize * config.patchSize - ) - - return (flattenedPatches, .init(gridT, gridH, gridW)) + return try QwenVL.patchify( + images: processedImages, mergeSize: config.mergeSize, patchSize: config.patchSize, + temporalPatchSize: config.temporalPatchSize) } public func prepare(input: UserInput) async throws -> LMInput { @@ -714,8 +578,9 @@ public class Qwen2VLProcessor: UserInputProcessor { processedImage = LMInput.ProcessedImage( pixels: 
imagePixelsConcatenated, frames: imagePixelsAndFrames.map { $0.1 }) if let imageFrames = processedImage?.frames { - promptTokens = try replacePaddingTokens( - in: promptTokens, frames: imageFrames, paddingToken: "<|image_pad|>") + promptTokens = try QwenVL.replacePaddingTokens( + in: promptTokens, frames: imageFrames, paddingToken: "<|image_pad|>", + mergeSize: config.mergeSize, tokenizer: tokenizer) } } @@ -733,7 +598,7 @@ public class Qwen2VLProcessor: UserInputProcessor { frame.frame, processing: input.processing) if resizedSize == .zero { let size = resizedImage.extent.size - let (resizedHeight, resizedWidth) = try targetSize( + let (resizedHeight, resizedWidth) = try QwenVL.targetSize( height: Int(size.height), width: Int(size.width), factor: config.patchSize * config.mergeSize, minPixels: config.minPixels, maxPixels: config.maxPixels) @@ -744,13 +609,18 @@ public class Qwen2VLProcessor: UserInputProcessor { } videosAsImageSequences.append(imageSequence.frames) } - let videoPixelsAndFrames = try videosAsImageSequences.map(patchify) + let videoPixelsAndFrames = try videosAsImageSequences.map { + try QwenVL.patchify( + images: $0, mergeSize: config.mergeSize, patchSize: config.patchSize, + temporalPatchSize: config.temporalPatchSize) + } let videoPixelsConcatenated = concatenated(videoPixelsAndFrames.map { $0.0 }) processedVideo = LMInput.ProcessedVideo( pixels: videoPixelsConcatenated, frames: videoPixelsAndFrames.map { $0.1 }) if let videoFrames = processedVideo?.frames { - promptTokens = try replacePaddingTokens( - in: promptTokens, frames: videoFrames, paddingToken: "<|video_pad|>") + promptTokens = try QwenVL.replacePaddingTokens( + in: promptTokens, frames: videoFrames, paddingToken: "<|video_pad|>", + mergeSize: config.mergeSize, tokenizer: tokenizer) } } @@ -761,42 +631,6 @@ public class Qwen2VLProcessor: UserInputProcessor { image: processedImage, video: processedVideo) } - - func replacePaddingTokens(in promptTokens: [Int], frames: [THW], paddingToken: String) - throws -> [Int] - { - // Replace single padding token with correct number for each image or video frame - let placeholderTokens = try tokenizer.encode( - text: "<|vision_start|>\(paddingToken)<|vision_end|>") - let placeholderRanges = promptTokens.ranges(of: placeholderTokens) - guard placeholderRanges.count == frames.count else { - throw VLMError.processing( - "Number of placeholder tokens does not match number of frames") - } - let mergeLength = config.mergeSize * config.mergeSize - let replacementSequences = try frames.map { frame in - let paddingCount = frame.product / mergeLength - return try tokenizer.encode( - text: - "<|vision_start|>\(Array(repeating: paddingToken, count: paddingCount).joined())<|vision_end|>" - ) - } - // Build the final array - var result: [Int] = [] - var currentIndex = promptTokens.startIndex - for (range, replacement) in zip(placeholderRanges, replacementSequences) { - // Add tokens before the placeholder - result.append(contentsOf: promptTokens[currentIndex ..< range.lowerBound]) - // Add replacement sequence - result.append(contentsOf: replacement) - currentIndex = range.upperBound - } - // Add any remaining tokens after the last replacement - if currentIndex < promptTokens.endIndex { - result.append(contentsOf: promptTokens[currentIndex...]) - } - return result - } } // MARK: - Model @@ -842,37 +676,10 @@ public class Qwen2VL: Module, VLMModel, KVCacheDimensionProvider { } // Insert special image tokens in the input_ids - return mergeInputIdsWithImageFeatures( - inputIds: inputIds, 
inputEmbeds: inputEmbeds, imageFeatures: hiddenStates)
-    }
-
-    private func mergeInputIdsWithImageFeatures(
-        inputIds: MLXArray, inputEmbeds: MLXArray, imageFeatures: MLXArray
-    ) -> MLXArray {
-        let imageTokenIndex = config.baseConfiguration.imageTokenId
-        let videoTokenIndex = config.baseConfiguration.videoTokenId
-
-        var imageIndices = [Int]()
-        for (i, v) in inputIds.asArray(Int.self).enumerated() {
-            if v == imageTokenIndex || v == videoTokenIndex {
-                imageIndices.append(i)
-            }
-        }
-
-        // Make sure shapes match before assignment
-        var result = inputEmbeds
-        if result.ndim == 2 {
-            result = result[.newAxis, 0..., 0...]
-        }
-
-        if imageFeatures.ndim == 2 {
-            let reshapedFeatures = imageFeatures[.newAxis, 0..., 0...]
-            result[0..., MLXArray(imageIndices), 0...] = reshapedFeatures
-        } else {
-            result[0..., MLXArray(imageIndices), 0...] = imageFeatures
-        }
-
-        return result
+        return QwenVL.mergeInputIdsWithImageFeatures(
+            inputIds: inputIds, inputEmbeds: inputEmbeds, imageFeatures: hiddenStates,
+            imageTokenId: config.baseConfiguration.imageTokenId,
+            videoTokenId: config.baseConfiguration.videoTokenId)
     }
 
     public func prepare(_ input: LMInput, cache: [any KVCache], windowSize: Int?) throws
diff --git a/Libraries/MLXVLM/Models/QwenVL.swift b/Libraries/MLXVLM/Models/QwenVL.swift
new file mode 100644
index 00000000..e1bf168c
--- /dev/null
+++ b/Libraries/MLXVLM/Models/QwenVL.swift
@@ -0,0 +1,259 @@
+import CoreImage
+import Foundation
+import Hub
+import MLX
+import MLXFast
+import MLXLMCommon
+import MLXNN
+import Tokenizers
+
+// MARK: - Common Utilities for Qwen 2 VL and Qwen 2.5 VL
+
+private func debug(_ message: @autoclosure () -> String) {
+    // print(message())
+}
+
+public struct QwenVL {
+    /// Rotates half the hidden dims of the input
+    static func rotateHalf(_ x: MLXArray) -> MLXArray {
+        let index = x.dim(-1) / 2
+        let x1 = x[.ellipsis, 0 ..< index]
+        let x2 = x[.ellipsis, index...]
+        return concatenated([-x2, x1], axis: -1)
+    }
+
+    static func mergeInputIdsWithImageFeatures(
+        inputIds: MLXArray, inputEmbeds: MLXArray, imageFeatures: MLXArray,
+        imageTokenId: Int, videoTokenId: Int
+    ) -> MLXArray {
+        var imageIndices = [Int]()
+        for (i, v) in inputIds.asArray(Int.self).enumerated() {
+            if v == imageTokenId || v == videoTokenId {
+                imageIndices.append(i)
+            }
+        }
+
+        // Make sure shapes match before assignment
+        var result = inputEmbeds
+        if result.ndim == 2 {
+            result = result[.newAxis, 0..., 0...]
+        }
+
+        if imageFeatures.ndim == 2 {
+            let reshapedFeatures = imageFeatures[.newAxis, 0..., 0...]
+            result[0..., MLXArray(imageIndices), 0...] = reshapedFeatures
+        } else {
+            result[0..., MLXArray(imageIndices), 0...] = imageFeatures
+        }
+
+        return result
+    }
+
+    public class VisionRotaryEmbedding {
+        let dimensions: Int
+        let theta: Float
+        let inverseFreq: MLXArray
+
+        init(dimensions: Int, theta: Float) {
+            self.dimensions = dimensions
+            self.theta = theta
+            let p = MLXArray(stride(from: 0, to: dimensions, by: 2)).asType(.float32) / dimensions
+            self.inverseFreq = 1.0 / pow(theta, p)
+        }
+
+        func callAsFunction(sequenceLength: Int) -> MLXArray {
+            let seq = MLXArray(0 ..< sequenceLength).asType(inverseFreq.dtype)
+            let freqs = outer(seq, inverseFreq)
+            return freqs
+        }
+    }
+
+    public class PatchEmbed: Module, UnaryLayer {
+        @ModuleInfo var proj: Conv3d
+
+        let patchSize: Int
+        let temporalPatchSize: Int
+        let inChannels: Int
+        let outputDimensions: Int
+
+        // For Qwen 2 VL
+        convenience init(
+            patchSize: Int, temporalPatchSize: Int, inChannels: Int, embedDimensions: Int
+        ) {
+            self.init(
+                patchSize: patchSize, temporalPatchSize: temporalPatchSize,
+                inChannels: inChannels, outputDimensions: embedDimensions)
+        }
+
+        // For Qwen 2.5 VL
+        convenience init(patchSize: Int, temporalPatchSize: Int, inChannels: Int, hiddenSize: Int) {
+            self.init(
+                patchSize: patchSize, temporalPatchSize: temporalPatchSize,
+                inChannels: inChannels, outputDimensions: hiddenSize)
+        }
+
+        // Common initializer
+        init(patchSize: Int, temporalPatchSize: Int, inChannels: Int, outputDimensions: Int) {
+            self.patchSize = patchSize
+            self.temporalPatchSize = temporalPatchSize
+            self.inChannels = inChannels
+            self.outputDimensions = outputDimensions
+
+            let kernelSize = IntOrTriple([temporalPatchSize, patchSize, patchSize])
+            self._proj.wrappedValue = Conv3d(
+                inputChannels: inChannels,
+                outputChannels: outputDimensions,
+                kernelSize: kernelSize,
+                stride: kernelSize,
+                bias: false
+            )
+        }
+
+        public func callAsFunction(_ hiddenStates: MLXArray) -> MLXArray {
+            var hiddenStates = hiddenStates.reshaped(
+                -1, inChannels, temporalPatchSize, patchSize, patchSize
+            ).movedAxis(source: 1, destination: 4)
+
+            hiddenStates = proj(hiddenStates)
+            hiddenStates = hiddenStates.reshaped(-1, outputDimensions)
+            return hiddenStates
+        }
+    }
+
+    // image_processing_qwen2_vl.smart_resize
+    static func targetSize(height: Int, width: Int, factor: Int, minPixels: Int, maxPixels: Int)
+        throws
+        -> (Int, Int)
+    {
+        debug("Original dimensions: \(width) × \(height)")
+        debug("Factor: \(factor), minPixels: \(minPixels), maxPixels: \(maxPixels)")
+
+        if height < factor {
+            throw VLMError.imageProcessingFailure(
+                "Height: \(height) must be larger than factor: \(factor)")
+        }
+        if width < factor {
+            throw VLMError.imageProcessingFailure(
+                "Width: \(width) must be larger than factor: \(factor)")
+        }
+        if max(height, width) / min(height, width) > 200 {
+            throw VLMError.imageProcessingFailure(
+                "Absolute aspect ratio must be smaller than 200: \(width) × \(height)")
+        }
+
+        var hBar = max(factor, Int(round(Float(height) / Float(factor))) * factor)
+        var wBar = max(factor, Int(round(Float(width) / Float(factor))) * factor)
+        debug("After rounding to factor multiples: \(wBar) × \(hBar)")
+
+        // Scale based on total pixel count
+        if hBar * wBar > maxPixels {
+            let beta = sqrt(Float(height * width) / Float(maxPixels))
+            hBar = Int(floor(Float(height) / beta / Float(factor))) * factor
+            wBar = Int(floor(Float(width) / beta / Float(factor))) * factor
+            debug("After scaling down for maxPixels: \(wBar) × \(hBar)")
+        } else if hBar * wBar < minPixels {
+            let beta = sqrt(Float(minPixels) / Float(height * width))
+            hBar = Int(ceil(Float(height) * beta / Float(factor))) * factor
+            wBar = Int(ceil(Float(width) * beta / Float(factor))) * factor
+            debug("After scaling up for minPixels: \(wBar) × \(hBar)")
+        }
+
+        // Ensure dimensions are divisible by the factor
+        hBar = (hBar / factor) * factor
+        wBar = (wBar / factor) * factor
+        debug("Final dimensions: \(wBar) × \(hBar)")
+        debug("Total pixels: \(wBar * hBar)")
+
+        // Final sanity check
+        if hBar <= 0 || wBar <= 0 {
+            throw VLMError.imageProcessingFailure(
+                "Invalid target dimensions: \(wBar) × \(hBar)")
+        }
+
+        return (hBar, wBar)
+    }
+
+    static func replacePaddingTokens(
+        in promptTokens: [Int], frames: [THW], paddingToken: String, mergeSize: Int,
+        tokenizer: any Tokenizer
+    ) throws -> [Int] {
+        // Replace single padding token with correct number for each image or video frame
+        let placeholderTokens = try tokenizer.encode(
+            text: "<|vision_start|>\(paddingToken)<|vision_end|>")
+        let placeholderRanges = promptTokens.ranges(of: placeholderTokens)
+        guard placeholderRanges.count == frames.count else {
+            throw VLMError.processing(
+                "Number of placeholder tokens does not match number of frames")
+        }
+        let mergeLength = mergeSize * mergeSize
+        let replacementSequences = try frames.map { frame in
+            let paddingCount = frame.product / mergeLength
+            return try tokenizer.encode(
+                text:
+                    "<|vision_start|>\(Array(repeating: paddingToken, count: paddingCount).joined())<|vision_end|>"
+            )
+        }
+        // Build the final array
+        var result: [Int] = []
+        var currentIndex = promptTokens.startIndex
+        for (range, replacement) in zip(placeholderRanges, replacementSequences) {
+            // Add tokens before the placeholder
+            result.append(contentsOf: promptTokens[currentIndex ..< range.lowerBound])
+            // Add replacement sequence
+            result.append(contentsOf: replacement)
+            currentIndex = range.upperBound
+        }
+        // Add any remaining tokens after the last replacement
+        if currentIndex < promptTokens.endIndex {
+            result.append(contentsOf: promptTokens[currentIndex...])
+        }
+        return result
+    }
+
+    static func patchify(images: [MLXArray], mergeSize: Int, patchSize: Int, temporalPatchSize: Int)
+        throws -> (
+            MLXArray, THW
+        )
+    {
+        guard let firstImage = images.first else {
+            throw VLMError.imageProcessingFailure("No images in video sequence")
+        }
+        let resizedHeight = firstImage.dim(-2)
+        let resizedWidth = firstImage.dim(-1)
+        var patches = concatenated(images)
+
+        // Pad to match temporal patch size if needed
+        let mod = patches.dim(0) % temporalPatchSize
+        if mod != 0 {
+            let lastPatch = patches[-1, .ellipsis]
+            let lastPatchRepeated = tiled(
+                lastPatch, repetitions: [temporalPatchSize - mod, 1, 1, 1])
+            patches = concatenated([patches, lastPatchRepeated])
+        }
+        let channel = patches.dim(1)
+        let gridT = patches.dim(0) / temporalPatchSize
+        let gridH = resizedHeight / patchSize
+        let gridW = resizedWidth / patchSize
+
+        patches = patches.reshaped(
+            gridT,
+            temporalPatchSize,
+            channel,
+            gridH / mergeSize,
+            mergeSize,
+            patchSize,
+            gridW / mergeSize,
+            mergeSize,
+            patchSize
+        )
+        patches = patches.transposed(0, 3, 6, 4, 7, 2, 1, 5, 8)
+
+        let flattenedPatches = patches.reshaped(
+            gridT * gridH * gridW,
+            channel * temporalPatchSize * patchSize * patchSize
+        )
+
+        return (flattenedPatches, .init(gridT, gridH, gridW))
+    }
+
+}
diff --git a/Libraries/MLXVLM/VLMModelFactory.swift b/Libraries/MLXVLM/VLMModelFactory.swift
index d63957a3..0079fc4f 100644
--- a/Libraries/MLXVLM/VLMModelFactory.swift
+++ b/Libraries/MLXVLM/VLMModelFactory.swift
@@ -83,11 +83,11 @@ public class VLMTypeRegistry: ModelTypeRegistry, @unchecked Sendable {
         [
             "paligemma": create(PaliGemmaConfiguration.self, PaliGemma.init),
             "qwen2_vl": create(Qwen2VLConfiguration.self, Qwen2VL.init),
+            "qwen2_5_vl": create(Qwen25VLConfiguration.self, Qwen25VL.init),
             "idefics3": create(Idefics3Configuration.self, Idefics3.init),
             "smolvlm": create(SmolVLM2Configuration.self, SmolVLM2.init),
         ]
     }
-
 }
 
 public class VLMProcessorTypeRegistry: ProcessorTypeRegistry, @unchecked Sendable {
@@ -101,15 +101,17 @@ public class VLMProcessorTypeRegistry: ProcessorTypeRegistry, @unchecked Sendabl
     {
         [
             "PaliGemmaProcessor": create(
-                PaliGemmaProcessorConfiguration.self, PaligGemmaProcessor.init),
-            "Qwen2VLProcessor": create(Qwen2VLProcessorConfiguration.self, Qwen2VLProcessor.init),
+                PaliGemmaProcessorConfiguration.self, PaliGemmaProcessor.init),
+            "Qwen2VLProcessor": create(
+                Qwen2VLProcessorConfiguration.self, Qwen2VLProcessor.init),
+            "Qwen2_5_VLProcessor": create(
+                Qwen25VLProcessorConfiguration.self, Qwen25VLProcessor.init),
             "Idefics3Processor": create(
                 Idefics3ProcessorConfiguration.self, Idefics3Processor.init),
             "SmolVLMProcessor": create(
                 SmolVLMProcessorConfiguration.self, SmolVLMProcessor.init),
         ]
     }
-
 }
 
 /// Registry of models and any overrides that go with them, e.g. prompt augmentation.
@@ -133,16 +135,28 @@ public class VLMRegistry: AbstractModelRegistry, @unchecked Sendable {
         defaultPrompt: "Describe the image in English"
     )
 
+    static public let qwen2_5VL3BInstruct4Bit = ModelConfiguration(
+        id: "mlx-community/Qwen2.5-VL-3B-Instruct-4bit",
+        defaultPrompt: "Describe the image in English"
+    )
+
+    static public let smolvlminstruct4bit = ModelConfiguration(
+        id: "mlx-community/SmolVLM-Instruct-4bit",
+        defaultPrompt: "Describe the image in English"
+    )
+
     static public let smolvlm = ModelConfiguration(
         id: "HuggingFaceTB/SmolVLM2-500M-Video-Instruct-mlx",
         defaultPrompt:
             "What is the main action or notable event happening in this segment? Describe it in one brief sentence."
     )
 
-    static private func all() -> [ModelConfiguration] {
+    static public func all() -> [ModelConfiguration] {
         [
             paligemma3bMix448_8bit,
             qwen2VL2BInstruct4Bit,
+            qwen2_5VL3BInstruct4Bit,
+            smolvlminstruct4bit,
             smolvlm,
         ]
     }
diff --git a/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved b/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved
index b2c346d5..9553e9a3 100644
--- a/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved
+++ b/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved
@@ -1,5 +1,5 @@
 {
-  "originHash" : "369f2014f0f4b1785f2b2642d3b4a3cbd3220a38b18d03ac9d74965949a0f0ba",
+  "originHash" : "0777c427cd29bb45ee52257882d29c3c2063039870a79b9b91a32154eb35f7b5",
   "pins" : [
     {
       "identity" : "gzipswift",
@@ -24,8 +24,8 @@
       "kind" : "remoteSourceControl",
       "location" : "https://github.com/ml-explore/mlx-swift",
       "state" : {
-        "revision" : "70dbb62128a5a1471a5ab80363430adb33470cab",
-        "version" : "0.21.2"
+        "revision" : "b990c58153af70eb0914bca7dd74401d341fa9ae",
+        "version" : "0.21.3"
       }
     },
     {
@@ -96,8 +96,8 @@
       "kind" : "remoteSourceControl",
       "location" : "https://github.com/apple/swift-numerics",
       "state" : {
-        "revision" : "0a5bc04095a675662cf24757cc0640aa2204253b",
-        "version" : "1.0.2"
+        "revision" : "e0ec0f5f3af6f3e4d5e7a19d2af26b481acb6ba8",
+        "version" : "1.0.3"
      }
     },
     {
@@ -105,8 +105,8 @@
       "kind" : "remoteSourceControl",
       "location" : "https://github.com/huggingface/swift-transformers",
       "state" : {
-        "revision" : "55710ddfb1ae804b4b7ce973be75cf2e41272185",
-        "version" : "0.1.17"
+        "revision" : "be855fac725dbae27264e47a3eb535cc422a4ba8",
+        "version" : "0.1.18"
       }
     }
   ],