Skip to content

Commit 3885b92

Browse files
authored
Make LLMModelFactory and VLMModelFactory inits public (#226)
* Make LLMModelFactory and VLMModelFactory init public * Make ModelRegistry init public * Make ProcessorTypeRegistry init public * Add init(creators:) to ProcessorTypeRegistry and update corresponding code * Make ModelTypeRegistry init public in LLMModelFactory * Make ModelTypeRegistry init public in VLMModelFactory
1 parent 06c825a commit 3885b92

File tree

2 files changed

+134
-40
lines changed

2 files changed

+134
-40
lines changed

Libraries/MLXLLM/LLMModelFactory.swift

Lines changed: 55 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,25 +22,42 @@ private func create<C: Codable, M>(
2222
/// Typically called via ``LLMModelFactory/load(hub:configuration:progressHandler:)``.
2323
public class ModelTypeRegistry: @unchecked Sendable {
2424

25+
/// Creates an empty registry.
26+
public init() {
27+
self.creators = [:]
28+
}
29+
30+
/// Creates a registry with given creators.
31+
public init(creators: [String: @Sendable (URL) throws -> any LanguageModel]) {
32+
self.creators = creators
33+
}
34+
35+
/// Shared instance with default model types.
36+
public static let shared: ModelTypeRegistry = .init(creators: all())
37+
38+
/// All predefined model types.
39+
private static func all() -> [String: @Sendable (URL) throws -> any LanguageModel] {
40+
[
41+
"mistral": create(LlamaConfiguration.self, LlamaModel.init),
42+
"llama": create(LlamaConfiguration.self, LlamaModel.init),
43+
"phi": create(PhiConfiguration.self, PhiModel.init),
44+
"phi3": create(Phi3Configuration.self, Phi3Model.init),
45+
"phimoe": create(PhiMoEConfiguration.self, PhiMoEModel.init),
46+
"gemma": create(GemmaConfiguration.self, GemmaModel.init),
47+
"gemma2": create(Gemma2Configuration.self, Gemma2Model.init),
48+
"qwen2": create(Qwen2Configuration.self, Qwen2Model.init),
49+
"starcoder2": create(Starcoder2Configuration.self, Starcoder2Model.init),
50+
"cohere": create(CohereConfiguration.self, CohereModel.init),
51+
"openelm": create(OpenElmConfiguration.self, OpenELMModel.init),
52+
"internlm2": create(InternLM2Configuration.self, InternLM2Model.init),
53+
]
54+
}
55+
2556
// Note: using NSLock as we have very small (just dictionary get/set)
2657
// critical sections and expect no contention. this allows the methods
2758
// to remain synchronous.
2859
private let lock = NSLock()
29-
30-
private var creators: [String: @Sendable (URL) throws -> any LanguageModel] = [
31-
"mistral": create(LlamaConfiguration.self, LlamaModel.init),
32-
"llama": create(LlamaConfiguration.self, LlamaModel.init),
33-
"phi": create(PhiConfiguration.self, PhiModel.init),
34-
"phi3": create(Phi3Configuration.self, Phi3Model.init),
35-
"phimoe": create(PhiMoEConfiguration.self, PhiMoEModel.init),
36-
"gemma": create(GemmaConfiguration.self, GemmaModel.init),
37-
"gemma2": create(Gemma2Configuration.self, Gemma2Model.init),
38-
"qwen2": create(Qwen2Configuration.self, Qwen2Model.init),
39-
"starcoder2": create(Starcoder2Configuration.self, Starcoder2Model.init),
40-
"cohere": create(CohereConfiguration.self, CohereModel.init),
41-
"openelm": create(OpenElmConfiguration.self, OpenELMModel.init),
42-
"internlm2": create(InternLM2Configuration.self, InternLM2Model.init),
43-
]
60+
private var creators: [String: @Sendable (URL) throws -> any LanguageModel]
4461

4562
/// Add a new model to the type registry.
4663
public func registerModelType(
@@ -72,8 +89,21 @@ public class ModelTypeRegistry: @unchecked Sendable {
7289
/// implementation, if needed.
7390
public class ModelRegistry: @unchecked Sendable {
7491

92+
/// Creates an empty registry.
93+
public init() {
94+
self.registry = Dictionary()
95+
}
96+
97+
/// Creates a new registry from the given model configurations.
98+
public init(modelConfigurations: [ModelConfiguration]) {
99+
self.registry = Dictionary(uniqueKeysWithValues: modelConfigurations.map { ($0.name, $0) })
100+
}
101+
102+
/// Shared instance with default model configurations.
103+
public static let shared = ModelRegistry(modelConfigurations: all())
104+
75105
private let lock = NSLock()
76-
private var registry = Dictionary(uniqueKeysWithValues: all().map { ($0.name, $0) })
106+
private var registry: [String: ModelConfiguration]
77107

78108
static public let smolLM_135M_4bit = ModelConfiguration(
79109
id: "mlx-community/SmolLM-135M-Instruct-4bit",
@@ -274,13 +304,19 @@ private struct LLMUserInputProcessor: UserInputProcessor {
274304
/// ```
275305
public class LLMModelFactory: ModelFactory {
276306

277-
public static let shared = LLMModelFactory()
307+
public init(typeRegistry: ModelTypeRegistry, modelRegistry: ModelRegistry) {
308+
self.typeRegistry = typeRegistry
309+
self.modelRegistry = modelRegistry
310+
}
311+
312+
/// Shared instance with default behavior.
313+
public static let shared = LLMModelFactory(typeRegistry: .shared, modelRegistry: .shared)
278314

279315
/// registry of model type, e.g. configuration value `llama` -> configuration and init methods
280-
public let typeRegistry = ModelTypeRegistry()
316+
public let typeRegistry: ModelTypeRegistry
281317

282318
/// registry of model id to configuration, e.g. `mlx-community/Llama-3.2-3B-Instruct-4bit`
283-
public let modelRegistry = ModelRegistry()
319+
public let modelRegistry: ModelRegistry
284320

285321
public func configuration(id: String) -> ModelConfiguration {
286322
modelRegistry.configuration(id: id)

Libraries/MLXVLM/VLMModelFactory.swift

Lines changed: 79 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -52,16 +52,34 @@ private func create<C: Codable, P>(
5252
/// Typically called via ``LLMModelFactory/load(hub:configuration:progressHandler:)``.
5353
public class ModelTypeRegistry: @unchecked Sendable {
5454

55+
/// Creates an empty registry.
56+
public init() {
57+
self.creators = [:]
58+
}
59+
60+
/// Creates a registry with given creators.
61+
public init(creators: [String: @Sendable (URL) throws -> any LanguageModel]) {
62+
self.creators = creators
63+
}
64+
65+
/// Shared instance with default model types.
66+
public static let shared: ModelTypeRegistry = .init(creators: all())
67+
68+
/// All predefined model types.
69+
private static func all() -> [String: @Sendable (URL) throws -> any LanguageModel] {
70+
[
71+
"paligemma": create(PaliGemmaConfiguration.self, PaliGemma.init),
72+
"qwen2_vl": create(Qwen2VLConfiguration.self, Qwen2VL.init),
73+
"idefics3": create(Idefics3Configuration.self, Idefics3.init),
74+
]
75+
}
76+
5577
// Note: using NSLock as we have very small (just dictionary get/set)
5678
// critical sections and expect no contention. this allows the methods
5779
// to remain synchronous.
5880
private let lock = NSLock()
5981

60-
private var creators: [String: @Sendable (URL) throws -> any LanguageModel] = [
61-
"paligemma": create(PaliGemmaConfiguration.self, PaliGemma.init),
62-
"qwen2_vl": create(Qwen2VLConfiguration.self, Qwen2VL.init),
63-
"idefics3": create(Idefics3Configuration.self, Idefics3.init),
64-
]
82+
private var creators: [String: @Sendable (URL) throws -> any LanguageModel]
6583

6684
/// Add a new model to the type registry.
6785
public func registerModelType(
@@ -90,20 +108,39 @@ public class ModelTypeRegistry: @unchecked Sendable {
90108

91109
public class ProcessorTypeRegistry: @unchecked Sendable {
92110

93-
// Note: using NSLock as we have very small (just dictionary get/set)
94-
// critical sections and expect no contention. this allows the methods
95-
// to remain synchronous.
96-
private let lock = NSLock()
111+
/// Creates an empty registry.
112+
public init() {
113+
self.creators = [:]
114+
}
115+
116+
/// Creates a registry with given creators.
117+
public init(creators: [String: @Sendable (URL, any Tokenizer) throws -> any UserInputProcessor])
118+
{
119+
self.creators = creators
120+
}
97121

98-
private var creators:
99-
[String: @Sendable (URL, any Tokenizer) throws -> any UserInputProcessor] = [
122+
/// Shared instance with default processor types.
123+
public static let shared: ProcessorTypeRegistry = .init(creators: all())
124+
125+
/// All predefined processor types.
126+
private static func all() -> [String: @Sendable (URL, any Tokenizer) throws ->
127+
any UserInputProcessor]
128+
{
129+
[
100130
"PaliGemmaProcessor": create(
101131
PaliGemmaProcessorConfiguration.self, PaligGemmaProcessor.init),
102-
"Qwen2VLProcessor": create(
103-
Qwen2VLProcessorConfiguration.self, Qwen2VLProcessor.init),
132+
"Qwen2VLProcessor": create(Qwen2VLProcessorConfiguration.self, Qwen2VLProcessor.init),
104133
"Idefics3Processor": create(
105134
Idefics3ProcessorConfiguration.self, Idefics3Processor.init),
106135
]
136+
}
137+
138+
// Note: using NSLock as we have very small (just dictionary get/set)
139+
// critical sections and expect no contention. this allows the methods
140+
// to remain synchronous.
141+
private let lock = NSLock()
142+
143+
private var creators: [String: @Sendable (URL, any Tokenizer) throws -> any UserInputProcessor]
107144

108145
/// Add a new model to the type registry.
109146
public func registerProcessorType(
@@ -140,12 +177,21 @@ public class ProcessorTypeRegistry: @unchecked Sendable {
140177
/// swift-tokenizers code handles a good chunk of that and this is a place to augment that
141178
/// implementation, if needed.
142179
public class ModelRegistry: @unchecked Sendable {
180+
/// Creates an empty registry.
181+
public init() {
182+
registry = Dictionary()
183+
}
184+
185+
/// Creates a new registry from the given model configurations.
186+
public init(modelConfigurations: [ModelConfiguration]) {
187+
registry = Dictionary(uniqueKeysWithValues: modelConfigurations.map { ($0.name, $0) })
188+
}
189+
190+
/// Shared instance with default model configurations.
191+
public static let shared = ModelRegistry(modelConfigurations: all())
143192

144193
private let lock = NSLock()
145-
private var registry = Dictionary(
146-
uniqueKeysWithValues: all().map {
147-
($0.name, $0)
148-
})
194+
private var registry: [String: ModelConfiguration]
149195

150196
static public let paligemma3bMix448_8bit = ModelConfiguration(
151197
id: "mlx-community/paligemma-3b-mix-448-8bit",
@@ -166,6 +212,7 @@ public class ModelRegistry: @unchecked Sendable {
166212
[
167213
paligemma3bMix448_8bit,
168214
qwen2VL2BInstruct4Bit,
215+
smolvlminstruct4bit,
169216
]
170217
}
171218

@@ -205,16 +252,27 @@ public class ModelRegistry: @unchecked Sendable {
205252
/// ```
206253
public class VLMModelFactory: ModelFactory {
207254

208-
public static let shared = VLMModelFactory()
255+
public init(
256+
typeRegistry: ModelTypeRegistry, processorRegistry: ProcessorTypeRegistry,
257+
modelRegistry: ModelRegistry
258+
) {
259+
self.typeRegistry = typeRegistry
260+
self.processorRegistry = processorRegistry
261+
self.modelRegistry = modelRegistry
262+
}
263+
264+
/// Shared instance with default behavior.
265+
public static let shared = VLMModelFactory(
266+
typeRegistry: .shared, processorRegistry: .shared, modelRegistry: .shared)
209267

210268
/// registry of model type, e.g. configuration value `paligemma` -> configuration and init methods
211-
public let typeRegistry = ModelTypeRegistry()
269+
public let typeRegistry: ModelTypeRegistry
212270

213271
/// registry of input processor type, e.g. configuration value `PaliGemmaProcessor` -> configuration and init methods
214-
public let processorRegistry = ProcessorTypeRegistry()
272+
public let processorRegistry: ProcessorTypeRegistry
215273

216274
/// registry of model id to configuration, e.g. `mlx-community/paligemma-3b-mix-448-8bit`
217-
public let modelRegistry = ModelRegistry()
275+
public let modelRegistry: ModelRegistry
218276

219277
public func configuration(id: String) -> ModelConfiguration {
220278
modelRegistry.configuration(id: id)

0 commit comments

Comments
 (0)