@@ -21,7 +21,7 @@ let imageSystemPrompt =
21
21
" You are an image understanding model capable of describing the salient features of any image. "
22
22
23
23
struct ContentView: View {
24
- @ State var prompt = " "
24
+
25
25
@State var llm = VLMEvaluator ( )
26
26
@Environment ( DeviceStat . self) private var deviceStat
27
27
@@ -200,14 +200,13 @@ struct ContentView: View {
200
200
. frame ( minHeight: 200 )
201
201
202
202
HStack {
203
- TextField ( " prompt " , text: $ prompt)
203
+ TextField ( " prompt " , text: Bindable ( llm ) . prompt)
204
204
. onSubmit ( generate)
205
205
. disabled ( llm. running)
206
206
#if os(visionOS)
207
207
. textFieldStyle( . roundedBorder)
208
208
#endif
209
- Button ( " generate " , action: generate)
210
- . disabled ( llm. running)
209
+ Button ( llm. running ? " stop " : " generate " , action: llm. running ? cancel : generate)
211
210
}
212
211
}
213
212
. onAppear {
@@ -251,7 +250,6 @@ struct ContentView: View {
251
250
}
252
251
}
253
252
. task {
254
- self . prompt = llm. modelConfiguration. defaultPrompt
255
253
_ = try ? await llm. load ( )
256
254
}
257
255
}
@@ -261,32 +259,36 @@ struct ContentView: View {
261
259
if let selectedImage = selectedImage {
262
260
#if os(iOS) || os(visionOS)
263
261
let ciImage = CIImage ( image: selectedImage)
264
- await llm. generate ( prompt : prompt , image: ciImage ?? CIImage ( ) , videoURL: nil )
262
+ llm. generate ( image: ciImage ?? CIImage ( ) , videoURL: nil )
265
263
#else
266
264
if let cgImage = selectedImage. cgImage (
267
265
forProposedRect: nil , context: nil , hints: nil )
268
266
{
269
267
let ciImage = CIImage ( cgImage: cgImage)
270
- await llm. generate ( prompt : prompt , image: ciImage, videoURL: nil )
268
+ llm. generate ( image: ciImage, videoURL: nil )
271
269
}
272
270
#endif
273
271
} else if let imageURL = currentImageURL {
274
272
do {
275
273
let ( data, _) = try await URLSession . shared. data ( from: imageURL)
276
274
if let ciImage = CIImage ( data: data) {
277
- await llm. generate ( prompt : prompt , image: ciImage, videoURL: nil )
275
+ llm. generate ( image: ciImage, videoURL: nil )
278
276
}
279
277
} catch {
280
278
print ( " Failed to load image: \( error. localizedDescription) " )
281
279
}
282
280
} else {
283
281
if let videoURL = selectedVideoURL {
284
- await llm. generate ( prompt : prompt , image: nil , videoURL: videoURL)
282
+ llm. generate ( image: nil , videoURL: videoURL)
285
283
}
286
284
}
287
285
}
288
286
}
289
287
288
/// Button action used while a generation is running: asks the evaluator
/// to cancel its current generation.
private func cancel() {
    llm.cancelGeneration()
}
291
+
290
292
#if os(macOS)
291
293
private func loadData( from url: URL) throws -> Data {
292
294
guard url. startAccessingSecurityScopedResource ( ) else {
@@ -326,6 +328,7 @@ class VLMEvaluator {
326
328
327
329
var running = false
328
330
331
+ var prompt = " "
329
332
var output = " "
330
333
var modelInfo = " "
331
334
var stat = " "
@@ -337,11 +340,10 @@ class VLMEvaluator {
337
340
/// parameters controlling the output – use values appropriate for the model selected above
338
341
let generateParameters = MLXLMCommon . GenerateParameters ( temperature: 0.7 , topP: 0.9 )
339
342
let maxTokens = 800
343
+ let updateInterval = 0.25
340
344
341
- /// update the display every N tokens -- 4 looks like it updates continuously
342
- /// and is low overhead. observed ~15% reduction in tokens/s when updating
343
- /// on every token
344
- let displayEveryNTokens = 4
345
+ /// A task responsible for handling the generation process.
346
+ var generationTask : Task < Void , Error > ?
345
347
346
348
enum LoadState {
347
349
case idle
@@ -371,6 +373,7 @@ class VLMEvaluator {
371
373
context. model. numParameters ( )
372
374
}
373
375
376
+ self . prompt = modelConfiguration. defaultPrompt
374
377
self . modelInfo = " Loaded \( modelConfiguration. id) . Weights: \( numParams / ( 1024 * 1024 ) ) M "
375
378
loadState = . loaded( modelContainer)
376
379
return modelContainer
@@ -380,10 +383,8 @@ class VLMEvaluator {
380
383
}
381
384
}
382
385
383
- func generate( prompt: String , image: CIImage ? , videoURL: URL ? ) async {
384
- guard !running else { return }
386
+ private func generate( prompt: String , image: CIImage ? , videoURL: URL ? ) async {
385
387
386
- running = true
387
388
self . output = " "
388
389
389
390
do {
@@ -392,7 +393,8 @@ class VLMEvaluator {
392
393
// each time you generate you will get something new
393
394
MLXRandom . seed ( UInt64 ( Date . timeIntervalSinceReferenceDate * 1000 ) )
394
395
395
- let result = try await modelContainer. perform { context in
396
+ try await modelContainer. perform { ( context: ModelContext ) -> Void in
397
+
396
398
let images : [ UserInput . Image ] =
397
399
if let image {
398
400
[ UserInput . Image. ciImage ( image) ]
@@ -436,38 +438,61 @@ class VLMEvaluator {
436
438
}
437
439
var userInput = UserInput ( messages: messages, images: images, videos: videos)
438
440
userInput. processing. resize = . init( width: 448 , height: 448 )
439
- let input = try await context. processor. prepare ( input: userInput)
440
- return try MLXLMCommon . generate (
441
- input: input,
442
- parameters: generateParameters,
443
- context: context
444
- ) { tokens in
445
- // update the output -- this will make the view show the text as it generates
446
- if tokens. count % displayEveryNTokens == 0 {
447
- let text = context. tokenizer. decode ( tokens: tokens)
441
+
442
+ let lmInput = try await context. processor. prepare ( input: userInput)
443
+
444
+ let stream = try MLXLMCommon . generate (
445
+ input: lmInput, parameters: generateParameters, context: context)
446
+
447
+ var tokenCount = 0
448
+ var lastEmissionTime : Date = Date ( )
449
+ var chunks = " "
450
+
451
// Consume the generation stream, batching text so the observable `output`
// is updated at most once per `updateInterval` seconds.
for await result in stream {
    switch result {
    case .chunk(let string):
        tokenCount += 1
        // Cancel the owning task once the token budget is reached.
        if tokenCount >= maxTokens { await generationTask?.cancel() }

        // Buffer the chunk *before* deciding whether to flush. Previously the
        // chunk that arrived when the interval had elapsed was neither buffered
        // nor emitted, so a piece of the output was lost on every flush.
        chunks += string

        let now = Date()
        if now.timeIntervalSince(lastEmissionTime) >= updateInterval {
            lastEmissionTime = now
            // Hand the buffered text to the main actor and start a new buffer.
            let text = chunks
            chunks = ""
            Task { @MainActor in
                self.output += text
            }
        }
    case .info(let info):
        // Completion info carries the throughput figure shown in the UI.
        Task { @MainActor in
            self.stat = "\(info.tokensPerSecond) tokens/s"
        }
    }
}
459
- }
460
473
461
- // update the text if needed, e.g. we haven't displayed because of displayEveryNTokens
462
- if result . output != self . output {
463
- self . output = result . output
474
+ Task { @ MainActor in
475
+ self . output += chunks
476
+ }
464
477
}
465
- self . stat = " Tokens/second: \( String ( format: " %.3f " , result. tokensPerSecond) ) "
466
-
467
478
} catch {
468
479
output = " Failed: \( error) "
469
480
}
481
+ }
482
+
483
/// Starts a generation run for the current `prompt` with optional image or video input.
///
/// The prompt is captured and then cleared, and the work runs in an unstructured
/// task stored in `generationTask` so that `cancelGeneration()` can cancel it.
/// Calls made while a run is already in progress are ignored.
///
/// - Parameters:
///   - image: Optional image passed through to `generate(prompt:image:videoURL:)`.
///   - videoURL: Optional video URL passed through to `generate(prompt:image:videoURL:)`.
func generate(image: CIImage?, videoURL: URL?) {
    guard !running else { return }

    // Mark the run as started before the task is scheduled. Setting the flag
    // inside the task body leaves a window in which a second call passes the
    // guard above and launches a competing generation.
    running = true

    let currentPrompt = prompt
    prompt = ""

    generationTask = Task {
        await generate(prompt: currentPrompt, image: image, videoURL: videoURL)
        running = false
    }
}
470
493
494
/// Cancels the task held in `generationTask`, if there is one, and clears
/// the `running` flag.
func cancelGeneration() {
    if let task = generationTask {
        task.cancel()
    }
    running = false
}
473
498
}
0 commit comments