Commit ec9523b

Authored by cyrilzakka, pcuenca, and davidkoski
VLM support for image and video processing with SmolVLM support (#206)
* Update MediaProcessing.swift
* smolvlm processing

Co-authored-by: Pedro Cuenca <[email protected]>
Co-authored-by: David Koski <[email protected]>
1 parent 4ce907b commit ec9523b

File tree

9 files changed: +723 −76 lines changed

Applications/VLMEval/ContentView.swift

Lines changed: 26 additions & 12 deletions
```diff
@@ -15,6 +15,11 @@ import SwiftUI
     typealias PlatformImage = NSImage
 #endif
 
+let videoSystemPrompt =
+    "Focus only on describing the key dramatic action or notable event occurring in this video segment. Skip general context or scene-setting details unless they are crucial to understanding the main action."
+let imageSystemPrompt =
+    "You are an image understanding model capable of describing the salient features of any image."
+
 struct ContentView: View {
     @State var prompt = ""
     @State var llm = VLMEvaluator()
```

```diff
@@ -28,7 +33,7 @@ struct ContentView: View {
             }
         }
     }
-    @State private var selectedVideoURL: URL? = nil {
+    @State private var selectedVideoURL: URL? {
         didSet {
             if let selectedVideoURL {
                 player = AVPlayer(url: selectedVideoURL)
```

```diff
@@ -61,7 +66,11 @@ struct ContentView: View {
             }
 
             VStack {
-                if let selectedImage {
+                if let player {
+                    VideoPlayer(player: player)
+                        .frame(height: 300)
+                        .cornerRadius(12)
+                } else if let selectedImage {
                     Group {
                         #if os(iOS) || os(visionOS)
                             Image(uiImage: selectedImage)
```

```diff
@@ -91,11 +100,6 @@ struct ContentView: View {
                             EmptyView()
                         }
                     }
-                } else if let player {
-                    VideoPlayer(player: player)
-                        .scaledToFit()
-                        .frame(maxHeight: 300)
-                        .cornerRadius(12)
                 }
 
                 HStack {
```

```diff
@@ -193,6 +197,7 @@ struct ContentView: View {
                     .id("bottom")
                 }
             }
+            .frame(minHeight: 200)
 
             HStack {
                 TextField("prompt", text: $prompt)
```

```diff
@@ -205,6 +210,11 @@ struct ContentView: View {
                     .disabled(llm.running)
             }
         }
+        .onAppear {
+            selectedVideoURL = URL(
+                string:
+                    "https://videos.pexels.com/video-files/4066325/4066325-uhd_2560_1440_24fps.mp4")!
+        }
         #if os(visionOS)
             .padding(40)
         #else
```

```diff
@@ -320,12 +330,12 @@ class VLMEvaluator {
     var modelInfo = ""
     var stat = ""
 
-    /// This controls which model loads. `qwen2VL2BInstruct4Bit` is one of the smaller ones, so this will fit on
+    /// This controls which model loads. `smolvlm` is very small even unquantized, so it will fit on
     /// more devices.
-    let modelConfiguration = ModelRegistry.qwen2VL2BInstruct4Bit
+    let modelConfiguration = VLMRegistry.smolvlm
 
-    /// parameters controlling the output
-    let generateParameters = MLXLMCommon.GenerateParameters(temperature: 0.6)
+    /// parameters controlling the output – use values appropriate for the model selected above
+    let generateParameters = MLXLMCommon.GenerateParameters(temperature: 0.7, topP: 0.9)
     let maxTokens = 800
 
     /// update the display every N tokens -- 4 looks like it updates continuously
```

```diff
@@ -401,7 +411,11 @@ class VLMEvaluator {
                 [
                     "role": "user",
                     "content": [
-                        ["type": "text", "text": prompt]
+                        [
+                            "type": "text",
+                            "text": videoURL != nil
+                                ? videoSystemPrompt : imageSystemPrompt,
+                        ]
                     ]
                         // Messages format for Qwen 2 VL, Qwen 2.5 VL. May need to be adapted for other models.
                         + images.map { _ in
```
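The last hunk swaps the raw user prompt for a media-aware system prompt. A minimal standalone sketch of that pattern follows; the helper function and the `["type": "image"]` entry shape are assumptions for illustration, since the diff truncates the `images.map { _ in ... }` body:

```swift
import Foundation

// Hypothetical helper mirroring the message construction in the hunk above.
// `videoSystemPrompt` / `imageSystemPrompt` are the constants added in the
// first hunk of this file.
func makeUserMessage(videoURL: URL?, imageCount: Int) -> [String: Any] {
    // Same branch as the diff: use the video prompt when a video is attached.
    let text = videoURL != nil ? videoSystemPrompt : imageSystemPrompt
    var content: [[String: String]] = [["type": "text", "text": text]]
    // One entry per attached image (assumed shape for Qwen 2 VL-style models).
    content += (0 ..< imageCount).map { _ in ["type": "image"] }
    return ["role": "user", "content": content]
}
```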

Libraries/MLXVLM/MediaProcessing.swift

Lines changed: 147 additions & 3 deletions
```diff
@@ -5,6 +5,18 @@ import CoreImage.CIFilterBuiltins
 import MLX
 import MLXLMCommon
 
+public struct VideoFrame {
+    let frame: CIImage
+    let timeStamp: CMTime
+}
+
+public struct ProcessedFrames {
+    let frames: [MLXArray]
+    let timestamps: [CMTime]
+    let totalDuration: CMTime
+}
+
 // TODO: verify working color space, rendering color space
 private let context = CIContext()
 
 /// Collection of methods for processing media (images, video, etc.).
```

```diff
@@ -58,6 +70,12 @@ public enum MediaProcessing {
         min(other.width / size.width, other.height / size.height)
     }
 
+    static public func aspectRatioForResample(_ image: CIImage, size: CGSize) -> Float {
+        let inputAspectRatio = image.extent.width / image.extent.height
+        let desiredAspectRatio = size.width / size.height
+        return Float(1 / inputAspectRatio * desiredAspectRatio)
+    }
+
     /// Resample the image using bicubic interpolation.
     public static func resampleBicubic(_ image: CIImage, to size: CGSize) -> CIImage {
         let filter = CIFilter.bicubicScaleTransform()
```
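For intuition on `aspectRatioForResample`: the value it returns is the horizontal distortion the scale filters need to bring the input to the target's aspect ratio before uniform scaling. A quick check with illustrative numbers (not from the commit):

```swift
import CoreImage

// Illustrative only: a 1920x1080 input resampled toward a 448x448 target.
let image = CIImage(color: .gray).cropped(to: CGRect(x: 0, y: 0, width: 1920, height: 1080))
let target = CGSize(width: 448, height: 448)

let inputAspectRatio = image.extent.width / image.extent.height    // 16/9 ≈ 1.778
let desiredAspectRatio = target.width / target.height              // 1.0
let aspect = Float(1 / inputAspectRatio * desiredAspectRatio)      // 0.5625

// Applied as `filter.aspectRatio`, this squeezes the 1920-wide frame down to
// an effective 1080, making the image square before the scale step.
```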
```diff
@@ -66,9 +84,34 @@ public enum MediaProcessing {
         filter.inputImage = image
 
         // set the aspect ratio to match the aspect ratio of the target
-        let inputAspectRatio = extent.width / extent.height
-        let desiredAspectRatio = size.width / size.height
-        filter.aspectRatio = Float(1 / inputAspectRatio * desiredAspectRatio)
+        filter.aspectRatio = aspectRatioForResample(image, size: size)
+
+        // that image is now the aspect ratio of the target and the size
+        // of the shorter dimension
+        let scale: CGFloat
+        if extent.width < extent.height {
+            scale = size.width / extent.width
+        } else {
+            scale = size.height / extent.height
+        }
+        filter.scale = Float(scale)
+
+        let rescaled = filter.outputImage!
+
+        // the image has a DoD larger than the requested size so crop
+        // it to the desired size
+        return rescaled.cropped(to: CGRect(origin: .zero, size: size))
+    }
+
+    /// Resample the image using Lanczos interpolation.
+    static public func resampleLanczos(_ image: CIImage, to size: CGSize) -> CIImage {
+        let filter = CIFilter.lanczosScaleTransform()
+        let extent = image.extent.size
+
+        filter.inputImage = image
+
+        // set the aspect ratio to match the aspect ratio of the target
+        filter.aspectRatio = aspectRatioForResample(image, size: size)
 
         // that image is now the aspect ratio of the target and the size
         // of the shorter dimension
```
// of the shorter dimension
```diff
@@ -264,4 +307,105 @@ public enum MediaProcessing {
 
         return ciImages
     }
+
+    static public func asProcessedSequence(
+        _ asset: AVAsset, samplesPerSecond: Int,
+        frameProcessing: (VideoFrame) throws -> VideoFrame = { $0 }
+    ) async throws -> ProcessedFrames {
+        return try await asProcessedSequence(
+            asset, maxFrames: Int.max, targetFPS: { _ in Double(samplesPerSecond) },
+            frameProcessing: frameProcessing)
+    }
+
+    static public func asProcessedSequence(
+        _ asset: AVAsset, maxFrames: Int, targetFPS: (CMTime) -> Double,
+        frameProcessing: (VideoFrame) throws -> VideoFrame = { $0 }
+    ) async throws -> ProcessedFrames {
+        // Use AVAssetImageGenerator to extract frames
+        let generator = AVAssetImageGenerator(asset: asset)
+        generator.appliesPreferredTrackTransform = true
+        generator.requestedTimeToleranceBefore = .zero
+        generator.requestedTimeToleranceAfter = .zero
+
+        guard let duration = try? await asset.load(.duration) else {
+            throw NSError(
+                domain: "MediaProcessing", code: -1,
+                userInfo: [NSLocalizedDescriptionKey: "Failed to load the asset's duration"])
+        }
+        let fps = targetFPS(duration)
+        // Note: the round was not present in `asCIImageSequence`, so we may now be passing 1 more frame to Qwen depending on video duration.
+        let estimatedFrames = Int(round(fps * duration.seconds))
+        var desiredFrames = min(estimatedFrames, maxFrames)
+        let finalFrameCount = max(desiredFrames, 1)
+
+        let sampledTimeValues = MLXArray.linspace(
+            0, duration.value, count: Int(finalFrameCount)
+        ).asArray(Int64.self)
+
+        // Construct a CMTime using the sampled CMTimeValue's and the asset's timescale
+        let timescale = duration.timescale
+        let sampledTimes = sampledTimeValues.map { CMTime(value: $0, timescale: timescale) }
+
+        // Collect the frames
+        var ciImages: [CIImage] = []
+        var timestamps: [CMTime] = []
+
+        var frames: [VideoFrame] = []
+
+        for await result in await generator.images(for: sampledTimes) {
+            switch result {
+            case .success(requestedTime: let requested, let image, actualTime: let actual):
+                let ciImage = CIImage(
+                    cgImage: image, options: [.colorSpace: CGColorSpace(name: CGColorSpace.sRGB)!])
+                let frame = try frameProcessing(.init(frame: ciImage, timeStamp: actual))
+                ciImages.append(frame.frame)
+                timestamps.append(frame.timeStamp)
+            case .failure(requestedTime: let requested, let error):
+                break
+            }
+        }
+
+        let framesAsArrays = ciImages.map { $0.asMLXArray() }
+        return ProcessedFrames(
+            frames: framesAsArrays,
+            timestamps: timestamps,
+            totalDuration: duration
+        )
+    }
+}
+
+// MARK: - Convenience
+
+extension CIImage {
+    public enum ResamplingMethod {
+        case bicubic
+        case lanczos
+    }
+
+    public func resampled(to size: CGSize, method: ResamplingMethod = .bicubic) -> CIImage {
+        switch method {
+        case .bicubic:
+            return MediaProcessing.resampleBicubic(self, to: size)
+        case .lanczos:
+            return MediaProcessing.resampleLanczos(self, to: size)
+        }
+    }
+
+    public func toSRGB() -> CIImage {
+        return MediaProcessing.inSRGBToneCurveSpace(self)
+    }
+
+    public func toLinear() -> CIImage {
+        return MediaProcessing.inLinearToneCurveSpace(self)
+    }
+
+    public func normalized(mean: (CGFloat, CGFloat, CGFloat), std: (CGFloat, CGFloat, CGFloat))
+        -> CIImage
+    {
+        return MediaProcessing.normalize(self, mean: mean, std: std)
+    }
+
+    public func asMLXArray(colorSpace: CGColorSpace? = nil) -> MLXArray {
+        return MediaProcessing.asMLXArray(self, colorSpace: colorSpace)
+    }
 }
```
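Taken together, a plausible call site for the new API (written as if inside the MLXVLM module, since the members of `VideoFrame` and `ProcessedFrames` are not public in this commit; the sample rate and target size are placeholders, not values from the diff):

```swift
import AVFoundation
import CoreImage

// Sample a video at 1 frame per second, resampling each frame on the fly.
func sampleFrames(from url: URL) async throws -> ProcessedFrames {
    let asset = AVURLAsset(url: url)
    return try await MediaProcessing.asProcessedSequence(
        asset, samplesPerSecond: 1
    ) { videoFrame in
        // Per-frame hook: convert to the sRGB tone curve and resize using
        // the new CIImage convenience methods added above.
        let resized = videoFrame.frame
            .toSRGB()
            .resampled(to: CGSize(width: 448, height: 448), method: .lanczos)
        return VideoFrame(frame: resized, timeStamp: videoFrame.timeStamp)
    }
}
```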
