Skip to content

Commit d656ad4

Browse files
authored
Read model weights (#91)
* Read weights from safetensors * Check file extension * Deintegrate Safetensor * Separate Safetensor from the weights * Rename test tensors to include type * Rename ModelWeights to Weights * Throw error for unsupported data types * Remove model weights from LanguageModel.Configurations * Move Weights to TensorUtils * Specify filenames to download in tests. * Make the weights optional and public Enable safe access to keys.
1 parent 37e234e commit d656ad4

8 files changed

+195
-3
lines changed

Package.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ let package = Package(
3030
.testTarget(name: "TokenizersTests", dependencies: ["Tokenizers", "Models", "Hub"], resources: [.process("Resources"), .process("Vocabs")]),
3131
.testTarget(name: "HubTests", dependencies: ["Hub"]),
3232
.testTarget(name: "PreTokenizerTests", dependencies: ["Tokenizers", "Hub"]),
33-
.testTarget(name: "TensorUtilsTests", dependencies: ["TensorUtils"]),
33+
.testTarget(name: "TensorUtilsTests", dependencies: ["TensorUtils", "Models", "Hub"], resources: [.process("Resources")]),
3434
.testTarget(name: "NormalizerTests", dependencies: ["Tokenizers", "Hub"]),
35-
.testTarget(name: "PostProcessorTests", dependencies: ["Tokenizers", "Hub"])
35+
.testTarget(name: "PostProcessorTests", dependencies: ["Tokenizers", "Hub"]),
3636
]
3737
)

Sources/Models/LanguageModel.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ public class LanguageModel {
2424
var tokenizerConfig: Config?
2525
var tokenizerData: Config
2626
}
27-
27+
2828
private var configuration: LanguageModelConfigurationFromHub? = nil
2929
private var _tokenizer: Tokenizer? = nil
3030

Sources/TensorUtils/Weights.swift

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import CoreML
2+
3+
4+
/// A read-only collection of named tensors loaded from a model weights file.
public struct Weights {

    /// Errors thrown while loading a weights file.
    enum WeightsError: Error {
        /// The file extension, container format, or tensor data type is not supported.
        /// `message` carries the offending extension, format name, or dtype string.
        case notSupported(message: String)
        /// The file is truncated or structurally malformed.
        case invalidFile
    }

    private let dictionary: [String: MLMultiArray]

    init(_ dictionary: [String: MLMultiArray]) {
        self.dictionary = dictionary
    }

    /// Returns the tensor stored under `key`, or `nil` when no tensor has that name.
    subscript(key: String) -> MLMultiArray? { dictionary[key] }

    /// Loads weights from a file on disk.
    ///
    /// - Parameter fileURL: Location of the weights file. Its extension must be
    ///   `safetensors`, `gguf`, or `mlx`.
    /// - Returns: The tensors parsed from the file.
    /// - Throws: `WeightsError.notSupported` when the extension is not accepted or when the
    ///   leading magic bytes match one of the recognized-but-unsupported formats
    ///   (reported as "gguf" / "mlx"); `WeightsError.invalidFile` when the file is too
    ///   short to be parsed; any error raised while reading the file or by the
    ///   safetensors parser.
    public static func from(fileURL: URL) throws -> Weights {
        guard ["safetensors", "gguf", "mlx"].contains(fileURL.pathExtension)
        else { throw WeightsError.notSupported(message: "\(fileURL.pathExtension)") }

        let data = try Data(contentsOf: fileURL, options: .mappedIfSafe)

        // `prefix`/`dropFirst` clamp to the bytes that actually exist, so sniffing the
        // magic number of a truncated or empty file cannot trap the way a fixed-range
        // `subdata(in:)` would.
        let magic = [UInt8](data.prefix(4))
        let magicTail = [UInt8](data.dropFirst(4).prefix(2))

        switch (magic, magicTail) {
        case ([0x47, 0x47, 0x55, 0x46], _): throw WeightsError.notSupported(message: ("gguf"))
        case ([0x93, 0x4e, 0x55, 0x4d], [0x50, 0x59]): throw WeightsError.notSupported(message: "mlx")
        default:
            // The safetensors parser reads an 8-byte length prefix first; reject
            // anything shorter here instead of letting it read out of bounds.
            guard data.count >= 8 else { throw WeightsError.invalidFile }
            return try Safetensor.from(data: data)
        }
    }
}
31+
32+
/// Parser for safetensors data: an 8-byte header length, a JSON header describing each
/// tensor, and the raw tensor bytes that follow the header.
struct Safetensor {

