@@ -9,16 +9,6 @@ import MLXLMCommon
import MLXNN
import Tokenizers

-// MARK: - Common
-
-/// Rotates half the hidden dims of the input
-private func rotateHalf(_ x: MLXArray) -> MLXArray {
-    let index = x.dim(-1) / 2
-    let x1 = x[.ellipsis, 0 ..< index]
-    let x2 = x[.ellipsis, index...]
-    return concatenated([-x2, x1], axis: -1)
-}
-
// MARK: - Language

private enum Language {
@@ -45,8 +35,8 @@ private enum Language {
        )[0..., .newAxis, 0..., 0...]

        // Apply rotary embedding
-        let qEmbed = (q * cos) + (rotateHalf(q) * sin)
-        let kEmbed = (k * cos) + (rotateHalf(k) * sin)
+        let qEmbed = (q * cos) + (QwenVL.rotateHalf(q) * sin)
+        let kEmbed = (k * cos) + (QwenVL.rotateHalf(k) * sin)
        return (qEmbed, kEmbed)
    }

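For reference, a minimal sketch of the shared helper these call sites now resolve to. The QwenVL container and its location are assumptions (the common file is not part of this diff); the body is assumed to mirror the rotateHalf free function removed above.

import MLX

// Hypothetical shared namespace; only the part needed by the call sites above.
enum QwenVL {
    /// Rotates half the hidden dims of the input (body copied from the removed helper).
    static func rotateHalf(_ x: MLXArray) -> MLXArray {
        let index = x.dim(-1) / 2
        let x1 = x[.ellipsis, 0 ..< index]
        let x2 = x[.ellipsis, index...]
        return concatenated([-x2, x1], axis: -1)
    }
}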
@@ -264,64 +254,10 @@ private enum Vision {
        sin = tiled(sin, repetitions: [1, 1, 2])
        sin = expandedDimensions(sin, axis: 0)

-        let output = (tensor * cos) + (rotateHalf(tensor) * sin)
+        let output = (tensor * cos) + (QwenVL.rotateHalf(tensor) * sin)
        return output.asType(tensor.dtype)
    }

-    fileprivate class VisionRotaryEmbedding {
-        let dimensions: Int
-        let theta: Float
-        let inverseFreq: MLXArray
-
-        init(dimensions: Int, theta: Float) {
-            self.dimensions = dimensions
-            self.theta = theta
-            let p = MLXArray(stride(from: 0, to: dimensions, by: 2)).asType(.float32) / dimensions
-            self.inverseFreq = 1.0 / pow(theta, p)
-        }
-
-        func callAsFunction(sequenceLength: Int) -> MLXArray {
-            let seq = MLXArray(0 ..< sequenceLength).asType(inverseFreq.dtype)
-            let freqs = outer(seq, inverseFreq)
-            return freqs
-        }
-    }
-
-    fileprivate class PatchEmbed: Module, UnaryLayer {
-        @ModuleInfo var proj: Conv3d
-
-        let patchSize: Int
-        let temporalPatchSize: Int
-        let inChannels: Int
-        let hiddenSize: Int
-
-        init(patchSize: Int, temporalPatchSize: Int, inChannels: Int, hiddenSize: Int) {
-            self.patchSize = patchSize
-            self.temporalPatchSize = temporalPatchSize
-            self.inChannels = inChannels
-            self.hiddenSize = hiddenSize
-
-            let kernelSize = IntOrTriple([temporalPatchSize, patchSize, patchSize])
-            self._proj.wrappedValue = Conv3d(
-                inputChannels: inChannels,
-                outputChannels: hiddenSize,
-                kernelSize: kernelSize,
-                stride: kernelSize,
-                bias: false
-            )
-        }
-
-        func callAsFunction(_ hiddenStates: MLXArray) -> MLXArray {
-            var hiddenStates = hiddenStates.reshaped(
-                -1, inChannels, temporalPatchSize, patchSize, patchSize
-            ).movedAxis(source: 1, destination: 4)
-
-            hiddenStates = proj(hiddenStates)
-            hiddenStates = hiddenStates.reshaped(-1, hiddenSize)
-            return hiddenStates
-        }
-    }
-
    fileprivate class PatchMerger: Module, UnaryLayer {
        let hiddenSize: Int
        @ModuleInfo(key: "ln_q") var layerNormQ: RMSNorm
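The two vision helpers removed here presumably move into the same shared QwenVL container referenced by the later hunks. A sketch of the rotary-embedding class, assuming its body is unchanged from the code removed above (PatchEmbed would move the same way); the usage lines at the end use illustrative values, not values taken from this diff.

import MLX

extension QwenVL {  // extends the hypothetical shared enum sketched earlier
    class VisionRotaryEmbedding {
        let dimensions: Int
        let theta: Float
        let inverseFreq: MLXArray

        init(dimensions: Int, theta: Float) {
            self.dimensions = dimensions
            self.theta = theta
            // Inverse frequencies for every second dimension: 1 / theta^(2i / dimensions)
            let p = MLXArray(stride(from: 0, to: dimensions, by: 2)).asType(.float32) / dimensions
            self.inverseFreq = 1.0 / pow(theta, p)
        }

        // Outer product of patch positions and inverse frequencies,
        // shape [sequenceLength, dimensions / 2].
        func callAsFunction(sequenceLength: Int) -> MLXArray {
            let seq = MLXArray(0 ..< sequenceLength).asType(inverseFreq.dtype)
            return outer(seq, inverseFreq)
        }
    }
}

// Illustrative only: dimensions is headDimensions / 2, as in the VisionModel init later in this diff.
let rope = QwenVL.VisionRotaryEmbedding(dimensions: 40, theta: 10_000)
let freqs = rope(sequenceLength: 4)  // shape [4, 20]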
@@ -457,8 +393,8 @@ private enum Vision {

    fileprivate class VisionModel: Module {

-        @ModuleInfo(key: "patch_embed") var patchEmbed: PatchEmbed
-        @ModuleInfo(key: "rotary_pos_emb") var rotaryPositionEmbedding: VisionRotaryEmbedding
+        @ModuleInfo(key: "patch_embed") var patchEmbed: QwenVL.PatchEmbed
+        @ModuleInfo(key: "rotary_pos_emb") var rotaryPositionEmbedding: QwenVL.VisionRotaryEmbedding
        @ModuleInfo(key: "blocks") var blocks: [Qwen25VLVisionBlock]
        @ModuleInfo(key: "merger") var patchMerger: PatchMerger

@@ -475,14 +411,14 @@
            self.spatialMergeUnit = config.spatialMergeSize * config.spatialMergeSize
            self.fullattBlockIndexes = config.fullattBlockIndexes

-            self._patchEmbed.wrappedValue = PatchEmbed(
+            self._patchEmbed.wrappedValue = QwenVL.PatchEmbed(
                patchSize: config.patchSize,
                temporalPatchSize: config.temporalPatchSize,
                inChannels: config.inChannels,
                hiddenSize: config.hiddenSize)

            let headDimensions = config.hiddenSize / config.numHeads
-            self._rotaryPositionEmbedding.wrappedValue = VisionRotaryEmbedding(
+            self._rotaryPositionEmbedding.wrappedValue = QwenVL.VisionRotaryEmbedding(
                dimensions: headDimensions / 2, theta: 10_000)

            self._blocks.wrappedValue = (0 ..< config.depth).map { _ in
@@ -729,38 +665,6 @@ public class Qwen25VLProcessor: UserInputProcessor {
        self.tokenizer = tokenizer
    }

-    // image_processing_qwen2_vl.smart_resize
-    private func targetSize(height: Int, width: Int, factor: Int, minPixels: Int, maxPixels: Int)
-        throws -> (Int, Int)
-    {
-        if height < factor {
-            throw VLMError.imageProcessingFailure(
-                "height: \(height) must be larger than factor: \(factor)")
-        }
-        if width < factor {
-            throw VLMError.imageProcessingFailure(
-                "width: \(width) must be larger than factor: \(factor)")
-        }
-        if max(height, width) / min(height, width) > 200 {
-            throw VLMError.imageProcessingFailure(
-                "absolute aspect ratio must be smaller than 200: \(width) x \(height)")
-        }
-
-        var hBar = max(factor, Int(round(Float(height) / Float(factor))) * factor)
-        var wBar = max(factor, Int(round(Float(width) / Float(factor))) * factor)
-
-        if hBar * wBar > maxPixels {
-            let beta = sqrt(Float(height * width) / Float(maxPixels))
-            hBar = Int(floor(Float(height) / beta / Float(factor))) * factor
-            wBar = Int(floor(Float(width) / beta / Float(factor))) * factor
-        } else if hBar * wBar < minPixels {
-            let beta = sqrt(Float(minPixels) / Float(height * width))
-            hBar = Int(ceil(Float(height) * beta / Float(factor))) * factor
-            wBar = Int(ceil(Float(width) * beta / Float(factor))) * factor
-        }
-        return (hBar, wBar)
-    }
-
    public func preprocess(images: [CIImage], processing: UserInput.Processing?) throws -> (
        MLXArray, THW
    ) {
@@ -770,7 +674,7 @@ public class Qwen25VLProcessor: UserInputProcessor {
        // image_processing_qwen2_vl._preprocess

        let size = images[0].extent.size
-        let (resizedHeight, resizedWidth) = try targetSize(
+        let (resizedHeight, resizedWidth) = try QwenVL.targetSize(
            height: Int(size.height), width: Int(size.width),
            factor: config.patchSize * config.mergeSize,
            minPixels: config.size.minPixels, maxPixels: config.size.maxPixels)
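The shared QwenVL.targetSize presumably keeps the smart_resize behaviour of the method removed earlier: snap height and width to multiples of factor = patchSize * mergeSize, then rescale if the resulting area falls outside [minPixels, maxPixels]. A hypothetical call with concrete numbers (factor 28 corresponds to the common Qwen2.5-VL defaults patchSize = 14, mergeSize = 2; the pixel bounds below are likewise assumed values, not taken from this diff):

// 700 x 1000 image, factor 28
let (resizedHeight, resizedWidth) = try QwenVL.targetSize(
    height: 700, width: 1000, factor: 28,
    minPixels: 3136, maxPixels: 1_003_520)
// round(700 / 28) * 28 = 700 and round(1000 / 28) * 28 = 1008; 700 * 1008 pixels
// already lies inside [minPixels, maxPixels], so the result is (700, 1008).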
@@ -845,8 +749,9 @@ public class Qwen25VLProcessor: UserInputProcessor {
            processedImage = LMInput.ProcessedImage(
                pixels: imagePixelsConcatenated, frames: imagePixelsAndFrames.map { $0.1 })
            if let imageFrames = processedImage?.frames {
-                promptTokens = try replacePaddingTokens(
-                    in: promptTokens, frames: imageFrames, paddingToken: "<|image_pad|>")
+                promptTokens = try QwenVL.replacePaddingTokens(
+                    in: promptTokens, frames: imageFrames, paddingToken: "<|image_pad|>",
+                    mergeSize: config.mergeSize, tokenizer: tokenizer)
            }
        }

@@ -868,8 +773,9 @@ public class Qwen25VLProcessor: UserInputProcessor {
            processedVideo = LMInput.ProcessedVideo(
                pixels: videoPixelsConcatenated, frames: videoPixelsAndFrames.map { $0.1 })
            if let videoFrames = processedVideo?.frames {
-                promptTokens = try replacePaddingTokens(
-                    in: promptTokens, frames: videoFrames, paddingToken: "<|video_pad|>")
+                promptTokens = try QwenVL.replacePaddingTokens(
+                    in: promptTokens, frames: videoFrames, paddingToken: "<|video_pad|>",
+                    mergeSize: config.mergeSize, tokenizer: tokenizer)
            }
        }

@@ -880,42 +786,6 @@ public class Qwen25VLProcessor: UserInputProcessor {
            image: processedImage,
            video: processedVideo)
    }
-
-    func replacePaddingTokens(in promptTokens: [Int], frames: [THW], paddingToken: String)
-        throws -> [Int]
-    {
-        // Replace single padding token with correct number for each image or video frame
-        let placeholderTokens = try tokenizer.encode(
-            text: "<|vision_start|>\(paddingToken)<|vision_end|>")
-        let placeholderRanges = promptTokens.ranges(of: placeholderTokens)
-        guard placeholderRanges.count == frames.count else {
-            throw VLMError.processing(
-                "Number of placeholder tokens does not match number of frames")
-        }
-        let mergeLength = config.mergeSize * config.mergeSize
-        let replacementSequences = try frames.map { frame in
-            let paddingCount = frame.product / mergeLength
-            return try tokenizer.encode(
-                text:
-                    "<|vision_start|>\(Array(repeating: paddingToken, count: paddingCount).joined())<|vision_end|>"
-            )
-        }
-        // Build the final array
-        var result: [Int] = []
-        var currentIndex = promptTokens.startIndex
-        for (range, replacement) in zip(placeholderRanges, replacementSequences) {
-            // Add tokens before the placeholder
-            result.append(contentsOf: promptTokens[currentIndex ..< range.lowerBound])
-            // Add replacement sequence
-            result.append(contentsOf: replacement)
-            currentIndex = range.upperBound
-        }
-        // Add any remaining tokens after the last replacement
-        if currentIndex < promptTokens.endIndex {
-            result.append(contentsOf: promptTokens[currentIndex...])
-        }
-        return result
-    }
}

// MARK: - Model
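The replacement helper now takes mergeSize and a tokenizer explicitly (see the call sites above), since the shared implementation can no longer read them from this processor's config and tokenizer properties. The padding arithmetic itself appears unchanged; a small illustrative check of the count it produces per frame:

// Hypothetical frame grid: 1 temporal x 36 x 36 patches, mergeSize 2.
let frame = THW(1, 36, 36)
let mergeSize = 2
let paddingCount = frame.product / (mergeSize * mergeSize)
// 1 * 36 * 36 / 4 = 324 <|image_pad|> tokens emitted between <|vision_start|> and <|vision_end|>.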
@@ -961,37 +831,10 @@ public class Qwen25VL: Module, VLMModel, KVCacheDimensionProvider {
        }

        // Insert special image tokens in the input_ids
-        return mergeInputIdsWithImageFeatures(
-            inputIds: inputIds, inputEmbeds: inputEmbeds, imageFeatures: hiddenStates)
-    }
-
-    private func mergeInputIdsWithImageFeatures(
-        inputIds: MLXArray, inputEmbeds: MLXArray, imageFeatures: MLXArray
-    ) -> MLXArray {
-        let imageTokenId = config.baseConfiguration.imageTokenId
-        let videoTokenId = config.baseConfiguration.videoTokenId
-
-        var imageIndices = [Int]()
-        for (i, v) in inputIds.asArray(Int.self).enumerated() {
-            if v == imageTokenId || v == videoTokenId {
-                imageIndices.append(i)
-            }
-        }
-
-        // Make sure shapes match before assignment
-        var result = inputEmbeds
-        if result.ndim == 2 {
-            result = result[.newAxis, 0..., 0...]
-        }
-
-        if imageFeatures.ndim == 2 {
-            let reshapedFeatures = imageFeatures[.newAxis, 0..., 0...]
-            result[0..., MLXArray(imageIndices), 0...] = reshapedFeatures
-        } else {
-            result[0..., MLXArray(imageIndices), 0...] = imageFeatures
-        }
-
-        return result
+        return QwenVL.mergeInputIdsWithImageFeatures(
+            inputIds: inputIds, inputEmbeds: inputEmbeds, imageFeatures: hiddenStates,
+            imageTokenId: config.baseConfiguration.imageTokenId,
+            videoTokenId: config.baseConfiguration.videoTokenId)
    }

    public func prepare(_ input: LMInput, cache: [any KVCache], windowSize: Int?) throws
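Similarly, the merge helper now receives the image and video token ids as parameters instead of reading them from this model's config. A sketch of the shared function the new call resolves to; the signature is inferred from the call site and the body is assumed to match the private method removed above.

import MLX

extension QwenVL {
    static func mergeInputIdsWithImageFeatures(
        inputIds: MLXArray, inputEmbeds: MLXArray, imageFeatures: MLXArray,
        imageTokenId: Int, videoTokenId: Int
    ) -> MLXArray {
        // Positions of <|image_pad|> / <|video_pad|> tokens in the prompt
        var imageIndices = [Int]()
        for (i, v) in inputIds.asArray(Int.self).enumerated() {
            if v == imageTokenId || v == videoTokenId {
                imageIndices.append(i)
            }
        }

        // Make sure shapes match before assignment
        var result = inputEmbeds
        if result.ndim == 2 {
            result = result[.newAxis, 0..., 0...]
        }

        // Scatter the vision features into the embedding rows of the pad tokens
        if imageFeatures.ndim == 2 {
            result[0..., MLXArray(imageIndices), 0...] = imageFeatures[.newAxis, 0..., 0...]
        } else {
            result[0..., MLXArray(imageIndices), 0...] = imageFeatures
        }

        return result
    }
}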