-
Notifications
You must be signed in to change notification settings - Fork 205
/
Copy pathQwen2.swift
278 lines (230 loc) · 9.31 KB
/
Qwen2.swift
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
//
// Qwen2.swift
// LLM
//
// Created by John Mai on 2024/3/3.
//
import Foundation
import MLX
import MLXFast
import MLXLMCommon
import MLXNN
// port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/qwen2.py
/// Multi-head self-attention with grouped key/value heads and rotary position embeddings.
///
/// Query/key/value projections carry a bias; the output projection does not.
private class Attention: Module {
    let args: Qwen2Configuration

    /// Softmax scale applied inside attention: `1 / sqrt(headDim)`.
    let scale: Float

    @ModuleInfo(key: "q_proj") var wq: Linear
    @ModuleInfo(key: "k_proj") var wk: Linear
    @ModuleInfo(key: "v_proj") var wv: Linear
    @ModuleInfo(key: "o_proj") var wo: Linear

    let rope: RoPE

    public init(_ args: Qwen2Configuration) {
        self.args = args

        let dim = args.hiddenSize
        let heads = args.attentionHeads
        let kvHeads = args.kvHeads
        let headDim = args.hiddenSize / heads
        self.scale = pow(Float(headDim), -0.5)

        _wq.wrappedValue = Linear(dim, heads * headDim, bias: true)
        _wk.wrappedValue = Linear(dim, kvHeads * headDim, bias: true)
        _wv.wrappedValue = Linear(dim, kvHeads * headDim, bias: true)
        _wo.wrappedValue = Linear(heads * headDim, dim, bias: false)

        self.rope = RoPE(
            dimensions: headDim, traditional: args.ropeTraditional, base: args.ropeTheta,
            scale: Attention.ropeScale(from: args.ropeScaling))
    }

    /// Computes the RoPE position scale from a `rope_scaling` config dictionary.
    ///
    /// Only linear scaling changes the scale, to `1 / factor`. Any other (or missing) scaling
    /// kind, or a missing `factor`, yields `1`.
    ///
    /// - Parameter ropeScaling: The raw `rope_scaling` dictionary, if the config has one.
    /// - Returns: The scale passed to `RoPE`.
    /// - Note: Traps via `fatalError` when linear scaling is requested but `factor` is not
    ///   convertible to a float.
    private static func ropeScale(from ropeScaling: [String: StringOrNumber]?) -> Float {
        guard let ropeScaling else { return 1 }

        // Hugging Face configs name the scaling kind `type` (legacy) or `rope_type` (newer).
        // Prefer `type` to keep the previous behavior, but fall back to `rope_type` so a
        // linear-scaled config is not silently run unscaled.
        let kind = ropeScaling["type"] ?? ropeScaling["rope_type"]
        guard kind == .string("linear"), let factor = ropeScaling["factor"] else {
            return 1
        }
        guard let value = factor.asFloat() else {
            fatalError("ropeScaling.factor must be a float")
        }
        return 1 / value
    }

    /// Runs attention over `x`.
    ///
    /// - Parameters:
    ///   - x: Hidden states whose first two dimensions are batch (`B`) and sequence length (`L`).
    ///   - mask: Optional attention mask forwarded to `scaledDotProductAttention`.
    ///   - cache: Optional KV cache. When present, RoPE is offset by `cache.offset` and the new
    ///     keys/values are appended via `cache.update`, which returns the full cached tensors.
    /// - Returns: The output projection of the attended values, reshaped to `[B, L, -1]` first.
    public func callAsFunction(
        _ x: MLXArray, mask: MLXArray? = nil, cache: KVCache?
    ) -> MLXArray {
        let (B, L) = (x.dim(0), x.dim(1))

        var queries = wq(x)
        var keys = wk(x)
        var values = wv(x)

        // Split heads and move the head axis ahead of the sequence axis: [B, heads, L, headDim].
        queries = queries.reshaped(B, L, args.attentionHeads, -1).transposed(0, 2, 1, 3)
        keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
        values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)

        if let cache {
            // Positions continue from what the cache already holds.
            queries = rope(queries, offset: cache.offset)
            keys = rope(keys, offset: cache.offset)
            (keys, values) = cache.update(keys: keys, values: values)
        } else {
            queries = rope(queries)
            keys = rope(keys)
        }

        let output = MLXFast.scaledDotProductAttention(
            queries: queries, keys: keys, values: values, scale: scale, mask: mask
        )
        .transposed(0, 2, 1, 3)
        .reshaped(B, L, -1)

