Skip to content

Commit efb9ed7

Browse files
Add Qwen 2.5 VL (#222)
* Add Qwen 2.5 VL
* Fix media downsampling
* Hoist attention mask generation to VisionModel -- avoid recomputing the mask 32 times

Co-authored-by: David Koski <[email protected]>
1 parent 581c6cb commit efb9ed7

File tree

8 files changed

+1461
-286
lines changed

8 files changed

+1461
-286
lines changed

Libraries/MLXVLM/MediaProcessing.swift

Lines changed: 52 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ public struct ProcessedFrames {
1616
let totalDuration: CMTime
1717
}
1818

19-
// TODO: verify working color space, rendering color space
2019
private let context = CIContext()
2120

2221
/// Collection of methods for processing media (images, video, etc.).
@@ -27,7 +26,7 @@ private let context = CIContext()
2726
/// var image: CIImage
2827
/// image = MediaProcessing.inSRGBToneCurveSpace(image)
2928
///
30-
/// // apply user instructions
29+
/// // Apply user instructions
3130
/// image = MediaProcessing.apply(image, processing: processing)
3231
///
3332
/// image = MediaProcessing.resampleBicubic(image, to: config.size.cgSize)
@@ -76,58 +75,58 @@ public enum MediaProcessing {
7675
return Float(1 / inputAspectRatio * desiredAspectRatio)
7776
}
7877

79-
/// Resample the image using bicubic interpolation.
80-
public static func resampleBicubic(_ image: CIImage, to size: CGSize) -> CIImage {
81-
let filter = CIFilter.bicubicScaleTransform()
82-
let extent = image.extent.size
83-
84-
filter.inputImage = image
85-
86-
// set the aspect ratio to match the aspect ratio of the target
87-
filter.aspectRatio = aspectRatioForResample(image, size: size)
88-
89-
// that image is now the aspect ratio of the target and the size
90-
// of the shorter dimension
91-
let scale: CGFloat
92-
if extent.width < extent.height {
93-
scale = size.width / extent.width
94-
} else {
95-
scale = size.height / extent.height
96-
}
97-
filter.scale = Float(scale)
98-
99-
let rescaled = filter.outputImage!
100-
101-
// the image has a DoD larger than the requested size so crop
102-
// it to the desired size
103-
return rescaled.cropped(to: CGRect(origin: .zero, size: size))
104-
}
105-
10678
/// Resample the image using Lanczos interpolation.
10779
static public func resampleLanczos(_ image: CIImage, to size: CGSize) -> CIImage {
108-
let filter = CIFilter.lanczosScaleTransform()
109-
let extent = image.extent.size
80+
// Create a Lanczos scale filter
81+
82+
let yScale = size.height / image.extent.height
83+
let xScale = size.width / image.extent.width
11084

85+
let filter = CIFilter.lanczosScaleTransform()
11186
filter.inputImage = image
87+
filter.scale = Float(yScale)
88+
filter.aspectRatio = Float(xScale / yScale)
89+
let scaledImage = filter.outputImage!
90+
91+
// Create a rect with the exact dimensions we want
92+
let exactRect = CGRect(
93+
x: 0,
94+
y: 0,
95+
width: size.width,
96+
height: size.height
97+
)
11298

113-
// set the aspect ratio to match the aspect ratio of the target
114-
filter.aspectRatio = aspectRatioForResample(image, size: size)
99+
// Crop to ensure exact dimensions
100+
return scaledImage.cropped(to: exactRect)
101+
}
115102

116-
// that image is now the aspect ratio of the target and the size
117-
// of the shorter dimension
118-
let scale: CGFloat
119-
if extent.width < extent.height {
120-
scale = size.width / extent.width
121-
} else {
122-
scale = size.height / extent.height
123-
}
124-
filter.scale = Float(scale)
103+
/// Resample the image using bicubic interpolation.
104+
/// - Parameters:
105+
/// - image: The image to resample
106+
/// - size: The target size
107+
/// - Returns: The resampled image
108+
public static func resampleBicubic(_ image: CIImage, to size: CGSize) -> CIImage {
109+
// Create a bicubic scale filter
110+
111+
let yScale = size.height / image.extent.height
112+
let xScale = size.width / image.extent.width
125113

126-
let rescaled = filter.outputImage!
114+
let filter = CIFilter.bicubicScaleTransform()
115+
filter.inputImage = image
116+
filter.scale = Float(yScale)
117+
filter.aspectRatio = Float(xScale / yScale)
118+
let scaledImage = filter.outputImage!
119+
120+
// Create a rect with the exact dimensions we want
121+
let exactRect = CGRect(
122+
x: 0,
123+
y: 0,
124+
width: size.width,
125+
height: size.height
126+
)
127127

128-
// the image has a DoD larger than the requested size so crop
129-
// it to the desired size
130-
return rescaled.cropped(to: CGRect(origin: .zero, size: size))
128+
// Crop to ensure exact dimensions
129+
return scaledImage.cropped(to: exactRect)
131130
}
132131

133132
/// Normalize the image using the given mean and standard deviation parameters.
@@ -137,7 +136,7 @@ public enum MediaProcessing {
137136
let filter = CIFilter.colorMatrix()
138137
filter.inputImage = image
139138

140-
// this should match
139+
// This should match
141140
// https://pytorch.org/vision/main/generated/torchvision.transforms.Normalize.html
142141
//
143142
// output[channel] = (input[channel] - mean[channel]) / std[channel]
@@ -156,6 +155,10 @@ public enum MediaProcessing {
156155
}
157156

158157
/// Convert the CIImage into a planar 3 channel MLXArray `[1, C, H, W]`
158+
/// - Parameters:
159+
/// - image: The image to convert
160+
/// - colorSpace: Optional color space for rendering
161+
/// - Returns: The MLXArray representation of the image
159162
public static func asMLXArray(_ image: CIImage, colorSpace: CGColorSpace? = nil) -> MLXArray {
160163
let size = image.extent.size
161164
let w = Int(size.width.rounded())
@@ -178,10 +181,10 @@ public enum MediaProcessing {
178181

179182
var array = MLXArray(data, [h, w, 4], type: Float32.self)
180183

181-
// drop 4th channel
184+
// Drop 4th channel
182185
array = array[0..., 0..., ..<3]
183186

184-
// convert to 1, C, H, W
187+
// Convert to 1, C, H, W
185188
array = array.reshaped(1, h, w, 3).transposed(0, 3, 1, 2)
186189

187190
return array

Libraries/MLXVLM/Models/Idefics3.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -851,7 +851,7 @@ public class Idefics3Processor: UserInputProcessor {
851851
height: fixedImageSize
852852
)
853853
image = MediaProcessing.apply(image, processing: input.processing)
854-
image = MediaProcessing.resampleBicubic(image, to: targetSize)
854+
image = try MediaProcessing.resampleBicubic(image, to: targetSize)
855855
image = MediaProcessing.normalize(
856856
image,
857857
mean: config.imageMeanTuple,

Libraries/MLXVLM/Models/Paligemma.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,7 @@ private enum Vision {
441441
/// PaliGemma VLM `UserInputProcessor`.
442442
///
443443
/// This is meant to be used with ``PaliGemma`` and is typically created by ``VLMModelFactory``.
444-
public class PaligGemmaProcessor: UserInputProcessor {
444+
public class PaliGemmaProcessor: UserInputProcessor {
445445

446446
private let config: PaliGemmaProcessorConfiguration
447447
private let tokenizer: any Tokenizer
@@ -451,7 +451,7 @@ public class PaligGemmaProcessor: UserInputProcessor {
451451
self.tokenizer = tokenizer
452452
}
453453

454-
private func prepare(image: CIImage, processing: UserInput.Processing?) -> MLXArray {
454+
private func prepare(image: CIImage, processing: UserInput.Processing?) throws -> MLXArray {
455455
// based on image_processing_siglip from transformers
456456
var image = image
457457

@@ -463,7 +463,7 @@ public class PaligGemmaProcessor: UserInputProcessor {
463463
// apply user instructions
464464
image = MediaProcessing.apply(image, processing: processing)
465465

466-
image = MediaProcessing.resampleBicubic(image, to: config.size.cgSize)
466+
image = try MediaProcessing.resampleBicubic(image, to: config.size.cgSize)
467467
image = MediaProcessing.normalize(
468468
image, mean: config.imageMeanTuple, std: config.imageStdTuple)
469469

@@ -705,7 +705,7 @@ public struct PaliGemmaConfiguration: Codable, Sendable {
705705
}
706706
}
707707

708-
/// Configuration for ``PaligGemmaProcessor``
708+
/// Configuration for ``PaliGemmaProcessor``
709709
public struct PaliGemmaProcessorConfiguration: Codable, Sendable {
710710

711711
public struct Size: Codable, Sendable {

0 commit comments

Comments
 (0)