Skip to content

Commit eb7ccf5

Browse files
authored
Merge pull request #1 from huggingface/hub-tokenizers
Download tokenizers data from the Hub
2 parents 209f0c5 + cbde4aa commit eb7ccf5

15 files changed

+371
-67
lines changed

Package.swift

+4-2
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,12 @@ let package = Package(
1010
.library(name: "Transformers", targets: ["Tokenizers", "Generation", "Models"]),
1111
],
1212
targets: [
13-
.target(name: "Tokenizers", resources: [.process("Vocabs")]),
13+
.target(name: "Hub"),
14+
.target(name: "Tokenizers", dependencies: ["Hub"]),
1415
.target(name: "TensorUtils"),
1516
.target(name: "Generation", dependencies: ["Tokenizers", "TensorUtils"]),
1617
.target(name: "Models", dependencies: ["Tokenizers", "Generation", "TensorUtils"]),
17-
.testTarget(name: "TokenizersTests", dependencies: ["Tokenizers"], resources: [.process("Resources")]),
18+
.testTarget(name: "TokenizersTests", dependencies: ["Tokenizers"], resources: [.process("Resources"), .process("Vocabs")]),
19+
.testTarget(name: "HubTests", dependencies: ["Hub"]),
1820
]
1921
)

Sources/Generation/Generation.swift

+3-3
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ public typealias InputTokens = [Int]
2222
public typealias GenerationOutput = [Int]
2323

2424
/// A callable (a model, usually), that predicts the next token after a given sequence
25-
public typealias NextTokenModel = (InputTokens) -> any MLShapedArrayProtocol
25+
public typealias NextTokenModel = (InputTokens, GenerationConfig) -> any MLShapedArrayProtocol
2626

