8
8
import CoreML
9
9
import Tokenizers
10
10
import Generation
11
+ import Hub
11
12
12
13
public class LanguageModel {
13
14
public let model : MLModel
@@ -17,10 +18,15 @@ public class LanguageModel {
17
18
18
19
let input_ids = " input_ids "
19
20
let attention_mask = " attention_mask "
20
-
21
- lazy public var tokenizer : Tokenizer = {
22
- return architecture. tokenizerClass. init ( )
23
- } ( )
21
+
22
    /// Bundle of the configuration payloads downloaded from the Hub for this model.
    struct Configurations {
        var modelConfig: Config
        // tokenizer_config.json is optional — not all models ship one (see loadConfig).
        var tokenizerConfig: Config?
        var tokenizerData: Config
    }

    // Download task started in init; awaited lazily by the async config properties.
    private var configPromise: Task<Configurations, Error>? = nil
    // Tokenizer built on first access and cached afterwards.
    // NOTE(review): not synchronized — confirm single-threaded access is intended.
    private var _tokenizer: Tokenizer? = nil
24
30
25
31
public required init ( model: MLModel ) {
26
32
self . model = model
@@ -49,6 +55,10 @@ public class LanguageModel {
49
55
minContextLength = 128
50
56
maxContextLength = 128
51
57
}
58
+
59
+ self . configPromise = Task . init {
60
+ return try await self . loadConfig ( )
61
+ }
52
62
}
53
63
}
54
64
@@ -71,16 +81,7 @@ public extension LanguageModel {
71
81
guard let modelName = model. configuration. modelDisplayName else { fatalError ( " Models must have a name that identifies them " ) }
72
82
return modelName
73
83
}
74
-
75
- var architecture : Architecture {
76
- guard let architecture = Architecture . from ( modelName: modelName) else { fatalError ( " Cannot obtain model architecture " ) }
77
- return architecture
78
- }
79
84
80
- var padTokenId : Int ? { architecture. padTokenId ?? architecture. eosTokenId }
81
- var bosTokenId : Int ? { architecture. bosTokenId }
82
- var eosTokenId : Int ? { architecture. eosTokenId }
83
-
84
85
    /// CoreML feature description for the model's `input_ids` input.
    /// Force-unwrap is deliberate: a language model without this input is unusable.
    var inputIdsDescription: MLFeatureDescription {
        model.modelDescription.inputDescriptionsByName[input_ids]!
    }
@@ -99,13 +100,13 @@ public extension LanguageModel {
99
100
}
100
101
101
102
// MLShapedArrayProtocol is either a MLShapedArray or a MLShapedArraySlice
102
- func predictNextTokenScores( _ tokens: InputTokens ) -> any MLShapedArrayProtocol {
103
+ func predictNextTokenScores( _ tokens: InputTokens , config : GenerationConfig ) -> any MLShapedArrayProtocol {
103
104
// TODO: exceptions
104
105
105
106
// Maybe pad or truncate
106
107
let maxTokens = min ( tokens. count, maxContextLength)
107
108
let padLength = maxTokens >= minContextLength ? 0 : minContextLength- maxTokens
108
- let inputTokens = Array ( tokens [ 0 ..< maxTokens] ) + Array( repeating: padTokenId ?? 0 , count: padLength)
109
+ let inputTokens = Array ( tokens [ 0 ..< maxTokens] ) + Array( repeating: config . padTokenId ?? 0 , count: padLength)
109
110
110
111
let inputIds = MLMultiArray . from ( inputTokens, dims: inputIdsShape. count)
111
112
var inputDictionary = [ inputIdsName: inputIds]
@@ -126,6 +127,100 @@ public extension LanguageModel {
126
127
}
127
128
}
128
129
130
extension LanguageModel {
    /// Downloads `config.json`, `tokenizer_config.json` and `tokenizer.json`
    /// from the Hub concurrently and bundles them into a `Configurations` value.
    ///
    /// - Returns: The three configuration payloads for this model.
    /// - Throws: If `config.json` or `tokenizer.json` cannot be retrieved.
    ///   A missing `tokenizer_config.json` is tolerated (not every model has one).
    func loadConfig() async throws -> Configurations {
        // TODO: caching
        // Kick off all three downloads in parallel via structured concurrency.
        async let modelDownload = try Hub.downloadConfig(repoId: modelName, filename: "config.json")
        async let tokenizerConfigDownload = try Hub.downloadConfig(repoId: modelName, filename: "tokenizer_config.json")
        async let vocabDownload = try Hub.downloadConfig(repoId: modelName, filename: "tokenizer.json")

        let model = try await modelDownload
        let vocab = try await vocabDownload
        // Swallow the error for the optional file only.
        let tokenizer = try? await tokenizerConfigDownload

        return Configurations(modelConfig: model, tokenizerConfig: tokenizer, tokenizerData: vocab)
    }
}
142
+
143
+ /// async properties downloaded from the configuration
144
/// Async properties derived from the configuration files downloaded in init.
public extension LanguageModel {
    // configPromise is created in init, so the force-unwraps below are safe by construction.
    var modelConfig: Config {
        get async throws { try await configPromise!.value.modelConfig }
    }

    var tokenizerConfig: Config? {
        get async throws { try await configPromise!.value.tokenizerConfig }
    }

    var tokenizerData: Config {
        get async throws { try await configPromise!.value.tokenizerData }
    }

    /// The `model_type` field of `config.json`, if present.
    var modelType: String? {
        get async throws { try await modelConfig.modelType?.stringValue }
    }

    /// Text-generation parameters from `task_specific_params`, if present.
    var textGenerationParameters: Config? {
        get async throws { try await modelConfig.taskSpecificParams?.textGeneration }
    }

    /// Defaults to `true` when the config does not specify `do_sample`.
    var defaultDoSample: Bool {
        get async throws { try await textGenerationParameters?.doSample?.boolValue ?? true }
    }

    /// Architecture resolved from `model_type`; nil when the type is unknown or absent.
    var architecture: Architecture? {
        get async throws {
            guard let type = try await modelType else { return nil }
            return Architecture.from(modelType: type)
        }
    }

    /// Pad token, falling back to the architecture's EOS token when no pad token exists.
    var padTokenId: Int? {
        get async throws {
            guard let arch = try await architecture else { return nil }
            return arch.padTokenId ?? arch.eosTokenId
        }
    }

    /// BOS token: model config takes precedence, then the architecture default.
    var bosTokenId: Int? {
        get async throws {
            let config = try await modelConfig
            if let id = config.bosTokenId?.intValue { return id }
            return try await architecture?.bosTokenId
        }
    }

    /// EOS token: model config takes precedence, then the architecture default.
    var eosTokenId: Int? {
        get async throws {
            let config = try await modelConfig
            if let id = config.eosTokenId?.intValue { return id }
            return try await architecture?.eosTokenId
        }
    }

    /// Tokenizer built from the downloaded `tokenizer.json`, cached after first use.
    /// NOTE(review): the `_tokenizer` cache is not synchronized — concurrent first
    /// accesses could build it twice; confirm callers are single-threaded.
    var tokenizer: Tokenizer {
        get async throws {
            if let cached = _tokenizer { return cached }
            guard let arch = try await architecture else { throw "Cannot retrieve Tokenizer" }
            let data = try await tokenizerData
            guard let vocab = data.model?.vocab?.dictionary as? [String: Int] else {
                throw "Cannot find vocab in tokenizer JSON"
            }
            let merges = data.model?.merges?.value as? [String]
            let built = arch.tokenizerClass.init(vocab: vocab, merges: merges)
            _tokenizer = built
            return built
        }
    }
}
223
+
129
224
extension LanguageModel : TextGenerationModel {
130
225
//TODO: retrieve from the json: https://huggingface.co/nlpcloud/instruct-gpt-j-fp16/blob/main/config.json#L26
131
226
public var defaultGenerationConfig : GenerationConfig {
@@ -139,3 +234,5 @@ extension LanguageModel: TextGenerationModel {
139
234
return config
140
235
}
141
236
}
237
+
238
// Allows throwing plain string literals as errors (used by the tokenizer accessor).
// NOTE(review): convenient but loses typed error matching; a dedicated Error enum
// would be more idiomatic — confirm before callers start catching specific errors.
extension String: Error {}
0 commit comments