
Commit 67d6c3d

BERT Embedding model
Create layers with weights
1 parent ac4d917 commit 67d6c3d

2 files changed: 137 additions, 125 deletions


Package.swift (1 addition, 1 deletion)

@@ -33,6 +33,6 @@ let package = Package(
         .testTarget(name: "TensorUtilsTests", dependencies: ["TensorUtils"]),
         .testTarget(name: "NormalizerTests", dependencies: ["Tokenizers", "Hub"]),
         .testTarget(name: "PostProcessorTests", dependencies: ["Tokenizers", "Hub"]),
-        .testTarget(name: "EmbeddingTests", dependencies: ["Embedding", "Tokenizers", "Hub"])
+        .testTarget(name: "EmbeddingTests", dependencies: ["Embedding", "Tokenizers", "Hub"], resources: [.process("Resources"), .process("Vocabs")])
     ]
 )
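For context, a minimal sketch of how a test in `EmbeddingTests` could reach the newly bundled fixtures; `.process`-ed resources are exposed through the target's generated `Bundle.module`, and the file name used here (`vocab.txt`) is only a hypothetical example, not a file this commit adds:

```swift
import Foundation
import XCTest

final class EmbeddingResourceSketch: XCTestCase {
    func testBundledVocabCanBeLoaded() throws {
        // Hypothetical resource name; anything placed under Resources/ or Vocabs/
        // in the test target becomes reachable via Bundle.module after `.process(...)`.
        let url = try XCTUnwrap(Bundle.module.url(forResource: "vocab", withExtension: "txt"))
        let text = try String(contentsOf: url, encoding: .utf8)
        XCTAssertFalse(text.isEmpty)
    }
}
```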

Sources/Embedding/Embedding.swift (136 additions, 124 deletions)

@@ -4,139 +4,151 @@ import CoreML
 import Accelerate
 
 
-public protocol Embedding {}
+class BERTEmbedding {
 
-public struct AutoEmbedding {} // Otherwise AutoModel
-
-extension AutoEmbedding {
-    public static func from(pretrained model: String, hubApi: HubApi = .shared) async throws -> Embedding {
-        return try await BGEM3Model(repoName: model, hubApi: hubApi)
-    }
-}
-
-class BERTEmbedding: Embedding { // Otherwise BERTModel
-    private let wordEmbedding: BNNS.EmbeddingLayer
-    private let positionEmbedding: BNNS.EmbeddingLayer
-    private let tokenTypeEmbedding: BNNS.EmbeddingLayer
-    private let normalization: BNNS.NormalizationLayer
-    private let dropout: BNNS.DropoutLayer
-
-    private let positionEmbeddingType = "absolute"
-
-    init(repoName: String) { fatalError() }
-
-    public func callAsFunction(inputIds: MLMultiArray? = nil,
-                               tokenTypeIDs: MLMultiArray? = nil,
-                               positionIDs: MLMultiArray? = nil,
-                               inputEmbeds: MLMultiArray? = nil,
-                               pastKeyValuesLength: Int = 0) -> MLMultiArray {
-        fatalError()
-    }
-}
-
-class BGEM3Model: Embedding {
-
-    struct Output {
-        let lastHidddenState: MLMultiArray // batchSize, sequenceLength, hiddenSize
-        let hiddenStates: MLMultiArray?
-        let attentions: MLMultiArray?
-
-        let loss: MLMultiArray?
-        let scores: MLMultiArray?
-        let pReps: MLMultiArray?
-        let qReps: MLMultiArray?
-    }
-
-    let withSparse = false
-    let withDense = true
-    let withColbert = false
-
-    let shouldNormalize = false
-    // let poolingMethod = "cls"
-    // let negativesCrossDevice = false
-    // let temperature = 1.0
-    // let enableSubBatch = true
-    // let unifiedFinetuning = true
-    // let useSelfDistill = false
-    // let colbertDim: Int? = nil
-    // let selfDistillStartStep: Int? = nil
-
-    private let tokenizer: Tokenizer
-    private let denseLayer: BNNS.FullyConnectedLayer
-    private let sparseLayer: BNNS.FullyConnectedLayer
-    private let colbertLayer: BNNS.FullyConnectedLayer
-
-    init(repoName: String, hubApi: HubApi) async throws {
-        let config = LanguageModelConfigurationFromHub(modelName: repoName)
-        self.tokenizer = try await AutoTokenizer.from(pretrained: repoName, hubApi: hubApi)
-
-        let hiddenSize = try await config.modelConfig.hiddenSize?.intValue ?? 384
-        let colbertDim: Int? = nil
-        let denseInput = BNNSNDArrayDescriptor(dataType: .float16, shape: .vector(hiddenSize, stride: 2))
-        let denseOutput = BNNSNDArrayDescriptor(dataType: .float16, shape: .vector(colbertDim ?? hiddenSize, stride: 2))
-        let denseWeights = BNNSNDArrayDescriptor(dataType: .float16, shape: .vector(hiddenSize, stride: 2))
-        self.denseLayer = BNNS.FullyConnectedLayer(input: denseInput, output: denseOutput, weights: denseWeights, bias: nil, activation: .identity)!
+    typealias Weights = [String: MLMultiArray]
+
+    var shape: [NSNumber] {[
+        NSNumber(value: maxPositionEmbeddings),
+        NSNumber(value: hiddenSize),
+    ]}
+
+    private let weights: Weights
+
+    private let positionEmbeddingType: String
+    private let hiddenSize: Int
+    private let vocabSize: Int
+    private let maxPositionEmbeddings: Int
+    private let typeVocabSize: Int
+    private let padTokenID: Int
+    private let normalizationEpsilon: Float
+    private let dropoutRate: Float = 1e-1
+    private let hiddenActivation: BNNS.ActivationFunction = .geluApproximation2(alpha: 1e-1, beta: 1e-1)
+
+    private var allocations: [BNNSNDArrayDescriptor] = []
+
+    private lazy var wordEmbedding: BNNS.EmbeddingLayer = {
+        let input = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Int64.self, shape: .vector(maxPositionEmbeddings))
+        allocations.append(input)
+        let dictData: [Float32] = weights["bert.embeddings.word_embeddings.weight"]!.toArray()
+        let dict = BNNSNDArrayDescriptor.allocate(initializingFrom: dictData, shape: .matrixColumnMajor(hiddenSize, vocabSize))
+        allocations.append(dict)
+        let output = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Float32.self, shape: .matrixColumnMajor(hiddenSize, maxPositionEmbeddings))
+        allocations.append(output)
 
-        let sparseInput = BNNSNDArrayDescriptor(dataType: .float16, shape: .vector(hiddenSize, stride: 2))
-        let sparseOutput = BNNSNDArrayDescriptor(dataType: .float16, shape: .vector(1, stride: 2))
-        let sparseWeights = BNNSNDArrayDescriptor(dataType: .float16, shape: .vector(hiddenSize, stride: 2))
-        self.sparseLayer = BNNS.FullyConnectedLayer(input: sparseInput, output: sparseOutput, weights: sparseWeights, bias: nil, activation: .identity)!
+        return BNNS.EmbeddingLayer(input: input, output: output, dictionary: dict, paddingIndex: 0, maximumNorm: 0, normType: .l2, scalesGradientByFrequency: false)!
+    }()
+
+    private lazy var positionEmbedding: BNNS.EmbeddingLayer = {
+        let input = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Int64.self, shape: .vector(maxPositionEmbeddings))
+        allocations.append(input)
+        let dictData: [Float32] = weights["bert.embeddings.position_embeddings.weight"]!.toArray()
+        let dict = BNNSNDArrayDescriptor.allocate(initializingFrom: dictData, shape: .matrixColumnMajor(hiddenSize, maxPositionEmbeddings))
+        allocations.append(dict)
+        let output = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Float32.self, shape: .matrixColumnMajor(hiddenSize, maxPositionEmbeddings))
+        allocations.append(output)
+
+        return BNNS.EmbeddingLayer(input: input, output: output, dictionary: dict, paddingIndex: -1, maximumNorm: 0, normType: .l2, scalesGradientByFrequency: true)!
+    }()
+
+    private lazy var tokenTypeEmbedding: BNNS.EmbeddingLayer = {
+        let input = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Int64.self, shape: .vector(maxPositionEmbeddings))
+        allocations.append(input)
+        let dictData: [Float32] = weights["bert.embeddings.token_type_embeddings.weight"]!.toArray()
+        let dict = BNNSNDArrayDescriptor.allocate(initializingFrom: dictData, shape: .matrixColumnMajor(hiddenSize, typeVocabSize))
+        allocations.append(dict)
+        let output = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Float32.self, shape: .matrixColumnMajor(hiddenSize, maxPositionEmbeddings))
+        allocations.append(output)
 
-        let colbertInput = BNNSNDArrayDescriptor(dataType: .float16, shape: .vector(hiddenSize, stride: 2))
-        let colbertOutput = BNNSNDArrayDescriptor(dataType: .float16, shape: .vector(1, stride: 2))
-        let colbertWeights = BNNSNDArrayDescriptor(dataType: .float16, shape: .vector(hiddenSize, stride: 2))
-        self.colbertLayer = BNNS.FullyConnectedLayer(input: colbertInput, output: colbertOutput, weights: colbertWeights, bias: nil, activation: .identity)!
-    }
-
-    public func callAsFunction(_ textInput: (indices: MLMultiArray, attentionMask: MLMultiArray)) -> Output {
-        fatalError()
+        return BNNS.EmbeddingLayer(input: input, output: output, dictionary: dict, paddingIndex: -1, maximumNorm: 0, normType: .l2, scalesGradientByFrequency: true)!
+    }()
+
+    private lazy var normalization: BNNS.NormalizationLayer = {
+        let input = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Float32.self, shape: .matrixRowMajor(maxPositionEmbeddings, hiddenSize))
+        allocations.append(input)
+        let output = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Float32.self, shape: .matrixRowMajor(maxPositionEmbeddings, hiddenSize))
+        allocations.append(output)
+
+        let betaWA: MLMultiArray! = weights["bert.embeddings.LayerNorm.beta"] ?? weights["bert.embeddings.LayerNorm.bias"]
+        let beta = BNNSNDArrayDescriptor.allocate(initializingFrom: betaWA.toArray() as [Float32], shape: .matrixColumnMajor(hiddenSize, maxPositionEmbeddings))
+        allocations.append(beta)
+
+        let gammaWA: MLMultiArray! = weights["bert.embeddings.LayerNorm.gamma"] ?? weights["bert.embeddings.LayerNorm.weight"]
+        let gamma = BNNSNDArrayDescriptor.allocate(initializingFrom: gammaWA.toArray() as [Float32], shape: .matrixColumnMajor(hiddenSize, maxPositionEmbeddings))
+        allocations.append(gamma)
+
+        return BNNS.NormalizationLayer(type: .batch(movingMean: nil, movingVariance: nil), input: input, output: output, beta: beta, gamma: gamma, epsilon: normalizationEpsilon, activation: hiddenActivation)!
+    }()
+
+    private lazy var dropout: BNNS.DropoutLayer = {
+        let input = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Float32.self, shape: .matrixColumnMajor(hiddenSize, maxPositionEmbeddings))
+        allocations.append(input)
+        let output = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Float32.self, shape: .matrixColumnMajor(hiddenSize, maxPositionEmbeddings))
+        allocations.append(output)
+
+        return BNNS.DropoutLayer(input: input, output: output, rate: dropoutRate, seed: 0, control: 0)!
+    }()
+
+    deinit {
+        allocations.forEach({ $0.deallocate() })
     }
 
-    private func forward(textInput: (indices: MLMultiArray, attentionMask: MLMultiArray)) -> [String: MLMultiArray] {
-        let lastHiddenState = self(textInput).lastHidddenState
-
-        var output = [String: MLMultiArray]()
-        if withDense {
-            output["dense"] = self.dense(hiddenState: lastHiddenState, mask: textInput.attentionMask)
-        }
-        if withSparse {
-            output["sparse"] = self.sparse(hiddenState: lastHiddenState, mask: textInput.attentionMask)
+    init(config: Config, weights: Weights = [:]) {
+        assert(config.model_type!.stringValue == "bert")
+        for key in [
+            "bert.embeddings.word_embeddings.weight",
+            "bert.embeddings.position_embeddings.weight",
+            "bert.embeddings.token_type_embeddings.weight",
+        ] { assert(weights.keys.contains(where: { $0 == key })) }
+        assert(weights.keys.contains(where: { $0 == "bert.embeddings.LayerNorm.beta" || $0 == "bert.embeddings.LayerNorm.bias" }))
+        assert(weights.keys.contains(where: { $0 == "bert.embeddings.LayerNorm.gamma" || $0 == "bert.embeddings.LayerNorm.weight" }))
+        assert(config.hidden_act!.stringValue == "gelu")
+        assert("absolute" == config.position_embedding_type!.stringValue!)
+        self.positionEmbeddingType = config.position_embedding_type!.stringValue!
+        self.hiddenSize = config.hidden_size!.intValue!
+        self.vocabSize = config.vocab_size!.intValue!
+        self.maxPositionEmbeddings = config.max_position_embeddings!.intValue!
+        self.typeVocabSize = config.type_vocab_size!.intValue!
+        self.padTokenID = config.pad_token_id!.intValue!
+        self.normalizationEpsilon = Float(config.layer_norm_eps!.doubleValue!)
+        self.weights = weights
+    }
+
+    public func callAsFunction(inputIDs: [Int64],
+                               tokenTypeIDs: [Int64]? = nil,
+                               positionIDs: [Int64]? = nil) -> MLMultiArray {
+        let inputLength = inputIDs.count
+        let inputIDs: [Int64] = inputIDs.padded(length: maxPositionEmbeddings)
+        let wordInput = BNNSNDArrayDescriptor.allocate(initializingFrom: inputIDs, shape: .vector(inputIDs.count))
+        let wordOutput = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Float32.self, shape: .matrixColumnMajor(hiddenSize, inputIDs.count))
+        defer {
+            wordInput.deallocate()
+            wordOutput.deallocate()
         }
-        if withColbert {
-            output["colbert"] = self.colbert(hiddenState: lastHiddenState, mask: textInput.attentionMask)
+        try! wordEmbedding.apply(batchSize: 1, input: wordInput, output: wordOutput)
+
+        let positionIDs = positionIDs ?? Array<Int64>(stride(from: 0, through: Int64(inputLength - 1), by: 1))
+        let positionInput = BNNSNDArrayDescriptor.allocate(initializingFrom: positionIDs.padded(length: maxPositionEmbeddings), shape: .vector(maxPositionEmbeddings))
+        let positionOutput = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Float32.self, shape: .matrixColumnMajor(hiddenSize, maxPositionEmbeddings))
+        defer {
+            positionInput.deallocate()
+            positionOutput.deallocate()
         }
-
-        if shouldNormalize {
-            if withDense {
-                // TODO: Normalize output["dense"] =
-                fatalError()
-            }
-            if withColbert {
-                // TODO: Normalize output["colbert"] =
-                fatalError()
-            }
+        try! self.positionEmbedding.apply(batchSize: 1, input: positionInput, output: positionOutput)
+
+        let tokenTypeIDs: [Int64] = tokenTypeIDs ?? Array(repeating: 0, count: maxPositionEmbeddings)
+        let typeInput = BNNSNDArrayDescriptor.allocate(initializingFrom: tokenTypeIDs, shape: .vector(maxPositionEmbeddings))
+        let typeOutput = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Float32.self, shape: .matrixColumnMajor(hiddenSize, maxPositionEmbeddings))
+        defer {
+            typeInput.deallocate()
+            typeOutput.deallocate()
         }
+        try! self.tokenTypeEmbedding.apply(batchSize: 1, input: typeInput, output: typeOutput)
 
-        return output
-    }
-
-    private func dense(hiddenState: MLMultiArray, mask: MLMultiArray) -> MLMultiArray {
-        assert(hiddenState.shape.count == 2)
-        var data = [Float]()
-        data.reserveCapacity(hiddenState.count)
-
-        for index in 0..<hiddenState.count {
-            data.append(hiddenState[index].floatValue)
-        }
-
-        return try! MLMultiArray(data)
-    }
-
-    private func sparse(hiddenState: MLMultiArray, mask: MLMultiArray) -> MLMultiArray {
-        fatalError()
-    }
+        let multiWord = try! wordOutput.makeMultiArray(of: Float32.self, shape: shape)
+        let multiPosition = try! positionOutput.makeMultiArray(of: Float32.self, shape: shape)
+        let multiType = try! typeOutput.makeMultiArray(of: Float32.self, shape: shape)
 
-    private func colbert(hiddenState: MLMultiArray, mask: MLMultiArray) -> MLMultiArray {
-        fatalError()
+        return multiWord + multiPosition + multiType
     }
 }
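As a rough usage sketch (not code from this commit): `BERTEmbedding` takes a parsed model config as `Config` plus a dictionary of weight tensors, and its `callAsFunction` looks up and sums the word, position, and token-type embeddings into a single `[maxPositionEmbeddings, hiddenSize]` MLMultiArray. The `padded(length:)` helper it calls is not part of this diff, so the version below is an assumption (right-padding with zeros), and the weight-loading step is left abstract:

```swift
import CoreML
import Hub  // assumed source of `Config`, as used by BERTEmbedding's initializer

// Assumed helper: BERTEmbedding calls `padded(length:)`, which this diff does not
// define; a plausible version right-pads the ID sequence with zeros.
extension Array where Element == Int64 {
    func padded(length: Int) -> [Int64] {
        count >= length ? Array(prefix(length)) : self + Array(repeating: 0, count: length - count)
    }
}

// Hypothetical call site: `config` is the parsed config.json and `weights` maps
// tensor names such as "bert.embeddings.word_embeddings.weight" to MLMultiArrays
// loaded elsewhere (e.g. from a converted checkpoint).
func embedTokens(config: Config, weights: BERTEmbedding.Weights, tokenIDs: [Int64]) -> MLMultiArray {
    let embedding = BERTEmbedding(config: config, weights: weights)
    // Word, position, and token-type lookups are applied and summed inside callAsFunction.
    return embedding(inputIDs: tokenIDs)
}
```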