2727
public typealias PredictionTokensCallback = (GenerationOutput) -> Void
2828
public typealias PredictionStringCallback = (String) -> Void
@@ -40,7 +40,7 @@ public extension Generation {
4040
// TODO: additional stopping criteria
4141
var outputTokens = tokens
4242
while outputTokens.count < config.maxLength {
43-
let logits = model(outputTokens)
43+
let logits = model(outputTokens, config)
4444
let (nextToken, _) = Math.argmax(logits)
4545
if nextToken == config.eosTokenId { break }
4646
outputTokens.append(nextToken)
@@ -55,7 +55,7 @@ public extension Generation {
5555
// TODO: additional stopping criteria
5656
var outputTokens = tokens
5757
while outputTokens.count < config.maxLength {
58-
let outputs = model(outputTokens)
58+
let outputs = model(outputTokens, config)
5959

6060
/// `floats` can be much faster than `scalars` for a vector with stride 1, as it uses `memcpy` in that case
6161
var logits = (outputs as? MLShapedArraySlice<Float>)?.floats ?? outputs.scalars as! [Float]

Sources/Hub/Hub.swift

+97
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
//
2+
// Hub.swift
3+
//
4+
//
5+
// Created by Pedro Cuenca on 18/5/23.
6+
//
7+
8+
import Foundation
9+
10+
/// Namespace for Hugging Face Hub download helpers.
public struct Hub {}

public extension Hub {
    /// Errors surfaced by the Hub download helpers.
    enum HubClientError: Error {
        case download
        case parse
    }

    /// Downloads the contents of `url` and returns the raw bytes.
    static func download(url: URL) async throws -> Data {
        let (data, _) = try await URLSession.shared.data(from: url)
        return data
    }

    /// String-based convenience overload.
    /// - Throws: `HubClientError.download` when `url` is not a valid URL.
    static func download(url: String) async throws -> Data {
        guard let realUrl = URL(string: url) else { throw HubClientError.download }
        // Delegate to the URL overload instead of duplicating the request logic.
        return try await download(url: realUrl)
    }

    /// Downloads `filename` from the given repo, and JSON-decodes it.
    /// Returns a `Config` (just a dictionary wrapper) as the object structure
    /// differs between tokenizers and models.
    /// - Throws: `HubClientError.parse` when the payload is not a JSON object.
    static func downloadConfig(repoId: String, filename: String) async throws -> Config {
        let url = "https://huggingface.co/\(repoId)/resolve/main/\(filename)"
        let data = try await download(url: url)

        let parsed = try JSONSerialization.jsonObject(with: data, options: [])
        guard let dictionary = parsed as? [String: Any] else { throw HubClientError.parse }
        return Config(dictionary)
    }
}
40+
41+
/// A lightweight `@dynamicMemberLookup` wrapper around a JSON dictionary.
/// Members may be written in camelCase even when the underlying keys use snake_case.
@dynamicMemberLookup
public struct Config {
    public private(set) var dictionary: [String: Any]

    init(_ dictionary: [String: Any]) {
        self.dictionary = dictionary
    }

    /// Converts snake_case to camelCase, e.g. "eos_token_id" -> "eosTokenId".
    func camelCase(_ string: String) -> String {
        var pieces: [String] = []
        for (index, part) in string.split(separator: "_").enumerated() {
            pieces.append(index == 0 ? part.lowercased() : part.capitalized)
        }
        return pieces.joined()
    }

    /// Converts camelCase to snake_case, e.g. "eosTokenId" -> "eos_token_id".
    func uncamelCase(_ string: String) -> String {
        var converted = ""
        // Tracks whether the previous scalar was NOT uppercase, so an underscore
        // is inserted only at a lower→upper boundary (no leading underscore).
        var boundaryPending = false

        for scalar in string.unicodeScalars {
            if CharacterSet.uppercaseLetters.contains(scalar) {
                if boundaryPending { converted += "_" }
                converted += Character(scalar).lowercased()
                boundaryPending = false
            } else {
                converted += String(scalar)
                boundaryPending = true
            }
        }

        return converted
    }

    /// Looks `member` up verbatim first, then snake_cased.
    /// Nested dictionaries become nested `Config`s; leaf values are wrapped under "value".
    public subscript(dynamicMember member: String) -> Config? {
        let key = dictionary[member] != nil ? member : uncamelCase(member)
        guard let raw = dictionary[key] else { return nil }
        if let nested = raw as? [String: Any] {
            return Config(nested)
        }
        return Config(["value": raw])
    }

    /// The wrapped leaf value, if this `Config` wraps one.
    public var value: Any? {
        dictionary["value"]
    }

    public var intValue: Int? { value as? Int }
    public var boolValue: Bool? { value as? Bool }
    public var stringValue: String? { value as? String }
}

Sources/Models/LanguageModel.swift

+112-15
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import CoreML
99
import Tokenizers
1010
import Generation
11+
import Hub
1112

1213
public class LanguageModel {
1314
public let model: MLModel
@@ -17,10 +18,15 @@ public class LanguageModel {
1718

1819
let input_ids = "input_ids"
1920
let attention_mask = "attention_mask"
20-
21-
lazy public var tokenizer: Tokenizer = {
22-
return architecture.tokenizerClass.init()
23-
}()
21+
22+
/// Bundle of the three JSON configuration files downloaded from the Hub
/// by `loadConfig()` for this model.
struct Configurations {
    /// Parsed `config.json` from the model repo.
    var modelConfig: Config
    /// Parsed `tokenizer_config.json`; nil because not all repos include this file.
    var tokenizerConfig: Config?
    /// Parsed `tokenizer.json` (vocab and optional merges).
    var tokenizerData: Config
}
27+
28+
private var configPromise: Task<Configurations, Error>? = nil
29+
private var _tokenizer: Tokenizer? = nil
2430

2531
public required init(model: MLModel) {
2632
self.model = model
@@ -49,6 +55,10 @@ public class LanguageModel {
4955
minContextLength = 128
5056
maxContextLength = 128
5157
}
58+
59+
self.configPromise = Task.init {
60+
return try await self.loadConfig()
61+
}
5262
}
5363
}
5464

@@ -71,16 +81,7 @@ public extension LanguageModel {
7181
guard let modelName = model.configuration.modelDisplayName else { fatalError("Models must have a name that identifies them") }
7282
return modelName
7383
}
74-
75-
var architecture: Architecture {
76-
guard let architecture = Architecture.from(modelName: modelName) else { fatalError("Cannot obtain model architecture") }
77-
return architecture
78-
}
7984

80-
var padTokenId: Int? { architecture.padTokenId ?? architecture.eosTokenId }
81-
var bosTokenId: Int? { architecture.bosTokenId }
82-
var eosTokenId: Int? { architecture.eosTokenId }
83-
8485
var inputIdsDescription: MLFeatureDescription {
8586
model.modelDescription.inputDescriptionsByName[input_ids]!
8687
}
@@ -99,13 +100,13 @@ public extension LanguageModel {
99100
}
100101

101102
// MLShapedArrayProtocol is either a MLShapedArray or a MLShapedArraySlice
102-
func predictNextTokenScores(_ tokens: InputTokens) -> any MLShapedArrayProtocol {
103+
func predictNextTokenScores(_ tokens: InputTokens, config: GenerationConfig) -> any MLShapedArrayProtocol {
103104
// TODO: exceptions
104105

105106
// Maybe pad or truncate
106107
let maxTokens = min(tokens.count, maxContextLength)
107108
let padLength = maxTokens >= minContextLength ? 0 : minContextLength-maxTokens
108-
let inputTokens = Array(tokens[0..<maxTokens]) + Array(repeating: padTokenId ?? 0, count: padLength)
109+
let inputTokens = Array(tokens[0..<maxTokens]) + Array(repeating: config.padTokenId ?? 0, count: padLength)
109110

110111
let inputIds = MLMultiArray.from(inputTokens, dims: inputIdsShape.count)
111112
var inputDictionary = [inputIdsName: inputIds]
@@ -126,6 +127,100 @@ public extension LanguageModel {
126127
}
127128
}
128129

130+
extension LanguageModel {
    /// Fetches the model and tokenizer configuration files for `modelName` from the Hub.
    /// The three downloads run concurrently.
    func loadConfig() async throws -> Configurations {
        // TODO: caching
        async let model = Hub.downloadConfig(repoId: modelName, filename: "config.json")
        async let tokenizer = Hub.downloadConfig(repoId: modelName, filename: "tokenizer_config.json")
        async let vocab = Hub.downloadConfig(repoId: modelName, filename: "tokenizer.json")

        // tokenizer_config.json does not exist in all repos, so its failure is swallowed.
        return Configurations(
            modelConfig: try await model,
            tokenizerConfig: try? await tokenizer,
            tokenizerData: try await vocab
        )
    }
}
142+
143+
/// Async properties resolved from the configuration files downloaded in `init`.
public extension LanguageModel {
    /// Parsed `config.json` for this model.
    var modelConfig: Config {
        get async throws {
            // `configPromise` is always created in init, hence the force unwrap.
            try await configPromise!.value.modelConfig
        }
    }

    /// Parsed `tokenizer_config.json`, when the repo provides one.
    var tokenizerConfig: Config? {
        get async throws {
            try await configPromise!.value.tokenizerConfig
        }
    }

    /// Parsed `tokenizer.json` (vocab and optional merges).
    var tokenizerData: Config {
        get async throws {
            try await configPromise!.value.tokenizerData
        }
    }

    /// The `model_type` field of the model config, e.g. "gpt2".
    var modelType: String? {
        get async throws {
            try await modelConfig.modelType?.stringValue
        }
    }

    /// The `task_specific_params.text_generation` section, if present.
    var textGenerationParameters: Config? {
        get async throws {
            try await modelConfig.taskSpecificParams?.textGeneration
        }
    }

    /// Default sampling setting for text generation; true when unspecified.
    var defaultDoSample: Bool {
        get async throws {
            try await textGenerationParameters?.doSample?.boolValue ?? true
        }
    }

    /// Architecture inferred from `model_type`; nil when it cannot be determined.
    var architecture: Architecture? {
        get async throws {
            if let modelType = try await modelType {
                return Architecture.from(modelType: modelType)
            }
            return nil
        }
    }

    /// Pad token id, falling back to the architecture's EOS token.
    var padTokenId: Int? {
        get async throws {
            if let architecture = try await architecture {
                return architecture.padTokenId ?? architecture.eosTokenId
            }
            return nil
        }
    }

    /// BOS token id from the model config, falling back to the architecture default.
    var bosTokenId: Int? {
        get async throws {
            if let fromConfig = try await modelConfig.bosTokenId?.intValue {
                return fromConfig
            }
            return try await architecture?.bosTokenId
        }
    }

    /// EOS token id from the model config, falling back to the architecture default.
    var eosTokenId: Int? {
        get async throws {
            if let fromConfig = try await modelConfig.eosTokenId?.intValue {
                return fromConfig
            }
            return try await architecture?.eosTokenId
        }
    }

    /// Tokenizer built lazily from the downloaded tokenizer data, then cached.
    var tokenizer: Tokenizer {
        get async throws {
            if let tokenizer = _tokenizer { return tokenizer }
            guard let architecture = try await architecture else { throw "Cannot retrieve Tokenizer" }
            let tokenizerData = try await tokenizerData
            guard let vocab = tokenizerData.model?.vocab?.dictionary as? [String: Int] else {
                throw "Cannot find vocab in tokenizer JSON"
            }
            let merges = tokenizerData.model?.merges?.value as? [String]
            let built = architecture.tokenizerClass.init(vocab: vocab, merges: merges)
            _tokenizer = built
            return built
        }
    }
}
223+
129224
extension LanguageModel: TextGenerationModel {
130225
//TODO: retrieve from the json: https://huggingface.co/nlpcloud/instruct-gpt-j-fp16/blob/main/config.json#L26
131226
public var defaultGenerationConfig: GenerationConfig {
@@ -139,3 +234,5 @@ extension LanguageModel: TextGenerationModel {
139234
return config
140235
}
141236
}
237+
238+
// Convenience: lets plain strings be thrown as errors (used by the `tokenizer` getter).
// NOTE(review): retroactive conformance on a standard-library type is fragile;
// a dedicated Error enum would be safer if this grows beyond ad-hoc messages.
extension String: Error {}

Sources/Models/LanguageModelTypes.swift

+8-8
Original file line numberDiff line numberDiff line change
@@ -13,29 +13,29 @@ public protocol LanguageModelProtocol {
1313
/// `name_or_path` in the Python world
1414
var modelName: String { get }
1515

16-
var tokenizer: Tokenizer { get }
16+
var tokenizer: Tokenizer { get async throws }
1717
var model: MLModel { get }
1818

1919
init(model: MLModel)
2020

2121
/// Make prediction callable (this works like __call__ in Python)
22-
func predictNextTokenScores(_ tokens: InputTokens) -> any MLShapedArrayProtocol //MLShapedArray<Float>
23-
func callAsFunction(_ tokens: InputTokens) -> any MLShapedArrayProtocol //MLShapedArray<Float>
22+
func predictNextTokenScores(_ tokens: InputTokens, config: GenerationConfig) -> any MLShapedArrayProtocol //MLShapedArray<Float>
23+
func callAsFunction(_ tokens: InputTokens, config: GenerationConfig) -> any MLShapedArrayProtocol //MLShapedArray<Float>
2424
}
2525

2626
public extension LanguageModelProtocol {
    /// Default conformance: calling the model directly forwards to `predictNextTokenScores`.
    func callAsFunction(_ tokens: InputTokens, config: GenerationConfig) -> any MLShapedArrayProtocol {
        return predictNextTokenScores(tokens, config: config)
    }
}
3131

3232
public protocol TextGenerationModel: Generation, LanguageModelProtocol {
3333
var defaultGenerationConfig: GenerationConfig { get }
34-
func generate(config: GenerationConfig, prompt: String, callback: PredictionStringCallback?) async -> String
34+
func generate(config: GenerationConfig, prompt: String, callback: PredictionStringCallback?) async throws -> String
3535
}
3636

3737
public extension TextGenerationModel {
    /// Default conformance: generates text by driving the shared `Generation` loop
    /// with this model's own forward pass and tokenizer.
    func generate(config: GenerationConfig, prompt: String, callback: PredictionStringCallback? = nil) async throws -> String {
        return try await generate(config: config, prompt: prompt, model: callAsFunction, tokenizer: tokenizer, callback: callback)
    }
}

Sources/Tokenizers/Architecture.swift

+9
Original file line numberDiff line numberDiff line change
@@ -45,4 +45,13 @@ extension Architecture {
4545
}
4646
return nil
4747
}
48+
49+
/// Returns the first supported architecture whose raw value appears in `modelType`,
/// or nil when none matches.
public static func from(modelType: String) -> Architecture? {
    return SupportedArchitecture.allCases
        .first { modelType.contains($0.rawValue) }?
        .architecture
}
4857
}

Sources/Tokenizers/BertTokenizer.swift

+2-11
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,9 @@ class BertTokenizer {
2020
private let vocab: [String: Int]
2121
private let ids_to_tokens: [Int: String]
2222

23-
required init() {
24-
let url = Bundle.module.url(forResource: "bert-vocab", withExtension: "txt")!
25-
let vocabTxt = try! String(contentsOf: url)
26-
let tokens = vocabTxt.split(separator: "\n").map { String($0) }
27-
var vocab: [String: Int] = [:]
28-
var ids_to_tokens: [Int: String] = [:]
29-
for (i, token) in tokens.enumerated() {
30-
vocab[token] = i
31-
ids_to_tokens[i] = token
32-
}
23+
/// Builds the tokenizer from an in-memory vocab.
/// `merges` is unused here; the parameter exists to satisfy the shared
/// tokenizer initializer signature.
required init(vocab: [String: Int], merges: [String]?) {
    self.vocab = vocab
    ids_to_tokens = Utils.invert(vocab)
    wordpieceTokenizer = WordpieceTokenizer(vocab: vocab)
}
3728

0 commit comments

Comments
 (0)