diff --git a/Package.swift b/Package.swift
index de98aa6..8d90a52 100644
--- a/Package.swift
+++ b/Package.swift
@@ -29,8 +29,8 @@ let package = Package(
         .testTarget(name: "TokenizersTests", dependencies: ["Tokenizers", "Models", "Hub"], resources: [.process("Resources"), .process("Vocabs")]),
         .testTarget(name: "HubTests", dependencies: ["Hub"]),
         .testTarget(name: "PreTokenizerTests", dependencies: ["Tokenizers", "Hub"]),
-        .testTarget(name: "TensorUtilsTests", dependencies: ["TensorUtils"]),
+        .testTarget(name: "TensorUtilsTests", dependencies: ["TensorUtils", "Models", "Hub"], resources: [.process("Resources")]),
         .testTarget(name: "NormalizerTests", dependencies: ["Tokenizers", "Hub"]),
-        .testTarget(name: "PostProcessorTests", dependencies: ["Tokenizers", "Hub"])
+        .testTarget(name: "PostProcessorTests", dependencies: ["Tokenizers", "Hub"]),
     ]
 )
diff --git a/Sources/Models/LanguageModel.swift b/Sources/Models/LanguageModel.swift
index e46a23c..457755a 100644
--- a/Sources/Models/LanguageModel.swift
+++ b/Sources/Models/LanguageModel.swift
@@ -24,7 +24,7 @@ public class LanguageModel {
         var tokenizerConfig: Config?
         var tokenizerData: Config
     }
-    
+
     private var configuration: LanguageModelConfigurationFromHub? = nil
     private var _tokenizer: Tokenizer? = nil
 
diff --git a/Sources/TensorUtils/Weights.swift b/Sources/TensorUtils/Weights.swift
new file mode 100644
index 0000000..2050e01
--- /dev/null
+++ b/Sources/TensorUtils/Weights.swift
@@ -0,0 +1,88 @@
+import CoreML
+
+
+public struct Weights {
+
+    enum WeightsError: Error {
+        case notSupported(message: String)
+        case invalidFile
+    }
+
+    private let dictionary: [String: MLMultiArray]
+
+    init(_ dictionary: [String: MLMultiArray]) {
+        self.dictionary = dictionary
+    }
+
+    subscript(key: String) -> MLMultiArray? { dictionary[key] }
+
+    public static func from(fileURL: URL) throws -> Weights {
+        guard ["safetensors", "gguf", "mlx"].contains(fileURL.pathExtension)
+        else { throw WeightsError.notSupported(message: "\(fileURL.pathExtension)") }
+
+        let data = try Data(contentsOf: fileURL, options: .mappedIfSafe)
+        switch ([UInt8](data.subdata(in: 0..<4)), [UInt8](data.subdata(in: 4..<6))) {
+        case ([0x47, 0x47, 0x55, 0x46], _): throw WeightsError.notSupported(message: ("gguf"))
+        case ([0x93, 0x4e, 0x55, 0x4d], [0x50, 0x59]): throw WeightsError.notSupported(message: "mlx")
+        default: return try Safetensor.from(data: data)
+        }
+    }
+}
+
+struct Safetensor {
+
+    typealias Error = Weights.WeightsError
+
+    struct Header {
+
+        struct Offset: Decodable {
+            let dataOffsets: [Int]?
+            let dtype: String?
+            let shape: [Int]?
+
+            /// Unsupported: "I8", "U8", "I16", "U16", "BF16"
+            var dataType: MLMultiArrayDataType? {
+                get throws {
+                    switch dtype {
+                    case "I32", "U32": .int32
+                    case "F16": .float16
+                    case "F32": .float32
+                    case "F64", "U64": .float64
+                    default: throw Error.notSupported(message: "\(dtype ?? "empty")")
+                    }
+                }
+            }
+        }
+
+        static func from(data: Data) throws -> [String: Offset?] {
+            let decoder = JSONDecoder()
+            decoder.keyDecodingStrategy = .convertFromSnakeCase
+            return try decoder.decode([String: Offset?].self, from: data)
+        }
+    }
+
+    static func from(data: Data) throws -> Weights {
+        let headerSize: Int = data.subdata(in: 0..<8).withUnsafeBytes({ $0.load(as: Int.self) })
+        guard headerSize < data.count else { throw Error.invalidFile }
+        let header = try Header.from(data: data.subdata(in: 8..<(headerSize + 8)))
+
+        var dict = [String: MLMultiArray]()
+        for (key, point) in header {
+            guard let offsets = point?.dataOffsets, offsets.count == 2,
+                  let shape = point?.shape as? [NSNumber],
+                  let dType = try point?.dataType
+            else { continue }
+
+            let strides = shape.dropFirst().reversed().reduce(into: [1]) { acc, a in
+                acc.insert(acc[0].intValue * a.intValue as NSNumber, at: 0)
+            }
+            let start = 8 + offsets[0] + headerSize
+            let end = 8 + offsets[1] + headerSize
+            let tensorData = data.subdata(in: start..
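For context, a minimal sketch of how the new `Weights` API could be exercised from the `TensorUtilsTests` target updated above. The resource name and tensor key are hypothetical (the actual fixtures under the test target's `Resources` are not shown in this diff), and `@testable import` is used because the `subscript(key:)` accessor is internal:

```swift
import XCTest
import CoreML
@testable import TensorUtils  // subscript(key:) is internal, hence @testable

final class WeightsUsageExample: XCTestCase {
    func testLoadsASafetensorsFile() throws {
        // Hypothetical fixture name; Bundle.module exists because the test target declares resources.
        let url = try XCTUnwrap(Bundle.module.url(forResource: "tiny-model", withExtension: "safetensors"))

        // from(fileURL:) rejects gguf/mlx files by magic bytes and otherwise parses the
        // safetensors layout: 8-byte header length, JSON header, then raw tensor bytes.
        let weights = try Weights.from(fileURL: url)

        // Keys are the tensor names from the safetensors header (name is illustrative).
        let tensor = try XCTUnwrap(weights["embedding.weight"])
        XCTAssertFalse(tensor.shape.isEmpty)
    }
}
```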