chat example (command line) #277

Open · wants to merge 2 commits into main
2 changes: 1 addition & 1 deletion Libraries/MLXLLM/LLMModelFactory.swift
@@ -264,7 +264,7 @@ public class LLMModelFactory: ModelFactory {
let modelDirectory = try await downloadModel(
hub: hub, configuration: configuration, progressHandler: progressHandler)

// load the generic config to unerstand which model and how to load the weights
// load the generic config to understand which model and how to load the weights
let configurationURL = modelDirectory.appending(component: "config.json")
let baseConfig = try JSONDecoder().decode(
BaseConfiguration.self, from: Data(contentsOf: configurationURL))
5 changes: 3 additions & 2 deletions Libraries/MLXLMCommon/Evaluate.swift
@@ -700,6 +700,7 @@ public func generate(
///
/// - Parameters:
/// - input: The input for the language model.
/// - cache: optional ``KVCache``
/// - parameters: The configuration options for token generation.
/// - context: The model context, including the model itself and associated tokenizer.
/// - Returns: An `AsyncStream` that emits `Generation` values, including generated tokens (`.token`)
@@ -729,10 +730,10 @@ public func generate(
/// }
/// ```
public func generate(
input: LMInput, parameters: GenerateParameters, context: ModelContext
input: LMInput, cache: [KVCache]? = nil, parameters: GenerateParameters, context: ModelContext
Comment (Collaborator Author): Easily pass the optional cache into the iterator.
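A minimal sketch of how a caller might use the new parameter, reusing one cache across turns so earlier messages are not re-evaluated. It mirrors the Chat.swift usage added in this PR and assumes it runs inside a `modelContainer.perform { context in ... }` closure with `parameters` already in scope:

```swift
// Sketch: keep one KVCache for the whole conversation and pass it to generate()
// on every turn. `context` and `parameters` are assumed to come from the
// surrounding ModelContainer.perform closure.
var chat: [Chat.Message] = [.system("You are a helpful assistant.")]
let cache = context.model.newCache(parameters: parameters)

for line in ["Hello!", "Tell me more."] {
    chat.append(.user(line, images: [], videos: []))

    let input = try await context.processor.prepare(input: UserInput(chat: chat))

    var output = ""
    for await item in try MLXLMCommon.generate(
        input: input, cache: cache, parameters: parameters, context: context
    ) {
        if case .chunk(let chunk) = item {
            output += chunk
        }
    }

    // keep the transcript in sync so the next turn only adds new tokens
    chat.append(.assistant(output))
}
```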

) throws -> AsyncStream<Generation> {
let iterator = try TokenIterator(
input: input, model: context.model, parameters: parameters)
input: input, model: context.model, cache: cache, parameters: parameters)
return generate(
input: input, context: context, iterator: iterator)
}
5 changes: 4 additions & 1 deletion Libraries/MLXLMCommon/KVCache.swift
@@ -39,7 +39,7 @@ public func createAttentionMask(h: MLXArray, cache: [KVCache]?) -> MLXArray? {
}

/// See https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/base.py#L11
public class KVCacheSimple: KVCache, Evaluatable {
public class KVCacheSimple: KVCache, Evaluatable, CustomDebugStringConvertible {
var keys: MLXArray?
var values: MLXArray?

@@ -97,4 +97,7 @@ public class KVCacheSimple: KVCache, Evaluatable {
)
}

public var debugDescription: String {
"\(String(describing: Self.self)) \(Unmanaged.passUnretained(self).toOpaque()), offset: \(offset), step: \(step), keys: \(keys?.shape.description ?? "-"), values: \(values?.shape.description ?? "-")"
Comment (Collaborator Author): Just a convenience I needed while debugging Qwen2.
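For example (a sketch, assuming a `[KVCache]` like the one Chat.swift below creates via `newCache`), each entry can be dumped mid-session:

```swift
// Sketch: print each cache entry while debugging; debugDescription reports the
// class name, instance address, offset, step, and key/value shapes.
for entry in state.cache {
    debugPrint(entry)
}
```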

}
}
4 changes: 4 additions & 0 deletions Libraries/MLXLMCommon/UserInput.swift
@@ -249,12 +249,14 @@ public struct UserInput: Sendable {
/// - Parameters:
/// - chat: structured content
/// - tools: optional tool specifications
/// - processing: optional processing to be applied to media
/// - additionalContext: optional context (model specific)
/// ### See Also
/// - ``Prompt-swift.enum/text(_:)``
/// - ``init(chat:tools:additionalContext:)``
public init(
chat: [Chat.Message],
processing: Processing = .init(),
tools: [ToolSpec]? = nil,
additionalContext: [String: Any]? = nil
) {
@@ -267,6 +269,8 @@
self.videos = chat.reduce(into: []) { result, message in
result.append(contentsOf: message.videos)
}

self.processing = processing
self.tools = tools
self.additionalContext = additionalContext
}
4 changes: 2 additions & 2 deletions Libraries/MLXVLM/Models/Qwen2VL.swift
@@ -103,15 +103,15 @@ private enum Language {
values = values.reshaped(B, L, kvHeads, headDim).transposed(0, 2, 1, 3)

let offset = cache?.offset ?? 0
let mask = mask?[0..., 0 ..< keys.dim(-2)]

queries = rotaryEmbedding(queries, offset: offset)
keys = rotaryEmbedding(keys, offset: offset)

if let cache {
(keys, values) = cache.update(keys: keys, values: values)
}

let mask = mask?[.ellipsis, 0 ..< keys.dim(-2)]
Comment (Collaborator Author): I found that the dimensions were mismatched on a second message -- the keys dimension needs to be taken after the KVCache update. In mlx-vlm it works out OK because:

  1. it looks like the KVCache isn't used persistently
  2. the KVCache implementation doesn't window the cache
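A tiny illustration of the shape bookkeeping (toy numbers, not model code): on a second message the cache already holds the earlier tokens, so the key length the mask must be sliced to only exists after `cache.update`:

```swift
// Toy numbers: the cache already holds 12 tokens from the first message and
// the current message adds 5 new tokens.
let offset = 12                      // cache.offset before this call
let newTokens = 5                    // keys.dim(-2) before cache.update

let keyLenBefore = newTokens         // 5  -- what the old code sliced the mask to
let keyLenAfter = offset + newTokens // 17 -- keys.dim(-2) after cache.update

// The attention mask spans the full history (17 columns), so slicing it with
// the pre-update length (5) drops the columns for the cached tokens and the
// shapes no longer line up on the second message.
print(keyLenBefore, keyLenAfter)     // 5 17
```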


let output = MLXFast.scaledDotProductAttention(
queries: queries, keys: keys, values: values, scale: scale, mask: mask
)
2 changes: 1 addition & 1 deletion Libraries/MLXVLM/VLMModelFactory.swift
@@ -208,7 +208,7 @@ public class VLMModelFactory: ModelFactory {
let modelDirectory = try await downloadModel(
hub: hub, configuration: configuration, progressHandler: progressHandler)

// load the generic config to unerstand which model and how to load the weights
// load the generic config to understand which model and how to load the weights
let configurationURL = modelDirectory.appending(
component: "config.json"
)
222 changes: 222 additions & 0 deletions Tools/llm-tool/Chat.swift
@@ -0,0 +1,222 @@
// Copyright © 2025 Apple Inc.

import ArgumentParser
import Foundation
import MLX
import MLXLLM
import MLXLMCommon
import MLXVLM

struct ChatCommand: AsyncParsableCommand {
static let configuration = CommandConfiguration(
commandName: "chat",
abstract: "interactive chat with model"
)

@OptionGroup var args: ModelArguments
@OptionGroup var memory: MemoryArguments
@OptionGroup var generate: GenerateArguments
@OptionGroup var media: MediaArguments

struct State {
var parameters: GenerateParameters
var processing: UserInput.Processing

var images: [UserInput.Image]
var videos: [UserInput.Video]

var chat: [Chat.Message]

var cache: [KVCache]

var printStats = false
}

@MainActor
mutating func run() async throws {
let defaultModel = MLXLLM.LLMRegistry.mistral7B4bit

// Load the model
let modelContainer = try await memory.start { [args] in
do {
return try await args.load(
defaultModel: defaultModel.name, modelFactory: LLMModelFactory.shared)
} catch ModelFactoryError.unsupportedModelType {
return try await args.load(
defaultModel: defaultModel.name, modelFactory: VLMModelFactory.shared)
}
}

// update the context/configuration with any command line parameters
await modelContainer.update { [generate] context in
generate.prepare(&context)
}

try await chat(modelContainer: modelContainer)
}

func chat(modelContainer: ModelContainer) async throws {
try await modelContainer.perform { context in
let parameters = generate.generateParameters
let initialState = State(
parameters: parameters,
processing: media.processing,
images: media.images, videos: media.videos,
chat: [.system(generate.system)],
cache: context.model.newCache(parameters: parameters))
Comment (Collaborator Author): We may want a follow-up PR for #276 to add the KVCache, but without the change to generate() it is hard to use.


var state = initialState

print("> ", terminator: "")
while let line = readLine() {
if line.hasPrefix("/") {
// handle commands
switch command(line: line, state: &state) {
case .exit:
return
case .reset:
state = initialState
state.cache = context.model.newCache(parameters: parameters)
continue
case .inference:
// continue and run inference
break
case .handled:
print("\n\n> ", terminator: "")
continue
}
} else {
// chat input
state.chat.append(.user(line, images: state.images, videos: state.videos))
}

// consume the media, if any
state.images.removeAll()
state.videos.removeAll()

// convert UserInput to LMInput
let userInput = UserInput(chat: state.chat, processing: state.processing)
let input = try await context.processor.prepare(input: userInput)

// generate the output
var output = ""
var result: GenerateCompletionInfo?
for await item in try MLXLMCommon.generate(
input: input, cache: state.cache, parameters: parameters, context: context
) {
switch item {
case .chunk(let string):
output += string
print(string, terminator: "")
case .info(let info):
result = info
}
}

// add the assistant response to the chat messages
state.chat.append(.assistant(output))

if state.printStats, let result {
print(
"\ntime to first token: \(result.promptTime.formatted()) tps: \(result.tokensPerSecond.formatted())"
)
}
print("\n\n> ", terminator: "")
}
}
}

enum CommandDisposition {
case exit
case reset
case inference
case handled
}

func help() {
print(
"""
/help -- this message
/quit -- terminate the chat
/memory -- print memory stats
/stats -- toggle token stats
/reset -- reset the chat session to initial state
/image [pathOrURL] -- provide an image
/video [pathOrURL] -- provide a video
/again -- rerun inference for last response
/parameters -- print generation parameters
/temperature [number] -- set the sampling temperature
/topP [number] -- set the top p sampling
/maxTokens [number] -- set the maximum number of tokens to generate or no number to remove limit
""")
}

func command(line: String, state: inout State) -> CommandDisposition {
let command = line.split(separator: " ")[0]
let rest = String(
line.dropFirst(command.count).trimmingCharacters(in: .whitespaces))

func url(_ string: String) -> URL? {
if string.hasPrefix("/") {
URL(filePath: string)
} else {
URL(string: string)
}
}

switch command {
case "/help":
help()

case "/quit":
return .exit

case "/memory":
let memory = GPU.snapshot()
print("Memory size: \(GPU.memoryLimit / 1024)K")
print("Cache size: \(GPU.cacheLimit / 1024)K")
print(memory.description)

case "/stats":
state.printStats.toggle()
print("Token stats: \(state.printStats ? "ON" : "OFF")")

case "/reset":
return .reset

case "/image":
if let url = url(rest) {
state.images.append(UserInput.Image.url(url))
}
case "/video":
if let url = url(rest) {
state.videos.append(UserInput.Video.url(url))
}

case "/again":
state.chat.removeLast()
return .inference

case "/parameters":
print(state.parameters)
case "/temperature":
if let value = Float(rest) {
state.parameters.temperature = value
print(state.parameters)
}
case "/topP":
if let value = Float(rest) {
state.parameters.topP = value
print(state.parameters)
}
case "/maxTokens":
state.parameters.maxTokens = Int(rest)
print(state.parameters)

default:
help()
}

return .handled
}
}