
Commit 1cfae03

Gemma 3 1B generates text
1 parent 5204728 commit 1cfae03

2 files changed: +620 −166 lines changed


Libraries/MLXLLM/Models/Gemma3Text.swift: +83 −121
@@ -90,6 +90,7 @@ private class Attention: Module {
     let scale: Float
     let isSliding: Bool
     let slidingWindow: Int
+    let slidingWindowPattern: Int
 
     @ModuleInfo(key: "q_proj") var queryProj: Linear
     @ModuleInfo(key: "k_proj") var keyProj: Linear
@@ -109,6 +110,7 @@ private class Attention: Module {
         self.headDim = config.headDim
         self.layerIdx = layerIdx
         self.slidingWindow = config.slidingWindow
+        self.slidingWindowPattern = config.slidingWindowPattern
 
         self.scale = pow(config.queryPreAttnScalar, -0.5)
 
@@ -152,53 +154,30 @@ private class Attention: Module {
         queries = queryNorm(queries)
         keys = keyNorm(keys)
 
-        var localMask = mask
+        var effectiveMask = mask
 
         if let cache {
-            // Apply RoPE with offset
             queries = rope(queries, offset: cache.offset)
             keys = rope(keys, offset: cache.offset)
             (keys, values) = cache.update(keys: keys, values: values)
-
-            // Handle sliding window for cached generation
-            if isSliding && cache.offset > slidingWindow && L == 1 {
-                // Create a sliding window mask for generation
-                let size = cache.offset + L
-                let windowStart = max(0, cache.offset - slidingWindow)
-
-                // Create a mask where everything is invalid (large negative value)
-                var slidingMaskData = Array(repeating: Float32(-1e9), count: size)
-
-                // Set the sliding window positions to valid (0)
-                for i in windowStart ..< min(windowStart + slidingWindow + 1, size) {
-                    slidingMaskData[i] = 0
-                }
-
-                // Create the MLXArray from the data
-                let slidingMask = MLXArray(slidingMaskData).reshaped(1, 1, 1, size)
-                localMask = slidingMask
-            }
         } else {
-            // Apply RoPE without offset
             queries = rope(queries)
             keys = rope(keys)
         }
 
-        // Scale queries by the pre-attention scalar
-        queries = queries * MLXArray(scale).asType(queries.dtype)
-
-        // Adjust mask for sliding window if needed
-        if isSliding && localMask != nil && localMask!.dim(-1) != keys.dim(-2) {
+        if let currentMask = effectiveMask, currentMask.dim(-1) != keys.dim(-2) {
             let keyLen = keys.dim(-2)
-            localMask = localMask![0..., 0..., 0..., (localMask!.dim(-1) - keyLen)...]
+            // Ensure slicing doesn't go out of bounds if keyLen is larger than mask dim
+            let startIdx = max(0, currentMask.dim(-1) - keyLen)
+            effectiveMask = currentMask[0..., 0..., 0..., startIdx...]
         }
 
         let output = MLXFast.scaledDotProductAttention(
             queries: queries,
             keys: keys,
             values: values,
-            scale: 1.0, // We already scaled the queries
-            mask: localMask
+            scale: scale,
+            mask: effectiveMask
         )
         .transposed(0, 2, 1, 3)
         .reshaped(B, L, -1)
@@ -227,10 +206,10 @@ private class MLP: Module {
 private class TransformerBlock: Module {
     @ModuleInfo(key: "self_attn") var selfAttention: Attention
     @ModuleInfo var mlp: MLP
-    @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm
-    @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm
-    @ModuleInfo(key: "pre_feedforward_layernorm") var preFeedforwardLayerNorm: RMSNorm
-    @ModuleInfo(key: "post_feedforward_layernorm") var postFeedforwardLayerNorm: RMSNorm
+    @ModuleInfo(key: "input_layernorm") var inputLayerNorm: Gemma.RMSNorm
+    @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: Gemma.RMSNorm
+    @ModuleInfo(key: "pre_feedforward_layernorm") var preFeedforwardLayerNorm: Gemma.RMSNorm
+    @ModuleInfo(key: "post_feedforward_layernorm") var postFeedforwardLayerNorm: Gemma.RMSNorm
 
     let numAttentionHeads: Int
     let hiddenSize: Int
@@ -242,13 +221,13 @@ private class TransformerBlock: Module {
         self._selfAttention.wrappedValue = Attention(config, layerIdx: layerIdx)
         self.mlp = MLP(dimensions: config.hiddenSize, hiddenDimensions: config.intermediateSize)
 
-        self._inputLayerNorm.wrappedValue = RMSNorm(
+        self._inputLayerNorm.wrappedValue = Gemma.RMSNorm(
             dimensions: config.hiddenSize, eps: config.rmsNormEps)
-        self._postAttentionLayerNorm.wrappedValue = RMSNorm(
+        self._postAttentionLayerNorm.wrappedValue = Gemma.RMSNorm(
             dimensions: config.hiddenSize, eps: config.rmsNormEps)
-        self._preFeedforwardLayerNorm.wrappedValue = RMSNorm(
+        self._preFeedforwardLayerNorm.wrappedValue = Gemma.RMSNorm(
             dimensions: config.hiddenSize, eps: config.rmsNormEps)
-        self._postFeedforwardLayerNorm.wrappedValue = RMSNorm(
+        self._postFeedforwardLayerNorm.wrappedValue = Gemma.RMSNorm(
             dimensions: config.hiddenSize, eps: config.rmsNormEps)
 
         super.init()
@@ -270,7 +249,7 @@ private class TransformerBlock: Module {
 private class Gemma3Model: Module {
     @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
     @ModuleInfo var layers: [TransformerBlock]
-    @ModuleInfo var norm: RMSNorm
+    @ModuleInfo var norm: Gemma.RMSNorm
 
     let config: Gemma3TextConfiguration
 
@@ -286,134 +265,117 @@ private class Gemma3Model: Module {
             TransformerBlock(config, layerIdx: layerIdx)
         }
 
-        self.norm = RMSNorm(dimensions: config.hiddenSize, eps: config.rmsNormEps)
+        self.norm = Gemma.RMSNorm(dimensions: config.hiddenSize, eps: config.rmsNormEps)
 
         super.init()
     }
 
-    private func createAdditiveCausalMask(n: Int, offset: Int) -> MLXArray {
-        let rinds = MLXArray(Int32(0) ..< Int32(offset + n))
-        let linds = offset != 0 ? MLXArray(Int32(offset) ..< Int32(offset + n)) : rinds
-        let mask = linds[0..., .newAxis] .< rinds[.newAxis]
-        // Make sure the mask has shape [1, 1, n, offset+n]
-        return (mask * Float32(-1e9)).reshaped(1, 1, n, offset + n)
-    }
-
-    // Create attention mask with sliding window support
-    private func createAttentionMask(h: MLXArray, cache: [KVCache]?, isSliding: Bool = false)
-        -> MLXArray?
+    func callAsFunction(_ inputs: MLXArray, mask: MLXArray? = nil, cache: [KVCache?]? = nil)
+        -> MLXArray
     {
-        let t = h.dim(1)
-
-        var offset = 0
-        if let cache = cache, !cache.isEmpty, let firstCache = cache.first(where: { $0 != nil }) {
-            offset = firstCache.offset
+        // Apply embedding with scaling
+        // TODO: Is type casting necessary here?
+        let scale = MLXArray(sqrtf(Float(config.hiddenSize))).asType(inputs.dtype)
+        var h = embedTokens(inputs) * scale
+
+        var layerCache = cache
+        if layerCache == nil {
+            // During training or first pass without cache, create nil placeholders
+            layerCache = Array(repeating: nil as KVCache?, count: layers.count)
         }
 
-        // For single token generation with history
-        if t == 1 && offset > 0 {
-            return nil // No mask needed for single token generation
-        } else if t <= 1 && offset == 0 {
-            return nil
-        }
+        var fullMask: MLXArray? = nil
+        var slidingWindowMask: MLXArray? = nil
 
-        // Create basic causal mask
-        var mask = createAdditiveCausalMask(n: t, offset: offset).asType(h.dtype)
-
-        // Apply sliding window constraint if needed
-        if isSliding && config.slidingWindow > 0 && (t + offset) > config.slidingWindow {
-            let windowSize = config.slidingWindow
-
-            // Create a mask that limits attention to the sliding window
-            for i in 0 ..< t {
-                let row = i + offset
-                let minCol = max(0, row - windowSize)
-
-                // Set values outside the window to large negative
-                if minCol > 0 {
-                    let maskSlice = mask[0, 0, i, 0 ..< minCol]
-                    let shape = maskSlice.shape
-                    mask[0, 0, i, 0 ..< minCol] = MLXArray(
-                        Array(repeating: Float(-1e9), count: minCol))
-                }
-            }
+        // TODO: Check this part carefully
+        if mask == nil {
+            // Create a standard causal mask for the input sequence length
+            let sequenceLength = inputs.dim(1)
+            slidingWindowMask = createAdditiveCausalMask(n: sequenceLength, offset: 0)
+            fullMask = slidingWindowMask
         }
+        for (i, layer) in layers.enumerated() {
+            let isGlobal = (i % config.slidingWindowPattern == config.slidingWindowPattern - 1)
+            var layerMask = mask // Start with the explicitly passed mask
 
-        return mask
-    }
-
-    func callAsFunction(_ inputs: MLXArray, cache: [KVCache]? = nil) -> MLXArray {
-        // Apply embedding with scaling
-        let scale = sqrt(Float(config.hiddenSize))
-        var h = embedTokens(inputs) * MLXArray(scale).asType(inputs.dtype)
-
-        // Create masks if needed
-        var localMasks: [MLXArray?] = Array(repeating: nil, count: config.hiddenLayers)
-
-        for i in 0 ..< config.hiddenLayers {
-            let isGlobal = (i + 1) % config.slidingWindowPattern == 0
-            let isSliding = !isGlobal
-
-            if isSliding && inputs.dim(1) > 1 {
-                localMasks[i] = createAttentionMask(h: h, cache: cache, isSliding: true)
-            } else {
-                localMasks[i] = createAttentionMask(h: h, cache: cache, isSliding: false)
+            if mask == nil {
+                layerMask = slidingWindowMask // Use the generated causal mask
             }
-        }
 
-        for (i, layer) in layers.enumerated() {
-            let layerCache = cache?[i]
-            h = layer(h, mask: localMasks[i], cache: layerCache)
+            // Apply the layer
+            h = layer(h, mask: layerMask, cache: layerCache?[i])
         }
 
         return norm(h)
     }
 }
 
-public class Gemma3TextModel: Module, LLMModel, KVCacheDimensionProvider {
+public class Gemma3TextModel: Module, LLMModel {
+
     @ModuleInfo private var model: Gemma3Model
     @ModuleInfo(key: "lm_head") var lmHead: Linear
 
     public let config: Gemma3TextConfiguration
     public var vocabularySize: Int { config.vocabularySize }
-    public var kvHeads: [Int]
 
     public init(_ config: Gemma3TextConfiguration) {
         self.config = config
        self.model = Gemma3Model(config)
        self._lmHead.wrappedValue = Linear(config.hiddenSize, config.vocabularySize, bias: false)
-
-        // Set up KV heads array based on sliding window pattern
-        var heads: [Int] = []
-        for i in 0 ..< config.hiddenLayers {
-            heads.append(config.kvHeads)
-        }
-        self.kvHeads = heads
-
         super.init()
     }
 
     public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]? = nil) -> MLXArray {
-        let out = model(inputs, cache: cache)
+        let optionalCache = cache?.map { $0 as KVCache? }
+        let out = model(inputs, cache: optionalCache)
         var finalLogits = lmHead(out)
-        if let softcap = config.finalLogitSoftcapping {
-            finalLogits = tanh(finalLogits / MLXArray(softcap)) * MLXArray(softcap)
+
+        // Apply final logit softcapping if configured
+        if let softcap = config.finalLogitSoftcapping, softcap > 0 {
+            let scale = MLXArray(softcap)
+            finalLogits = tanh(finalLogits / scale) * scale
         }
         return finalLogits
     }
 
+    // TODO: Check this
     public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
         var sanitizedWeights = weights
-        // Copy embedding weights to lm_head if not present
         if sanitizedWeights["lm_head.weight"] == nil {
             if let embedWeight = sanitizedWeights["model.embed_tokens.weight"] {
                 sanitizedWeights["lm_head.weight"] = embedWeight
+            } else {
+                print("Warning: Unable to find model.embed_tokens.weight for lm_head weight tying.")
             }
         }
-        // Remove RoPE frequency weights as they're computed on the fly
+        // Keep filtering RoPE keys if they exist in the checkpoint (though usually not saved)
         return sanitizedWeights.filter { key, _ in
-            !key.contains("self_attn.rotary_emb.inv_freq")
+            !key.contains("self_attn.rope.inv_freq")
+                && !key.contains("self_attn.rotary_emb.inv_freq")
+        }
+    }
+
+    // public func loraLinearLayers() -> LoRALinearLayers {
+    //     model.layers.map { ($0.selfAttention, ["q_proj", "v_proj", "k_proj", "o_proj"]) } // Add k/o proj? Check common practice
+    // }
+
+    public func newCache(parameters: GenerateParameters? = nil) -> [KVCache] {
+        var caches = [KVCache]()
+        let slidingWindow = config.slidingWindow > 0 ? config.slidingWindow : 4096
+        let slidingWindowPattern = config.slidingWindowPattern
+
+        for i in 0 ..< config.hiddenLayers {
+            let isGlobalLayer = (i % slidingWindowPattern == slidingWindowPattern - 1)
+
+            if isGlobalLayer {
+                caches.append(StandardKVCache())
+            } else {
+                caches.append(
                    RotatingKVCache(maxSize: slidingWindow, keep: 0)
+                )
+            }
         }
+        return caches
     }
 }
 
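For reference, a minimal standalone sketch (not part of the commit) of the layer pattern this diff relies on in both the per-layer mask selection and newCache: every slidingWindowPattern-th layer attends globally and gets a StandardKVCache, while the remaining layers use sliding-window attention backed by a RotatingKVCache. The layer count (26) and pattern (6) below are assumed Gemma 3 1B values, used here for illustration only.

// Assumed Gemma 3 1B settings, for illustration only.
let hiddenLayers = 26
let slidingWindowPattern = 6

for i in 0 ..< hiddenLayers {
    // Same predicate as in the diff's newCache and Gemma3Model.callAsFunction.
    let isGlobal = (i % slidingWindowPattern == slidingWindowPattern - 1)
    let kind = isGlobal ? "global (StandardKVCache)" : "sliding (RotatingKVCache)"
    print("layer \(i): \(kind)")
}

With these assumed values, layers 5, 11, 17, and 23 would be the global-attention layers; all other layers stay within the sliding window.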

0 commit comments
