
Commit 0737268

Factor out shared parts
1 parent e324eff commit 0737268

File tree

3 files changed: +224 additions, -350 deletions

Libraries/MLXVLM/Models/Qwen25VL.swift

Lines changed: 18 additions & 175 deletions
@@ -9,16 +9,6 @@ import MLXLMCommon
 import MLXNN
 import Tokenizers
 
-// MARK: - Common
-
-/// Rotates half the hidden dims of the input
-private func rotateHalf(_ x: MLXArray) -> MLXArray {
-    let index = x.dim(-1) / 2
-    let x1 = x[.ellipsis, 0 ..< index]
-    let x2 = x[.ellipsis, index...]
-    return concatenated([-x2, x1], axis: -1)
-}
-
 // MARK: - Language
 
 private enum Language {
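The removed helper is referenced below as QwenVL.rotateHalf. Its new home is one of the other two changed files and is not shown in this diff; a minimal sketch of what the shared namespace presumably looks like, reusing the removed body verbatim (the enum name is taken from the new call sites, the rest is an assumption):

```swift
import MLX

// Sketch only: the QwenVL namespace is inferred from the new call sites below;
// its actual declaration lives in another file of this commit and is not shown here.
public enum QwenVL {
    /// Rotates half the hidden dims of the input (same body as the removed local function).
    static func rotateHalf(_ x: MLXArray) -> MLXArray {
        let index = x.dim(-1) / 2
        let x1 = x[.ellipsis, 0 ..< index]
        let x2 = x[.ellipsis, index...]
        return concatenated([-x2, x1], axis: -1)
    }
}
```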
@@ -45,8 +35,8 @@ private enum Language {
         )[0..., .newAxis, 0..., 0...]
 
         // Apply rotary embedding
-        let qEmbed = (q * cos) + (rotateHalf(q) * sin)
-        let kEmbed = (k * cos) + (rotateHalf(k) * sin)
+        let qEmbed = (q * cos) + (QwenVL.rotateHalf(q) * sin)
+        let kEmbed = (k * cos) + (QwenVL.rotateHalf(k) * sin)
         return (qEmbed, kEmbed)
     }
 
@@ -264,64 +254,10 @@ private enum Vision {
         sin = tiled(sin, repetitions: [1, 1, 2])
         sin = expandedDimensions(sin, axis: 0)
 
-        let output = (tensor * cos) + (rotateHalf(tensor) * sin)
+        let output = (tensor * cos) + (QwenVL.rotateHalf(tensor) * sin)
         return output.asType(tensor.dtype)
     }
 
-    fileprivate class VisionRotaryEmbedding {
-        let dimensions: Int
-        let theta: Float
-        let inverseFreq: MLXArray
-
-        init(dimensions: Int, theta: Float) {
-            self.dimensions = dimensions
-            self.theta = theta
-            let p = MLXArray(stride(from: 0, to: dimensions, by: 2)).asType(.float32) / dimensions
-            self.inverseFreq = 1.0 / pow(theta, p)
-        }
-
-        func callAsFunction(sequenceLength: Int) -> MLXArray {
-            let seq = MLXArray(0 ..< sequenceLength).asType(inverseFreq.dtype)
-            let freqs = outer(seq, inverseFreq)
-            return freqs
-        }
-    }
-
-    fileprivate class PatchEmbed: Module, UnaryLayer {
-        @ModuleInfo var proj: Conv3d
-
-        let patchSize: Int
-        let temporalPatchSize: Int
-        let inChannels: Int
-        let hiddenSize: Int
-
-        init(patchSize: Int, temporalPatchSize: Int, inChannels: Int, hiddenSize: Int) {
-            self.patchSize = patchSize
-            self.temporalPatchSize = temporalPatchSize
-            self.inChannels = inChannels
-            self.hiddenSize = hiddenSize
-
-            let kernelSize = IntOrTriple([temporalPatchSize, patchSize, patchSize])
-            self._proj.wrappedValue = Conv3d(
-                inputChannels: inChannels,
-                outputChannels: hiddenSize,
-                kernelSize: kernelSize,
-                stride: kernelSize,
-                bias: false
-            )
-        }
-
-        func callAsFunction(_ hiddenStates: MLXArray) -> MLXArray {
-            var hiddenStates = hiddenStates.reshaped(
-                -1, inChannels, temporalPatchSize, patchSize, patchSize
-            ).movedAxis(source: 1, destination: 4)
-
-            hiddenStates = proj(hiddenStates)
-            hiddenStates = hiddenStates.reshaped(-1, hiddenSize)
-            return hiddenStates
-        }
-    }
-
     fileprivate class PatchMerger: Module, UnaryLayer {
         let hiddenSize: Int
         @ModuleInfo(key: "ln_q") var layerNormQ: RMSNorm
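VisionRotaryEmbedding and PatchEmbed move out of this file as well. Assuming they keep the removed implementations under the shared QwenVL namespace (as the property types in the VisionModel hunk below suggest), usage is unchanged; a hedged sketch with made-up sizes:

```swift
import MLX

// Hypothetical head size; the real value comes from the vision configuration.
let headDimensions = 80
let rope = QwenVL.VisionRotaryEmbedding(dimensions: headDimensions / 2, theta: 10_000)

// Per the removed body: inverseFreq[i] = 1 / theta^(2i / dimensions), and the returned
// table is outer(positions, inverseFreq) with shape [sequenceLength, dimensions / 2].
let freqs = rope(sequenceLength: 16)
```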
@@ -457,8 +393,8 @@ private enum Vision {
 
     fileprivate class VisionModel: Module {
 
-        @ModuleInfo(key: "patch_embed") var patchEmbed: PatchEmbed
-        @ModuleInfo(key: "rotary_pos_emb") var rotaryPositionEmbedding: VisionRotaryEmbedding
+        @ModuleInfo(key: "patch_embed") var patchEmbed: QwenVL.PatchEmbed
+        @ModuleInfo(key: "rotary_pos_emb") var rotaryPositionEmbedding: QwenVL.VisionRotaryEmbedding
         @ModuleInfo(key: "blocks") var blocks: [Qwen25VLVisionBlock]
         @ModuleInfo(key: "merger") var patchMerger: PatchMerger
 
@@ -475,14 +411,14 @@ private enum Vision {
             self.spatialMergeUnit = config.spatialMergeSize * config.spatialMergeSize
             self.fullattBlockIndexes = config.fullattBlockIndexes
 
-            self._patchEmbed.wrappedValue = PatchEmbed(
+            self._patchEmbed.wrappedValue = QwenVL.PatchEmbed(
                 patchSize: config.patchSize,
                 temporalPatchSize: config.temporalPatchSize,
                 inChannels: config.inChannels,
                 hiddenSize: config.hiddenSize)
 
             let headDimensions = config.hiddenSize / config.numHeads
-            self._rotaryPositionEmbedding.wrappedValue = VisionRotaryEmbedding(
+            self._rotaryPositionEmbedding.wrappedValue = QwenVL.VisionRotaryEmbedding(
                 dimensions: headDimensions / 2, theta: 10_000)
 
             self._blocks.wrappedValue = (0 ..< config.depth).map { _ in
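For the relocated patch embedding, the constructor call above is unchanged apart from the QwenVL prefix. A rough usage sketch with hypothetical sizes, following the removed PatchEmbed body (reshape to [N, C, T, P, P], move channels last, Conv3d with kernel == stride, flatten to [N, hiddenSize]):

```swift
import MLX

// All sizes below are hypothetical stand-ins for the vision configuration.
let patchSize = 14, temporalPatchSize = 2, inChannels = 3, hiddenSize = 1280
let patchEmbed = QwenVL.PatchEmbed(
    patchSize: patchSize, temporalPatchSize: temporalPatchSize,
    inChannels: inChannels, hiddenSize: hiddenSize)

// One row per flattened spatio-temporal patch; output is one embedding per patch.
let numPatches = 256
let patches = zeros([numPatches, inChannels * temporalPatchSize * patchSize * patchSize])
let embeddings = patchEmbed(patches)  // [numPatches, hiddenSize]
```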
@@ -729,38 +665,6 @@ public class Qwen25VLProcessor: UserInputProcessor {
         self.tokenizer = tokenizer
     }
 
-    // image_processing_qwen2_vl.smart_resize
-    private func targetSize(height: Int, width: Int, factor: Int, minPixels: Int, maxPixels: Int)
-        throws -> (Int, Int)
-    {
-        if height < factor {
-            throw VLMError.imageProcessingFailure(
-                "height: \(height) must be larger than factor: \(factor)")
-        }
-        if width < factor {
-            throw VLMError.imageProcessingFailure(
-                "width: \(width) must be larger than factor: \(factor)")
-        }
-        if max(height, width) / min(height, width) > 200 {
-            throw VLMError.imageProcessingFailure(
-                "absolute aspect ratio must be smaller than 200: \(width)x\(height)")
-        }
-
-        var hBar = max(factor, Int(round(Float(height) / Float(factor))) * factor)
-        var wBar = max(factor, Int(round(Float(width) / Float(factor))) * factor)
-
-        if hBar * wBar > maxPixels {
-            let beta = sqrt(Float(height * width) / Float(maxPixels))
-            hBar = Int(floor(Float(height) / beta / Float(factor))) * factor
-            wBar = Int(floor(Float(width) / beta / Float(factor))) * factor
-        } else if hBar * wBar < minPixels {
-            let beta = sqrt(Float(minPixels) / Float(height * width))
-            hBar = Int(ceil(Float(height) * beta / Float(factor))) * factor
-            wBar = Int(ceil(Float(width) * beta / Float(factor))) * factor
-        }
-        return (hBar, wBar)
-    }
-
     public func preprocess(images: [CIImage], processing: UserInput.Processing?) throws -> (
         MLXArray, THW
     ) {
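The smart_resize port now lives in the shared namespace and is called below as QwenVL.targetSize. A worked call with hypothetical numbers, as it would appear inside a throwing context (the factor and pixel budgets are illustrative, not values read from this config):

```swift
// Hypothetical inputs; factor is patchSize * mergeSize at the real call site.
let (resizedHeight, resizedWidth) = try QwenVL.targetSize(
    height: 480, width: 640,
    factor: 28,
    minPixels: 3_136,      // illustrative lower budget
    maxPixels: 1_000_000)  // illustrative upper budget
// Each side snaps to the nearest multiple of factor: 480 -> 476, 640 -> 644.
// 476 * 644 = 306,544 pixels lies inside the budget, so no extra scaling is applied.
```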
@@ -770,7 +674,7 @@ public class Qwen25VLProcessor: UserInputProcessor {
         // image_processing_qwen2_vl._preprocess
 
         let size = images[0].extent.size
-        let (resizedHeight, resizedWidth) = try targetSize(
+        let (resizedHeight, resizedWidth) = try QwenVL.targetSize(
             height: Int(size.height), width: Int(size.width),
             factor: config.patchSize * config.mergeSize,
             minPixels: config.size.minPixels, maxPixels: config.size.maxPixels)
@@ -845,8 +749,9 @@ public class Qwen25VLProcessor: UserInputProcessor {
             processedImage = LMInput.ProcessedImage(
                 pixels: imagePixelsConcatenated, frames: imagePixelsAndFrames.map { $0.1 })
             if let imageFrames = processedImage?.frames {
-                promptTokens = try replacePaddingTokens(
-                    in: promptTokens, frames: imageFrames, paddingToken: "<|image_pad|>")
+                promptTokens = try QwenVL.replacePaddingTokens(
+                    in: promptTokens, frames: imageFrames, paddingToken: "<|image_pad|>",
+                    mergeSize: config.mergeSize, tokenizer: tokenizer)
             }
         }
 
@@ -868,8 +773,9 @@ public class Qwen25VLProcessor: UserInputProcessor {
             processedVideo = LMInput.ProcessedVideo(
                 pixels: videoPixelsConcatenated, frames: videoPixelsAndFrames.map { $0.1 })
             if let videoFrames = processedVideo?.frames {
-                promptTokens = try replacePaddingTokens(
-                    in: promptTokens, frames: videoFrames, paddingToken: "<|video_pad|>")
+                promptTokens = try QwenVL.replacePaddingTokens(
+                    in: promptTokens, frames: videoFrames, paddingToken: "<|video_pad|>",
+                    mergeSize: config.mergeSize, tokenizer: tokenizer)
             }
         }
 
@@ -880,42 +786,6 @@ public class Qwen25VLProcessor: UserInputProcessor {
             image: processedImage,
             video: processedVideo)
     }
-
-    func replacePaddingTokens(in promptTokens: [Int], frames: [THW], paddingToken: String)
-        throws -> [Int]
-    {
-        // Replace single padding token with correct number for each image or video frame
-        let placeholderTokens = try tokenizer.encode(
-            text: "<|vision_start|>\(paddingToken)<|vision_end|>")
-        let placeholderRanges = promptTokens.ranges(of: placeholderTokens)
-        guard placeholderRanges.count == frames.count else {
-            throw VLMError.processing(
-                "Number of placeholder tokens does not match number of frames")
-        }
-        let mergeLength = config.mergeSize * config.mergeSize
-        let replacementSequences = try frames.map { frame in
-            let paddingCount = frame.product / mergeLength
-            return try tokenizer.encode(
-                text:
-                    "<|vision_start|>\(Array(repeating: paddingToken, count: paddingCount).joined())<|vision_end|>"
-            )
-        }
-        // Build the final array
-        var result: [Int] = []
-        var currentIndex = promptTokens.startIndex
-        for (range, replacement) in zip(placeholderRanges, replacementSequences) {
-            // Add tokens before the placeholder
-            result.append(contentsOf: promptTokens[currentIndex ..< range.lowerBound])
-            // Add replacement sequence
-            result.append(contentsOf: replacement)
-            currentIndex = range.upperBound
-        }
-        // Add any remaining tokens after the last replacement
-        if currentIndex < promptTokens.endIndex {
-            result.append(contentsOf: promptTokens[currentIndex...])
-        }
-        return result
-    }
 }
 
 // MARK: - Model
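The relocated replacePaddingTokens keeps the same expansion rule, with mergeSize and the tokenizer now passed in explicitly (see the new call sites above). The per-frame padding count from the removed body, worked with hypothetical numbers:

```swift
// From the removed body: paddingCount = frame.product / mergeLength,
// where mergeLength = mergeSize * mergeSize.
let mergeSize = 2                // hypothetical config.mergeSize
let (t, h, w) = (1, 28, 28)      // hypothetical THW grid for one image
let paddingCount = (t * h * w) / (mergeSize * mergeSize)
// 784 / 4 = 196 <|image_pad|> tokens between <|vision_start|> and <|vision_end|>
```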
@@ -961,37 +831,10 @@ public class Qwen25VL: Module, VLMModel, KVCacheDimensionProvider {
         }
 
         // Insert special image tokens in the input_ids
-        return mergeInputIdsWithImageFeatures(
-            inputIds: inputIds, inputEmbeds: inputEmbeds, imageFeatures: hiddenStates)
-    }
-
-    private func mergeInputIdsWithImageFeatures(
-        inputIds: MLXArray, inputEmbeds: MLXArray, imageFeatures: MLXArray
-    ) -> MLXArray {
-        let imageTokenId = config.baseConfiguration.imageTokenId
-        let videoTokenId = config.baseConfiguration.videoTokenId
-
-        var imageIndices = [Int]()
-        for (i, v) in inputIds.asArray(Int.self).enumerated() {
-            if v == imageTokenId || v == videoTokenId {
-                imageIndices.append(i)
-            }
-        }
-
-        // Make sure shapes match before assignment
-        var result = inputEmbeds
-        if result.ndim == 2 {
-            result = result[.newAxis, 0..., 0...]
-        }
-
-        if imageFeatures.ndim == 2 {
-            let reshapedFeatures = imageFeatures[.newAxis, 0..., 0...]
-            result[0..., MLXArray(imageIndices), 0...] = reshapedFeatures
-        } else {
-            result[0..., MLXArray(imageIndices), 0...] = imageFeatures
-        }
-
-        return result
+        return QwenVL.mergeInputIdsWithImageFeatures(
+            inputIds: inputIds, inputEmbeds: inputEmbeds, imageFeatures: hiddenStates,
+            imageTokenId: config.baseConfiguration.imageTokenId,
+            videoTokenId: config.baseConfiguration.videoTokenId)
     }
 
     public func prepare(_ input: LMInput, cache: [any KVCache], windowSize: Int?) throws
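Finally, mergeInputIdsWithImageFeatures is shared as well, now taking the image and video token ids as parameters instead of reading the model configuration. A sketch of the presumed shared helper, reusing the removed body above (the extension form and exact access level are assumptions; the real declaration is in another file of this commit):

```swift
import MLX

extension QwenVL {
    static func mergeInputIdsWithImageFeatures(
        inputIds: MLXArray, inputEmbeds: MLXArray, imageFeatures: MLXArray,
        imageTokenId: Int, videoTokenId: Int
    ) -> MLXArray {
        // Positions of the image/video placeholder tokens in the prompt
        var imageIndices = [Int]()
        for (i, v) in inputIds.asArray(Int.self).enumerated() {
            if v == imageTokenId || v == videoTokenId {
                imageIndices.append(i)
            }
        }

        // Make sure shapes match before assignment
        var result = inputEmbeds
        if result.ndim == 2 {
            result = result[.newAxis, 0..., 0...]
        }

        if imageFeatures.ndim == 2 {
            result[0..., MLXArray(imageIndices), 0...] = imageFeatures[.newAxis, 0..., 0...]
        } else {
            result[0..., MLXArray(imageIndices), 0...] = imageFeatures
        }

        return result
    }
}
```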
