@@ -90,6 +90,7 @@ private class Attention: Module {
    let scale: Float
    let isSliding: Bool
    let slidingWindow: Int
+   let slidingWindowPattern: Int

    @ModuleInfo(key: "q_proj") var queryProj: Linear
    @ModuleInfo(key: "k_proj") var keyProj: Linear
@@ -109,6 +110,7 @@ private class Attention: Module {
        self.headDim = config.headDim
        self.layerIdx = layerIdx
        self.slidingWindow = config.slidingWindow
+       self.slidingWindowPattern = config.slidingWindowPattern

        self.scale = pow(config.queryPreAttnScalar, -0.5)

@@ -152,53 +154,30 @@ private class Attention: Module {
        queries = queryNorm(queries)
        keys = keyNorm(keys)

-       var localMask = mask
+       var effectiveMask = mask

        if let cache {
-           // Apply RoPE with offset
            queries = rope(queries, offset: cache.offset)
            keys = rope(keys, offset: cache.offset)
            (keys, values) = cache.update(keys: keys, values: values)
-
-           // Handle sliding window for cached generation
-           if isSliding && cache.offset > slidingWindow && L == 1 {
-               // Create a sliding window mask for generation
-               let size = cache.offset + L
-               let windowStart = max(0, cache.offset - slidingWindow)
-
-               // Create a mask where everything is invalid (large negative value)
-               var slidingMaskData = Array(repeating: Float32(-1e9), count: size)
-
-               // Set the sliding window positions to valid (0)
-               for i in windowStart ..< min(windowStart + slidingWindow + 1, size) {
-                   slidingMaskData[i] = 0
-               }
-
-               // Create the MLXArray from the data
-               let slidingMask = MLXArray(slidingMaskData).reshaped(1, 1, 1, size)
-               localMask = slidingMask
-           }
        } else {
-           // Apply RoPE without offset
            queries = rope(queries)
            keys = rope(keys)
        }

-       // Scale queries by the pre-attention scalar
-       queries = queries * MLXArray(scale).asType(queries.dtype)
-
-       // Adjust mask for sliding window if needed
-       if isSliding && localMask != nil && localMask!.dim(-1) != keys.dim(-2) {
+       if let currentMask = effectiveMask, currentMask.dim(-1) != keys.dim(-2) {
            let keyLen = keys.dim(-2)
-           localMask = localMask![0..., 0..., 0..., (localMask!.dim(-1) - keyLen)...]
+           // Ensure slicing doesn't go out of bounds if keyLen is larger than the mask dimension
+           let startIdx = max(0, currentMask.dim(-1) - keyLen)
+           effectiveMask = currentMask[0..., 0..., 0..., startIdx...]
        }

        let output = MLXFast.scaledDotProductAttention(
            queries: queries,
            keys: keys,
            values: values,
-           scale: 1.0,  // We already scaled the queries
-           mask: localMask
+           scale: scale,
+           mask: effectiveMask
        )
        .transposed(0, 2, 1, 3)
        .reshaped(B, L, -1)
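A side note, not part of the diff: the hunk above stops pre-scaling the queries and instead passes `scale` directly to `MLXFast.scaledDotProductAttention`. The two forms are equivalent because the kernel computes `softmax(scale * Q·Kᵀ + mask) · V`, so multiplying Q by `scale` and passing `scale: 1.0` produces the same scores. A minimal standalone sketch with made-up shapes (the tensor sizes and values are assumptions; only the `MLXFast.scaledDotProductAttention` call mirrors the diff):

import MLX
import MLXFast
import MLXRandom

// Toy tensors shaped (batch, heads, length, headDim); values are arbitrary.
let q = MLXRandom.normal([1, 4, 8, 16])
let k = MLXRandom.normal([1, 4, 8, 16])
let v = MLXRandom.normal([1, 4, 8, 16])
let scale: Float = 0.25  // headDim = 16, so 16^(-0.5)
let noMask: MLXArray? = nil

// Old style: scale the queries up front, tell the kernel not to scale again.
let preScaled = MLXFast.scaledDotProductAttention(
    queries: q * MLXArray(scale), keys: k, values: v, scale: 1.0, mask: noMask)

// New style: let the kernel apply the scale.
let direct = MLXFast.scaledDotProductAttention(
    queries: q, keys: k, values: v, scale: scale, mask: noMask)

// The outputs match up to floating-point error.
print(allClose(preScaled, direct))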
@@ -227,10 +206,10 @@ private class MLP: Module {
private class TransformerBlock: Module {
    @ModuleInfo(key: "self_attn") var selfAttention: Attention
    @ModuleInfo var mlp: MLP
-   @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm
-   @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm
-   @ModuleInfo(key: "pre_feedforward_layernorm") var preFeedforwardLayerNorm: RMSNorm
-   @ModuleInfo(key: "post_feedforward_layernorm") var postFeedforwardLayerNorm: RMSNorm
+   @ModuleInfo(key: "input_layernorm") var inputLayerNorm: Gemma.RMSNorm
+   @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: Gemma.RMSNorm
+   @ModuleInfo(key: "pre_feedforward_layernorm") var preFeedforwardLayerNorm: Gemma.RMSNorm
+   @ModuleInfo(key: "post_feedforward_layernorm") var postFeedforwardLayerNorm: Gemma.RMSNorm

    let numAttentionHeads: Int
    let hiddenSize: Int
@@ -242,13 +221,13 @@ private class TransformerBlock: Module {
        self._selfAttention.wrappedValue = Attention(config, layerIdx: layerIdx)
        self.mlp = MLP(dimensions: config.hiddenSize, hiddenDimensions: config.intermediateSize)

-       self._inputLayerNorm.wrappedValue = RMSNorm(
+       self._inputLayerNorm.wrappedValue = Gemma.RMSNorm(
            dimensions: config.hiddenSize, eps: config.rmsNormEps)
-       self._postAttentionLayerNorm.wrappedValue = RMSNorm(
+       self._postAttentionLayerNorm.wrappedValue = Gemma.RMSNorm(
            dimensions: config.hiddenSize, eps: config.rmsNormEps)
-       self._preFeedforwardLayerNorm.wrappedValue = RMSNorm(
+       self._preFeedforwardLayerNorm.wrappedValue = Gemma.RMSNorm(
            dimensions: config.hiddenSize, eps: config.rmsNormEps)
-       self._postFeedforwardLayerNorm.wrappedValue = RMSNorm(
+       self._postFeedforwardLayerNorm.wrappedValue = Gemma.RMSNorm(
            dimensions: config.hiddenSize, eps: config.rmsNormEps)

        super.init()
@@ -270,7 +249,7 @@ private class TransformerBlock: Module {
private class Gemma3Model: Module {
    @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
    @ModuleInfo var layers: [TransformerBlock]
-   @ModuleInfo var norm: RMSNorm
+   @ModuleInfo var norm: Gemma.RMSNorm

    let config: Gemma3TextConfiguration

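A side note, not part of the diff: the hunks above swap the generic `RMSNorm` for `Gemma.RMSNorm`. In Gemma-family implementations the norm conventionally scales by `(1 + weight)` rather than `weight` alone, which is why a dedicated type matters when loading checkpoints. The following is a rough sketch of that style of norm under that assumption; `GemmaStyleRMSNorm` is a hypothetical name, not the repo's actual `Gemma.RMSNorm` source:

import MLX
import MLXFast
import MLXNN

// Sketch of a Gemma-style RMSNorm: the learned weight is applied as (1 + weight),
// unlike a plain RMSNorm which multiplies by the weight directly.
class GemmaStyleRMSNorm: Module, UnaryLayer {
    let weight: MLXArray
    let eps: Float

    init(dimensions: Int, eps: Float = 1e-6) {
        self.weight = MLXArray.ones([dimensions])
        self.eps = eps
        super.init()
    }

    func callAsFunction(_ x: MLXArray) -> MLXArray {
        MLXFast.rmsNorm(x, weight: weight + 1, eps: eps)
    }
}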
@@ -286,134 +265,117 @@ private class Gemma3Model: Module {
            TransformerBlock(config, layerIdx: layerIdx)
        }

-       self.norm = RMSNorm(dimensions: config.hiddenSize, eps: config.rmsNormEps)
+       self.norm = Gemma.RMSNorm(dimensions: config.hiddenSize, eps: config.rmsNormEps)

        super.init()
    }

-   private func createAdditiveCausalMask(n: Int, offset: Int) -> MLXArray {
-       let rinds = MLXArray(Int32(0) ..< Int32(offset + n))
-       let linds = offset != 0 ? MLXArray(Int32(offset) ..< Int32(offset + n)) : rinds
-       let mask = linds[0..., .newAxis] .< rinds[.newAxis]
-       // Make sure the mask has shape [1, 1, n, offset+n]
-       return (mask * Float32(-1e9)).reshaped(1, 1, n, offset + n)
-   }
-
-   // Create attention mask with sliding window support
-   private func createAttentionMask(h: MLXArray, cache: [KVCache]?, isSliding: Bool = false)
-       -> MLXArray?
+   func callAsFunction(_ inputs: MLXArray, mask: MLXArray? = nil, cache: [KVCache?]? = nil)
+       -> MLXArray
    {
-       let t = h.dim(1)
-
-       var offset = 0
-       if let cache = cache, !cache.isEmpty, let firstCache = cache.first(where: { $0 != nil }) {
-           offset = firstCache.offset
+       // Apply embedding with scaling
+       // TODO: Is type casting necessary here?
+       let scale = MLXArray(sqrtf(Float(config.hiddenSize))).asType(inputs.dtype)
+       var h = embedTokens(inputs) * scale
+
+       var layerCache = cache
+       if layerCache == nil {
+           // During training or first pass without cache, create nil placeholders
+           layerCache = Array(repeating: nil as KVCache?, count: layers.count)
        }

-       // For single token generation with history
-       if t == 1 && offset > 0 {
-           return nil  // No mask needed for single token generation
-       } else if t <= 1 && offset == 0 {
-           return nil
-       }
+       var fullMask: MLXArray? = nil
+       var slidingWindowMask: MLXArray? = nil

-       // Create basic causal mask
-       var mask = createAdditiveCausalMask(n: t, offset: offset).asType(h.dtype)
-
-       // Apply sliding window constraint if needed
-       if isSliding && config.slidingWindow > 0 && (t + offset) > config.slidingWindow {
-           let windowSize = config.slidingWindow
-
-           // Create a mask that limits attention to the sliding window
-           for i in 0 ..< t {
-               let row = i + offset
-               let minCol = max(0, row - windowSize)
-
-               // Set values outside the window to large negative
-               if minCol > 0 {
-                   let maskSlice = mask[0, 0, i, 0 ..< minCol]
-                   let shape = maskSlice.shape
-                   mask[0, 0, i, 0 ..< minCol] = MLXArray(
-                       Array(repeating: Float(-1e9), count: minCol))
-               }
-           }
+       // TODO: Check this part carefully
+       if mask == nil {
+           // Create a standard causal mask for the input sequence length
+           let sequenceLength = inputs.dim(1)
+           slidingWindowMask = createAdditiveCausalMask(n: sequenceLength, offset: 0)
+           fullMask = slidingWindowMask
        }
+       for (i, layer) in layers.enumerated() {
+           let isGlobal = (i % config.slidingWindowPattern == config.slidingWindowPattern - 1)
+           var layerMask = mask  // Start with the explicitly passed mask

-       return mask
-   }
-
-   func callAsFunction(_ inputs: MLXArray, cache: [KVCache]? = nil) -> MLXArray {
-       // Apply embedding with scaling
-       let scale = sqrt(Float(config.hiddenSize))
-       var h = embedTokens(inputs) * MLXArray(scale).asType(inputs.dtype)
-
-       // Create masks if needed
-       var localMasks: [MLXArray?] = Array(repeating: nil, count: config.hiddenLayers)
-
-       for i in 0 ..< config.hiddenLayers {
-           let isGlobal = (i + 1) % config.slidingWindowPattern == 0
-           let isSliding = !isGlobal
-
-           if isSliding && inputs.dim(1) > 1 {
-               localMasks[i] = createAttentionMask(h: h, cache: cache, isSliding: true)
-           } else {
-               localMasks[i] = createAttentionMask(h: h, cache: cache, isSliding: false)
+           if mask == nil {
+               layerMask = slidingWindowMask  // Use the generated causal mask
            }
-       }

-       for (i, layer) in layers.enumerated() {
-           let layerCache = cache?[i]
-           h = layer(h, mask: localMasks[i], cache: layerCache)
+           // Apply the layer
+           h = layer(h, mask: layerMask, cache: layerCache?[i])
        }

        return norm(h)
    }
}

-public class Gemma3TextModel: Module, LLMModel, KVCacheDimensionProvider {
+public class Gemma3TextModel: Module, LLMModel {
+
    @ModuleInfo private var model: Gemma3Model
    @ModuleInfo(key: "lm_head") var lmHead: Linear

    public let config: Gemma3TextConfiguration
    public var vocabularySize: Int { config.vocabularySize }
-   public var kvHeads: [Int]

    public init(_ config: Gemma3TextConfiguration) {
        self.config = config
        self.model = Gemma3Model(config)
        self._lmHead.wrappedValue = Linear(config.hiddenSize, config.vocabularySize, bias: false)
-
-       // Set up KV heads array based on sliding window pattern
-       var heads: [Int] = []
-       for i in 0 ..< config.hiddenLayers {
-           heads.append(config.kvHeads)
-       }
-       self.kvHeads = heads
-
        super.init()
    }

    public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]? = nil) -> MLXArray {
-       let out = model(inputs, cache: cache)
+       let optionalCache = cache?.map { $0 as KVCache? }
+       let out = model(inputs, cache: optionalCache)
        var finalLogits = lmHead(out)
-       if let softcap = config.finalLogitSoftcapping {
-           finalLogits = tanh(finalLogits / MLXArray(softcap)) * MLXArray(softcap)
+
+       // Apply final logit softcapping if configured
+       if let softcap = config.finalLogitSoftcapping, softcap > 0 {
+           let scale = MLXArray(softcap)
+           finalLogits = tanh(finalLogits / scale) * scale
        }
        return finalLogits
    }

+   // TODO: Check this
    public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
        var sanitizedWeights = weights
-       // Copy embedding weights to lm_head if not present
        if sanitizedWeights["lm_head.weight"] == nil {
            if let embedWeight = sanitizedWeights["model.embed_tokens.weight"] {
                sanitizedWeights["lm_head.weight"] = embedWeight
+           } else {
+               print("Warning: Unable to find model.embed_tokens.weight for lm_head weight tying.")
            }
        }
-       // Remove RoPE frequency weights as they're computed on the fly
+       // Keep filtering RoPE keys if they exist in the checkpoint (though usually not saved)
        return sanitizedWeights.filter { key, _ in
-           !key.contains("self_attn.rotary_emb.inv_freq")
+           !key.contains("self_attn.rope.inv_freq")
+               && !key.contains("self_attn.rotary_emb.inv_freq")
+       }
+   }
+
+   // public func loraLinearLayers() -> LoRALinearLayers {
+   //     model.layers.map { ($0.selfAttention, ["q_proj", "v_proj", "k_proj", "o_proj"]) }  // Add k/o proj? Check common practice
+   // }
+
+   public func newCache(parameters: GenerateParameters? = nil) -> [KVCache] {
+       var caches = [KVCache]()
+       let slidingWindow = config.slidingWindow > 0 ? config.slidingWindow : 4096
+       let slidingWindowPattern = config.slidingWindowPattern
+
+       for i in 0 ..< config.hiddenLayers {
+           let isGlobalLayer = (i % slidingWindowPattern == slidingWindowPattern - 1)
+
+           if isGlobalLayer {
+               caches.append(StandardKVCache())
+           } else {
+               caches.append(
+                   RotatingKVCache(maxSize: slidingWindow, keep: 0)
+               )
+           }
        }
+       return caches
    }
}

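A side note, not part of the diff: `newCache` above alternates cache types by layer index. With a `slidingWindowPattern` of, say, 6, every sixth layer (indices 5, 11, ...) uses global attention and gets a `StandardKVCache`, while the remaining layers use sliding-window attention and get a `RotatingKVCache` bounded by `slidingWindow`. A tiny standalone sketch of that index arithmetic (the 12-layer / pattern-6 numbers are made up for illustration):

// Map layer index to cache kind for a hypothetical 12-layer model with pattern 6.
let slidingWindowPattern = 6
let hiddenLayers = 12

for i in 0..<hiddenLayers {
    let isGlobalLayer = (i % slidingWindowPattern == slidingWindowPattern - 1)
    let kind = isGlobalLayer ? "StandardKVCache (global)" : "RotatingKVCache (sliding window)"
    print("layer \(i): \(kind)")
}
// Layers 5 and 11 get StandardKVCache; all others get RotatingKVCache.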