Commit ea86d94

Add argument for extra EOS token to llm-tool (#217)
* Improve llm-tool
  - make ModelConfiguration and ModelContext properties mutable
  - update context/configuration with extra EOS tokens

Co-authored-by: David Koski <[email protected]>
Parent: fc9dfc1

6 files changed: +44 −24 lines

Libraries/MLXLMCommon/ModelConfiguration.swift

Lines changed: 2 additions & 2 deletions

```diff
@@ -31,10 +31,10 @@ public struct ModelConfiguration: Sendable {
     public let overrideTokenizer: String?
 
     /// A reasonable default prompt for the model
-    public let defaultPrompt: String
+    public var defaultPrompt: String
 
     /// Additional tokens to use for end of string
-    public let extraEOSTokens: Set<String>
+    public var extraEOSTokens: Set<String>
 
     public init(
         id: String, tokenizerId: String? = nil, overrideTokenizer: String? = nil,
```
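Making these two properties `var` is what lets downstream code adjust a configuration after it is created, rather than rebuilding it. A minimal sketch of the new flexibility (the model id and `<|end|>` token simply mirror the scheme arguments later in this commit):

```swift
import MLXLMCommon

// With extraEOSTokens now a var, a caller can widen the stop set
// after the configuration exists.
var configuration = ModelConfiguration(id: "microsoft/Phi-4-mini-instruct")
configuration.extraEOSTokens.insert("<|end|>")
configuration.defaultPrompt = "Why is the sky blue?"
```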

Libraries/MLXLMCommon/ModelContainer.swift

Lines changed: 7 additions & 3 deletions

```diff
@@ -32,12 +32,11 @@ import Tokenizers
 /// }
 /// ```
 public actor ModelContainer {
-    let context: ModelContext
-    nonisolated public let configuration: ModelConfiguration
+    var context: ModelContext
+    public var configuration: ModelConfiguration { context.configuration }
 
     public init(context: ModelContext) {
         self.context = context
-        self.configuration = context.configuration
     }
 
     /// Perform an action on the model and/or tokenizer. Callers _must_ eval any `MLXArray` before returning as
@@ -75,4 +74,9 @@ public actor ModelContainer {
         try await action(context, values)
     }
 
+    /// Update the owned `ModelContext`.
+    /// - Parameter action: update action
+    public func update(_ action: @Sendable (inout ModelContext) -> Void) {
+        action(&context)
+    }
 }
```
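`configuration` changes from a stored `nonisolated` property to a computed one backed by the actor-isolated `context`, so reads from outside the actor now require `await`, and the new `update(_:)` method becomes the sanctioned way to mutate the context. A small usage sketch, assuming a `modelContainer` has already been loaded:

```swift
// Mutate the actor-owned ModelContext in place; the closure is
// @Sendable because it runs on the ModelContainer actor.
await modelContainer.update { context in
    context.configuration.extraEOSTokens.insert("<|end|>")
}

// Reading configuration is now actor-isolated, hence the await.
let configuration = await modelContainer.configuration
```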

Libraries/MLXLMCommon/ModelFactory.swift

Lines changed: 4 additions & 4 deletions

```diff
@@ -22,10 +22,10 @@ public enum ModelFactoryError: Error {
 /// See also ``ModelFactory/loadContainer(hub:configuration:progressHandler:)`` and
 /// ``ModelContainer``.
 public struct ModelContext {
-    public let configuration: ModelConfiguration
-    public let model: any LanguageModel
-    public let processor: any UserInputProcessor
-    public let tokenizer: Tokenizer
+    public var configuration: ModelConfiguration
+    public var model: any LanguageModel
+    public var processor: any UserInputProcessor
+    public var tokenizer: Tokenizer
 
     public init(
         configuration: ModelConfiguration, model: any LanguageModel,
```

Tools/llm-tool/LLMTool.swift

Lines changed: 21 additions & 6 deletions

```diff
@@ -31,6 +31,8 @@ struct ModelArguments: ParsableArguments, Sendable {
 
         let modelName = self.model ?? defaultModel
 
+        print("Loading \(modelName)...")
+
         if modelName.hasPrefix("/") {
             // path
             modelConfiguration = ModelConfiguration(directory: URL(filePath: modelName))
@@ -67,6 +69,9 @@ struct GenerateArguments: ParsableArguments, Sendable {
     @Option(name: .long, help: "The number of tokens to consider for repetition penalty")
     var repetitionContextSize: Int = 20
 
+    @Option(name: .long, help: "Additional end-of-sequence token to stop generation")
+    var extraEosToken: String?
+
     @Option(name: .long, help: "The PRNG seed")
     var seed: UInt64 = 0
 
@@ -89,17 +94,22 @@ struct GenerateArguments: ParsableArguments, Sendable {
         }
     }
 
+    func prepare(
+        _ context: inout ModelContext
+    ) {
+        if let extraEosToken {
+            context.configuration.extraEOSTokens.insert(extraEosToken)
+        }
+    }
+
     func generate(
         input: LMInput, context: ModelContext
-    )
-        throws -> GenerateResult
-    {
+    ) throws -> GenerateResult {
         var detokenizer = NaiveStreamingDetokenizer(tokenizer: context.tokenizer)
 
         return try MLXLMCommon.generate(
             input: input, parameters: generateParameters, context: context
         ) { tokens in
-
             if let last = tokens.last {
                 detokenizer.append(token: last)
             }
@@ -276,11 +286,16 @@ struct EvaluateCommand: AsyncParsableCommand {
             try await args.load(defaultModel: defaultModel.name, modelFactory: modelFactory)
         }
 
+        // update the context/configuration with any command line parameters
+        await modelContainer.update { [generate] context in
+            generate.prepare(&context)
+        }
+
         // Get the resolved configuration (this has the default prompt)
-        let modelConfiguration = modelContainer.configuration
+        let modelConfiguration = await modelContainer.configuration
 
         if !generate.quiet {
-            print("Model loaded -> \(modelConfiguration.id)")
+            print("Loaded \(modelConfiguration.name)")
         }
 
         let userInput = self.userInput(modelConfiguration: modelConfiguration)
```
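Taken together: `GenerateArguments` gains an `--extra-eos-token` option, `prepare(_:)` folds it into the configuration, and `EvaluateCommand` applies it through `ModelContainer.update` before generation begins. Assuming a built `llm-tool`, an invocation along the lines of `llm-tool eval --model microsoft/Phi-4-mini-instruct --prompt 'Why is the sky blue?' --extra-eos-token '<|end|>'` (inferred from the scheme arguments below) should then stop generation when the model emits `<|end|>`.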

Tools/llm-tool/LoraCommands.swift

Lines changed: 5 additions & 4 deletions

```diff
@@ -48,7 +48,7 @@ struct LoRAModelArguments: ParsableArguments, Sendable {
         // convert some of the Linear layers to LoRALinear
         await modelContainer.perform { context in
             guard let lora = context.model as? LoRAModel else {
-                fatalError("Model \(modelContainer.configuration.name) is not a LoRAModel")
+                fatalError("Model \(context.configuration.name) is not a LoRAModel")
             }
             LoRATrain.convert(model: context.model, layers: lora.loraLinearLayers(loraLayers))
         }
@@ -197,7 +197,7 @@ struct LoRAFuseCommand: AsyncParsableCommand {
         // fuse them back into Linear/QuantizedLinear
         await modelContainer.perform { [args, deQuantize] context in
             guard let lora = context.model as? LoRAModel else {
-                fatalError("Model \(modelContainer.configuration.name) is not a LoRAModel")
+                fatalError("Model \(context.configuration.name) is not a LoRAModel")
             }
 
             LoRATrain.fuse(
@@ -207,7 +207,7 @@ struct LoRAFuseCommand: AsyncParsableCommand {
 
         // make the new directory and copy files from source model
         try FileManager.default.createDirectory(at: outputURL, withIntermediateDirectories: true)
-        let inputURL = modelContainer.configuration.modelDirectory()
+        let inputURL = await modelContainer.configuration.modelDirectory()
         let enumerator = FileManager.default.enumerator(
             at: inputURL, includingPropertiesForKeys: nil)!
         for case let url as URL in enumerator {
@@ -296,7 +296,8 @@ struct LoRAEvalCommand: AsyncParsableCommand {
 
         memory.start()
 
-        let prompt = generate.prompt ?? modelContainer.configuration.defaultPrompt
+        let defaultPrompt = await modelContainer.configuration.defaultPrompt
+        let prompt = generate.prompt ?? defaultPrompt
 
         if !generate.quiet {
             print("Starting generation ...")
```

mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/llm-tool.xcscheme

Lines changed: 5 additions & 5 deletions

```diff
@@ -56,12 +56,12 @@
          isEnabled = "NO">
       </CommandLineArgument>
       <CommandLineArgument
-         argument = "--prompt &apos;Describe the image in English.&apos; --image https://www.gstatic.com/webp/gallery/1.webp"
+         argument = "--model microsoft/Phi-4-mini-instruct --prompt &quot;Why is the sky blue?&quot; --extra-eos-token &quot;&lt;|end|&gt;&quot;"
          isEnabled = "NO">
       </CommandLineArgument>
       <CommandLineArgument
-         argument = "--model mlx-community/Qwen2-VL-2B-Instruct-4bit"
-         isEnabled = "NO">
+         argument = "--model mlx-community/Qwen2-VL-2B-Instruct-4bit --prompt &apos;Describe the image in English.&apos; --image https://www.gstatic.com/webp/gallery/1.webp"
+         isEnabled = "YES">
       </CommandLineArgument>
       <CommandLineArgument
          argument = "--repetition-penalty 1.2"
@@ -89,15 +89,15 @@
       </CommandLineArgument>
       <CommandLineArgument
          argument = "--prompt &apos;Why is the sky blue?&apos;"
-         isEnabled = "YES">
+         isEnabled = "NO">
       </CommandLineArgument>
       <CommandLineArgument
          argument = "--model mlx-community/Mistral-7B-v0.1-hf-4bit-mlx"
          isEnabled = "NO">
       </CommandLineArgument>
       <CommandLineArgument
          argument = "--model mlx-community/Llama-3.2-1B-Instruct-4bit"
-         isEnabled = "YES">
+         isEnabled = "NO">
       </CommandLineArgument>
       <CommandLineArgument
          argument = "--model mlx-community/phi-2-hf-4bit-mlx"
```
