Commit 194bb10

Add text-only model
1 parent 283ddea

3 files changed (+339, −1 lines)

Applications/LLMEval/ContentView.swift (+1, −1)

@@ -165,7 +165,7 @@ class LLMEvaluator {
 
     /// This controls which model loads. `qwen2_5_1_5b` is one of the smaller ones, so this will fit on
     /// more devices.
-    let modelConfiguration = ModelRegistry.qwen2_5_1_5b
+    let modelConfiguration = LLMRegistry.gemma3_1B_4bit
 
     /// parameters controlling the output
     let generateParameters = GenerateParameters(temperature: 0.6)
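
For reference, the configuration the app now loads by default is the `gemma3_1B_4bit` registry entry added in LLMModelFactory.swift below. A minimal loading sketch, assuming the factory's async `loadContainer(configuration:)` entry point (that call is an assumption based on how the example apps load models, not part of this diff):

import MLXLLM
import MLXLMCommon

// Assumed API: LLMModelFactory.shared.loadContainer(configuration:) fetches the
// checkpoint and returns a ModelContainer ready for generation.
func loadDefaultModel() async throws -> ModelContainer {
    try await LLMModelFactory.shared.loadContainer(
        configuration: LLMRegistry.gemma3_1B_4bit)
}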

Libraries/MLXLLM/LLMModelFactory.swift (+7)

@@ -40,6 +40,7 @@ public class LLMTypeRegistry: ModelTypeRegistry, @unchecked Sendable {
             "cohere": create(CohereConfiguration.self, CohereModel.init),
             "openelm": create(OpenElmConfiguration.self, OpenELMModel.init),
             "internlm2": create(InternLM2Configuration.self, InternLM2Model.init),
+            "gemma3_text": create(Gemma3TextConfiguration.self, Gemma3TextModel.init),
         ]
     }

@@ -166,6 +167,11 @@ public class LLMRegistry: AbstractModelRegistry, @unchecked Sendable {
         defaultPrompt: "What is the difference between a fruit and a vegetable?"
     )
 
+    static public let gemma3_1B_4bit = ModelConfiguration(
+        id: "mlx-community/gemma-3-1b-it-4bit",
+        defaultPrompt: "What is the difference between a fruit and a vegetable?"
+    )
+
     private static func all() -> [ModelConfiguration] {
         [
             codeLlama13b4bit,
@@ -187,6 +193,7 @@ public class LLMRegistry: AbstractModelRegistry, @unchecked Sendable {
             qwen2_5_7b,
             qwen2_5_1_5b,
             smolLM_135M_4bit,
+            gemma3_1B_4bit,
         ]
     }
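
The `gemma3_text` creator registered above is what lets the factory match a checkpoint whose config.json declares `"model_type": "gemma3_text"` to the new model class. A minimal decoding sketch, using only the keys that Gemma3TextConfiguration requires; the numeric values are placeholders for illustration, not values read from the published gemma-3-1b-it-4bit checkpoint, and every omitted key falls back to the defaults hard-coded in Gemma3Text.swift below:

import Foundation
import MLXLLM

// Placeholder config: only model_type, hidden_size, num_hidden_layers and
// intermediate_size are required; attentionHeads (4), headDim (256), kvHeads (1),
// slidingWindowPattern (6) and the rest come from the decoder's defaults.
func decodeExampleConfiguration() throws -> Gemma3TextConfiguration {
    let json = """
        {
            "model_type": "gemma3_text",
            "hidden_size": 1152,
            "num_hidden_layers": 26,
            "intermediate_size": 6912
        }
        """.data(using: .utf8)!
    return try JSONDecoder().decode(Gemma3TextConfiguration.self, from: json)
}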

Gemma3Text.swift (new file, +331)

//
//  Gemma3Text.swift
//  mlx-swift-examples
//
//  Created by Anthony DePasquale on 14.03.2025.
//

import Foundation
import MLX
import MLXFast
import MLXLMCommon
import MLXNN

public struct Gemma3TextConfiguration: Codable {
    let modelType: String
    let hiddenSize: Int
    let hiddenLayers: Int
    let intermediateSize: Int
    let attentionHeads: Int
    let headDim: Int
    let rmsNormEps: Float
    let vocabularySize: Int
    let kvHeads: Int
    let ropeGlobalBaseFreq: Float
    let ropeLocalBaseFreq: Float
    let ropeTraditional: Bool
    let queryPreAttnScalar: Float
    let slidingWindow: Int
    let slidingWindowPattern: Int

    enum CodingKeys: String, CodingKey {
        case modelType = "model_type"
        case hiddenSize = "hidden_size"
        case hiddenLayers = "num_hidden_layers"
        case intermediateSize = "intermediate_size"
        case attentionHeads = "num_attention_heads"
        case headDim = "head_dim"
        case rmsNormEps = "rms_norm_eps"
        case vocabularySize = "vocab_size"
        case kvHeads = "num_key_value_heads"
        case ropeGlobalBaseFreq = "rope_global_base_freq"
        case ropeLocalBaseFreq = "rope_local_base_freq"
        case ropeTraditional = "rope_traditional"
        case queryPreAttnScalar = "query_pre_attn_scalar"
        case slidingWindow = "sliding_window"
        case slidingWindowPattern = "sliding_window_pattern"
    }

    public init(from decoder: Decoder) throws {
        let container = try decoder.container(keyedBy: CodingKeys.self)

        modelType = try container.decode(String.self, forKey: .modelType)
        hiddenSize = try container.decode(Int.self, forKey: .hiddenSize)
        hiddenLayers = try container.decode(Int.self, forKey: .hiddenLayers)
        intermediateSize = try container.decode(Int.self, forKey: .intermediateSize)

        // Default values with optional decoding
        attentionHeads = try container.decodeIfPresent(Int.self, forKey: .attentionHeads) ?? 4
        headDim = try container.decodeIfPresent(Int.self, forKey: .headDim) ?? 256
        rmsNormEps = try container.decodeIfPresent(Float.self, forKey: .rmsNormEps) ?? 1.0e-6
        vocabularySize = try container.decodeIfPresent(Int.self, forKey: .vocabularySize) ?? 262144
        kvHeads = try container.decodeIfPresent(Int.self, forKey: .kvHeads) ?? 1
        ropeGlobalBaseFreq =
            try container.decodeIfPresent(Float.self, forKey: .ropeGlobalBaseFreq) ?? 1_000_000.0
        ropeLocalBaseFreq =
            try container.decodeIfPresent(Float.self, forKey: .ropeLocalBaseFreq) ?? 10_000.0
        ropeTraditional =
            try container.decodeIfPresent(Bool.self, forKey: .ropeTraditional) ?? false
        queryPreAttnScalar =
            try container.decodeIfPresent(Float.self, forKey: .queryPreAttnScalar) ?? 256
        slidingWindow = try container.decodeIfPresent(Int.self, forKey: .slidingWindow) ?? 512
        slidingWindowPattern =
            try container.decodeIfPresent(Int.self, forKey: .slidingWindowPattern) ?? 6
    }
}

private class Attention: Module {
    let nHeads: Int
    let nKVHeads: Int
    let repeats: Int
    let headDim: Int
    let layerIdx: Int
    let scale: Float
    let isSliding: Bool

    @ModuleInfo(key: "q_proj") var queryProj: Linear
    @ModuleInfo(key: "k_proj") var keyProj: Linear
    @ModuleInfo(key: "v_proj") var valueProj: Linear
    @ModuleInfo(key: "o_proj") var outputProj: Linear

    @ModuleInfo(key: "q_norm") var queryNorm: GemmaUtils.RMSNorm
    @ModuleInfo(key: "k_norm") var keyNorm: GemmaUtils.RMSNorm

    @ModuleInfo var rope: RoPE

    init(_ config: Gemma3TextConfiguration, layerIdx: Int) {
        let dim = config.hiddenSize
        self.nHeads = config.attentionHeads
        self.nKVHeads = config.kvHeads
        self.repeats = nHeads / nKVHeads
        self.headDim = config.headDim
        self.layerIdx = layerIdx

        self.scale = pow(config.queryPreAttnScalar, -0.5)

        self._queryProj.wrappedValue = Linear(dim, nHeads * headDim, bias: false)
        self._keyProj.wrappedValue = Linear(dim, nKVHeads * headDim, bias: false)
        self._valueProj.wrappedValue = Linear(dim, nKVHeads * headDim, bias: false)
        self._outputProj.wrappedValue = Linear(nHeads * headDim, dim, bias: false)

        self._queryNorm.wrappedValue = GemmaUtils.RMSNorm(
            dimensions: headDim, eps: config.rmsNormEps)
        self._keyNorm.wrappedValue = GemmaUtils.RMSNorm(dimensions: headDim, eps: config.rmsNormEps)

        self.isSliding = (layerIdx + 1) % config.slidingWindowPattern != 0

        let baseFreq = isSliding ? config.ropeLocalBaseFreq : config.ropeGlobalBaseFreq
        self._rope.wrappedValue = RoPE(
            dimensions: headDim,
            traditional: config.ropeTraditional,
            base: baseFreq
        )

        super.init()
    }

    func callAsFunction(
        _ x: MLXArray,
        mask: MLXArray? = nil,
        cache: KVCache? = nil
    ) -> MLXArray {
        let (B, L, _) = (x.dim(0), x.dim(1), x.dim(2))

        var queries = queryProj(x)
        var keys = keyProj(x)
        var values = valueProj(x)

        queries = queries.reshaped(B, L, nHeads, -1).transposed(0, 2, 1, 3)
        keys = keys.reshaped(B, L, nKVHeads, -1).transposed(0, 2, 1, 3)
        values = values.reshaped(B, L, nKVHeads, -1).transposed(0, 2, 1, 3)

        queries = queryNorm(queries)
        keys = keyNorm(keys)

        var localMask = mask

        if let cache = cache {
            queries = rope(queries, offset: cache.offset)
            keys = rope(keys, offset: cache.offset)
            (keys, values) = cache.update(keys: keys, values: values)
        } else {
            queries = rope(queries)
            keys = rope(keys)
        }

        // Sliding window mask adjustment
        if localMask != nil && localMask!.dim(-1) != keys.dim(-2) {
            let keyLen = keys.dim(-2)
            localMask = localMask![0..., 0..., 0..., (localMask!.dim(-1) - keyLen)...]
        }

        let output = MLXFast.scaledDotProductAttention(
            queries: queries,
            keys: keys,
            values: values,
            scale: scale,
            mask: localMask
        )
        .transposed(0, 2, 1, 3)
        .reshaped(B, L, -1)

        return outputProj(output)
    }
}

private class MLP: Module {
    @ModuleInfo(key: "gate_proj") var gateProj: Linear
    @ModuleInfo(key: "down_proj") var downProj: Linear
    @ModuleInfo(key: "up_proj") var upProj: Linear

    init(dimensions: Int, hiddenDimensions: Int) {
        self._gateProj.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false)
        self._downProj.wrappedValue = Linear(hiddenDimensions, dimensions, bias: false)
        self._upProj.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false)
        super.init()
    }

    func callAsFunction(_ x: MLXArray) -> MLXArray {
        return downProj(geluApproximate(gateProj(x)) * upProj(x))
    }
}

private class TransformerBlock: Module {
    @ModuleInfo(key: "self_attn") var selfAttention: Attention
    @ModuleInfo var mlp: MLP
    @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm
    @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm
    @ModuleInfo(key: "pre_feedforward_layernorm") var preFeedforwardLayerNorm: RMSNorm
    @ModuleInfo(key: "post_feedforward_layernorm") var postFeedforwardLayerNorm: RMSNorm

    let numAttentionHeads: Int
    let hiddenSize: Int

    init(_ config: Gemma3TextConfiguration, layerIdx: Int) {
        self.numAttentionHeads = config.attentionHeads
        self.hiddenSize = config.hiddenSize

        self._selfAttention.wrappedValue = Attention(config, layerIdx: layerIdx)
        self.mlp = MLP(dimensions: config.hiddenSize, hiddenDimensions: config.intermediateSize)

        self._inputLayerNorm.wrappedValue = RMSNorm(
            dimensions: config.hiddenSize, eps: config.rmsNormEps)
        self._postAttentionLayerNorm.wrappedValue = RMSNorm(
            dimensions: config.hiddenSize, eps: config.rmsNormEps)
        self._preFeedforwardLayerNorm.wrappedValue = RMSNorm(
            dimensions: config.hiddenSize, eps: config.rmsNormEps)
        self._postFeedforwardLayerNorm.wrappedValue = RMSNorm(
            dimensions: config.hiddenSize, eps: config.rmsNormEps)

        super.init()
    }

    func callAsFunction(
        _ x: MLXArray,
        mask: MLXArray? = nil
    ) -> MLXArray {
        let r = selfAttention(inputLayerNorm(x), mask: mask, cache: nil)
        let h = x + postAttentionLayerNorm(r)
        let r2 = mlp(preFeedforwardLayerNorm(h))
        let out = h + postFeedforwardLayerNorm(r2)
        return out
    }
}

private class Gemma3Model: Module {
    @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
    @ModuleInfo var layers: [TransformerBlock]
    @ModuleInfo var norm: RMSNorm

    let config: Gemma3TextConfiguration

    init(_ config: Gemma3TextConfiguration) {
        self.config = config

        self._embedTokens.wrappedValue = Embedding(
            embeddingCount: config.vocabularySize,
            dimensions: config.hiddenSize
        )

        self._layers.wrappedValue = (0 ..< config.hiddenLayers).map { layerIdx in
            TransformerBlock(config, layerIdx: layerIdx)
        }

        self.norm = RMSNorm(dimensions: config.hiddenSize, eps: config.rmsNormEps)

        super.init()
    }

    func callAsFunction(_ inputs: MLXArray, mask: MLXArray? = nil) -> MLXArray {
        var h = embedTokens(inputs)
        h = h * sqrt(Float(config.hiddenSize))

        var fullMask: MLXArray? = nil
        var slidingWindowMask: MLXArray? = nil

        if mask == nil {
            let j = config.slidingWindowPattern
            slidingWindowMask = createAttentionMask(h: h, cache: nil)
        }

        for (i, layer) in layers.enumerated() {
            let isSliding = (i % config.slidingWindowPattern == config.slidingWindowPattern - 1)

            var layerMask = mask
            if mask == nil {
                layerMask = isSliding ? slidingWindowMask : fullMask
            }

            h = layer(h, mask: layerMask)
        }

        return norm(h)
    }
}

public class Gemma3TextModel: Module, LLMModel, KVCacheDimensionProvider {
    @ModuleInfo private var model: Gemma3Model
    @ModuleInfo(key: "lm_head") var lmHead: Linear

    public let config: Gemma3TextConfiguration
    public var vocabularySize: Int { config.vocabularySize }
    public var kvHeads: [Int]

    public init(_ config: Gemma3TextConfiguration) {
        self.config = config
        self.model = Gemma3Model(config)
        self._lmHead.wrappedValue = Linear(config.hiddenSize, config.vocabularySize, bias: false)

        // Set up KV heads array based on sliding window pattern
        var heads: [Int] = []
        for i in 0 ..< config.hiddenLayers {
            heads.append(config.kvHeads)
        }
        self.kvHeads = heads

        super.init()
    }

    public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray {
        let out = model(inputs, mask: nil)
        return lmHead(out)
    }

    public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
        var sanitizedWeights = weights

        if sanitizedWeights["lm_head.weight"] == nil {
            sanitizedWeights["lm_head.weight"] = sanitizedWeights["model.embed_tokens.weight"]
        }

        return sanitizedWeights.filter { key, _ in
            !key.contains("self_attn.rotary_emb.inv_freq")
        }
    }
}

extension Gemma3TextModel: LoRAModel {
    public func loraLinearLayers() -> LoRALinearLayers {
        model.layers.map { ($0.selfAttention, ["q_proj", "v_proj"]) }
    }
}
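
One detail worth noting in the new file: Attention.init marks a layer as sliding when `(layerIdx + 1) % slidingWindowPattern != 0`, so with the default pattern of 6 every sixth layer is built with the global RoPE base frequency and all other layers use the local one. A small standalone sketch of that schedule (plain Swift, no MLX dependencies):

// Reproduces the layer schedule from Attention.init:
//   isSliding = (layerIdx + 1) % slidingWindowPattern != 0
let slidingWindowPattern = 6  // default from Gemma3TextConfiguration
for layerIdx in 0 ..< 12 {
    let isSliding = (layerIdx + 1) % slidingWindowPattern != 0
    let kind = isSliding ? "sliding window (rope_local_base_freq)" : "global (rope_global_base_freq)"
    print("layer \(layerIdx): \(kind)")
}
// With 12 layers, indices 5 and 11 get the global base; the rest get the local base.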
