Skip to content

Commit 9356549

Browse files
committed
More media downsampling fixes
1 parent 788fffa commit 9356549

File tree

5 files changed

+216
-71
lines changed

5 files changed

+216
-71
lines changed

Libraries/MLXVLM/MediaProcessing.swift

Lines changed: 38 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -63,26 +63,48 @@ public enum MediaProcessing {
6363
/// - image: The image to resample
6464
/// - size: The target size
6565
/// - Returns: The resampled image
66-
static public func resampleBicubic(_ image: CIImage, to size: CGSize) -> CIImage {
67-
let filter = CIFilter.bicubicScaleTransform()
68-
let extent = image.extent.size
69-
70-
filter.inputImage = image
66+
public static func resampleBicubic(_ image: CIImage, to size: CGSize) -> CIImage {
67+
// First, create a CIFilter for precise resampling
68+
guard let filter = CIFilter(name: "CILanczosScaleTransform") else {
69+
// Fall back to affine transform if filter isn't available
70+
let scaleX = size.width / image.extent.width
71+
let scaleY = size.height / image.extent.height
72+
let transform = CGAffineTransform(scaleX: scaleX, y: scaleY)
73+
let scaled = image.transformed(by: transform)
74+
75+
// Force exact dimensions by cropping
76+
return scaled.cropped(to: CGRect(origin: .zero, size: size))
77+
}
7178

72-
// set the aspect ratio to match the aspect ratio of the target
73-
let inputAspectRatio = extent.width / extent.height
74-
let desiredAspectRatio = size.width / size.height
75-
filter.aspectRatio = Float(1 / inputAspectRatio * desiredAspectRatio)
79+
filter.setValue(image, forKey: kCIInputImageKey)
80+
filter.setValue(size.width / image.extent.width, forKey: kCIInputScaleKey)
81+
filter.setValue(1.0, forKey: kCIInputAspectRatioKey)
7682

77-
// Use the same scaling approach regardless of orientation
78-
let scale = min(size.width / extent.width, size.height / extent.height)
79-
filter.scale = Float(scale)
83+
guard let scaledImage = filter.outputImage else {
84+
// Fall back if filter fails
85+
let scaleX = size.width / image.extent.width
86+
let scaleY = size.height / image.extent.height
87+
let transform = CGAffineTransform(scaleX: scaleX, y: scaleY)
88+
let scaled = image.transformed(by: transform)
8089

81-
let rescaled = filter.outputImage!
90+
return scaled.cropped(to: CGRect(origin: .zero, size: size))
91+
}
8292

83-
// The image has a DoD larger than the requested size, so crop
84-
// it to the desired size
85-
return rescaled.cropped(to: CGRect(origin: .zero, size: size))
93+
// Calculate the crop rect to get exactly the requested size
94+
// Scale height separately to match the target height
95+
let heightScale = size.height / scaledImage.extent.height
96+
let finalImage = scaledImage.transformed(by: CGAffineTransform(scaleX: 1.0, y: heightScale))
97+
98+
// Create a rect with the exact dimensions we want
99+
let exactRect = CGRect(
100+
x: 0,
101+
y: 0,
102+
width: size.width,
103+
height: size.height
104+
)
105+
106+
// Crop to ensure exact dimensions
107+
return finalImage.cropped(to: exactRect)
86108
}
87109

88110
/// Normalize the image using the given mean and standard deviation parameters.

Libraries/MLXVLM/Models/Qwen25VL.swift

Lines changed: 76 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -668,18 +668,18 @@ public class Qwen25VLProcessor: UserInputProcessor {
668668
public func preprocess(images: [CIImage], processing: UserInput.Processing?) throws -> (
669669
MLXArray, THW
670670
) {
671-
// first apply the user requested resizing, etc. if any
671+
// First apply the user requested resizing, etc. if any
672672
let images = images.map { MediaProcessing.apply($0, processing: processing) }
673673

674674
// image_processing_qwen2_vl._preprocess
675-
676675
let size = images[0].extent.size
677676
let (resizedHeight, resizedWidth) = try QwenVL.targetSize(
678677
height: Int(size.height), width: Int(size.width),
679678
factor: config.patchSize * config.mergeSize,
680679
minPixels: config.size.minPixels, maxPixels: config.size.maxPixels)
681680
let resizedSize = CGSize(width: resizedWidth, height: resizedHeight)
682681

682+
// Process images
683683
let processedImages =
684684
try images
685685
.map {
@@ -696,42 +696,79 @@ public class Qwen25VLProcessor: UserInputProcessor {
696696
MediaProcessing.asMLXArray($0)
697697
}
698698

699+
// Calculate grid dimensions
700+
let gridT = images.count
701+
let gridH = resizedHeight / config.patchSize
702+
let gridW = resizedWidth / config.patchSize
703+
704+
// Ensure dimensions are valid
705+
guard
706+
resizedHeight % config.patchSize == 0 && resizedWidth % config.patchSize == 0
707+
&& gridH % config.mergeSize == 0 && gridW % config.mergeSize == 0
708+
else {
709+
throw VLMError.imageProcessingFailure(
710+
"Image dimensions must be divisible by patch size and merge size")
711+
}
712+
713+
// Concatenate images and handle temporal patch size
699714
var patches = concatenated(processedImages)
715+
let channel = patches.dim(1)
716+
717+
// Pad to match temporal patch size if needed
700718
let mod = patches.dim(0) % config.temporalPatchSize
701719
if mod != 0 {
702720
let lastPatch = patches[-1, .ellipsis]
703721
let lastPatchRepeated = tiled(
704722
lastPatch, repetitions: [config.temporalPatchSize - mod, 1, 1, 1])
705723
patches = concatenated([patches, lastPatchRepeated])
706724
}
707-
let channel = patches.dim(1)
708-
let gridT = patches.dim(0) / self.config.temporalPatchSize
709-
let gridH = resizedHeight / self.config.patchSize
710-
let gridW = resizedWidth / self.config.patchSize
711-
712-
patches = patches.reshaped(
713-
gridT,
714-
config.temporalPatchSize,
715-
channel,
716-
gridH / config.mergeSize,
717-
config.mergeSize,
718-
config.patchSize,
719-
gridW / config.mergeSize,
720-
config.mergeSize,
721-
config.patchSize
722-
)
725+
726+
// Recalculate gridT after padding
727+
let actualGridT = patches.dim(0) / config.temporalPatchSize
728+
729+
// Calculate expected size for verification
730+
let totalElements = patches.size
731+
let expectedElements =
732+
actualGridT * config.temporalPatchSize * channel * resizedHeight * resizedWidth
733+
734+
// Try to reshape with careful dimension calculation
735+
do {
736+
patches = patches.reshaped(
737+
actualGridT,
738+
config.temporalPatchSize,
739+
channel,
740+
gridH / config.mergeSize,
741+
config.mergeSize,
742+
config.patchSize,
743+
gridW / config.mergeSize,
744+
config.mergeSize,
745+
config.patchSize
746+
)
747+
} catch {
748+
// If reshape fails, provide detailed error
749+
throw VLMError.imageProcessingFailure(
750+
"Failed to reshape patches: \(error). Patches shape: \(patches.shape), "
751+
+ "Target shape: (\(actualGridT), \(config.temporalPatchSize), \(channel), "
752+
+ "\(gridH / config.mergeSize), \(config.mergeSize), \(config.patchSize), "
753+
+ "\(gridW / config.mergeSize), \(config.mergeSize), \(config.patchSize))"
754+
)
755+
}
756+
757+
// Continue with transpose and final reshape
723758
patches = patches.transposed(0, 3, 6, 4, 7, 2, 1, 5, 8)
724759

725760
let flattenedPatches = patches.reshaped(
726-
gridT * gridH * gridW,
727-
channel * config.temporalPatchSize * config.patchSize * config.patchSize
761+
actualGridT * (gridH / config.mergeSize) * (gridW / config.mergeSize),
762+
channel * config.temporalPatchSize * (config.mergeSize * config.patchSize)
763+
* (config.mergeSize * config.patchSize)
728764
)
729765

730-
return (flattenedPatches, .init(gridT, gridH, gridW))
766+
return (flattenedPatches, .init(actualGridT, gridH, gridW))
731767
}
732768

733769
public func prepare(input: UserInput) async throws -> LMInput {
734770
let messages = input.prompt.asMessages()
771+
735772
var promptTokens = try tokenizer.applyChatTemplate(messages: messages)
736773

737774
// Text-only input
@@ -748,10 +785,16 @@ public class Qwen25VLProcessor: UserInputProcessor {
748785
let imagePixelsConcatenated = concatenated(imagePixelsAndFrames.map { $0.0 })
749786
processedImage = LMInput.ProcessedImage(
750787
pixels: imagePixelsConcatenated, frames: imagePixelsAndFrames.map { $0.1 })
788+
751789
if let imageFrames = processedImage?.frames {
752-
promptTokens = try QwenVL.replacePaddingTokens(
753-
in: promptTokens, frames: imageFrames, paddingToken: "<|image_pad|>",
754-
mergeSize: config.mergeSize, tokenizer: tokenizer)
790+
do {
791+
promptTokens = try QwenVL.replacePaddingTokens(
792+
in: promptTokens, frames: imageFrames, paddingToken: "<|image_pad|>",
793+
mergeSize: config.mergeSize, tokenizer: tokenizer)
794+
} catch {
795+
print("Error in replacePaddingTokens: \(error)")
796+
throw error
797+
}
755798
}
756799
}
757800

@@ -772,10 +815,16 @@ public class Qwen25VLProcessor: UserInputProcessor {
772815
let videoPixelsConcatenated = concatenated(videoPixelsAndFrames.map { $0.0 })
773816
processedVideo = LMInput.ProcessedVideo(
774817
pixels: videoPixelsConcatenated, frames: videoPixelsAndFrames.map { $0.1 })
818+
775819
if let videoFrames = processedVideo?.frames {
776-
promptTokens = try QwenVL.replacePaddingTokens(
777-
in: promptTokens, frames: videoFrames, paddingToken: "<|video_pad|>",
778-
mergeSize: config.mergeSize, tokenizer: tokenizer)
820+
do {
821+
promptTokens = try QwenVL.replacePaddingTokens(
822+
in: promptTokens, frames: videoFrames, paddingToken: "<|video_pad|>",
823+
mergeSize: config.mergeSize, tokenizer: tokenizer)
824+
} catch {
825+
print("Error in video replacePaddingTokens: \(error)")
826+
throw error
827+
}
779828
}
780829
}
781830

Libraries/MLXVLM/Models/Qwen2VL.swift

Lines changed: 76 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -531,18 +531,18 @@ public class Qwen2VLProcessor: UserInputProcessor {
531531
public func preprocess(images: [CIImage], processing: UserInput.Processing?) throws -> (
532532
MLXArray, THW
533533
) {
534-
// first apply the user requested resizing, etc. if any
534+
// First apply the user requested resizing, etc. if any
535535
let images = images.map { MediaProcessing.apply($0, processing: processing) }
536536

537537
// image_processing_qwen2_vl._preprocess
538-
539538
let size = images[0].extent.size
540539
let (resizedHeight, resizedWidth) = try QwenVL.targetSize(
541540
height: Int(size.height), width: Int(size.width),
542541
factor: config.patchSize * config.mergeSize,
543542
minPixels: config.size.minPixels, maxPixels: config.size.maxPixels)
544543
let resizedSize = CGSize(width: resizedWidth, height: resizedHeight)
545544

545+
// Process images
546546
let processedImages =
547547
try images
548548
.map {
@@ -559,42 +559,79 @@ public class Qwen2VLProcessor: UserInputProcessor {
559559
MediaProcessing.asMLXArray($0)
560560
}
561561

562+
// Calculate grid dimensions
563+
let gridT = images.count
564+
let gridH = resizedHeight / config.patchSize
565+
let gridW = resizedWidth / config.patchSize
566+
567+
// Ensure dimensions are valid
568+
guard
569+
resizedHeight % config.patchSize == 0 && resizedWidth % config.patchSize == 0
570+
&& gridH % config.mergeSize == 0 && gridW % config.mergeSize == 0
571+
else {
572+
throw VLMError.imageProcessingFailure(
573+
"Image dimensions must be divisible by patch size and merge size")
574+
}
575+
576+
// Concatenate images and handle temporal patch size
562577
var patches = concatenated(processedImages)
578+
let channel = patches.dim(1)
579+
580+
// Pad to match temporal patch size if needed
563581
let mod = patches.dim(0) % config.temporalPatchSize
564582
if mod != 0 {
565583
let lastPatch = patches[-1, .ellipsis]
566584
let lastPatchRepeated = tiled(
567585
lastPatch, repetitions: [config.temporalPatchSize - mod, 1, 1, 1])
568586
patches = concatenated([patches, lastPatchRepeated])
569587
}
570-
let channel = patches.dim(1)
571-
let gridT = patches.dim(0) / self.config.temporalPatchSize
572-
let gridH = resizedHeight / self.config.patchSize
573-
let gridW = resizedWidth / self.config.patchSize
574-
575-
patches = patches.reshaped(
576-
gridT,
577-
config.temporalPatchSize,
578-
channel,
579-
gridH / config.mergeSize,
580-
config.mergeSize,
581-
config.patchSize,
582-
gridW / config.mergeSize,
583-
config.mergeSize,
584-
config.patchSize
585-
)
588+
589+
// Recalculate gridT after padding
590+
let actualGridT = patches.dim(0) / config.temporalPatchSize
591+
592+
// Calculate expected size for verification
593+
let totalElements = patches.size
594+
let expectedElements =
595+
actualGridT * config.temporalPatchSize * channel * resizedHeight * resizedWidth
596+
597+
// Try to reshape with careful dimension calculation
598+
do {
599+
patches = patches.reshaped(
600+
actualGridT,
601+
config.temporalPatchSize,
602+
channel,
603+
gridH / config.mergeSize,
604+
config.mergeSize,
605+
config.patchSize,
606+
gridW / config.mergeSize,
607+
config.mergeSize,
608+
config.patchSize
609+
)
610+
} catch {
611+
// If reshape fails, provide detailed error
612+
throw VLMError.imageProcessingFailure(
613+
"Failed to reshape patches: \(error). Patches shape: \(patches.shape), "
614+
+ "Target shape: (\(actualGridT), \(config.temporalPatchSize), \(channel), "
615+
+ "\(gridH / config.mergeSize), \(config.mergeSize), \(config.patchSize), "
616+
+ "\(gridW / config.mergeSize), \(config.mergeSize), \(config.patchSize))"
617+
)
618+
}
619+
620+
// Continue with transpose and final reshape
586621
patches = patches.transposed(0, 3, 6, 4, 7, 2, 1, 5, 8)
587622

588623
let flattenedPatches = patches.reshaped(
589-
gridT * gridH * gridW,
590-
channel * config.temporalPatchSize * config.patchSize * config.patchSize
624+
actualGridT * (gridH / config.mergeSize) * (gridW / config.mergeSize),
625+
channel * config.temporalPatchSize * (config.mergeSize * config.patchSize)
626+
* (config.mergeSize * config.patchSize)
591627
)
592628

593-
return (flattenedPatches, .init(gridT, gridH, gridW))
629+
return (flattenedPatches, .init(actualGridT, gridH, gridW))
594630
}
595631

596632
public func prepare(input: UserInput) async throws -> LMInput {
597633
let messages = input.prompt.asMessages()
634+
598635
var promptTokens = try tokenizer.applyChatTemplate(messages: messages)
599636

600637
// Text-only input
@@ -611,10 +648,16 @@ public class Qwen2VLProcessor: UserInputProcessor {
611648
let imagePixelsConcatenated = concatenated(imagePixelsAndFrames.map { $0.0 })
612649
processedImage = LMInput.ProcessedImage(
613650
pixels: imagePixelsConcatenated, frames: imagePixelsAndFrames.map { $0.1 })
651+
614652
if let imageFrames = processedImage?.frames {
615-
promptTokens = try QwenVL.replacePaddingTokens(
616-
in: promptTokens, frames: imageFrames, paddingToken: "<|image_pad|>",
617-
mergeSize: config.mergeSize, tokenizer: tokenizer)
653+
do {
654+
promptTokens = try QwenVL.replacePaddingTokens(
655+
in: promptTokens, frames: imageFrames, paddingToken: "<|image_pad|>",
656+
mergeSize: config.mergeSize, tokenizer: tokenizer)
657+
} catch {
658+
print("Error in replacePaddingTokens: \(error)")
659+
throw error
660+
}
618661
}
619662
}
620663

@@ -635,10 +678,16 @@ public class Qwen2VLProcessor: UserInputProcessor {
635678
let videoPixelsConcatenated = concatenated(videoPixelsAndFrames.map { $0.0 })
636679
processedVideo = LMInput.ProcessedVideo(
637680
pixels: videoPixelsConcatenated, frames: videoPixelsAndFrames.map { $0.1 })
681+
638682
if let videoFrames = processedVideo?.frames {
639-
promptTokens = try QwenVL.replacePaddingTokens(
640-
in: promptTokens, frames: videoFrames, paddingToken: "<|video_pad|>",
641-
mergeSize: config.mergeSize, tokenizer: tokenizer)
683+
do {
684+
promptTokens = try QwenVL.replacePaddingTokens(
685+
in: promptTokens, frames: videoFrames, paddingToken: "<|video_pad|>",
686+
mergeSize: config.mergeSize, tokenizer: tokenizer)
687+
} catch {
688+
print("Error in video replacePaddingTokens: \(error)")
689+
throw error
690+
}
642691
}
643692
}
644693

0 commit comments

Comments
 (0)