@@ -4,139 +4,151 @@ import CoreML
 import Accelerate
 
 
-public protocol Embedding { }
+class BERTEmbedding {
 
-public struct AutoEmbedding { } // Otherwise AutoModel
-
-extension AutoEmbedding {
-    public static func from(pretrained model: String, hubApi: HubApi = .shared) async throws -> Embedding {
-        return try await BGEM3Model(repoName: model, hubApi: hubApi)
-    }
-}
-
-class BERTEmbedding: Embedding { // Otherwise BERTModel
-    private let wordEmbedding: BNNS.EmbeddingLayer
-    private let positionEmbedding: BNNS.EmbeddingLayer
-    private let tokenTypeEmbedding: BNNS.EmbeddingLayer
-    private let normalization: BNNS.NormalizationLayer
-    private let dropout: BNNS.DropoutLayer
-
-    private let positionEmbeddingType = "absolute"
-
-    init(repoName: String) { fatalError() }
-
-    public func callAsFunction(inputIds: MLMultiArray? = nil,
-                               tokenTypeIDs: MLMultiArray? = nil,
-                               positionIDs: MLMultiArray? = nil,
-                               inputEmbeds: MLMultiArray? = nil,
-                               pastKeyValuesLength: Int = 0) -> MLMultiArray {
-        fatalError()
-    }
-}
-
-class BGEM3Model: Embedding {
-
-    struct Output {
-        let lastHidddenState: MLMultiArray // batchSize, sequenceLength, hiddenSize
-        let hiddenStates: MLMultiArray?
-        let attentions: MLMultiArray?
-
-        let loss: MLMultiArray?
-        let scores: MLMultiArray?
-        let pReps: MLMultiArray?
-        let qReps: MLMultiArray?
-    }
-
-    let withSparse = false
-    let withDense = true
-    let withColbert = false
-
-    let shouldNormalize = false
-    // let poolingMethod = "cls"
-    // let negativesCrossDevice = false
-    // let temperature = 1.0
-    // let enableSubBatch = true
-    // let unifiedFinetuning = true
-    // let useSelfDistill = false
-    // let colbertDim: Int? = nil
-    // let selfDistillStartStep: Int? = nil
-
-    private let tokenizer: Tokenizer
-    private let denseLayer: BNNS.FullyConnectedLayer
-    private let sparseLayer: BNNS.FullyConnectedLayer
-    private let colbertLayer: BNNS.FullyConnectedLayer
-
-    init(repoName: String, hubApi: HubApi) async throws {
-        let config = LanguageModelConfigurationFromHub(modelName: repoName)
-        self.tokenizer = try await AutoTokenizer.from(pretrained: repoName, hubApi: hubApi)
-
-        let hiddenSize = try await config.modelConfig.hiddenSize?.intValue ?? 384
-        let colbertDim: Int? = nil
-        let denseInput = BNNSNDArrayDescriptor(dataType: .float16, shape: .vector(hiddenSize, stride: 2))
-        let denseOutput = BNNSNDArrayDescriptor(dataType: .float16, shape: .vector(colbertDim ?? hiddenSize, stride: 2))
-        let denseWeights = BNNSNDArrayDescriptor(dataType: .float16, shape: .vector(hiddenSize, stride: 2))
-        self.denseLayer = BNNS.FullyConnectedLayer(input: denseInput, output: denseOutput, weights: denseWeights, bias: nil, activation: .identity)!
+    typealias Weights = [String: MLMultiArray]
+
+    var shape: [NSNumber] { [
+        NSNumber(value: maxPositionEmbeddings),
+        NSNumber(value: hiddenSize),
+    ] }
+
+    private let weights: Weights
+
+    private let positionEmbeddingType: String
+    private let hiddenSize: Int
+    private let vocabSize: Int
+    private let maxPositionEmbeddings: Int
+    private let typeVocabSize: Int
+    private let padTokenID: Int
+    private let normalizationEpsilon: Float
+    private let dropoutRate: Float = 1e-1
+    private let hiddenActivation: BNNS.ActivationFunction = .geluApproximation2(alpha: 1e-1, beta: 1e-1)
+
+    private var allocations: [BNNSNDArrayDescriptor] = []
+
+    private lazy var wordEmbedding: BNNS.EmbeddingLayer = {
+        let input = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Int64.self, shape: .vector(maxPositionEmbeddings))
+        allocations.append(input)
+        let dictData: [Float32] = weights["bert.embeddings.word_embeddings.weight"]!.toArray()
+        let dict = BNNSNDArrayDescriptor.allocate(initializingFrom: dictData, shape: .matrixColumnMajor(hiddenSize, vocabSize))
+        allocations.append(dict)
+        let output = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Float32.self, shape: .matrixColumnMajor(hiddenSize, maxPositionEmbeddings))
+        allocations.append(output)
 
-        let sparseInput = BNNSNDArrayDescriptor(dataType: .float16, shape: .vector(hiddenSize, stride: 2))
-        let sparseOutput = BNNSNDArrayDescriptor(dataType: .float16, shape: .vector(1, stride: 2))
-        let sparseWeights = BNNSNDArrayDescriptor(dataType: .float16, shape: .vector(hiddenSize, stride: 2))
-        self.sparseLayer = BNNS.FullyConnectedLayer(input: sparseInput, output: sparseOutput, weights: sparseWeights, bias: nil, activation: .identity)!
+        return BNNS.EmbeddingLayer(input: input, output: output, dictionary: dict, paddingIndex: 0, maximumNorm: 0, normType: .l2, scalesGradientByFrequency: false)!
+    }()
+
+    private lazy var positionEmbedding: BNNS.EmbeddingLayer = {
+        let input = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Int64.self, shape: .vector(maxPositionEmbeddings))
+        allocations.append(input)
+        let dictData: [Float32] = weights["bert.embeddings.position_embeddings.weight"]!.toArray()
+        let dict = BNNSNDArrayDescriptor.allocate(initializingFrom: dictData, shape: .matrixColumnMajor(hiddenSize, maxPositionEmbeddings))
+        allocations.append(dict)
+        let output = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Float32.self, shape: .matrixColumnMajor(hiddenSize, maxPositionEmbeddings))
+        allocations.append(output)
+
+        return BNNS.EmbeddingLayer(input: input, output: output, dictionary: dict, paddingIndex: -1, maximumNorm: 0, normType: .l2, scalesGradientByFrequency: true)!
+    }()
+
+    private lazy var tokenTypeEmbedding: BNNS.EmbeddingLayer = {
+        let input = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Int64.self, shape: .vector(maxPositionEmbeddings))
+        allocations.append(input)
+        let dictData: [Float32] = weights["bert.embeddings.token_type_embeddings.weight"]!.toArray()
+        let dict = BNNSNDArrayDescriptor.allocate(initializingFrom: dictData, shape: .matrixColumnMajor(hiddenSize, typeVocabSize))
+        allocations.append(dict)
+        let output = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Float32.self, shape: .matrixColumnMajor(hiddenSize, maxPositionEmbeddings))
+        allocations.append(output)
 
-        let colbertInput = BNNSNDArrayDescriptor(dataType: .float16, shape: .vector(hiddenSize, stride: 2))
-        let colbertOutput = BNNSNDArrayDescriptor(dataType: .float16, shape: .vector(1, stride: 2))
-        let colbertWeights = BNNSNDArrayDescriptor(dataType: .float16, shape: .vector(hiddenSize, stride: 2))
-        self.colbertLayer = BNNS.FullyConnectedLayer(input: colbertInput, output: colbertOutput, weights: colbertWeights, bias: nil, activation: .identity)!
-    }
-
-    public func callAsFunction(_ textInput: (indices: MLMultiArray, attentionMask: MLMultiArray)) -> Output {
-        fatalError()
+        return BNNS.EmbeddingLayer(input: input, output: output, dictionary: dict, paddingIndex: -1, maximumNorm: 0, normType: .l2, scalesGradientByFrequency: true)!
+    }()
+
+    private lazy var normalization: BNNS.NormalizationLayer = {
+        let input = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Float32.self, shape: .matrixRowMajor(maxPositionEmbeddings, hiddenSize))
+        allocations.append(input)
+        let output = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Float32.self, shape: .matrixRowMajor(maxPositionEmbeddings, hiddenSize))
+        allocations.append(output)
+
+        let betaWA: MLMultiArray! = weights["bert.embeddings.LayerNorm.beta"] ?? weights["bert.embeddings.LayerNorm.bias"]
+        let beta = BNNSNDArrayDescriptor.allocate(initializingFrom: betaWA.toArray() as [Float32], shape: .matrixColumnMajor(hiddenSize, maxPositionEmbeddings))
+        allocations.append(beta)
+
+        let gammaWA: MLMultiArray! = weights["bert.embeddings.LayerNorm.gamma"] ?? weights["bert.embeddings.LayerNorm.weight"]
+        let gamma = BNNSNDArrayDescriptor.allocate(initializingFrom: gammaWA.toArray() as [Float32], shape: .matrixColumnMajor(hiddenSize, maxPositionEmbeddings))
+        allocations.append(gamma)
+
+        return BNNS.NormalizationLayer(type: .batch(movingMean: nil, movingVariance: nil), input: input, output: output, beta: beta, gamma: gamma, epsilon: normalizationEpsilon, activation: hiddenActivation)!
+    }()
+
+    private lazy var dropout: BNNS.DropoutLayer = {
+        let input = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Float32.self, shape: .matrixColumnMajor(hiddenSize, maxPositionEmbeddings))
+        allocations.append(input)
+        let output = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Float32.self, shape: .matrixColumnMajor(hiddenSize, maxPositionEmbeddings))
+        allocations.append(output)
+
+        return BNNS.DropoutLayer(input: input, output: output, rate: dropoutRate, seed: 0, control: 0)!
+    }()
+
+    deinit {
+        allocations.forEach({ $0.deallocate() })
     }
 
-    private func forward(textInput: (indices: MLMultiArray, attentionMask: MLMultiArray)) -> [String: MLMultiArray] {
-        let lastHiddenState = self(textInput).lastHidddenState
-
-        var output = [String: MLMultiArray]()
-        if withDense {
-            output["dense"] = self.dense(hiddenState: lastHiddenState, mask: textInput.attentionMask)
-        }
-        if withSparse {
-            output["sparse"] = self.sparse(hiddenState: lastHiddenState, mask: textInput.attentionMask)
+    init(config: Config, weights: Weights = [:]) {
+        assert(config.model_type!.stringValue == "bert")
+        for key in [
+            "bert.embeddings.word_embeddings.weight",
+            "bert.embeddings.position_embeddings.weight",
+            "bert.embeddings.token_type_embeddings.weight",
+        ] { assert(weights.keys.contains(where: { $0 == key })) }
+        assert(weights.keys.contains(where: { $0 == "bert.embeddings.LayerNorm.beta" || $0 == "bert.embeddings.LayerNorm.bias" }))
+        assert(weights.keys.contains(where: { $0 == "bert.embeddings.LayerNorm.gamma" || $0 == "bert.embeddings.LayerNorm.weight" }))
+        assert(config.hidden_act!.stringValue == "gelu")
+        assert("absolute" == config.position_embedding_type!.stringValue!)
+        self.positionEmbeddingType = config.position_embedding_type!.stringValue!
+        self.hiddenSize = config.hidden_size!.intValue!
+        self.vocabSize = config.vocab_size!.intValue!
+        self.maxPositionEmbeddings = config.max_position_embeddings!.intValue!
+        self.typeVocabSize = config.type_vocab_size!.intValue!
+        self.padTokenID = config.pad_token_id!.intValue!
+        self.normalizationEpsilon = Float(config.layer_norm_eps!.doubleValue!)
+        self.weights = weights
+    }
+
+    public func callAsFunction(inputIDs: [Int64],
+                               tokenTypeIDs: [Int64]? = nil,
+                               positionIDs: [Int64]? = nil) -> MLMultiArray {
+        let inputLength = inputIDs.count
+        let inputIDs: [Int64] = inputIDs.padded(length: maxPositionEmbeddings)
+        let wordInput = BNNSNDArrayDescriptor.allocate(initializingFrom: inputIDs, shape: .vector(inputIDs.count))
+        let wordOutput = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Float32.self, shape: .matrixColumnMajor(hiddenSize, inputIDs.count))
+        defer {
+            wordInput.deallocate()
+            wordOutput.deallocate()
         }
-        if withColbert {
-            output["colbert"] = self.colbert(hiddenState: lastHiddenState, mask: textInput.attentionMask)
+        try! wordEmbedding.apply(batchSize: 1, input: wordInput, output: wordOutput)
+
+        let positionIDs = positionIDs ?? Array<Int64>(stride(from: 0, through: Int64(inputLength - 1), by: 1))
+        let positionInput = BNNSNDArrayDescriptor.allocate(initializingFrom: positionIDs.padded(length: maxPositionEmbeddings), shape: .vector(maxPositionEmbeddings))
+        let positionOutput = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Float32.self, shape: .matrixColumnMajor(hiddenSize, maxPositionEmbeddings))
+        defer {
+            positionInput.deallocate()
+            positionOutput.deallocate()
         }
-
-        if shouldNormalize {
-            if withDense {
-                // TODO: Normalize output["dense"] =
-                fatalError()
-            }
-            if withColbert {
-                // TODO: Normalize output["colbert"] =
-                fatalError()
-            }
+        try! self.positionEmbedding.apply(batchSize: 1, input: positionInput, output: positionOutput)
+
+        let tokenTypeIDs: [Int64] = tokenTypeIDs ?? Array(repeating: 0, count: maxPositionEmbeddings)
+        let typeInput = BNNSNDArrayDescriptor.allocate(initializingFrom: tokenTypeIDs, shape: .vector(maxPositionEmbeddings))
+        let typeOutput = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Float32.self, shape: .matrixColumnMajor(hiddenSize, maxPositionEmbeddings))
+        defer {
+            typeInput.deallocate()
+            typeOutput.deallocate()
         }
+        try! self.tokenTypeEmbedding.apply(batchSize: 1, input: typeInput, output: typeOutput)
 
-        return output
-    }
-
-    private func dense(hiddenState: MLMultiArray, mask: MLMultiArray) -> MLMultiArray {
-        assert(hiddenState.shape.count == 2)
-        var data = [Float]()
-        data.reserveCapacity(hiddenState.count)
-
-        for index in 0..<hiddenState.count {
-            data.append(hiddenState[index].floatValue)
-        }
-
-        return try! MLMultiArray(data)
-    }
-
-    private func sparse(hiddenState: MLMultiArray, mask: MLMultiArray) -> MLMultiArray {
-        fatalError()
-    }
+        let multiWord = try! wordOutput.makeMultiArray(of: Float32.self, shape: shape)
+        let multiPosition = try! positionOutput.makeMultiArray(of: Float32.self, shape: shape)
+        let multiType = try! typeOutput.makeMultiArray(of: Float32.self, shape: shape)
 
-    private func colbert(hiddenState: MLMultiArray, mask: MLMultiArray) -> MLMultiArray {
-        fatalError()
+        return multiWord + multiPosition + multiType
     }
 }
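
For reference, a minimal usage sketch of the new BERTEmbedding (not part of the diff). It assumes a Config fetched the same way the removed BGEM3Model obtained one, via LanguageModelConfigurationFromHub, and a hypothetical loadTensors() helper that returns the bert.embeddings.* tensors the initializer asserts on as MLMultiArrays; the token IDs are illustrative.

import CoreML

// Inside an async context; the repo name is an example.
let config = try await LanguageModelConfigurationFromHub(modelName: "google-bert/bert-base-uncased").modelConfig
let weights: BERTEmbedding.Weights = loadTensors()   // hypothetical helper: [String: MLMultiArray] from a converted checkpoint

let embedding = BERTEmbedding(config: config, weights: weights)

// IDs from a matching WordPiece tokenizer (illustrative values);
// tokenTypeIDs default to zeros and positionIDs to 0..<count.
let inputIDs: [Int64] = [101, 7592, 2088, 102]       // [CLS] hello world [SEP]
let embedded = embedding(inputIDs: inputIDs)          // MLMultiArray of shape [maxPositionEmbeddings, hiddenSize]

Note that in this hunk callAsFunction returns the sum of the word, position, and token-type embeddings; the normalization and dropout layers are constructed but not yet applied to that sum.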