@@ -14,6 +14,7 @@ private class Attention: Module {
     let heads: Int
     let kvHeads: Int
     let headDim: Int
+    let ropeDim: Int
 
     @ModuleInfo(key: "qkv_proj") var wqkv: Linear
     @ModuleInfo(key: "o_proj") var wo: Linear
@@ -42,6 +43,7 @@ private class Attention: Module {
         self.kvHeads = args.kvHeads
 
         self.headDim = args.hiddenSize / heads
+        self.ropeDim = Int(Float(headDim) * args.partialRotaryFactor)
         self.scale = pow(Float(headDim), -0.5)
 
         self._wqkv.wrappedValue = Linear(dim, (heads + 2 * kvHeads) * headDim, bias: false)
@@ -63,15 +65,15 @@ private class Attention: Module {
         {
             self.rope = .suScaledRotaryEmbedding(
                 SuScaledRotaryEmbedding(
-                    dimensions: headDim, base: args.ropeTheta,
+                    dimensions: ropeDim, base: args.ropeTheta,
                     maxPositionEmbeddings: args.maxPositionEmbeddings,
                     originalMaxPositionEmbeddings: args.originalMaxPositionEmbeddings,
                     longFactor: longFactor))
 
         } else {
             self.rope = .rope(
                 RoPE(
-                    dimensions: headDim, traditional: args.ropeTraditional, base: args.ropeTheta,
+                    dimensions: ropeDim, traditional: args.ropeTraditional, base: args.ropeTheta,
                     scale: ropeScale))
         }
     }
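Note: the two hunks above first derive ropeDim from partialRotaryFactor, then hand it to both rotary-embedding variants, so only a leading slice of each attention head gets rotated. A minimal sketch of the arithmetic; the concrete numbers are illustrative Phi-4-mini-style values, not taken from this diff:

// Sketch: how partialRotaryFactor narrows the rotated slice of each head.
// hiddenSize, heads, and the 0.75 factor are assumptions for illustration.
let hiddenSize = 3072
let heads = 24
let partialRotaryFactor: Float = 0.75

let headDim = hiddenSize / heads                          // 128
let ropeDim = Int(Float(headDim) * partialRotaryFactor)   // 96

// RoPE(dimensions: ropeDim, ...) then rotates only the first 96 of the
// 128 components in each head; the trailing 32 pass through unrotated.
print(headDim, ropeDim)  // 128 96

With partial_rotary_factor absent (default 1.0), ropeDim equals headDim, so existing Phi-3 checkpoints behave exactly as before.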
@@ -157,9 +159,11 @@ private class Phi3ModelInner: Module {
 
     fileprivate let layers: [TransformerBlock]
     let norm: RMSNorm
+    let args: Phi3Configuration
 
     public init(_ args: Phi3Configuration) {
         precondition(args.vocabularySize > 0)
+        self.args = args
 
         self._embedTokens.wrappedValue = Embedding(
             embeddingCount: args.vocabularySize, dimensions: args.hiddenSize)
@@ -190,19 +194,31 @@ public class Phi3Model: Module, LLMModel, KVCacheDimensionProvider {
     public let kvHeads: [Int]
 
     private let model: Phi3ModelInner
+    private let args: Phi3Configuration
 
-    @ModuleInfo(key: "lm_head") var lmHead: Linear
+    @ModuleInfo(key: "lm_head") var lmHead: Linear?
 
     public init(_ args: Phi3Configuration) {
         self.vocabularySize = args.vocabularySize
         self.kvHeads = (0 ..< args.hiddenLayers).map { _ in args.kvHeads }
         self.model = Phi3ModelInner(args)
-        self._lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false)
+        self.args = args
+
+        if !args.tieWordEmbeddings {
+            self._lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false)
+        }
     }
 
     public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray {
         let out = model(inputs, cache: cache)
-        return lmHead(out)
+        if args.tieWordEmbeddings {
+            return model.embedTokens.asLinear(out)
+        } else if let lmHead {
+            return lmHead(out)
+        } else {
+            fatalError(
+                "Model configuration error: Neither tied embeddings nor lm_head is available")
+        }
     }
 }
 
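Note: with tie_word_embeddings set, the output projection reuses the input embedding matrix instead of a separate lm_head, which is what model.embedTokens.asLinear(out) does above. A conceptual sketch of the tying, assuming MLXNN's Embedding exposes its weight array (asLinear itself is the helper this diff calls); the shapes are made up for illustration:

import MLX
import MLXNN

// Weight tying in miniature: logits come from the transposed embedding
// matrix rather than from a separate lm_head Linear.
let vocab = 8
let dims = 4
let embedding = Embedding(embeddingCount: vocab, dimensions: dims)

let hidden = MLX.zeros([1, 3, dims])   // hidden states [batch, seq, dims]
let tied = embedding.asLinear(hidden)  // logits [1, 3, vocab]

// The same computation spelled out by hand (assumes `weight` is accessible):
let manual = matmul(hidden, embedding.weight.transposed())
print(tied.shape, manual.shape)  // [1, 3, 8] [1, 3, 8]

Making lmHead optional (rather than leaving an unused Linear in place) also keeps tied checkpoints loadable, since they ship no lm_head weights to bind to the module.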
@@ -235,8 +251,10 @@ public struct Phi3Configuration: Codable, Sendable {
     var ropeTheta: Float = 10_000
     var ropeTraditional: Bool = false
     var ropeScaling: RopeScalingWithFactorArrays?
+    var partialRotaryFactor: Float = 1.0
     var maxPositionEmbeddings: Int
     var originalMaxPositionEmbeddings: Int
+    var tieWordEmbeddings: Bool = false
 
     enum CodingKeys: String, CodingKey {
         case hiddenSize = "hidden_size"
@@ -249,8 +267,10 @@ public struct Phi3Configuration: Codable, Sendable {
         case ropeTheta = "rope_theta"
         case ropeTraditional = "rope_traditional"
         case ropeScaling = "rope_scaling"
+        case partialRotaryFactor = "partial_rotary_factor"
         case maxPositionEmbeddings = "max_position_embeddings"
         case originalMaxPositionEmbeddings = "original_max_position_embeddings"
+        case tieWordEmbeddings = "tie_word_embeddings"
     }
 
     public init(from decoder: Decoder) throws {
@@ -278,10 +298,15 @@ public struct Phi3Configuration: Codable, Sendable {
             Bool.self, forKey: Phi3Configuration.CodingKeys.ropeTraditional) ?? false
         ropeScaling = try container.decodeIfPresent(
             RopeScalingWithFactorArrays.self, forKey: .ropeScaling)
+        partialRotaryFactor =
+            try container.decodeIfPresent(
+                Float.self, forKey: .partialRotaryFactor) ?? 1.0
         maxPositionEmbeddings = try container.decode(Int.self, forKey: .maxPositionEmbeddings)
         originalMaxPositionEmbeddings = try container.decode(
             Int.self, forKey: .originalMaxPositionEmbeddings)
-
+        tieWordEmbeddings =
+            try container.decodeIfPresent(
+                Bool.self, forKey: .tieWordEmbeddings) ?? false
     }
 }
 
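Note: both new keys are decoded with decodeIfPresent and a fallback, so configs that predate them (standard Phi-3 checkpoints) still load unchanged. A self-contained sketch of that pattern; MiniConfig is a hypothetical stand-in, not the real Phi3Configuration:

import Foundation

// Minimal demo of the decodeIfPresent-with-default pattern used above.
struct MiniConfig: Codable {
    var partialRotaryFactor: Float
    var tieWordEmbeddings: Bool

    enum CodingKeys: String, CodingKey {
        case partialRotaryFactor = "partial_rotary_factor"
        case tieWordEmbeddings = "tie_word_embeddings"
    }

    init(from decoder: Decoder) throws {
        let container = try decoder.container(keyedBy: CodingKeys.self)
        // Missing keys fall back to the Phi-3 defaults, so older configs still decode.
        partialRotaryFactor =
            try container.decodeIfPresent(Float.self, forKey: .partialRotaryFactor) ?? 1.0
        tieWordEmbeddings =
            try container.decodeIfPresent(Bool.self, forKey: .tieWordEmbeddings) ?? false
    }
}

let legacy = #"{}"#.data(using: .utf8)!
let newer = #"{"partial_rotary_factor": 0.75, "tie_word_embeddings": true}"#.data(using: .utf8)!

print(try! JSONDecoder().decode(MiniConfig.self, from: legacy))  // 1.0, false
print(try! JSONDecoder().decode(MiniConfig.self, from: newer))   // 0.75, true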