Commit 0c50f71

Support Phi-4-mini (#216)
* Support Phi-4-mini
1 parent 11b7b43 commit 0c50f71

File tree: 2 files changed (+38, −8 lines)

Libraries/MLXLLM/Models/Phi3.swift (+31, −6)
@@ -14,6 +14,7 @@ private class Attention: Module {
    let heads: Int
    let kvHeads: Int
    let headDim: Int
+   let ropeDim: Int

    @ModuleInfo(key: "qkv_proj") var wqkv: Linear
    @ModuleInfo(key: "o_proj") var wo: Linear
@@ -42,6 +43,7 @@ private class Attention: Module {
        self.kvHeads = args.kvHeads

        self.headDim = args.hiddenSize / heads
+       self.ropeDim = Int(Float(headDim) * args.partialRotaryFactor)
        self.scale = pow(Float(headDim), -0.5)

        self._wqkv.wrappedValue = Linear(dim, (heads + 2 * kvHeads) * headDim, bias: false)
@@ -63,15 +65,15 @@ private class Attention: Module {
        {
            self.rope = .suScaledRotaryEmbedding(
                SuScaledRotaryEmbedding(
-                   dimensions: headDim, base: args.ropeTheta,
+                   dimensions: ropeDim, base: args.ropeTheta,
                    maxPositionEmbeddings: args.maxPositionEmbeddings,
                    originalMaxPositionEmbeddings: args.originalMaxPositionEmbeddings,
                    longFactor: longFactor))

        } else {
            self.rope = .rope(
                RoPE(
-                   dimensions: headDim, traditional: args.ropeTraditional, base: args.ropeTheta,
+                   dimensions: ropeDim, traditional: args.ropeTraditional, base: args.ropeTheta,
                    scale: ropeScale))
        }
    }
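
Phi-4-mini rotates only part of each attention head, controlled by the new partial_rotary_factor; ropeDim is the leading slice of headDim that both RoPE variants now rotate. A minimal sketch of the arithmetic, with hypothetical config values (the real numbers come from the model's config.json):

    // Hypothetical config values, for illustration only.
    let hiddenSize = 3072
    let heads = 24
    let partialRotaryFactor: Float = 0.75

    let headDim = hiddenSize / heads                         // 128
    let ropeDim = Int(Float(headDim) * partialRotaryFactor)  // 96

    print(headDim, ropeDim)  // only the first 96 of 128 dimensions are rotated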
@@ -157,9 +159,11 @@ private class Phi3ModelInner: Module {

    fileprivate let layers: [TransformerBlock]
    let norm: RMSNorm
+   let args: Phi3Configuration

    public init(_ args: Phi3Configuration) {
        precondition(args.vocabularySize > 0)
+       self.args = args

        self._embedTokens.wrappedValue = Embedding(
            embeddingCount: args.vocabularySize, dimensions: args.hiddenSize)
@@ -190,19 +194,31 @@ public class Phi3Model: Module, LLMModel, KVCacheDimensionProvider {
    public let kvHeads: [Int]

    private let model: Phi3ModelInner
+   private let args: Phi3Configuration

-   @ModuleInfo(key: "lm_head") var lmHead: Linear
+   @ModuleInfo(key: "lm_head") var lmHead: Linear?

    public init(_ args: Phi3Configuration) {
        self.vocabularySize = args.vocabularySize
        self.kvHeads = (0 ..< args.hiddenLayers).map { _ in args.kvHeads }
        self.model = Phi3ModelInner(args)
-       self._lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false)
+       self.args = args
+
+       if !args.tieWordEmbeddings {
+           self._lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false)
+       }
    }

    public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray {
        let out = model(inputs, cache: cache)
-       return lmHead(out)
+       if args.tieWordEmbeddings {
+           return model.embedTokens.asLinear(out)
+       } else if let lmHead {
+           return lmHead(out)
+       } else {
+           fatalError(
+               "Model configuration error: Neither tied embeddings nor lm_head is available")
+       }
    }
}

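With tie_word_embeddings, the model carries no separate lm_head; the output logits come from applying the token embedding matrix transposed (embedTokens.asLinear). A minimal MLX-free sketch of the tying idea, using plain arrays and made-up values:

    // Embedding matrix W: [vocabularySize, hiddenSize]. Tied logits are
    // hidden · Wᵀ — one dot product per vocabulary row, so token lookup
    // and the output projection share the same weights.
    let embedding: [[Float]] = [
        [0.1, 0.2],  // token 0
        [0.3, 0.4],  // token 1
        [0.5, 0.6],  // token 2
    ]
    let hidden: [Float] = [1.0, -1.0]

    let logits = embedding.map { row in
        zip(row, hidden).reduce(Float(0)) { $0 + $1.0 * $1.1 }
    }
    // logits ≈ [-0.1, -0.1, -0.1]
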
@@ -235,8 +251,10 @@ public struct Phi3Configuration: Codable, Sendable {
    var ropeTheta: Float = 10_000
    var ropeTraditional: Bool = false
    var ropeScaling: RopeScalingWithFactorArrays?
+   var partialRotaryFactor: Float = 1.0
    var maxPositionEmbeddings: Int
    var originalMaxPositionEmbeddings: Int
+   var tieWordEmbeddings: Bool = false

    enum CodingKeys: String, CodingKey {
        case hiddenSize = "hidden_size"
@@ -249,8 +267,10 @@ public struct Phi3Configuration: Codable, Sendable {
        case ropeTheta = "rope_theta"
        case ropeTraditional = "rope_traditional"
        case ropeScaling = "rope_scaling"
+       case partialRotaryFactor = "partial_rotary_factor"
        case maxPositionEmbeddings = "max_position_embeddings"
        case originalMaxPositionEmbeddings = "original_max_position_embeddings"
+       case tieWordEmbeddings = "tie_word_embeddings"
    }

    public init(from decoder: Decoder) throws {
@@ -278,10 +298,15 @@ public struct Phi3Configuration: Codable, Sendable {
            Bool.self, forKey: Phi3Configuration.CodingKeys.ropeTraditional) ?? false
        ropeScaling = try container.decodeIfPresent(
            RopeScalingWithFactorArrays.self, forKey: .ropeScaling)
+       partialRotaryFactor =
+           try container.decodeIfPresent(
+               Float.self, forKey: .partialRotaryFactor) ?? 1.0
        maxPositionEmbeddings = try container.decode(Int.self, forKey: .maxPositionEmbeddings)
        originalMaxPositionEmbeddings = try container.decode(
            Int.self, forKey: .originalMaxPositionEmbeddings)
-
+       tieWordEmbeddings =
+           try container.decodeIfPresent(
+               Bool.self, forKey: .tieWordEmbeddings) ?? false
    }
}

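Both new configuration keys are read with decodeIfPresent and a fallback, so existing Phi-3 config.json files that lack them decode exactly as before. A self-contained sketch of that pattern (MiniConfig is a hypothetical stand-in, not a library type):

    import Foundation

    struct MiniConfig: Codable {
        var partialRotaryFactor: Float
        var tieWordEmbeddings: Bool

        enum CodingKeys: String, CodingKey {
            case partialRotaryFactor = "partial_rotary_factor"
            case tieWordEmbeddings = "tie_word_embeddings"
        }

        init(from decoder: Decoder) throws {
            let container = try decoder.container(keyedBy: CodingKeys.self)
            // Missing keys fall back to Phi-3 behavior: full rotation, untied head.
            partialRotaryFactor =
                try container.decodeIfPresent(Float.self, forKey: .partialRotaryFactor) ?? 1.0
            tieWordEmbeddings =
                try container.decodeIfPresent(Bool.self, forKey: .tieWordEmbeddings) ?? false
        }
    }

    // A legacy config without the new keys still decodes.
    let legacy = try! JSONDecoder().decode(MiniConfig.self, from: Data("{}".utf8))
    print(legacy.partialRotaryFactor, legacy.tieWordEmbeddings)  // 1.0 false
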
Libraries/MLXLLM/SuScaledRotaryEmbedding.swift (+7, −2)
@@ -39,9 +39,14 @@ public class SuScaledRotaryEmbedding: Module {
    }

    public func callAsFunction(_ x: MLXArray, offset: Int = 0) -> MLXArray {
+       // Apply scaling only to the dimensions that will be rotated
+       var scaledX = x
+       let sliceToScale = scaledX[.ellipsis, 0 ..< dimensions]
+       scaledX[.ellipsis, 0 ..< dimensions] = scale * sliceToScale
+
        return MLXFast.RoPE(
-           self.scale * x,
-           dimensions: x.shape.last!,
+           scaledX,
+           dimensions: dimensions,
            traditional: false,
            base: nil,
            scale: 1.0,
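
Previously the whole input was scaled and x.shape.last! dimensions were rotated; with partial rotary embeddings only the first `dimensions` components of each head get the su-scaling and the rotation, and the trailing components pass through unchanged. A toy illustration of the partial scaling (scale value hypothetical):

    let headDim = 8
    let ropeDim = 6
    let scale: Float = 1.19  // hypothetical su-scale factor
    var head = [Float](repeating: 1.0, count: headDim)

    // Scale only the slice that will be rotated; leave the tail untouched.
    for i in 0 ..< ropeDim {
        head[i] *= scale
    }
    print(head)  // [1.19, 1.19, 1.19, 1.19, 1.19, 1.19, 1.0, 1.0]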
