chat example (command line) #277
@@ -39,7 +39,7 @@ public func createAttentionMask(h: MLXArray, cache: [KVCache]?) -> MLXArray? {
 }

 /// See https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/base.py#L11
-public class KVCacheSimple: KVCache, Evaluatable {
+public class KVCacheSimple: KVCache, Evaluatable, CustomDebugStringConvertible {
     var keys: MLXArray?
     var values: MLXArray?

@@ -97,4 +97,7 @@ public class KVCacheSimple: KVCache, Evaluatable {
         )
     }

+    public var debugDescription: String {
+        "\(String(describing: Self.self)) \(Unmanaged.passUnretained(self).toOpaque()), offset: \(offset), step: \(step), keys: \(keys?.shape.description ?? "-"), values: \(values?.shape.description ?? "-")"
+    }
 }

Review comment: Just a convenience I needed while debugging Qwen2.
@@ -103,15 +103,15 @@ private enum Language {
            values = values.reshaped(B, L, kvHeads, headDim).transposed(0, 2, 1, 3)

            let offset = cache?.offset ?? 0
-           let mask = mask?[0..., 0 ..< keys.dim(-2)]

            queries = rotaryEmbedding(queries, offset: offset)
            keys = rotaryEmbedding(keys, offset: offset)

            if let cache {
                (keys, values) = cache.update(keys: keys, values: values)
            }

+           let mask = mask?[.ellipsis, 0 ..< keys.dim(-2)]

            let output = MLXFast.scaledDotProductAttention(
                queries: queries, keys: keys, values: values, scale: scale, mask: mask
            )

Review comment: I found that the dimensions were mismatched on a second message -- the keys dimension needs to be considered after the KVCache. In mlx-vlm it works out ok because:
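To make that mismatch concrete, here is a minimal shape sketch in plain Swift (the sizes are hypothetical, not taken from the model): on a follow-up message the cache already holds earlier tokens, so the key length seen by attention is only correct after cache.update, and the mask has to be sliced against that length.

    // Hypothetical sizes for a second chat turn (illustration only).
    let cachedTokens = 32                                 // tokens already stored in the KV cache
    let newTokens = 8                                     // prompt length of the second message

    let keyLengthBeforeUpdate = newTokens                 // keys.dim(-2) before cache.update
    let keyLengthAfterUpdate = cachedTokens + newTokens   // keys.dim(-2) after cache.update

    // The attention mask spans every key position the new queries may attend to,
    // i.e. cached plus new tokens.
    let maskLength = cachedTokens + newTokens

    // Slicing the mask before the update keeps only `newTokens` columns and no
    // longer matches the keys passed to scaledDotProductAttention; slicing after does.
    print(maskLength == keyLengthBeforeUpdate)            // false
    print(maskLength == keyLengthAfterUpdate)             // true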
@@ -0,0 +1,222 @@
// Copyright © 2025 Apple Inc.

import ArgumentParser
import Foundation
import MLX
import MLXLLM
import MLXLMCommon
import MLXVLM

struct ChatCommand: AsyncParsableCommand {
    static let configuration = CommandConfiguration(
        commandName: "chat",
        abstract: "interactive chat with model"
    )

    @OptionGroup var args: ModelArguments
    @OptionGroup var memory: MemoryArguments
    @OptionGroup var generate: GenerateArguments
    @OptionGroup var media: MediaArguments

    struct State {
        var parameters: GenerateParameters
        var processing: UserInput.Processing

        var images: [UserInput.Image]
        var videos: [UserInput.Video]

        var chat: [Chat.Message]

        var cache: [KVCache]

        var printStats = false
    }

    @MainActor
    mutating func run() async throws {
        let defaultModel = MLXLLM.LLMRegistry.mistral7B4bit

        // Load the model
        let modelContainer = try await memory.start { [args] in
            do {
                return try await args.load(
                    defaultModel: defaultModel.name, modelFactory: LLMModelFactory.shared)
            } catch ModelFactoryError.unsupportedModelType {
                return try await args.load(
                    defaultModel: defaultModel.name, modelFactory: VLMModelFactory.shared)
            }
        }

        // update the context/configuration with any command line parameters
        await modelContainer.update { [generate] context in
            generate.prepare(&context)
        }

        try await chat(modelContainer: modelContainer)
    }

    func chat(modelContainer: ModelContainer) async throws {
        try await modelContainer.perform { context in
            let parameters = generate.generateParameters
            let initialState = State(
                parameters: parameters,
                processing: media.processing,
                images: media.images, videos: media.videos,
                chat: [.system(generate.system)],
                cache: context.model.newCache(parameters: parameters))
Review comment: We may want a follow up PR for #276 to add the KVCache, but without the change to
            var state = initialState

            print("> ", terminator: "")
            while let line = readLine() {
                if line.hasPrefix("/") {
                    // handle commands
                    switch command(line: line, state: &state) {
                    case .exit:
                        return
                    case .reset:
                        state = initialState
                        state.cache = context.model.newCache(parameters: parameters)
                        continue
                    case .inference:
                        // continue and run inference
                        break
                    case .handled:
                        print("\n\n> ", terminator: "")
                        continue
                    }
                } else {
                    // chat input
                    state.chat.append(.user(line, images: state.images, videos: state.videos))
                }

                // consume the media, if any
                state.images.removeAll()
                state.videos.removeAll()

                // convert UserInput to LMInput
                let userInput = UserInput(chat: state.chat, processing: state.processing)
                let input = try await context.processor.prepare(input: userInput)

                // generate the output
                var output = ""
                var result: GenerateCompletionInfo?
                for await item in try MLXLMCommon.generate(
                    input: input, cache: state.cache, parameters: parameters, context: context
                ) {
                    switch item {
                    case .chunk(let string):
                        output += string
                        print(string, terminator: "")
                    case .info(let info):
                        result = info
                    }
                }

                // add the assistant response to the chat messages
                state.chat.append(.assistant(output))

                if state.printStats, let result {
                    print(
                        "\ntime to first token: \(result.promptTime.formatted()) tps: \(result.tokensPerSecond.formatted())"
                    )
                }
                print("\n\n> ", terminator: "")
            }
        }
    }

    enum CommandDisposition {
        case exit
        case reset
        case inference
        case handled
    }

    func help() {
        print(
            """
            /help -- this message
            /quit -- terminate the chat
            /memory -- print memory stats
            /stats -- toggle token stats
            /reset -- reset the chat session to initial state
            /image [pathOrURL] -- provide an image
            /video [pathOrURL] -- provide a video
            /again -- rerun inference for last response
            /parameters -- print generation parameters
            /temperature [number] -- set the sampling temperature
            /topP [number] -- set the top p sampling
            /maxTokens [number] -- set the maximum number of tokens to generate or no number to remove limit
            """)
    }

    func command(line: String, state: inout State) -> CommandDisposition {
        let command = line.split(separator: " ")[0]
        let rest = String(
            line.dropFirst(command.count).trimmingCharacters(in: .whitespaces))

        func url(_ string: String) -> URL? {
            if string.hasPrefix("/") {
                URL(filePath: string)
            } else {
                URL(string: string)
            }
        }

        switch command {
        case "/help":
            help()

        case "/quit":
            return .exit

        case "/memory":
            let memory = GPU.snapshot()
            print("Memory size: \(GPU.memoryLimit / 1024)K")
            print("Cache size: \(GPU.cacheLimit / 1024)K")
            print(memory.description)

        case "/stats":
            state.printStats.toggle()
            print("Token stats: \(state.printStats ? "ON" : "OFF")")

        case "/reset":
            return .reset

        case "/image":
            if let url = url(rest) {
                state.images.append(UserInput.Image.url(url))
            }
        case "/video":
            if let url = url(rest) {
                state.videos.append(UserInput.Video.url(url))
            }

        case "/again":
            state.chat.removeLast()
            return .inference

        case "/parameters":
            print(state.parameters)
        case "/temperature":
            if let value = Float(rest) {
                state.parameters.temperature = value
                print(state.parameters)
            }
        case "/topP":
            if let value = Float(rest) {
                state.parameters.topP = value
                print(state.parameters)
            }
        case "/maxTokens":
            state.parameters.maxTokens = Int(rest)
            print(state.parameters)

        default:
            help()
        }

        return .handled
    }
}
Review comment: Easily pass the optional cache in to the iterator.
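For context, a small sketch of what that refers to, modeled on the calls in ChatCommand above (the generate overload that takes cache:, the .chunk case, and the single-argument .user message are assumptions taken from this diff, not a confirmed public API): keeping one cache alive across turns lets each call reuse the keys and values computed for earlier messages.

    import MLXLMCommon

    // Sketch only: one KV cache reused across chat turns, mirroring ChatCommand above.
    func runTurn(
        _ text: String, chat: inout [Chat.Message], cache: [KVCache],
        parameters: GenerateParameters, context: ModelContext
    ) async throws -> String {
        chat.append(.user(text))

        // Earlier turns already live in the cache, so only the new tokens are processed.
        let input = try await context.processor.prepare(input: UserInput(chat: chat))

        var output = ""
        for await item in try MLXLMCommon.generate(
            input: input, cache: cache, parameters: parameters, context: context
        ) {
            if case .chunk(let string) = item { output += string }
        }

        chat.append(.assistant(output))
        return output
    }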