Skip to content

Commit c8164a2

Browse files
authored
Move common Registry types to MLXLMCommon (#229)
* Move VLMTypeRegistry to MLXLMCommon * Move ModelRegistry to MLXLMCommon * Move ProcessorTypeRegistry to MLXLMCommon * Group registries * Rename ModelRegistry to AbstractModelRegistry * Add ModelRegistry as a typealias and deprecate it
1 parent f35df96 commit c8164a2

File tree

5 files changed

+162
-207
lines changed

5 files changed

+162
-207
lines changed

Libraries/MLXLLM/LLMModelFactory.swift

Lines changed: 11 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -20,20 +20,10 @@ private func create<C: Codable, M>(
2020
/// Registry of model type, e.g 'llama', to functions that can instantiate the model from configuration.
2121
///
2222
/// Typically called via ``LLMModelFactory/load(hub:configuration:progressHandler:)``.
23-
public class ModelTypeRegistry: @unchecked Sendable {
24-
25-
/// Creates an empty registry.
26-
public init() {
27-
self.creators = [:]
28-
}
29-
30-
/// Creates a registry with given creators.
31-
public init(creators: [String: @Sendable (URL) throws -> any LanguageModel]) {
32-
self.creators = creators
33-
}
23+
public class LLMTypeRegistry: ModelTypeRegistry, @unchecked Sendable {
3424

3525
/// Shared instance with default model types.
36-
public static let shared: ModelTypeRegistry = .init(creators: all())
26+
public static let shared: LLMTypeRegistry = .init(creators: all())
3727

3828
/// All predefined model types.
3929
private static func all() -> [String: @Sendable (URL) throws -> any LanguageModel] {
@@ -53,32 +43,6 @@ public class ModelTypeRegistry: @unchecked Sendable {
5343
]
5444
}
5545

56-
// Note: using NSLock as we have very small (just dictionary get/set)
57-
// critical sections and expect no contention. this allows the methods
58-
// to remain synchronous.
59-
private let lock = NSLock()
60-
private var creators: [String: @Sendable (URL) throws -> any LanguageModel]
61-
62-
/// Add a new model to the type registry.
63-
public func registerModelType(
64-
_ type: String, creator: @Sendable @escaping (URL) throws -> any LanguageModel
65-
) {
66-
lock.withLock {
67-
creators[type] = creator
68-
}
69-
}
70-
71-
/// Given a `modelType` and configuration file instantiate a new `LanguageModel`.
72-
public func createModel(configuration: URL, modelType: String) throws -> LanguageModel {
73-
let creator = lock.withLock {
74-
creators[modelType]
75-
}
76-
guard let creator else {
77-
throw ModelFactoryError.unsupportedModelType(modelType)
78-
}
79-
return try creator(configuration)
80-
}
81-
8246
}
8347

8448
/// Registry of models and any overrides that go with them, e.g. prompt augmentation.
@@ -87,23 +51,10 @@ public class ModelTypeRegistry: @unchecked Sendable {
8751
/// The python tokenizers have a very rich set of implementations and configuration. The
8852
/// swift-tokenizers code handles a good chunk of that and this is a place to augment that
8953
/// implementation, if needed.
90-
public class ModelRegistry: @unchecked Sendable {
91-
92-
/// Creates an empty registry.
93-
public init() {
94-
self.registry = Dictionary()
95-
}
96-
97-
/// Creates a new registry with from given model configurations.
98-
public init(modelConfigurations: [ModelConfiguration]) {
99-
self.registry = Dictionary(uniqueKeysWithValues: modelConfigurations.map { ($0.name, $0) })
100-
}
54+
public class LLMRegistry: AbstractModelRegistry, @unchecked Sendable {
10155

10256
/// Shared instance with default model configurations.
103-
public static let shared = ModelRegistry(modelConfigurations: all())
104-
105-
private let lock = NSLock()
106-
private var registry: [String: ModelConfiguration]
57+
public static let shared = LLMRegistry(modelConfigurations: all())
10758

10859
static public let smolLM_135M_4bit = ModelConfiguration(
10960
id: "mlx-community/SmolLM-135M-Instruct-4bit",
@@ -239,31 +190,11 @@ public class ModelRegistry: @unchecked Sendable {
239190
]
240191
}
241192

242-
public func register(configurations: [ModelConfiguration]) {
243-
lock.withLock {
244-
for c in configurations {
245-
registry[c.name] = c
246-
}
247-
}
248-
}
249-
250-
public func configuration(id: String) -> ModelConfiguration {
251-
lock.withLock {
252-
if let c = registry[id] {
253-
return c
254-
} else {
255-
return ModelConfiguration(id: id)
256-
}
257-
}
258-
}
259-
260-
public var models: some Collection<ModelConfiguration> & Sendable {
261-
lock.withLock {
262-
return registry.values
263-
}
264-
}
265193
}
266194

195+
@available(*, deprecated, renamed: "LLMRegistry", message: "Please use LLMRegistry directly.")
196+
public typealias ModelRegistry = LLMRegistry
197+
267198
private struct LLMUserInputProcessor: UserInputProcessor {
268199

269200
let tokenizer: Tokenizer
@@ -304,19 +235,20 @@ private struct LLMUserInputProcessor: UserInputProcessor {
304235
/// ```
305236
public class LLMModelFactory: ModelFactory {
306237

307-
public init(typeRegistry: ModelTypeRegistry, modelRegistry: ModelRegistry) {
238+
public init(typeRegistry: ModelTypeRegistry, modelRegistry: AbstractModelRegistry) {
308239
self.typeRegistry = typeRegistry
309240
self.modelRegistry = modelRegistry
310241
}
311242

312243
/// Shared instance with default behavior.
313-
public static let shared = LLMModelFactory(typeRegistry: .shared, modelRegistry: .shared)
244+
public static let shared = LLMModelFactory(
245+
typeRegistry: LLMTypeRegistry.shared, modelRegistry: LLMRegistry.shared)
314246

315247
/// registry of model type, e.g. configuration value `llama` -> configuration and init methods
316248
public let typeRegistry: ModelTypeRegistry
317249

318250
/// registry of model id to configuration, e.g. `mlx-community/Llama-3.2-3B-Instruct-4bit`
319-
public let modelRegistry: ModelRegistry
251+
public let modelRegistry: AbstractModelRegistry
320252

321253
public func configuration(id: String) -> ModelConfiguration {
322254
modelRegistry.configuration(id: id)
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
// Copyright © 2024 Apple Inc.
2+
3+
import Foundation
4+
5+
open class AbstractModelRegistry: @unchecked Sendable {
6+
7+
/// Creates an empty registry.
8+
public init() {
9+
self.registry = Dictionary()
10+
}
11+
12+
/// Creates a new registry with from given model configurations.
13+
public init(modelConfigurations: [ModelConfiguration]) {
14+
self.registry = Dictionary(uniqueKeysWithValues: modelConfigurations.map { ($0.name, $0) })
15+
}
16+
17+
private let lock = NSLock()
18+
private var registry: [String: ModelConfiguration]
19+
20+
public func register(configurations: [ModelConfiguration]) {
21+
lock.withLock {
22+
for c in configurations {
23+
registry[c.name] = c
24+
}
25+
}
26+
}
27+
28+
public func configuration(id: String) -> ModelConfiguration {
29+
lock.withLock {
30+
if let c = registry[id] {
31+
return c
32+
} else {
33+
return ModelConfiguration(id: id)
34+
}
35+
}
36+
}
37+
38+
public var models: some Collection<ModelConfiguration> & Sendable {
39+
lock.withLock {
40+
return registry.values
41+
}
42+
}
43+
}
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
// Copyright © 2024 Apple Inc.
2+
3+
import Foundation
4+
5+
open class ModelTypeRegistry: @unchecked Sendable {
6+
7+
/// Creates an empty registry.
8+
public init() {
9+
self.creators = [:]
10+
}
11+
12+
/// Creates a registry with given creators.
13+
public init(creators: [String: @Sendable (URL) throws -> any LanguageModel]) {
14+
self.creators = creators
15+
}
16+
17+
// Note: using NSLock as we have very small (just dictionary get/set)
18+
// critical sections and expect no contention. this allows the methods
19+
// to remain synchronous.
20+
private let lock = NSLock()
21+
private var creators: [String: @Sendable (URL) throws -> any LanguageModel]
22+
23+
/// Add a new model to the type registry.
24+
public func registerModelType(
25+
_ type: String, creator: @Sendable @escaping (URL) throws -> any LanguageModel
26+
) {
27+
lock.withLock {
28+
creators[type] = creator
29+
}
30+
}
31+
32+
/// Given a `modelType` and configuration file instantiate a new `LanguageModel`.
33+
public func createModel(configuration: URL, modelType: String) throws -> LanguageModel {
34+
let creator = lock.withLock {
35+
creators[modelType]
36+
}
37+
guard let creator else {
38+
throw ModelFactoryError.unsupportedModelType(modelType)
39+
}
40+
return try creator(configuration)
41+
}
42+
43+
}
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
// Copyright © 2024 Apple Inc.
2+
3+
import Foundation
4+
import Tokenizers
5+
6+
open class ProcessorTypeRegistry: @unchecked Sendable {
7+
8+
/// Creates an empty registry.
9+
public init() {
10+
self.creators = [:]
11+
}
12+
13+
/// Creates a registry with given creators.
14+
public init(creators: [String: @Sendable (URL, any Tokenizer) throws -> any UserInputProcessor])
15+
{
16+
self.creators = creators
17+
}
18+
19+
// Note: using NSLock as we have very small (just dictionary get/set)
20+
// critical sections and expect no contention. this allows the methods
21+
// to remain synchronous.
22+
private let lock = NSLock()
23+
24+
private var creators: [String: @Sendable (URL, any Tokenizer) throws -> any UserInputProcessor]
25+
26+
/// Add a new model to the type registry.
27+
public func registerProcessorType(
28+
_ type: String,
29+
creator: @Sendable @escaping (
30+
URL,
31+
any Tokenizer
32+
) throws -> any UserInputProcessor
33+
) {
34+
lock.withLock {
35+
creators[type] = creator
36+
}
37+
}
38+
39+
/// Given a `processorType` and configuration file instantiate a new `UserInputProcessor`.
40+
public func createModel(configuration: URL, processorType: String, tokenizer: any Tokenizer)
41+
throws -> any UserInputProcessor
42+
{
43+
let creator = lock.withLock {
44+
creators[processorType]
45+
}
46+
guard let creator else {
47+
throw ModelFactoryError.unsupportedProcessorType(processorType)
48+
}
49+
return try creator(configuration, tokenizer)
50+
}
51+
52+
}

0 commit comments

Comments
 (0)