    typealias Error = Weights.WeightsError

    /// The JSON header that maps entry names to their metadata.
    struct Header {

        /// Metadata of a single header entry.
        ///
        /// Every field is optional so that an entry lacking the tensor fields still
        /// decodes; such entries are skipped when the tensors are materialized.
        struct Offset: Decodable {
            /// Start and end byte offsets of the tensor, relative to the end of the header.
            let dataOffsets: [Int]?
            /// Element type name as written in the header (e.g. "F32").
            let dtype: String?
            /// Tensor dimensions.
            let shape: [Int]?

            /// The `MLMultiArrayDataType` used to expose this entry.
            ///
            /// Unsupported: "I8", "U8", "I16", "U16", "BF16"
            ///
            /// NOTE(review): "U32" and "U64" are exposed by reinterpreting their bytes as
            /// `.int32` / `.float64`; confirm this is intended, since the stored values are
            /// not converted.
            ///
            /// - Throws: `Error.notSupported` for any other (or missing) dtype.
            var dataType: MLMultiArrayDataType? {
                get throws {
                    switch dtype {
                    case "I32", "U32": .int32
                    case "F16": .float16
                    case "F32": .float32
                    case "F64", "U64": .float64
                    default: throw Error.notSupported(message: "\(dtype ?? "empty")")
                    }
                }
            }
        }

        /// Decodes the JSON header, converting snake_case keys (e.g. `data_offsets`)
        /// to the camelCase property names of `Offset`.
        static func from(data: Data) throws -> [String: Offset?] {
            let decoder = JSONDecoder()
            decoder.keyDecodingStrategy = .convertFromSnakeCase
            return try decoder.decode([String: Offset?].self, from: data)
        }
    }

    /// Parses safetensors bytes into `Weights`.
    ///
    /// - Parameter data: The complete file contents, indexed from zero.
    /// - Returns: One `MLMultiArray` per header entry that carries offsets, a shape, and
    ///   a supported dtype. Entries without those fields are skipped.
    /// - Throws: `Error.invalidFile` when the length prefix, header size, or a tensor's
    ///   offsets fall outside `data`; `Error.notSupported` for an unsupported dtype;
    ///   any JSON decoding or `MLMultiArray` initialization error.
    static func from(data: Data) throws -> Weights {
        let prefixLength = 8
        guard data.count >= prefixLength else { throw Error.invalidFile }

        // The header length is stored as an unsigned 64-bit little-endian integer.
        // Reading it unsigned and converting with `Int(exactly:)` rejects values that
        // would otherwise wrap negative and slip past the bounds check.
        let rawHeaderSize = data.subdata(in: 0..<prefixLength).withUnsafeBytes {
            UInt64(littleEndian: $0.loadUnaligned(as: UInt64.self))
        }
        guard let headerSize = Int(exactly: rawHeaderSize),
              headerSize <= data.count - prefixLength
        else { throw Error.invalidFile }

        let payloadStart = prefixLength + headerSize
        let payloadSize = data.count - payloadStart
        let header = try Header.from(data: data.subdata(in: prefixLength..<payloadStart))

        var dict = [String: MLMultiArray]()
        for (key, point) in header {
            guard let offsets = point?.dataOffsets, offsets.count == 2,
                  let shape = point?.shape as? [NSNumber],
                  let dType = try point?.dataType
            else { continue }

            // Offsets come from the file; validate them before slicing so that a
            // malformed header raises an error instead of trapping.
            guard offsets[0] >= 0, offsets[0] <= offsets[1], offsets[1] <= payloadSize
            else { throw Error.invalidFile }

            // Row-major strides: the last dimension is contiguous, and each earlier
            // stride is the product of all later dimensions.
            let strides = shape.dropFirst().reversed().reduce(into: [1]) { acc, a in
                acc.insert(acc[0].intValue * a.intValue as NSNumber, at: 0)
            }
            let start = payloadStart + offsets[0]
            let end = payloadStart + offsets[1]
            let tensorData = data.subdata(in: start..<end) as NSData
            let ptr = UnsafeMutableRawPointer(mutating: tensorData.bytes)

            // `MLMultiArray(dataPointer:)` does not copy the bytes. Capturing
            // `tensorData` in the deallocator keeps the backing buffer alive for as
            // long as the array exists and releases it once the array is deallocated.
            dict[key] = try MLMultiArray(
                dataPointer: ptr,
                shape: shape,
                dataType: dType,
                strides: strides,
                deallocator: { _ in withExtendedLifetime(tensorData) {} }
            )
        }

        return Weights(dict)
    }
}
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
@testable import TensorUtils
2+
@testable import Hub
3+
import XCTest
4+
5+
/// Tests for loading tensors through `Weights.from(fileURL:)`, using a model downloaded
/// from the Hub and small safetensors fixtures bundled with the test target.
class WeightsTests: XCTestCase {

