Commit 289bb67

Updated LLMEval and VLMEval with the new AsyncStream token generation. (#256)
1 parent ec9523b commit 289bb67

2 files changed: +124 -79 lines changed
Applications/LLMEval/ContentView.swift

Lines changed: 60 additions & 40 deletions
@@ -13,7 +13,6 @@ struct ContentView: View {
     @Environment(DeviceStat.self) private var deviceStat
 
     @State var llm = LLMEvaluator()
-    @State var prompt = "What's the current weather in Paris?"
 
     enum displayStyle: String, CaseIterable, Identifiable {
         case plain, markdown
@@ -83,14 +82,13 @@ struct ContentView: View {
             }
 
             HStack {
-                TextField("prompt", text: $prompt)
+                TextField("prompt", text: Bindable(llm).prompt)
                     .onSubmit(generate)
                     .disabled(llm.running)
                     #if os(visionOS)
                         .textFieldStyle(.roundedBorder)
                     #endif
-                Button("generate", action: generate)
-                    .disabled(llm.running)
+                Button(llm.running ? "stop" : "generate", action: llm.running ? cancel : generate)
             }
         }
         #if os(visionOS)
@@ -130,17 +128,19 @@ struct ContentView: View {
 
         }
         .task {
-            self.prompt = llm.modelConfiguration.defaultPrompt
             // pre-load the weights on launch to speed up the first generation
            _ = try? await llm.load()
         }
     }
 
     private func generate() {
-        Task {
-            await llm.generate(prompt: prompt)
-        }
+        llm.generate()
     }
+
+    private func cancel() {
+        llm.cancelGeneration()
+    }
+
     private func copyToClipboard(_ string: String) {
         #if os(macOS)
             NSPasteboard.general.clearContents()
@@ -159,22 +159,22 @@ class LLMEvaluator {
 
     var includeWeatherTool = false
 
+    var prompt = ""
     var output = ""
     var modelInfo = ""
     var stat = ""
 
     /// This controls which model loads. `qwen2_5_1_5b` is one of the smaller ones, so this will fit on
     /// more devices.
-    let modelConfiguration = ModelRegistry.qwen2_5_1_5b
+    let modelConfiguration = LLMRegistry.qwen2_5_1_5b
 
     /// parameters controlling the output
     let generateParameters = GenerateParameters(temperature: 0.6)
     let maxTokens = 240
+    let updateInterval = 0.25
 
-    /// update the display every N tokens -- 4 looks like it updates continuously
-    /// and is low overhead. observed ~15% reduction in tokens/s when updating
-    /// on every token
-    let displayEveryNTokens = 4
+    /// A task responsible for handling the generation process.
+    var generationTask: Task<Void, Error>?
 
     enum LoadState {
         case idle
@@ -227,6 +227,7 @@ class LLMEvaluator {
                 context.model.numParameters()
             }
 
+            self.prompt = modelConfiguration.defaultPrompt
             self.modelInfo =
                 "Loaded \(modelConfiguration.id). Weights: \(numParams / (1024*1024))M"
             loadState = .loaded(modelContainer)
@@ -237,53 +238,72 @@ class LLMEvaluator {
         }
     }
 
-    func generate(prompt: String) async {
-        guard !running else { return }
+    private func generate(prompt: String) async {
 
-        running = true
         self.output = ""
+        let userInput = UserInput(prompt: prompt)
 
         do {
             let modelContainer = try await load()
 
             // each time you generate you will get something new
             MLXRandom.seed(UInt64(Date.timeIntervalSinceReferenceDate * 1000))
 
-            let result = try await modelContainer.perform { context in
-                let input = try await context.processor.prepare(
-                    input: .init(
-                        messages: [
-                            ["role": "system", "content": "You are a helpful assistant."],
-                            ["role": "user", "content": prompt],
-                        ], tools: includeWeatherTool ? [currentWeatherToolSpec] : nil))
-                return try MLXLMCommon.generate(
-                    input: input, parameters: generateParameters, context: context
-                ) { tokens in
-                    // Show the text in the view as it generates
-                    if tokens.count % displayEveryNTokens == 0 {
-                        let text = context.tokenizer.decode(tokens: tokens)
+            try await modelContainer.perform { (context: ModelContext) -> Void in
+                let lmInput = try await context.processor.prepare(input: userInput)
+                let stream = try MLXLMCommon.generate(
+                    input: lmInput, parameters: generateParameters, context: context)
+
+                var tokenCount = 0
+                var lastEmissionTime: Date = Date()
+                var chunks = ""
+
+                for await result in stream {
+                    switch result {
+                    case .chunk(let string):
+                        tokenCount += 1
+                        if tokenCount >= maxTokens { await generationTask?.cancel() }
+                        let now = Date()
+                        if now.timeIntervalSince(lastEmissionTime) >= updateInterval {
+                            lastEmissionTime = now
+                            let text = chunks
+                            chunks = ""
+                            Task { @MainActor in
+                                self.output += text
+                            }
+                        } else {
+                            chunks += string
+                        }
+                    case .info(let info):
                         Task { @MainActor in
-                            self.output = text
+                            self.stat = "\(info.tokensPerSecond) tokens/s"
                         }
                     }
-                    if tokens.count >= maxTokens {
-                        return .stop
-                    } else {
-                        return .more
-                    }
                 }
-            }
 
-            // update the text if needed, e.g. we haven't displayed because of displayEveryNTokens
-            if result.output != self.output {
-                self.output = result.output
+                Task { @MainActor in
+                    self.output += chunks
+                }
             }
-            self.stat = " Tokens/second: \(String(format: "%.3f", result.tokensPerSecond))"
 
         } catch {
             output = "Failed: \(error)"
         }
+    }
+
+    func generate() {
+        guard !running else { return }
+        let currentPrompt = prompt
+        prompt = ""
+        generationTask = Task {
+            running = true
+            await generate(prompt: currentPrompt)
+            running = false
+        }
+    }
 
+    func cancelGeneration() {
+        generationTask?.cancel()
         running = false
     }
 }

Applications/VLMEval/ContentView.swift

Lines changed: 64 additions & 39 deletions
@@ -21,7 +21,7 @@ let imageSystemPrompt =
     "You are an image understanding model capable of describing the salient features of any image."
 
 struct ContentView: View {
-    @State var prompt = ""
+
     @State var llm = VLMEvaluator()
     @Environment(DeviceStat.self) private var deviceStat
 
@@ -200,14 +200,13 @@ struct ContentView: View {
                     .frame(minHeight: 200)
 
                 HStack {
-                    TextField("prompt", text: $prompt)
+                    TextField("prompt", text: Bindable(llm).prompt)
                         .onSubmit(generate)
                         .disabled(llm.running)
                         #if os(visionOS)
                             .textFieldStyle(.roundedBorder)
                         #endif
-                    Button("generate", action: generate)
-                        .disabled(llm.running)
+                    Button(llm.running ? "stop" : "generate", action: llm.running ? cancel : generate)
                 }
             }
             .onAppear {
@@ -251,7 +250,6 @@ struct ContentView: View {
             }
         }
         .task {
-            self.prompt = llm.modelConfiguration.defaultPrompt
             _ = try? await llm.load()
         }
     }
@@ -261,32 +259,36 @@ struct ContentView: View {
             if let selectedImage = selectedImage {
                 #if os(iOS) || os(visionOS)
                     let ciImage = CIImage(image: selectedImage)
-                    await llm.generate(prompt: prompt, image: ciImage ?? CIImage(), videoURL: nil)
+                    llm.generate(image: ciImage ?? CIImage(), videoURL: nil)
                 #else
                     if let cgImage = selectedImage.cgImage(
                         forProposedRect: nil, context: nil, hints: nil)
                     {
                         let ciImage = CIImage(cgImage: cgImage)
-                        await llm.generate(prompt: prompt, image: ciImage, videoURL: nil)
+                        llm.generate(image: ciImage, videoURL: nil)
                     }
                 #endif
             } else if let imageURL = currentImageURL {
                 do {
                     let (data, _) = try await URLSession.shared.data(from: imageURL)
                     if let ciImage = CIImage(data: data) {
-                        await llm.generate(prompt: prompt, image: ciImage, videoURL: nil)
+                        llm.generate(image: ciImage, videoURL: nil)
                    }
                } catch {
                    print("Failed to load image: \(error.localizedDescription)")
                }
            } else {
                if let videoURL = selectedVideoURL {
-                    await llm.generate(prompt: prompt, image: nil, videoURL: videoURL)
+                    llm.generate(image: nil, videoURL: videoURL)
                }
            }
        }
    }
 
+    private func cancel() {
+        llm.cancelGeneration()
+    }
+
     #if os(macOS)
         private func loadData(from url: URL) throws -> Data {
             guard url.startAccessingSecurityScopedResource() else {
@@ -326,6 +328,7 @@ class VLMEvaluator {
 
     var running = false
 
+    var prompt = ""
     var output = ""
     var modelInfo = ""
     var stat = ""
@@ -337,11 +340,10 @@ class VLMEvaluator {
     /// parameters controlling the output – use values appropriate for the model selected above
     let generateParameters = MLXLMCommon.GenerateParameters(temperature: 0.7, topP: 0.9)
     let maxTokens = 800
+    let updateInterval = 0.25
 
-    /// update the display every N tokens -- 4 looks like it updates continuously
-    /// and is low overhead. observed ~15% reduction in tokens/s when updating
-    /// on every token
-    let displayEveryNTokens = 4
+    /// A task responsible for handling the generation process.
+    var generationTask: Task<Void, Error>?
 
     enum LoadState {
         case idle
@@ -371,6 +373,7 @@ class VLMEvaluator {
                 context.model.numParameters()
             }
 
+            self.prompt = modelConfiguration.defaultPrompt
             self.modelInfo = "Loaded \(modelConfiguration.id). Weights: \(numParams / (1024*1024))M"
             loadState = .loaded(modelContainer)
             return modelContainer
@@ -380,10 +383,8 @@ class VLMEvaluator {
         }
     }
 
-    func generate(prompt: String, image: CIImage?, videoURL: URL?) async {
-        guard !running else { return }
+    private func generate(prompt: String, image: CIImage?, videoURL: URL?) async {
 
-        running = true
         self.output = ""
 
         do {
@@ -392,7 +393,8 @@ class VLMEvaluator {
             // each time you generate you will get something new
             MLXRandom.seed(UInt64(Date.timeIntervalSinceReferenceDate * 1000))
 
-            let result = try await modelContainer.perform { context in
+            try await modelContainer.perform { (context: ModelContext) -> Void in
+
                 let images: [UserInput.Image] =
                     if let image {
                         [UserInput.Image.ciImage(image)]
@@ -436,38 +438,61 @@ class VLMEvaluator {
                     }
                 var userInput = UserInput(messages: messages, images: images, videos: videos)
                 userInput.processing.resize = .init(width: 448, height: 448)
-                let input = try await context.processor.prepare(input: userInput)
-                return try MLXLMCommon.generate(
-                    input: input,
-                    parameters: generateParameters,
-                    context: context
-                ) { tokens in
-                    // update the output -- this will make the view show the text as it generates
-                    if tokens.count % displayEveryNTokens == 0 {
-                        let text = context.tokenizer.decode(tokens: tokens)
+
+                let lmInput = try await context.processor.prepare(input: userInput)
+
+                let stream = try MLXLMCommon.generate(
+                    input: lmInput, parameters: generateParameters, context: context)
+
+                var tokenCount = 0
+                var lastEmissionTime: Date = Date()
+                var chunks = ""
+
+                for await result in stream {
+                    switch result {
+                    case .chunk(let string):
+                        tokenCount += 1
+                        if tokenCount >= maxTokens { await generationTask?.cancel() }
+                        let now = Date()
+                        if now.timeIntervalSince(lastEmissionTime) >= updateInterval {
+                            lastEmissionTime = now
+                            let text = chunks
+                            chunks = ""
+                            Task { @MainActor in
+                                self.output += text
+                            }
+                        } else {
+                            chunks += string
+                        }
+                    case .info(let info):
                         Task { @MainActor in
-                            self.output = text
+                            self.stat = "\(info.tokensPerSecond) tokens/s"
                         }
                     }
-
-                    if tokens.count >= maxTokens {
-                        return .stop
-                    } else {
-                        return .more
-                    }
                 }
-            }
 
-            // update the text if needed, e.g. we haven't displayed because of displayEveryNTokens
-            if result.output != self.output {
-                self.output = result.output
+                Task { @MainActor in
+                    self.output += chunks
+                }
             }
-            self.stat = " Tokens/second: \(String(format: "%.3f", result.tokensPerSecond))"
-
         } catch {
             output = "Failed: \(error)"
         }
+    }
+
+    func generate(image: CIImage?, videoURL: URL?) {
+        guard !running else { return }
+        let currentPrompt = prompt
+        prompt = ""
+        generationTask = Task {
+            running = true
+            await generate(prompt: currentPrompt, image: image, videoURL: videoURL)
            running = false
+        }
+    }
 
+    func cancelGeneration() {
+        generationTask?.cancel()
        running = false
    }
 }
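For reference, the start/stop plumbing that both evaluators now share, reduced to its essentials. Everything here mirrors the diff above; the `@Observable` attribute is an assumption (the class declarations sit outside the changed hunks), and the body of the private `generate` is elided.

// Sketch of the Task-based start/stop pattern from this commit (assumes @Observable).
import Foundation
import Observation

@Observable
class Evaluator {
    var running = false
    var prompt = ""
    var output = ""

    /// A task responsible for handling the generation process.
    var generationTask: Task<Void, Error>?

    private func generate(prompt: String) async {
        // ... run the streaming loop shown in the diff, appending to `output` ...
    }

    func generate() {
        guard !running else { return }
        let currentPrompt = prompt
        prompt = ""
        generationTask = Task {
            running = true
            await generate(prompt: currentPrompt)
            running = false
        }
    }

    func cancelGeneration() {
        generationTask?.cancel()
        running = false
    }
}

The view toggles between the two entry points with `Button(llm.running ? "stop" : "generate", action: llm.running ? cancel : generate)`, and the `maxTokens` cap is enforced inside the streaming loop by cancelling `generationTask` once the chunk count reaches the limit.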
