Skip to content

Implement Structured Chat Messages #257

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Apr 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,14 @@ jobs:
brew install swift-format
pre-commit run --all
if ! git diff --quiet; then echo 'Style checks failed, please install pre-commit and run pre-commit run --all and push the change'; exit 1; fi
- run:
name: Run Tests (Xcode, macOS)
command: |
xcodebuild -version
xcrun --show-sdk-build-version
swift --version
find . -name Package.resolved -exec rm {} \;
xcodebuild test -scheme mlx-libraries-Package -destination 'platform=OS X'
- run:
name: Build Examples
command: |
Expand Down
13 changes: 11 additions & 2 deletions Applications/VLMEval/ContentView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -412,13 +412,22 @@ class VLMEvaluator {
if !images.isEmpty || !videos.isEmpty {
[
[
"role": "user",
"role": "system",
"content": [
[
"type": "text",
"text": videoURL != nil
? videoSystemPrompt : imageSystemPrompt,
]
],
],
[
"role": "user",
"content": [
[
"type": "text",
"text": prompt,
]
]
// Messages format for Qwen 2 VL, Qwen 2.5 VL. May need to be adapted for other models.
+ images.map { _ in
Expand All @@ -427,7 +436,7 @@ class VLMEvaluator {
+ videos.map { _ in
["type": "video"]
},
]
],
]
} else {
[
Expand Down
18 changes: 13 additions & 5 deletions Libraries/MLXLLM/LLMModelFactory.swift
Original file line number Diff line number Diff line change
Expand Up @@ -199,23 +199,29 @@ private struct LLMUserInputProcessor: UserInputProcessor {

let tokenizer: Tokenizer
let configuration: ModelConfiguration
let messageGenerator: MessageGenerator

internal init(tokenizer: any Tokenizer, configuration: ModelConfiguration) {
internal init(
tokenizer: any Tokenizer, configuration: ModelConfiguration,
messageGenerator: MessageGenerator
) {
self.tokenizer = tokenizer
self.configuration = configuration
self.messageGenerator = messageGenerator
}

func prepare(input: UserInput) throws -> LMInput {
let messages = messageGenerator.generate(from: input)

do {
let messages = input.prompt.asMessages()
let promptTokens = try tokenizer.applyChatTemplate(
messages: messages, tools: input.tools, additionalContext: input.additionalContext)
return LMInput(tokens: MLXArray(promptTokens))
} catch {
// #150 -- it might be a TokenizerError.chatTemplate("No chat template was specified")
// but that is not public so just fall back to text
let prompt = input.prompt
.asMessages()
let prompt =
messages
.compactMap { $0["content"] as? String }
.joined(separator: ". ")
let promptTokens = tokenizer.encode(text: prompt)
Expand Down Expand Up @@ -273,7 +279,9 @@ public class LLMModelFactory: ModelFactory {

return .init(
configuration: configuration, model: model,
processor: LLMUserInputProcessor(tokenizer: tokenizer, configuration: configuration),
processor: LLMUserInputProcessor(
tokenizer: tokenizer, configuration: configuration,
messageGenerator: DefaultMessageGenerator()),
tokenizer: tokenizer)
}

Expand Down
114 changes: 114 additions & 0 deletions Libraries/MLXLMCommon/Chat.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
// Copyright © 2025 Apple Inc.

/// Namespace for the structured chat types.
public enum Chat {
    /// One structured chat message: a sender role, text content, and any
    /// attached images or videos.
    public struct Message {
        /// Who sent a message.
        public enum Role: String {
            case user, assistant, system
        }

        /// The role of the message sender.
        public var role: Role

        /// The content of the message.
        public var content: String

        /// Array of image data associated with the message.
        public var images: [UserInput.Image]

        /// Array of video data associated with the message.
        public var videos: [UserInput.Video]

        /// Creates a message; `images` and `videos` default to empty arrays.
        public init(
            role: Role,
            content: String,
            images: [UserInput.Image] = [],
            videos: [UserInput.Video] = []
        ) {
            self.role = role
            self.content = content
            self.images = images
            self.videos = videos
        }

        /// Convenience constructor for a message with the `.system` role.
        public static func system(
            _ content: String,
            images: [UserInput.Image] = [],
            videos: [UserInput.Video] = []
        ) -> Self {
            .init(role: .system, content: content, images: images, videos: videos)
        }

        /// Convenience constructor for a message with the `.assistant` role.
        public static func assistant(
            _ content: String,
            images: [UserInput.Image] = [],
            videos: [UserInput.Video] = []
        ) -> Self {
            .init(role: .assistant, content: content, images: images, videos: videos)
        }

        /// Convenience constructor for a message with the `.user` role.
        public static func user(
            _ content: String,
            images: [UserInput.Image] = [],
            videos: [UserInput.Video] = []
        ) -> Self {
            .init(role: .user, content: content, images: images, videos: videos)
        }
    }
}

/// A type that converts a structured ``Chat.Message``
/// into the model-specific ``Message``
/// (raw dictionary) format.
///
/// Typically this is owned and used by a ``UserInputProcessor``:
///
/// ```swift
/// public func prepare(input: UserInput) async throws -> LMInput {
/// let messages = Qwen2VLMessageGenerator().generate(from: input)
/// ...
/// ```
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added some documentation

/// Converts structured ``Chat.Message`` values into raw ``Message`` dictionaries.
public protocol MessageGenerator {

/// Converts a single structured chat message into its raw form.
///
/// - Parameter message: The structured message to convert.
/// - Returns: `[String: Any]` aka ``Message``.
func generate(message: Chat.Message) -> Message
}

extension MessageGenerator {
    /// Converts every structured message via `generate(message:)`,
    /// keeping the input order.
    ///
    /// - Returns: Array of `[String: Any]` aka ``Message``.
    public func generate(messages: [Chat.Message]) -> [Message] {
        messages.map { generate(message: $0) }
    }

    /// Generates raw messages from whichever prompt form `input` carries.
    ///
    /// - `.text` is wrapped in a single user message and converted.
    /// - `.messages` is already in raw form and is returned unchanged.
    /// - `.chat` is converted message by message.
    public func generate(from input: UserInput) -> [Message] {
        switch input.prompt {
        case .text(let text):
            return generate(messages: [.user(text)])
        case .messages(let rawMessages):
            return rawMessages
        case .chat(let chatMessages):
            return generate(messages: chatMessages)
        }
    }
}

/// The stock ``MessageGenerator``: each raw message carries only the
/// `role` and `content` keys. Images and videos on the structured
/// message are not included in the output.
///
/// ```swift
/// [
///     "role": message.role.rawValue,
///     "content": message.content,
/// ]
/// ```
public struct DefaultMessageGenerator: MessageGenerator {
    public init() {}

    /// Builds the two-key raw dictionary for `message`.
    public func generate(message: Chat.Message) -> Message {
        let raw: Message = [
            "role": message.role.rawValue,
            "content": message.content,
        ]
        return raw
    }
}
Loading