    let downloadDestination: URL = {
        FileManager.default.urls(for: .cachesDirectory, in: .userDomainMask).first!.appending(component: "huggingface-tests")
    }()

    var hubApi: HubApi { HubApi(downloadBase: downloadDestination) }

    /// Loads a bundled `.safetensors` fixture, recording a test failure (rather than
    /// crashing the test process) when the resource is missing.
    private func loadWeights(fromResource name: String) throws -> Weights {
        let modelFile = try XCTUnwrap(Bundle.module.url(forResource: name, withExtension: "safetensors"))
        return try Weights.from(fileURL: modelFile)
    }

    /// Downloads a small BERT checkpoint and checks the type, size, rank, and a few
    /// sample values of two of its tensors.
    func testLoadWeightsFromFileURL() async throws {
        let repo = "google/bert_uncased_L-2_H-128_A-2"
        let modelDir = try await hubApi.snapshot(from: repo, matching: ["config.json", "model.safetensors"])

        let files = try FileManager.default.contentsOfDirectory(at: modelDir, includingPropertiesForKeys: [.isReadableKey])
        XCTAssertTrue(files.contains(where: { $0.lastPathComponent == "config.json" }))
        XCTAssertTrue(files.contains(where: { $0.lastPathComponent == "model.safetensors" }))

        let modelFile = modelDir.appending(path: "model.safetensors")
        let weights = try Weights.from(fileURL: modelFile)

        let bias = try XCTUnwrap(weights["bert.embeddings.LayerNorm.bias"])
        XCTAssertEqual(bias.dataType, .float32)
        XCTAssertEqual(bias.count, 128)
        XCTAssertEqual(bias.shape.count, 1)

        let embeddings = try XCTUnwrap(weights["bert.embeddings.word_embeddings.weight"])
        XCTAssertEqual(embeddings.dataType, .float32)
        XCTAssertEqual(embeddings.count, 3906816)
        XCTAssertEqual(embeddings.shape.count, 2)

        XCTAssertEqual(embeddings[[0, 0]].floatValue, -0.0041, accuracy: 1e-3)
        XCTAssertEqual(embeddings[[3, 4]].floatValue, 0.0037, accuracy: 1e-3)
        XCTAssertEqual(embeddings[[5, 3]].floatValue, -0.5371, accuracy: 1e-3)
        XCTAssertEqual(embeddings[[7, 8]].floatValue, 0.0460, accuracy: 1e-3)
        XCTAssertEqual(embeddings[[11, 7]].floatValue, -0.0058, accuracy: 1e-3)
    }

