Skip to content

Commit 560d4f4

Browse files
committed
Initial implementation for structured chat messages
1 parent 289bb67 commit 560d4f4

File tree

6 files changed

+161
-19
lines changed

6 files changed

+161
-19
lines changed

Libraries/MLXLLM/LLMModelFactory.swift

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -199,23 +199,37 @@ private struct LLMUserInputProcessor: UserInputProcessor {
199199

200200
let tokenizer: Tokenizer
201201
let configuration: ModelConfiguration
202+
let messageGenerator: MessageGenerator
202203

203-
internal init(tokenizer: any Tokenizer, configuration: ModelConfiguration) {
204+
internal init(
205+
tokenizer: any Tokenizer, configuration: ModelConfiguration,
206+
messageGenerator: MessageGenerator
207+
) {
204208
self.tokenizer = tokenizer
205209
self.configuration = configuration
210+
self.messageGenerator = messageGenerator
206211
}
207212

208213
func prepare(input: UserInput) throws -> LMInput {
214+
let messages =
215+
switch input.prompt {
216+
case .text(let text):
217+
messageGenerator.generate(messages: [.user(text)])
218+
case .messages(let messages):
219+
messages
220+
case .chat(let messages):
221+
messageGenerator.generate(messages: messages)
222+
}
223+
209224
do {
210-
let messages = input.prompt.asMessages()
211225
let promptTokens = try tokenizer.applyChatTemplate(
212226
messages: messages, tools: input.tools, additionalContext: input.additionalContext)
213227
return LMInput(tokens: MLXArray(promptTokens))
214228
} catch {
215229
// #150 -- it might be a TokenizerError.chatTemplate("No chat template was specified")
216230
// but that is not public so just fall back to text
217-
let prompt = input.prompt
218-
.asMessages()
231+
let prompt =
232+
messages
219233
.compactMap { $0["content"] as? String }
220234
.joined(separator: ". ")
221235
let promptTokens = tokenizer.encode(text: prompt)
@@ -273,7 +287,9 @@ public class LLMModelFactory: ModelFactory {
273287

274288
return .init(
275289
configuration: configuration, model: model,
276-
processor: LLMUserInputProcessor(tokenizer: tokenizer, configuration: configuration),
290+
processor: LLMUserInputProcessor(
291+
tokenizer: tokenizer, configuration: configuration,
292+
messageGenerator: DefaultMessageGenerator()),
277293
tokenizer: tokenizer)
278294
}
279295

Libraries/MLXLMCommon/Chat.swift

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
// Copyright © 2025 Apple Inc.
2+
3+
/// Namespace for structured chat message types.
public enum Chat {

    /// A single structured chat message: a sender role plus text content and
    /// any attached media.
    ///
    /// Declared `Sendable` because `UserInput.Prompt` is a `Sendable` enum and
    /// carries `[Chat.Message]` in its `.chat` case — without this conformance
    /// the `.chat` payload breaks `Prompt`'s `Sendable` requirement under
    /// strict concurrency checking. (`UserInput.Image` / `UserInput.Video`
    /// are members of the `Sendable` struct `UserInput`, so they satisfy it.)
    public struct Message: Sendable {
        /// The role of the message sender.
        public let role: Role

        /// The content of the message.
        public let content: String

        /// Array of image data associated with the message.
        public let images: [UserInput.Image]

        /// Array of video data associated with the message.
        public let videos: [UserInput.Video]

        /// Convenience factory for a `.system` message.
        public static func system(
            _ content: String, images: [UserInput.Image] = [], videos: [UserInput.Video] = []
        ) -> Self {
            Self(role: .system, content: content, images: images, videos: videos)
        }

        /// Convenience factory for an `.assistant` message.
        public static func assistant(
            _ content: String, images: [UserInput.Image] = [], videos: [UserInput.Video] = []
        ) -> Self {
            Self(role: .assistant, content: content, images: images, videos: videos)
        }

        /// Convenience factory for a `.user` message.
        public static func user(
            _ content: String, images: [UserInput.Image] = [], videos: [UserInput.Video] = []
        ) -> Self {
            Self(role: .user, content: content, images: images, videos: videos)
        }

        /// The sender of a chat message.
        public enum Role: String, Sendable {
            case user
            case assistant
            case system
        }
    }
}
42+
43+
/// Converts structured `Chat.Message` values into the dictionary form
/// (`Message`, i.e. `[String: Any]`) consumed by tokenizer chat templates.
public protocol MessageGenerator {
    /// Returns the `[String: Any]` (aka `Message`) form of a single message.
    func generate(message: Chat.Message) -> Message
}

extension MessageGenerator {
    /// Returns the `[String: Any]` (aka `Message`) forms of all messages,
    /// preserving their order.
    public func generate(messages: [Chat.Message]) -> [Message] {
        messages.map { generate(message: $0) }
    }
}
61+
62+
/// Produces messages in the common plain-text chat format:
/// `["role": <role>, "content": <text>]`. Attached images and videos are
/// not represented in the output.
public struct DefaultMessageGenerator: MessageGenerator {
    public init() {}

    public func generate(message: Chat.Message) -> Message {
        var raw: Message = [:]
        raw["role"] = message.role.rawValue
        raw["content"] = message.content
        return raw
    }
}
72+
73+
/// Produces messages in the fragment-list format used by Qwen 2 VL and
/// Qwen 2.5 VL: `content` is an array of typed fragments — the text fragment
/// first, then one placeholder fragment per attached image and video.
/// May need to be adapted for other models.
public struct Qwen2VLMessageGenerator: MessageGenerator {
    public init() {}

    public func generate(message: Chat.Message) -> Message {
        // Start with the text fragment, then append a typed placeholder
        // for each piece of attached media (the media itself is carried
        // separately on the message).
        var content: [[String: Any]] = [
            ["type": "text", "text": message.content]
        ]
        for _ in message.images {
            content.append(["type": "image"])
        }
        for _ in message.videos {
            content.append(["type": "video"])
        }

        return [
            "role": message.role.rawValue,
            "content": content,
        ]
    }
}

Libraries/MLXLMCommon/UserInput.swift

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,16 @@ public struct UserInput: Sendable {
1717
/// The prompt portion of a `UserInput`: plain text, pre-built tokenizer
/// message dictionaries, or structured chat messages.
public enum Prompt: Sendable, CustomStringConvertible {
    case text(String)
    case messages([Message])
    case chat([Chat.Message])

    /// A human-readable rendering: the text itself, or the messages
    /// joined one per line.
    public var description: String {
        switch self {
        case .text(let text):
            text
        case .messages(let messages):
            messages.map { $0.description }.joined(separator: "\n")
        case .chat(let messages):
            messages.map(\.content).joined(separator: "\n")
        }
    }
}
@@ -156,6 +150,18 @@ public struct UserInput: Sendable {
156150
self.additionalContext = additionalContext
157151
}
158152

153+
/// Creates a `UserInput` from structured chat messages.
///
/// - Parameters:
///   - messages: structured chat messages making up the conversation
///   - images: image attachments for the input
///   - videos: video attachments for the input
///   - processing: media pre-processing instructions; defaulted so existing
///     callers are unaffected. Added for consistency with the `prompt:`-based
///     initializer, which also accepts `processing` — previously chat-based
///     callers had no way to specify it.
///   - tools: optional tool specifications
///   - additionalContext: optional extra values passed through to the model
public init(
    messages: [Chat.Message], images: [Image] = [Image](), videos: [Video] = [Video](),
    processing: Processing = .init(),
    tools: [ToolSpec]? = nil,
    additionalContext: [String: Any]? = nil
) {
    self.prompt = .chat(messages)
    self.images = images
    self.videos = videos
    self.processing = processing
    self.tools = tools
    self.additionalContext = additionalContext
}
164+
159165
public init(
160166
prompt: Prompt, images: [Image] = [Image](), processing: Processing = .init(),
161167
tools: [ToolSpec]? = nil, additionalContext: [String: Any]? = nil

Libraries/MLXVLM/Models/Idefics3.swift

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -817,10 +817,19 @@ public class Idefics3Processor: UserInputProcessor {
817817
self.tokenizer = tokenizer
818818
}
819819

820-
public func prepare(input: UserInput) throws -> LMInput {
821-
822-
let prompt = input.prompt.asMessages().last?["content"] as? String ?? ""
820+
/// Extracts a single plain-text prompt from the user input: the text itself
/// for `.text`, otherwise the content of the last message (empty string when
/// there are no messages).
private func prompt(from userInput: UserInput) -> String {
    switch userInput.prompt {
    case .text(let text):
        return text
    case .messages(let messages):
        return messages.last?["content"] as? String ?? ""
    case .chat(let messages):
        return messages.last?.content ?? ""
    }
}
823830

831+
public func prepare(input: UserInput) throws -> LMInput {
832+
let prompt = prompt(from: input)
824833
if input.images.isEmpty {
825834
// No image scenario
826835
let tokens = try tokenizer.encode(text: prompt)

Libraries/MLXVLM/Models/Paligemma.swift

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -478,7 +478,7 @@ public class PaligGemmaProcessor: UserInputProcessor {
478478
}
479479

480480
// this doesn't have a chat template so just use the last message.
481-
var prompt = input.prompt.asMessages().last?["content"] as? String ?? ""
481+
var prompt = prompt(from: input)
482482

483483
// based on transformers/processing_paligemma
484484
let count = input.images.count * config.imageSequenceLength
@@ -495,6 +495,17 @@ public class PaligGemmaProcessor: UserInputProcessor {
495495
return LMInput(text: .init(tokens: promptArray, mask: mask), image: .init(pixels: pixels))
496496
}
497497

498+
/// Extracts a single plain-text prompt from the user input. This model has
/// no chat template, so for message-based prompts only the last message's
/// content is used (empty string when there are no messages).
private func prompt(from userInput: UserInput) -> String {
    switch userInput.prompt {
    case .text(let text):
        return text
    case .messages(let messages):
        return messages.last?["content"] as? String ?? ""
    case .chat(let messages):
        return messages.last?.content ?? ""
    }
}
508+
498509
}
499510

500511
// MARK: - Model

Libraries/MLXVLM/Models/Qwen2VL.swift

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -696,7 +696,16 @@ public class Qwen2VLProcessor: UserInputProcessor {
696696
}
697697

698698
public func prepare(input: UserInput) async throws -> LMInput {
699-
let messages = input.prompt.asMessages()
699+
let generator = Qwen2VLMessageGenerator()
700+
let messages =
701+
switch input.prompt {
702+
case .text(let text):
703+
generator.generate(messages: [.user(text)])
704+
case .messages(let messages):
705+
messages
706+
case .chat(let messages):
707+
generator.generate(messages: messages)
708+
}
700709
var promptTokens = try tokenizer.applyChatTemplate(messages: messages)
701710

702711
// Text-only input

0 commit comments

Comments (0)