Skip to content

Commit 362939f

Browse files
Merge branch 'main' into qwen-2.5-vl
2 parents 008d804 + 3885b92 commit 362939f

File tree

3 files changed

+151
-39
lines changed

3 files changed

+151
-39
lines changed

Libraries/Embedders/Models.swift

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,12 @@ public struct ModelConfiguration: Sendable {
8383
return ModelConfiguration(id: id)
8484
}
8585
}
86+
87+
@MainActor
88+
public static var models: some Collection<ModelConfiguration> & Sendable {
89+
bootstrap()
90+
return Self.registry.values
91+
}
8692
}
8793

8894
extension ModelConfiguration {

Libraries/MLXLLM/LLMModelFactory.swift

Lines changed: 61 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,25 +22,42 @@ private func create<C: Codable, M>(
2222
/// Typically called via ``LLMModelFactory/load(hub:configuration:progressHandler:)``.
2323
public class ModelTypeRegistry: @unchecked Sendable {
2424

25+
/// Creates an empty registry.
26+
public init() {
27+
self.creators = [:]
28+
}
29+
30+
/// Creates a registry with given creators.
31+
public init(creators: [String: @Sendable (URL) throws -> any LanguageModel]) {
32+
self.creators = creators
33+
}
34+
35+
/// Shared instance with default model types.
36+
public static let shared: ModelTypeRegistry = .init(creators: all())
37+
38+
/// All predefined model types.
39+
private static func all() -> [String: @Sendable (URL) throws -> any LanguageModel] {
40+
[
41+
"mistral": create(LlamaConfiguration.self, LlamaModel.init),
42+
"llama": create(LlamaConfiguration.self, LlamaModel.init),
43+
"phi": create(PhiConfiguration.self, PhiModel.init),
44+
"phi3": create(Phi3Configuration.self, Phi3Model.init),
45+
"phimoe": create(PhiMoEConfiguration.self, PhiMoEModel.init),
46+
"gemma": create(GemmaConfiguration.self, GemmaModel.init),
47+
"gemma2": create(Gemma2Configuration.self, Gemma2Model.init),
48+
"qwen2": create(Qwen2Configuration.self, Qwen2Model.init),
49+
"starcoder2": create(Starcoder2Configuration.self, Starcoder2Model.init),
50+
"cohere": create(CohereConfiguration.self, CohereModel.init),
51+
"openelm": create(OpenElmConfiguration.self, OpenELMModel.init),
52+
"internlm2": create(InternLM2Configuration.self, InternLM2Model.init),
53+
]
54+
}
55+
2556
// Note: using NSLock as we have very small (just dictionary get/set)
2657
// critical sections and expect no contention. this allows the methods
2758
// to remain synchronous.
2859
private let lock = NSLock()
29-
30-
private var creators: [String: @Sendable (URL) throws -> any LanguageModel] = [
31-
"mistral": create(LlamaConfiguration.self, LlamaModel.init),
32-
"llama": create(LlamaConfiguration.self, LlamaModel.init),
33-
"phi": create(PhiConfiguration.self, PhiModel.init),
34-
"phi3": create(Phi3Configuration.self, Phi3Model.init),
35-
"phimoe": create(PhiMoEConfiguration.self, PhiMoEModel.init),
36-
"gemma": create(GemmaConfiguration.self, GemmaModel.init),
37-
"gemma2": create(Gemma2Configuration.self, Gemma2Model.init),
38-
"qwen2": create(Qwen2Configuration.self, Qwen2Model.init),
39-
"starcoder2": create(Starcoder2Configuration.self, Starcoder2Model.init),
40-
"cohere": create(CohereConfiguration.self, CohereModel.init),
41-
"openelm": create(OpenElmConfiguration.self, OpenELMModel.init),
42-
"internlm2": create(InternLM2Configuration.self, InternLM2Model.init),
43-
]
60+
private var creators: [String: @Sendable (URL) throws -> any LanguageModel]
4461

4562
/// Add a new model to the type registry.
4663
public func registerModelType(
@@ -72,8 +89,21 @@ public class ModelTypeRegistry: @unchecked Sendable {
7289
/// implementation, if needed.
7390
public class ModelRegistry: @unchecked Sendable {
7491

92+
/// Creates an empty registry.
93+
public init() {
94+
self.registry = Dictionary()
95+
}
96+
97+
/// Creates a new registry from the given model configurations.
98+
public init(modelConfigurations: [ModelConfiguration]) {
99+
self.registry = Dictionary(uniqueKeysWithValues: modelConfigurations.map { ($0.name, $0) })
100+
}
101+
102+
/// Shared instance with default model configurations.
103+
public static let shared = ModelRegistry(modelConfigurations: all())
104+
75105
private let lock = NSLock()
76-
private var registry = Dictionary(uniqueKeysWithValues: all().map { ($0.name, $0) })
106+
private var registry: [String: ModelConfiguration]
77107

78108
static public let smolLM_135M_4bit = ModelConfiguration(
79109
id: "mlx-community/SmolLM-135M-Instruct-4bit",
@@ -226,6 +256,12 @@ public class ModelRegistry: @unchecked Sendable {
226256
}
227257
}
228258
}
259+
260+
public var models: some Collection<ModelConfiguration> & Sendable {
261+
lock.withLock {
262+
return registry.values
263+
}
264+
}
229265
}
230266

231267
private struct LLMUserInputProcessor: UserInputProcessor {
@@ -268,13 +304,19 @@ private struct LLMUserInputProcessor: UserInputProcessor {
268304
/// ```
269305
public class LLMModelFactory: ModelFactory {
270306

271-
public static let shared = LLMModelFactory()
307+
public init(typeRegistry: ModelTypeRegistry, modelRegistry: ModelRegistry) {
308+
self.typeRegistry = typeRegistry
309+
self.modelRegistry = modelRegistry
310+
}
311+
312+
/// Shared instance with default behavior.
313+
public static let shared = LLMModelFactory(typeRegistry: .shared, modelRegistry: .shared)
272314

273315
/// registry of model type, e.g. configuration value `llama` -> configuration and init methods
274-
public let typeRegistry = ModelTypeRegistry()
316+
public let typeRegistry: ModelTypeRegistry
275317

276318
/// registry of model id to configuration, e.g. `mlx-community/Llama-3.2-3B-Instruct-4bit`
277-
public let modelRegistry = ModelRegistry()
319+
public let modelRegistry: ModelRegistry
278320

279321
public func configuration(id: String) -> ModelConfiguration {
280322
modelRegistry.configuration(id: id)

Libraries/MLXVLM/VLMModelFactory.swift

Lines changed: 84 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -67,17 +67,34 @@ private func create<C: Codable, P>(
6767
/// Typically called via ``VLMModelFactory/load(hub:configuration:progressHandler:)``.
6868
public class ModelTypeRegistry: @unchecked Sendable {
6969

70+
/// Creates an empty registry.
71+
public init() {
72+
self.creators = [:]
73+
}
74+
75+
/// Creates a registry with given creators.
76+
public init(creators: [String: @Sendable (URL) throws -> any LanguageModel]) {
77+
self.creators = creators
78+
}
79+
80+
/// Shared instance with default model types.
81+
public static let shared: ModelTypeRegistry = .init(creators: all())
82+
83+
/// All predefined model types.
84+
private static func all() -> [String: @Sendable (URL) throws -> any LanguageModel] {
85+
[
86+
"paligemma": create(PaliGemmaConfiguration.self, PaliGemma.init),
87+
"qwen2_vl": create(Qwen2VLConfiguration.self, Qwen2VL.init),
88+
"idefics3": create(Idefics3Configuration.self, Idefics3.init),
89+
]
90+
}
91+
7092
// Note: using NSLock as we have very small (just dictionary get/set)
7193
// critical sections and expect no contention. this allows the methods
7294
// to remain synchronous.
7395
private let lock = NSLock()
7496

75-
private var creators: [String: @Sendable (URL) throws -> any LanguageModel] = [
76-
"paligemma": create(PaliGemmaConfiguration.self, PaliGemma.init),
77-
"qwen2_vl": create(Qwen2VLConfiguration.self, Qwen2VL.init),
78-
"qwen2_5_vl": create(Qwen25VLConfiguration.self, Qwen25VL.init),
79-
"idefics3": create(Idefics3Configuration.self, Idefics3.init),
80-
]
97+
private var creators: [String: @Sendable (URL) throws -> any LanguageModel]
8198

8299
/// Add a new model to the type registry.
83100
public func registerModelType(
@@ -105,13 +122,25 @@ public class ModelTypeRegistry: @unchecked Sendable {
105122

106123
public class ProcessorTypeRegistry: @unchecked Sendable {
107124

108-
// Note: using NSLock as we have very small (just dictionary get/set)
109-
// critical sections and expect no contention. this allows the methods
110-
// to remain synchronous.
111-
private let lock = NSLock()
125+
/// Creates an empty registry.
126+
public init() {
127+
self.creators = [:]
128+
}
112129

113-
private var creators:
114-
[String: @Sendable (URL, any Tokenizer) throws -> any UserInputProcessor] = [
130+
/// Creates a registry with given creators.
131+
public init(creators: [String: @Sendable (URL, any Tokenizer) throws -> any UserInputProcessor])
132+
{
133+
self.creators = creators
134+
}
135+
136+
/// Shared instance with default processor types.
137+
public static let shared: ProcessorTypeRegistry = .init(creators: all())
138+
139+
/// All predefined processor types.
140+
private static func all() -> [String: @Sendable (URL, any Tokenizer) throws ->
141+
any UserInputProcessor]
142+
{
143+
[
115144
"PaliGemmaProcessor": create(
116145
PaliGemmaProcessorConfiguration.self, PaliGemmaProcessor.init),
117146
"Qwen2VLProcessor": create(
@@ -121,6 +150,14 @@ public class ProcessorTypeRegistry: @unchecked Sendable {
121150
"Idefics3Processor": create(
122151
Idefics3ProcessorConfiguration.self, Idefics3Processor.init),
123152
]
153+
}
154+
155+
// Note: using NSLock as we have very small (just dictionary get/set)
156+
// critical sections and expect no contention. this allows the methods
157+
// to remain synchronous.
158+
private let lock = NSLock()
159+
160+
private var creators: [String: @Sendable (URL, any Tokenizer) throws -> any UserInputProcessor]
124161

125162
/// Add a new processor type to the processor type registry.
126163
public func registerProcessorType(
@@ -156,12 +193,21 @@ public class ProcessorTypeRegistry: @unchecked Sendable {
156193
/// swift-tokenizers code handles a good chunk of that and this is a place to augment that
157194
/// implementation, if needed.
158195
public class ModelRegistry: @unchecked Sendable {
196+
/// Creates an empty registry.
197+
public init() {
198+
registry = Dictionary()
199+
}
200+
201+
/// Creates a new registry from the given model configurations.
202+
public init(modelConfigurations: [ModelConfiguration]) {
203+
registry = Dictionary(uniqueKeysWithValues: modelConfigurations.map { ($0.name, $0) })
204+
}
205+
206+
/// Shared instance with default model configurations.
207+
public static let shared = ModelRegistry(modelConfigurations: all())
159208

160209
private let lock = NSLock()
161-
private var registry = Dictionary(
162-
uniqueKeysWithValues: all().map {
163-
($0.name, $0)
164-
})
210+
private var registry: [String: ModelConfiguration]
165211

166212
static public let paligemma3bMix448_8bit = ModelConfiguration(
167213
id: "mlx-community/paligemma-3b-mix-448-8bit",
@@ -188,6 +234,7 @@ public class ModelRegistry: @unchecked Sendable {
188234
paligemma3bMix448_8bit,
189235
qwen2VL2BInstruct4Bit,
190236
qwen2_5VL3BInstruct4Bit,
237+
smolvlminstruct4bit,
191238
]
192239
}
193240

@@ -208,6 +255,12 @@ public class ModelRegistry: @unchecked Sendable {
208255
}
209256
}
210257
}
258+
259+
public var models: some Collection<ModelConfiguration> & Sendable {
260+
lock.withLock {
261+
return registry.values
262+
}
263+
}
211264
}
212265

213266
/// Factory for creating new VLMs.
@@ -221,16 +274,27 @@ public class ModelRegistry: @unchecked Sendable {
221274
/// ```
222275
public class VLMModelFactory: ModelFactory {
223276

224-
public static let shared = VLMModelFactory()
277+
public init(
278+
typeRegistry: ModelTypeRegistry, processorRegistry: ProcessorTypeRegistry,
279+
modelRegistry: ModelRegistry
280+
) {
281+
self.typeRegistry = typeRegistry
282+
self.processorRegistry = processorRegistry
283+
self.modelRegistry = modelRegistry
284+
}
285+
286+
/// Shared instance with default behavior.
287+
public static let shared = VLMModelFactory(
288+
typeRegistry: .shared, processorRegistry: .shared, modelRegistry: .shared)
225289

226290
/// registry of model type, e.g. configuration value `paligemma` -> configuration and init methods
227-
public let typeRegistry = ModelTypeRegistry()
291+
public let typeRegistry: ModelTypeRegistry
228292

229293
/// registry of input processor type, e.g. configuration value `PaliGemmaProcessor` -> configuration and init methods
230-
public let processorRegistry = ProcessorTypeRegistry()
294+
public let processorRegistry: ProcessorTypeRegistry
231295

232296
/// registry of model id to configuration, e.g. `mlx-community/paligemma-3b-mix-448-8bit`
233-
public let modelRegistry = ModelRegistry()
297+
public let modelRegistry: ModelRegistry
234298

235299
public func configuration(id: String) -> ModelConfiguration {
236300
modelRegistry.configuration(id: id)

0 commit comments

Comments
 (0)