    /// Reads a one-dimensional int32 tensor from a bundled fixture.
    func testSafetensorReadTensor1D() throws {
        let weights = try loadWeights(fromResource: "tensor-1d-int32")
        let tensor = try XCTUnwrap(weights["embedding"])
        XCTAssertEqual(tensor.dataType, .int32)
        XCTAssertEqual(tensor[[0]], 1)
        XCTAssertEqual(tensor[[1]], 2)
        XCTAssertEqual(tensor[[2]], 3)
    }

    /// Reads a two-dimensional float64 tensor from a bundled fixture.
    func testSafetensorReadTensor2D() throws {
        let weights = try loadWeights(fromResource: "tensor-2d-float64")
        let tensor = try XCTUnwrap(weights["embedding"])
        XCTAssertEqual(tensor.dataType, .float64)
        XCTAssertEqual(tensor[[0, 0]], 1)
        XCTAssertEqual(tensor[[0, 1]], 2)
        XCTAssertEqual(tensor[[0, 2]], 3)
        XCTAssertEqual(tensor[[1, 0]], 24)
        XCTAssertEqual(tensor[[1, 1]], 25)
        XCTAssertEqual(tensor[[1, 2]], 26)
    }

    /// Reads a three-dimensional float32 tensor from a bundled fixture.
    func testSafetensorReadTensor3D() throws {
        let weights = try loadWeights(fromResource: "tensor-3d-float32")
        let tensor = try XCTUnwrap(weights["embedding"])
        XCTAssertEqual(tensor.dataType, .float32)
        XCTAssertEqual(tensor[[0, 0, 0]], 22)
        XCTAssertEqual(tensor[[0, 0, 1]], 23)
        XCTAssertEqual(tensor[[0, 0, 2]], 24)
        XCTAssertEqual(tensor[[0, 1, 0]], 11)
        XCTAssertEqual(tensor[[0, 1, 1]], 12)
        XCTAssertEqual(tensor[[0, 1, 2]], 13)
        XCTAssertEqual(tensor[[1, 0, 0]], 2)
        XCTAssertEqual(tensor[[1, 0, 1]], 3)
        XCTAssertEqual(tensor[[1, 0, 2]], 4)
        XCTAssertEqual(tensor[[1, 1, 0]], 1)
        XCTAssertEqual(tensor[[1, 1, 1]], 2)
        XCTAssertEqual(tensor[[1, 1, 2]], 3)
    }

    /// Reads a four-dimensional float32 tensor from a bundled fixture.
    func testSafetensorReadTensor4D() throws {
        let weights = try loadWeights(fromResource: "tensor-4d-float32")
        let tensor = try XCTUnwrap(weights["embedding"])
        XCTAssertEqual(tensor.dataType, .float32)
        XCTAssertEqual(tensor[[0, 0, 0, 0]], 11)
        XCTAssertEqual(tensor[[0, 0, 0, 1]], 12)
        XCTAssertEqual(tensor[[0, 0, 0, 2]], 13)
        XCTAssertEqual(tensor[[0, 0, 1, 0]], 1)
        XCTAssertEqual(tensor[[0, 0, 1, 1]], 2)
        XCTAssertEqual(tensor[[0, 0, 1, 2]], 3)
        XCTAssertEqual(tensor[[0, 0, 2, 0]], 4)
        XCTAssertEqual(tensor[[0, 0, 2, 1]], 5)
        XCTAssertEqual(tensor[[0, 0, 2, 2]], 6)
        XCTAssertEqual(tensor[[1, 0, 0, 0]], 22)
        XCTAssertEqual(tensor[[1, 0, 0, 1]], 23)
        XCTAssertEqual(tensor[[1, 0, 0, 2]], 24)
        XCTAssertEqual(tensor[[1, 0, 1, 0]], 15)
        XCTAssertEqual(tensor[[1, 0, 1, 1]], 16)
        XCTAssertEqual(tensor[[1, 0, 1, 2]], 17)
        XCTAssertEqual(tensor[[1, 0, 2, 0]], 17)
        XCTAssertEqual(tensor[[1, 0, 2, 1]], 18)
        XCTAssertEqual(tensor[[1, 0, 2, 2]], 19)
    }
}

0 commit comments

Comments
 (0)