        return wo(output)
    }
}
/// Gated feed-forward layer computing `down(silu(gate(x)) * up(x))`.
///
/// All three projections are bias-free. `gate` and `up` widen `dimensions` to
/// `hiddenDimensions`, and `down` maps back.
private class MLP: Module, UnaryLayer {
    @ModuleInfo(key: "gate_proj") var gate: Linear
    @ModuleInfo(key: "down_proj") var down: Linear
    @ModuleInfo(key: "up_proj") var up: Linear

    public init(dimensions: Int, hiddenDimensions: Int) {
        _gate.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false)
        _down.wrappedValue = Linear(hiddenDimensions, dimensions, bias: false)
        _up.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false)
    }

    public func callAsFunction(_ x: MLXArray) -> MLXArray {
        // SiLU-activated gate branch, multiplied elementwise with the linear "up" branch.
        let gated = silu(gate(x))
        let expanded = up(x)
        return down(gated * expanded)
    }
}
/// A single decoder layer with pre-normalization.
///
/// The attention and MLP sub-layers each see an RMS-normalized input, and their outputs are
/// added back onto the residual stream.
private class TransformerBlock: Module {
    @ModuleInfo(key: "self_attn") var attention: Attention
    let mlp: MLP

    @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm
    @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm

    public init(_ args: Qwen2Configuration) {
        let width = args.hiddenSize
        let epsilon = args.rmsNormEps

        _attention.wrappedValue = Attention(args)
        self.mlp = MLP(dimensions: width, hiddenDimensions: args.intermediateSize)
        _inputLayerNorm.wrappedValue = RMSNorm(dimensions: width, eps: epsilon)
        _postAttentionLayerNorm.wrappedValue = RMSNorm(dimensions: width, eps: epsilon)
    }

    /// Applies attention then the MLP, each as a residual update.
    ///
    /// - Parameters:
    ///   - x: Input hidden states.
    ///   - mask: Optional attention mask forwarded to the attention sub-layer.
    ///   - cache: Optional KV cache forwarded to the attention sub-layer.
    public func callAsFunction(
        _ x: MLXArray, mask: MLXArray? = nil, cache: KVCache?
    ) -> MLXArray {
        let residual = x + attention(inputLayerNorm(x), mask: mask, cache: cache)
        return residual + mlp(postAttentionLayerNorm(residual))
    }
}
/// Transformer body: token embedding, `hiddenLayers` decoder blocks, and a final RMS norm.
///
/// Produces normalized hidden states; the projection to vocabulary logits happens in
/// `Qwen2Model`.
private class Qwen2ModelInner: Module {
    @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding

    fileprivate let layers: [TransformerBlock]
    let norm: RMSNorm

    public init(_ args: Qwen2Configuration) {
        precondition(args.vocabularySize > 0)

        _embedTokens.wrappedValue = Embedding(
            embeddingCount: args.vocabularySize, dimensions: args.hiddenSize)

        var blocks: [TransformerBlock] = []
        for _ in 0 ..< args.hiddenLayers {
            blocks.append(TransformerBlock(args))
        }
        self.layers = blocks

        self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
    }

    /// Embeds `inputs` and runs them through every decoder block.
    ///
    /// - Parameters:
    ///   - inputs: Token ids to embed.
    ///   - cache: Optional per-layer KV caches; entry `i` is handed to layer `i`.
    /// - Returns: The final normalized hidden states.
    public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]? = nil) -> MLXArray {
        var hidden = embedTokens(inputs)
        let mask: MLXArray? = createAttentionMask(h: hidden, cache: cache)

        for index in layers.indices {
            hidden = layers[index](hidden, mask: mask, cache: cache?[index])
        }

        return norm(hidden)
    }
}
/// Qwen2 language model: the transformer body plus the projection to vocabulary logits.
///
/// When `tieWordEmbeddings` is set, no `lm_head` layer is created and the embedding matrix is
/// reused (via `asLinear`) as the output projection.
public class Qwen2Model: Module, LLMModel, KVCacheDimensionProvider {
    public let vocabularySize: Int

    /// KV head count for each layer (`args.kvHeads` repeated once per hidden layer).
    public let kvHeads: [Int]

    private let model: Qwen2ModelInner
    let configuration: Qwen2Configuration

    @ModuleInfo(key: "lm_head") var lmHead: Linear?

