@@ -668,18 +668,18 @@ public class Qwen25VLProcessor: UserInputProcessor {
668
668
public func preprocess( images: [ CIImage ] , processing: UserInput . Processing ? ) throws -> (
669
669
MLXArray , THW
670
670
) {
671
- // first apply the user requested resizing, etc. if any
671
+ // First apply the user requested resizing, etc. if any
672
672
let images = images. map { MediaProcessing . apply ( $0, processing: processing) }
673
673
674
674
// image_processing_qwen2_vl._preprocess
675
-
676
675
let size = images [ 0 ] . extent. size
677
676
let ( resizedHeight, resizedWidth) = try QwenVL . targetSize (
678
677
height: Int ( size. height) , width: Int ( size. width) ,
679
678
factor: config. patchSize * config. mergeSize,
680
679
minPixels: config. size. minPixels, maxPixels: config. size. maxPixels)
681
680
let resizedSize = CGSize ( width: resizedWidth, height: resizedHeight)
682
681
682
+ // Process images
683
683
let processedImages =
684
684
try images
685
685
. map {
@@ -696,42 +696,79 @@ public class Qwen25VLProcessor: UserInputProcessor {
696
696
MediaProcessing . asMLXArray ( $0)
697
697
}
698
698
699
+ // Calculate grid dimensions
700
+ let gridT = images. count
701
+ let gridH = resizedHeight / config. patchSize
702
+ let gridW = resizedWidth / config. patchSize
703
+
704
+ // Ensure dimensions are valid
705
+ guard
706
+ resizedHeight % config. patchSize == 0 && resizedWidth % config. patchSize == 0
707
+ && gridH % config. mergeSize == 0 && gridW % config. mergeSize == 0
708
+ else {
709
+ throw VLMError . imageProcessingFailure (
710
+ " Image dimensions must be divisible by patch size and merge size " )
711
+ }
712
+
713
+ // Concatenate images and handle temporal patch size
699
714
var patches = concatenated ( processedImages)
715
+ let channel = patches. dim ( 1 )
716
+
717
+ // Pad to match temporal patch size if needed
700
718
let mod = patches. dim ( 0 ) % config. temporalPatchSize
701
719
if mod != 0 {
702
720
let lastPatch = patches [ - 1 , . ellipsis]
703
721
let lastPatchRepeated = tiled (
704
722
lastPatch, repetitions: [ config. temporalPatchSize - mod, 1 , 1 , 1 ] )
705
723
patches = concatenated ( [ patches, lastPatchRepeated] )
706
724
}
707
- let channel = patches. dim ( 1 )
708
- let gridT = patches. dim ( 0 ) / self . config. temporalPatchSize
709
- let gridH = resizedHeight / self . config. patchSize
710
- let gridW = resizedWidth / self . config. patchSize
711
-
712
- patches = patches. reshaped (
713
- gridT,
714
- config. temporalPatchSize,
715
- channel,
716
- gridH / config. mergeSize,
717
- config. mergeSize,
718
- config. patchSize,
719
- gridW / config. mergeSize,
720
- config. mergeSize,
721
- config. patchSize
722
- )
725
+
726
+ // Recalculate gridT after padding
727
+ let actualGridT = patches. dim ( 0 ) / config. temporalPatchSize
728
+
729
+ // Calculate expected size for verification
730
+ let totalElements = patches. size
731
+ let expectedElements =
732
+ actualGridT * config. temporalPatchSize * channel * resizedHeight * resizedWidth
733
+
734
+ // Try to reshape with careful dimension calculation
735
+ do {
736
+ patches = patches. reshaped (
737
+ actualGridT,
738
+ config. temporalPatchSize,
739
+ channel,
740
+ gridH / config. mergeSize,
741
+ config. mergeSize,
742
+ config. patchSize,
743
+ gridW / config. mergeSize,
744
+ config. mergeSize,
745
+ config. patchSize
746
+ )
747
+ } catch {
748
+ // If reshape fails, provide detailed error
749
+ throw VLMError . imageProcessingFailure (
750
+ " Failed to reshape patches: \( error) . Patches shape: \( patches. shape) , "
751
+ + " Target shape: ( \( actualGridT) , \( config. temporalPatchSize) , \( channel) , "
752
+ + " \( gridH / config. mergeSize) , \( config. mergeSize) , \( config. patchSize) , "
753
+ + " \( gridW / config. mergeSize) , \( config. mergeSize) , \( config. patchSize) ) "
754
+ )
755
+ }
756
+
757
+ // Continue with transpose and final reshape
723
758
patches = patches. transposed ( 0 , 3 , 6 , 4 , 7 , 2 , 1 , 5 , 8 )
724
759
725
760
let flattenedPatches = patches. reshaped (
726
- gridT * gridH * gridW,
727
- channel * config. temporalPatchSize * config. patchSize * config. patchSize
761
+ actualGridT * ( gridH / config. mergeSize) * ( gridW / config. mergeSize) ,
762
+ channel * config. temporalPatchSize * ( config. mergeSize * config. patchSize)
763
+ * ( config. mergeSize * config. patchSize)
728
764
)
729
765
730
- return ( flattenedPatches, . init( gridT , gridH, gridW) )
766
+ return ( flattenedPatches, . init( actualGridT , gridH, gridW) )
731
767
}
732
768
733
769
public func prepare( input: UserInput ) async throws -> LMInput {
734
770
let messages = input. prompt. asMessages ( )
771
+
735
772
var promptTokens = try tokenizer. applyChatTemplate ( messages: messages)
736
773
737
774
// Text-only input
@@ -748,10 +785,16 @@ public class Qwen25VLProcessor: UserInputProcessor {
748
785
let imagePixelsConcatenated = concatenated ( imagePixelsAndFrames. map { $0. 0 } )
749
786
processedImage = LMInput . ProcessedImage (
750
787
pixels: imagePixelsConcatenated, frames: imagePixelsAndFrames. map { $0. 1 } )
788
+
751
789
if let imageFrames = processedImage? . frames {
752
- promptTokens = try QwenVL . replacePaddingTokens (
753
- in: promptTokens, frames: imageFrames, paddingToken: " <|image_pad|> " ,
754
- mergeSize: config. mergeSize, tokenizer: tokenizer)
790
+ do {
791
+ promptTokens = try QwenVL . replacePaddingTokens (
792
+ in: promptTokens, frames: imageFrames, paddingToken: " <|image_pad|> " ,
793
+ mergeSize: config. mergeSize, tokenizer: tokenizer)
794
+ } catch {
795
+ print ( " Error in replacePaddingTokens: \( error) " )
796
+ throw error
797
+ }
755
798
}
756
799
}
757
800
@@ -772,10 +815,16 @@ public class Qwen25VLProcessor: UserInputProcessor {
772
815
let videoPixelsConcatenated = concatenated ( videoPixelsAndFrames. map { $0. 0 } )
773
816
processedVideo = LMInput . ProcessedVideo (
774
817
pixels: videoPixelsConcatenated, frames: videoPixelsAndFrames. map { $0. 1 } )
818
+
775
819
if let videoFrames = processedVideo? . frames {
776
- promptTokens = try QwenVL . replacePaddingTokens (
777
- in: promptTokens, frames: videoFrames, paddingToken: " <|video_pad|> " ,
778
- mergeSize: config. mergeSize, tokenizer: tokenizer)
820
+ do {
821
+ promptTokens = try QwenVL . replacePaddingTokens (
822
+ in: promptTokens, frames: videoFrames, paddingToken: " <|video_pad|> " ,
823
+ mergeSize: config. mergeSize, tokenizer: tokenizer)
824
+ } catch {
825
+ print ( " Error in video replacePaddingTokens: \( error) " )
826
+ throw error
827
+ }
779
828
}
780
829
}
781
830
0 commit comments