
Commit 4add690

Merge branch 'main' into qwen-2.5-vl
2 parents 19e2aa8 + c8164a2 commit 4add690

14 files changed, +221 -248 lines changed

Libraries/MLXLLM/LLMModelFactory.swift

Lines changed: 11 additions & 79 deletions

@@ -20,20 +20,10 @@ private func create<C: Codable, M>(
 /// Registry of model type, e.g 'llama', to functions that can instantiate the model from configuration.
 ///
 /// Typically called via ``LLMModelFactory/load(hub:configuration:progressHandler:)``.
-public class ModelTypeRegistry: @unchecked Sendable {
-
-    /// Creates an empty registry.
-    public init() {
-        self.creators = [:]
-    }
-
-    /// Creates a registry with given creators.
-    public init(creators: [String: @Sendable (URL) throws -> any LanguageModel]) {
-        self.creators = creators
-    }
+public class LLMTypeRegistry: ModelTypeRegistry, @unchecked Sendable {
 
     /// Shared instance with default model types.
-    public static let shared: ModelTypeRegistry = .init(creators: all())
+    public static let shared: LLMTypeRegistry = .init(creators: all())
 
     /// All predefined model types.
     private static func all() -> [String: @Sendable (URL) throws -> any LanguageModel] {
@@ -53,32 +43,6 @@ public class ModelTypeRegistry: @unchecked Sendable {
         ]
     }
 
-    // Note: using NSLock as we have very small (just dictionary get/set)
-    // critical sections and expect no contention. this allows the methods
-    // to remain synchronous.
-    private let lock = NSLock()
-    private var creators: [String: @Sendable (URL) throws -> any LanguageModel]
-
-    /// Add a new model to the type registry.
-    public func registerModelType(
-        _ type: String, creator: @Sendable @escaping (URL) throws -> any LanguageModel
-    ) {
-        lock.withLock {
-            creators[type] = creator
-        }
-    }
-
-    /// Given a `modelType` and configuration file instantiate a new `LanguageModel`.
-    public func createModel(configuration: URL, modelType: String) throws -> LanguageModel {
-        let creator = lock.withLock {
-            creators[modelType]
-        }
-        guard let creator else {
-            throw ModelFactoryError.unsupportedModelType(modelType)
-        }
-        return try creator(configuration)
-    }
-
 }
 
 /// Registry of models and any overrides that go with them, e.g. prompt augmentation.
@@ -87,23 +51,10 @@ public class ModelTypeRegistry: @unchecked Sendable {
 /// The python tokenizers have a very rich set of implementations and configuration. The
 /// swift-tokenizers code handles a good chunk of that and this is a place to augment that
 /// implementation, if needed.
-public class ModelRegistry: @unchecked Sendable {
-
-    /// Creates an empty registry.
-    public init() {
-        self.registry = Dictionary()
-    }
-
-    /// Creates a new registry with from given model configurations.
-    public init(modelConfigurations: [ModelConfiguration]) {
-        self.registry = Dictionary(uniqueKeysWithValues: modelConfigurations.map { ($0.name, $0) })
-    }
+public class LLMRegistry: AbstractModelRegistry, @unchecked Sendable {
 
     /// Shared instance with default model configurations.
-    public static let shared = ModelRegistry(modelConfigurations: all())
-
-    private let lock = NSLock()
-    private var registry: [String: ModelConfiguration]
+    public static let shared = LLMRegistry(modelConfigurations: all())
 
     static public let smolLM_135M_4bit = ModelConfiguration(
         id: "mlx-community/SmolLM-135M-Instruct-4bit",
@@ -239,31 +190,11 @@ public class ModelRegistry: @unchecked Sendable {
         ]
     }
 
-    public func register(configurations: [ModelConfiguration]) {
-        lock.withLock {
-            for c in configurations {
-                registry[c.name] = c
-            }
-        }
-    }
-
-    public func configuration(id: String) -> ModelConfiguration {
-        lock.withLock {
-            if let c = registry[id] {
-                return c
-            } else {
-                return ModelConfiguration(id: id)
-            }
-        }
-    }
-
-    public var models: some Collection<ModelConfiguration> & Sendable {
-        lock.withLock {
-            return registry.values
-        }
-    }
 }
 
+@available(*, deprecated, renamed: "LLMRegistry", message: "Please use LLMRegistry directly.")
+public typealias ModelRegistry = LLMRegistry
+
 private struct LLMUserInputProcessor: UserInputProcessor {
 
     let tokenizer: Tokenizer
@@ -304,19 +235,20 @@ private struct LLMUserInputProcessor: UserInputProcessor {
 /// ```
 public class LLMModelFactory: ModelFactory {
 
-    public init(typeRegistry: ModelTypeRegistry, modelRegistry: ModelRegistry) {
+    public init(typeRegistry: ModelTypeRegistry, modelRegistry: AbstractModelRegistry) {
         self.typeRegistry = typeRegistry
         self.modelRegistry = modelRegistry
     }
 
     /// Shared instance with default behavior.
-    public static let shared = LLMModelFactory(typeRegistry: .shared, modelRegistry: .shared)
+    public static let shared = LLMModelFactory(
+        typeRegistry: LLMTypeRegistry.shared, modelRegistry: LLMRegistry.shared)
 
     /// registry of model type, e.g. configuration value `llama` -> configuration and init methods
    public let typeRegistry: ModelTypeRegistry
 
     /// registry of model id to configuration, e.g. `mlx-community/Llama-3.2-3B-Instruct-4bit`
-    public let modelRegistry: ModelRegistry
+    public let modelRegistry: AbstractModelRegistry
 
     public func configuration(id: String) -> ModelConfiguration {
         modelRegistry.configuration(id: id)
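
This refactor splits the old registries in two: reusable base classes (`ModelTypeRegistry`, `AbstractModelRegistry`, added below) and thin LLM-specific subclasses (`LLMTypeRegistry`, `LLMRegistry`) that only supply the default creators and configurations. A minimal usage sketch, based only on the API visible in this diff (the custom model id is a hypothetical placeholder):

```swift
import MLXLLM
import MLXLMCommon

// Register an extra configuration on the shared LLM registry
// (register(configurations:) is inherited from AbstractModelRegistry).
LLMRegistry.shared.register(configurations: [
    ModelConfiguration(id: "my-org/My-Model-4bit")  // hypothetical id
])

// The shared factory now wires the two registries together explicitly.
let factory = LLMModelFactory(
    typeRegistry: LLMTypeRegistry.shared,
    modelRegistry: LLMRegistry.shared
)

// Code that still references `ModelRegistry` keeps compiling through the
// deprecated typealias, but with a deprecation warning.
let config = factory.configuration(id: "mlx-community/SmolLM-135M-Instruct-4bit")
```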

Libraries/MLXLLM/Models/Cohere.swift

Lines changed: 1 addition & 1 deletion

@@ -163,7 +163,7 @@ public class CohereModel: Module, LLMModel, KVCacheDimensionProvider {
 
     public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray {
         var out = model(inputs, cache: cache)
-        out = matmul(out, model.embedTokens.weight.T)
+        out = model.embedTokens.asLinear(out)
         out = out * self.logitScale
         return out
     }
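
The only change here swaps the hand-written projection against the transposed embedding weight for `Embedding.asLinear`, which applies the same tied weight as a linear layer. A conceptual sketch of the equivalence, not code from this commit (it assumes the MLXNN `Embedding(embeddingCount:dimensions:)` initializer; shapes are arbitrary):

```swift
import MLX
import MLXNN

let vocab = 32
let dims = 8
let embed = Embedding(embeddingCount: vocab, dimensions: dims)
let hidden = MLXArray.zeros([1, 4, dims])  // [batch, tokens, dims]

// Before: project hidden states onto the vocabulary by hand.
let logitsManual = matmul(hidden, embed.weight.T)

// After: let the embedding layer apply its own weight as the output head.
let logitsTied = embed.asLinear(hidden)
// Both yield [1, 4, vocab]-shaped logits from the shared weight.
```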

Libraries/MLXLLM/Models/OpenELM.swift

Lines changed: 13 additions & 15 deletions

@@ -27,7 +27,6 @@ func makeDivisible(_ v: Float, divisor: Int = 8, minValue: Float? = nil) -> Int
 }
 
 private class MultiHeadCausalAttention: Module {
-    var args: OpenElmConfiguration
     let scale: Float
     let heads: Int
     let headDim: Int
@@ -36,18 +35,17 @@ private class MultiHeadCausalAttention: Module {
     @ModuleInfo(key: "qkv_proj") var qkvProj: Linear
     @ModuleInfo(key: "out_proj") var outProj: Linear
 
-    @ModuleInfo(key: "q_norm") var qNorm: RMSNorm
-    @ModuleInfo(key: "k_norm") var kNorm: RMSNorm
+    @ModuleInfo(key: "q_norm") var qNorm: RMSNorm?
+    @ModuleInfo(key: "k_norm") var kNorm: RMSNorm?
 
     let rope: RoPE
 
     public init(_ args: OpenElmConfiguration, layerId: Int) {
-        self.args = args
         self.headDim = args.headDimensions
         let modelDim = args.modelDim
 
-        self.heads = self.args.numQueryHeads[layerId]
-        self.kvHeads = self.args.kvHeads[layerId]
+        self.heads = args.numQueryHeads[layerId]
+        self.kvHeads = args.kvHeads[layerId]
         self.scale = pow(Float(headDim), -0.5)
 
         let opSize = (heads + (kvHeads * 2)) * headDim
@@ -74,7 +72,7 @@ private class MultiHeadCausalAttention: Module {
         var keys = qkvSplit[1]
         var values = qkvSplit[2]
 
-        if args.normalizeQkProjections {
+        if let qNorm, let kNorm {
             queries = qNorm(queries)
             keys = kNorm(keys)
         }
@@ -181,27 +179,27 @@ public class OpenELMModel: Module, LLMModel, KVCacheDimensionProvider {
     public let vocabularySize: Int
     public let kvHeads: [Int]
 
-    let shareInputOutputLayers: Bool
     let transformer: OpenELMModelInner
 
-    @ModuleInfo(key: "lm_head") var lmHead: Linear
+    @ModuleInfo(key: "lm_head") var lmHead: Linear?
 
     public init(_ args: OpenElmConfiguration) {
         self.vocabularySize = args.vocabularySize
         self.kvHeads = args.kvHeads
 
         self.transformer = OpenELMModelInner(args)
-        self.shareInputOutputLayers = args.shareInputOutputLayers
-        self._lmHead.wrappedValue = Linear(
-            args.numTransformerLayers, args.vocabularySize, bias: false)
+        if !args.shareInputOutputLayers {
+            self._lmHead.wrappedValue = Linear(
+                args.numTransformerLayers, args.vocabularySize, bias: false)
+        }
    }
 
     public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray {
         var out = transformer(inputs, cache: cache)
-        if shareInputOutputLayers {
-            out = matmul(out, transformer.embedTokens.weight.T)
-        } else {
+        if let lmHead {
             out = lmHead(out)
+        } else {
+            out = transformer.embedTokens.asLinear(out)
         }
 
         return out
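
Besides the `asLinear` switch, this diff drops the stored `shareInputOutputLayers` flag and makes `lmHead` and the q/k norms optional, branching on their presence instead of on configuration flags. An illustrative sketch of that pattern (a hypothetical `TinyLM`, not code from the repo):

```swift
import MLX
import MLXNN

class TinyLM: Module {
    let embed: Embedding
    @ModuleInfo(key: "lm_head") var lmHead: Linear?

    init(vocab: Int, dims: Int, tieWeights: Bool) {
        self.embed = Embedding(embeddingCount: vocab, dimensions: dims)
        // Only create the output head when the weights are untied.
        if !tieWeights {
            self._lmHead.wrappedValue = Linear(dims, vocab, bias: false)
        }
    }

    func callAsFunction(_ hidden: MLXArray) -> MLXArray {
        if let lmHead {
            return lmHead(hidden)  // untied: dedicated projection
        } else {
            return embed.asLinear(hidden)  // tied: reuse the embedding weight
        }
    }
}
```

Branching on the optional also keeps an unused `lm_head` weight out of the module tree entirely.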

Libraries/MLXLLM/Models/Starcoder2.swift

Lines changed: 1 addition & 1 deletion

@@ -173,7 +173,7 @@ public class Starcoder2Model: Module, LLMModel, KVCacheDimensionProvider {
         if !tieWordEmbeddings {
             return lmHead(out)
         } else {
-            out = matmul(out, model.embedTokens.weight.T)
+            out = model.embedTokens.asLinear(out)
             return out
         }
     }

Libraries/MLXLMCommon/ModelConfiguration.swift

Lines changed: 2 additions & 2 deletions

@@ -31,10 +31,10 @@ public struct ModelConfiguration: Sendable {
     public let overrideTokenizer: String?
 
     /// A reasonable default prompt for the model
-    public let defaultPrompt: String
+    public var defaultPrompt: String
 
     /// Additional tokens to use for end of string
-    public let extraEOSTokens: Set<String>
+    public var extraEOSTokens: Set<String>
 
     public init(
         id: String, tokenizerId: String? = nil, overrideTokenizer: String? = nil,
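
Switching `defaultPrompt` and `extraEOSTokens` from `let` to `var` lets callers adjust a configuration after creating it, which the new `ModelContainer.update` below builds on. A minimal sketch; the prompt and token values are placeholders, not library defaults:

```swift
import MLXLMCommon

var config = ModelConfiguration(id: "mlx-community/SmolLM-135M-Instruct-4bit")
config.defaultPrompt = "Summarize the following text:"  // now mutable
config.extraEOSTokens.insert("<|endoftext|>")           // now mutable
```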

Libraries/MLXLMCommon/ModelContainer.swift

Lines changed: 7 additions & 3 deletions

@@ -32,12 +32,11 @@ import Tokenizers
 /// }
 /// ```
 public actor ModelContainer {
-    let context: ModelContext
-    nonisolated public let configuration: ModelConfiguration
+    var context: ModelContext
+    public var configuration: ModelConfiguration { context.configuration }
 
     public init(context: ModelContext) {
         self.context = context
-        self.configuration = context.configuration
     }
 
     /// Perform an action on the model and/or tokenizer. Callers _must_ eval any `MLXArray` before returning as
@@ -75,4 +74,9 @@ public actor ModelContainer {
         try await action(context, values)
     }
 
+    /// Update the owned `ModelContext`.
+    /// - Parameter action: update action
+    public func update(_ action: @Sendable (inout ModelContext) -> Void) {
+        action(&context)
+    }
 }
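
`ModelContainer` now owns a mutable `ModelContext` and derives `configuration` from it; the new `update(_:)` method is how callers mutate the context inside the actor (together with the `var` fields on `ModelContext` in the next file). A usage sketch; the container is assumed to come from an earlier `loadContainer` call and the prompt string is a placeholder:

```swift
import MLXLMCommon

func customize(_ container: ModelContainer) async {
    // Mutate the owned ModelContext inside the actor.
    await container.update { context in
        context.configuration.defaultPrompt = "You are a helpful assistant."
    }

    // `configuration` is now a computed property over the current context.
    let prompt = await container.configuration.defaultPrompt
    print(prompt)
}
```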

Libraries/MLXLMCommon/ModelFactory.swift

Lines changed: 4 additions & 4 deletions

@@ -22,10 +22,10 @@ public enum ModelFactoryError: Error {
 /// See also ``ModelFactory/loadContainer(hub:configuration:progressHandler:)`` and
 /// ``ModelContainer``.
 public struct ModelContext {
-    public let configuration: ModelConfiguration
-    public let model: any LanguageModel
-    public let processor: any UserInputProcessor
-    public let tokenizer: Tokenizer
+    public var configuration: ModelConfiguration
+    public var model: any LanguageModel
+    public var processor: any UserInputProcessor
+    public var tokenizer: Tokenizer
 
     public init(
         configuration: ModelConfiguration, model: any LanguageModel,
Lines changed: 43 additions & 0 deletions

@@ -0,0 +1,43 @@
+// Copyright © 2024 Apple Inc.
+
+import Foundation
+
+open class AbstractModelRegistry: @unchecked Sendable {
+
+    /// Creates an empty registry.
+    public init() {
+        self.registry = Dictionary()
+    }
+
+    /// Creates a new registry with from given model configurations.
+    public init(modelConfigurations: [ModelConfiguration]) {
+        self.registry = Dictionary(uniqueKeysWithValues: modelConfigurations.map { ($0.name, $0) })
+    }
+
+    private let lock = NSLock()
+    private var registry: [String: ModelConfiguration]
+
+    public func register(configurations: [ModelConfiguration]) {
+        lock.withLock {
+            for c in configurations {
+                registry[c.name] = c
+            }
+        }
+    }
+
+    public func configuration(id: String) -> ModelConfiguration {
+        lock.withLock {
+            if let c = registry[id] {
+                return c
+            } else {
+                return ModelConfiguration(id: id)
+            }
+        }
+    }
+
+    public var models: some Collection<ModelConfiguration> & Sendable {
+        lock.withLock {
+            return registry.values
+        }
+    }
+}
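
This new base class (a new file; its path does not appear in the headers above) factors the thread-safe id-to-configuration lookup out of the old `ModelRegistry`, leaving `LLMRegistry` to supply only the default configuration list. Because it is `open`, other registries can subclass it; a hypothetical example, assuming the class ships in MLXLMCommon (not part of this commit):

```swift
import Foundation
import MLXLMCommon

final class MyProjectRegistry: AbstractModelRegistry, @unchecked Sendable {
    static let shared = MyProjectRegistry(modelConfigurations: [
        ModelConfiguration(id: "my-org/internal-model-4bit")  // placeholder id
    ])
}

// Unknown ids fall back to a plain ModelConfiguration(id:) rather than failing.
let known = MyProjectRegistry.shared.configuration(id: "my-org/internal-model-4bit")
let fallback = MyProjectRegistry.shared.configuration(id: "some-org/other-model")
```
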
Lines changed: 43 additions & 0 deletions

@@ -0,0 +1,43 @@
+// Copyright © 2024 Apple Inc.
+
+import Foundation
+
+open class ModelTypeRegistry: @unchecked Sendable {
+
+    /// Creates an empty registry.
+    public init() {
+        self.creators = [:]
+    }
+
+    /// Creates a registry with given creators.
+    public init(creators: [String: @Sendable (URL) throws -> any LanguageModel]) {
+        self.creators = creators
+    }
+
+    // Note: using NSLock as we have very small (just dictionary get/set)
+    // critical sections and expect no contention. this allows the methods
+    // to remain synchronous.
+    private let lock = NSLock()
+    private var creators: [String: @Sendable (URL) throws -> any LanguageModel]
+
+    /// Add a new model to the type registry.
+    public func registerModelType(
+        _ type: String, creator: @Sendable @escaping (URL) throws -> any LanguageModel
+    ) {
+        lock.withLock {
+            creators[type] = creator
+        }
+    }
+
+    /// Given a `modelType` and configuration file instantiate a new `LanguageModel`.
+    public func createModel(configuration: URL, modelType: String) throws -> LanguageModel {
+        let creator = lock.withLock {
+            creators[modelType]
+        }
+        guard let creator else {
+            throw ModelFactoryError.unsupportedModelType(modelType)
+        }
+        return try creator(configuration)
+    }
+
+}
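
The `ModelTypeRegistry` implementation moves into this new file essentially unchanged, but as an `open` class so factories other than the LLM one can build on it. A sketch of registering a custom model type on the shared LLM subclass; `MyModel` and `MyModelConfiguration` are hypothetical types (assumed to conform to `LanguageModel` and `Codable`), not part of the library:

```swift
import Foundation
import MLXLLM
import MLXLMCommon

func registerMyModel() {
    // The creator receives the URL of the model's config.json.
    LLMTypeRegistry.shared.registerModelType("my_model") { url in
        let data = try Data(contentsOf: url)
        let config = try JSONDecoder().decode(MyModelConfiguration.self, from: data)
        return MyModel(config)  // hypothetical LanguageModel implementation
    }
}
```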