    public init(_ args: Qwen2Configuration) {
        self.configuration = args
        self.vocabularySize = args.vocabularySize
        self.kvHeads = Array(repeating: args.kvHeads, count: args.hiddenLayers)
        self.model = Qwen2ModelInner(args)

        // Only untied models own a separate output projection.
        if !args.tieWordEmbeddings {
            _lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false)
        }
    }

    /// Returns vocabulary logits for `inputs`.
    ///
    /// - Parameters:
    ///   - inputs: Token ids.
    ///   - cache: Optional per-layer KV caches.
    public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray {
        let hidden = model(inputs, cache: cache)
        guard let lmHead else {
            // Tied weights: project with the input embedding matrix.
            return model.embedTokens.asLinear(hidden)
        }
        return lmHead(hidden)
    }

    /// Drops checkpoint entries this model does not load.
    ///
    /// Removes `lm_head.weight` when embeddings are tied, and any precomputed
    /// `self_attn.rotary_emb.inv_freq` tensors.
    public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
        let dropsTiedHead = configuration.tieWordEmbeddings
        return weights.filter { entry in
            if dropsTiedHead && entry.key == "lm_head.weight" {
                return false
            }
            return !entry.key.contains("self_attn.rotary_emb.inv_freq")
        }
    }
}
/// Qwen2 hyperparameters decoded from a Hugging Face style `config.json`.
///
/// Keys that may be absent fall back to defaults:
/// - `rope_theta` → `1_000_000`
/// - `rope_traditional` → `false`
/// - `rope_scaling` → `nil`
/// - `tie_word_embeddings` → `false`
/// - `num_key_value_heads` → the value of `num_attention_heads`
public struct Qwen2Configuration: Codable, Sendable {
    var hiddenSize: Int
    var hiddenLayers: Int
    var intermediateSize: Int
    var attentionHeads: Int
    var rmsNormEps: Float
    var vocabularySize: Int
    /// Number of key/value heads; equals `attentionHeads` when the config omits it.
    var kvHeads: Int
    var ropeTheta: Float = 1_000_000
    var ropeTraditional: Bool = false
    /// Raw `rope_scaling` dictionary from the config, if any.
    var ropeScaling: [String: StringOrNumber]? = nil
    var tieWordEmbeddings = false

    enum CodingKeys: String, CodingKey {
        case hiddenSize = "hidden_size"
        case hiddenLayers = "num_hidden_layers"
        case intermediateSize = "intermediate_size"
        case attentionHeads = "num_attention_heads"
        case rmsNormEps = "rms_norm_eps"
        case vocabularySize = "vocab_size"
        case kvHeads = "num_key_value_heads"
        case ropeTheta = "rope_theta"
        case ropeTraditional = "rope_traditional"
        case ropeScaling = "rope_scaling"
        case tieWordEmbeddings = "tie_word_embeddings"
    }

    public init(from decoder: Decoder) throws {
        // Custom decoding so that optional keys can fall back to the defaults documented above.
        let container = try decoder.container(keyedBy: CodingKeys.self)

        self.hiddenSize = try container.decode(Int.self, forKey: .hiddenSize)
        self.hiddenLayers = try container.decode(Int.self, forKey: .hiddenLayers)
        self.intermediateSize = try container.decode(Int.self, forKey: .intermediateSize)
        let attentionHeads = try container.decode(Int.self, forKey: .attentionHeads)
        self.attentionHeads = attentionHeads
        self.rmsNormEps = try container.decode(Float.self, forKey: .rmsNormEps)
        self.vocabularySize = try container.decode(Int.self, forKey: .vocabularySize)

        // `num_key_value_heads` is optional in Hugging Face's Qwen2Config (and in the mlx_lm
        // Python port); when missing, every query head has its own KV head.
        self.kvHeads =
            try container.decodeIfPresent(Int.self, forKey: .kvHeads) ?? attentionHeads

        self.ropeTheta =
            try container.decodeIfPresent(Float.self, forKey: .ropeTheta) ?? 1_000_000
        self.ropeTraditional =
            try container.decodeIfPresent(Bool.self, forKey: .ropeTraditional) ?? false
        self.ropeScaling = try container.decodeIfPresent(
            [String: StringOrNumber].self, forKey: .ropeScaling)
        self.tieWordEmbeddings =
            try container.decodeIfPresent(Bool.self, forKey: .tieWordEmbeddings) ?? false
    }
}
// MARK: - LoRA

extension Qwen2Model: LoRAModel {
    /// Lists the layers that receive LoRA adapters.
    ///
    /// Each decoder block contributes its attention module, with the query and value
    /// projections as the adapted keys.
    public func loraLinearLayers() -> LoRALinearLayers {
        let adaptedKeys = ["q_proj", "v_proj"]
        return model.layers.map { block in (block.attention, adaptedKeys) }
    }
}