
Commit fc0be87

smdesai and Sachin Desai authored
adding support for Granite (#284)
Co-authored-by: Sachin Desai <[email protected]>
1 parent 4c681c5 commit fc0be87

File tree

3 files changed: +293 -3 lines changed

Applications/LLMEval/ContentView.swift

Lines changed: 5 additions & 3 deletions
@@ -247,10 +247,12 @@ class LLMEvaluator {
     private func generate(prompt: String) async {

         self.output = ""
+        let chat: [Chat.Message] = [
+            .system("You are a helpful assistant"),
+            .user(prompt),
+        ]
         let userInput = UserInput(
-            prompt: prompt,
-            additionalContext: ["enable_thinking": enableThinking]
-        )
+            chat: chat, additionalContext: ["enable_thinking": enableThinking])

         do {
             let modelContainer = try await load()
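For reference, a minimal sketch of the chat-based prompt construction introduced above. The system text and the literal enable_thinking value are illustrative, not part of the commit; only UserInput(chat:additionalContext:) and Chat.Message come from the diff itself.

import MLXLMCommon

// Build a UserInput from a chat transcript rather than a bare prompt string.
// System text and the enable_thinking flag are illustrative placeholders.
func makeInput(prompt: String) -> UserInput {
    let chat: [Chat.Message] = [
        .system("You are a helpful assistant"),
        .user(prompt),
    ]
    return UserInput(chat: chat, additionalContext: ["enable_thinking": false])
}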

Libraries/MLXLLM/LLMModelFactory.swift

Lines changed: 7 additions & 0 deletions
@@ -42,6 +42,7 @@ public class LLMTypeRegistry: ModelTypeRegistry, @unchecked Sendable {
             "cohere": create(CohereConfiguration.self, CohereModel.init),
             "openelm": create(OpenElmConfiguration.self, OpenELMModel.init),
             "internlm2": create(InternLM2Configuration.self, InternLM2Model.init),
+            "granite": create(GraniteConfiguration.self, GraniteModel.init),
         ]
     }

@@ -193,13 +194,19 @@ public class LLMRegistry: AbstractModelRegistry, @unchecked Sendable {
         defaultPrompt: "What is the difference between a fruit and a vegetable?"
     )

+    static public let granite3_3_2b_4bit = ModelConfiguration(
+        id: "mlx-community/granite-3.3-2b-instruct-4bit",
+        defaultPrompt: ""
+    )
+
     private static func all() -> [ModelConfiguration] {
         [
             codeLlama13b4bit,
             deepSeekR1_7B_4bit,
             gemma2bQuantized,
             gemma_2_2b_it_4bit,
             gemma_2_9b_it_4bit,
+            granite3_3_2b_4bit,
             llama3_1_8B_4bit,
             llama3_2_1B_4bit,
             llama3_2_3B_4bit,
Libraries/MLXLLM/Models/Granite.swift

Lines changed: 281 additions & 0 deletions
@@ -0,0 +1,281 @@
//
// Granite.swift
// mlx-swift-examples
//
// Created by Sachin Desai on 4/25/25.
//

// Port of https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/granite.py

import Foundation
import MLX
import MLXFast
import MLXLMCommon
import MLXNN

private class Attention: Module {
    let args: GraniteConfiguration
    let scale: Float

    @ModuleInfo(key: "q_proj") var wq: Linear
    @ModuleInfo(key: "k_proj") var wk: Linear
    @ModuleInfo(key: "v_proj") var wv: Linear
    @ModuleInfo(key: "o_proj") var wo: Linear

    let rope: RoPE

    public init(_ args: GraniteConfiguration) {
        self.args = args

        let dim = args.hiddenSize
        let nHeads = args.attentionHeads
        let nKvHeads = args.kvHeads
        let headDim = dim / nHeads

        self.scale = args.attentionMultiplier
        let attentionBias = args.attentionBias

        self._wq.wrappedValue = Linear(dim, nHeads * headDim, bias: attentionBias)
        self._wk.wrappedValue = Linear(dim, nKvHeads * headDim, bias: attentionBias)
        self._wv.wrappedValue = Linear(dim, nKvHeads * headDim, bias: attentionBias)
        self._wo.wrappedValue = Linear(nHeads * headDim, dim, bias: attentionBias)

        let ropeScale: Float
        if let ropeScaling = args.ropeScaling, ropeScaling["type"] == .string("linear"),
            let factor = ropeScaling["factor"]
        {
            if let v = factor.asFloat() {
                ropeScale = 1 / v
            } else {
                fatalError("ropeScaling.factor must be a float")
            }
        } else {
            ropeScale = 1
        }
        rope = RoPE(dimensions: headDim, traditional: false, base: args.ropeTheta, scale: ropeScale)
    }

    public func callAsFunction(_ x: MLXArray, mask: MLXArray? = nil, cache: KVCache?) -> MLXArray {
        let (B, L) = (x.dim(0), x.dim(1))

        var queries = wq(x)
        var keys = wk(x)
        var values = wv(x)

        // prepare the queries, keys and values for the attention computation
        queries = queries.reshaped(B, L, args.attentionHeads, -1).transposed(0, 2, 1, 3)
        keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
        values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)

        if let cache {
            queries = rope(queries, offset: cache.offset)
            keys = rope(keys, offset: cache.offset)
            (keys, values) = cache.update(keys: keys, values: values)
        } else {
            queries = rope(queries)
            keys = rope(keys)
        }

        let output = MLXFast.scaledDotProductAttention(
            queries: queries, keys: keys, values: values, scale: self.scale, mask: mask
        )
        .transposed(0, 2, 1, 3)
        .reshaped(B, L, -1)

        return wo(output)
    }
}

private class MLP: Module, UnaryLayer {
    @ModuleInfo(key: "gate_proj") var gate: Linear
    @ModuleInfo(key: "down_proj") var down: Linear
    @ModuleInfo(key: "up_proj") var up: Linear

    public init(_ args: GraniteConfiguration) {
        let dim = args.hiddenSize
        let hiddenDim = args.intermediateSize
        let mlpBias = args.mlpBias

        self._gate.wrappedValue = Linear(dim, hiddenDim, bias: mlpBias)
        self._down.wrappedValue = Linear(hiddenDim, dim, bias: mlpBias)
        self._up.wrappedValue = Linear(dim, hiddenDim, bias: mlpBias)
    }

    public func callAsFunction(_ x: MLXArray) -> MLXArray {
        down(silu(gate(x)) * up(x))
    }
}

private class TransformerBlock: Module {
    @ModuleInfo(key: "self_attn") var attention: Attention
    let mlp: MLP

    @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm
    @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm

    let residualMultiplier: Float

    public init(_ args: GraniteConfiguration) {
        let hiddenSize = args.hiddenSize

        self._attention.wrappedValue = Attention(args)
        self.mlp = MLP(args)

        self._inputLayerNorm.wrappedValue = RMSNorm(
            dimensions: hiddenSize, eps: args.rmsNormEps)
        self._postAttentionLayerNorm.wrappedValue = RMSNorm(
            dimensions: hiddenSize, eps: args.rmsNormEps)

        self.residualMultiplier = args.residualMultiplier
    }

    public func callAsFunction(_ x: MLXArray, mask: MLXArray? = nil, cache: KVCache?) -> MLXArray {
        var r = attention(inputLayerNorm(x), mask: mask, cache: cache)
        let h = x + r * residualMultiplier
        r = mlp(postAttentionLayerNorm(h))
        let out = h + r * residualMultiplier
        return out
    }
}

private class GraniteModelInner: Module {
    @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
    fileprivate let layers: [TransformerBlock]
    let norm: RMSNorm
    let embeddingMultiplier: Float

    public init(_ args: GraniteConfiguration) {
        precondition(args.vocabularySize > 0)

        self._embedTokens.wrappedValue = Embedding(
            embeddingCount: args.vocabularySize, dimensions: args.hiddenSize)
        self.layers = (0 ..< args.hiddenLayers)
            .map { _ in
                TransformerBlock(args)
            }

        self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
        self.embeddingMultiplier = args.embeddingMultiplier
    }

    public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]? = nil) -> MLXArray {
        var h = embedTokens(inputs) * embeddingMultiplier

        let mask: MLXArray? = createAttentionMask(h: h, cache: cache)

        for (i, layer) in layers.enumerated() {
            h = layer(h, mask: mask, cache: cache?[i])
        }

        return norm(h)
    }
}
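// Note: relative to a plain Llama-style decoder, Granite adds four config-driven
// scalings, all visible in this file: token embeddings are multiplied by
// embedding_multiplier (GraniteModelInner), each residual branch by
// residual_multiplier (TransformerBlock), attention scores use
// attention_multiplier in place of 1/sqrt(headDim) (Attention), and the final
// logits are divided by logits_scaling (GraniteModel, below).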
public class GraniteModel: Module, LLMModel, KVCacheDimensionProvider {
    public let vocabularySize: Int
    public let kvHeads: [Int]
    let logitsScaling: Float

    private let model: GraniteModelInner
    let configuration: GraniteConfiguration

    @ModuleInfo(key: "lm_head") var lmHead: Linear?

    public init(_ args: GraniteConfiguration) {
        self.configuration = args
        self.vocabularySize = args.vocabularySize
        self.kvHeads = (0 ..< args.hiddenLayers).map { _ in args.kvHeads }

        self.model = GraniteModelInner(args)

        if !args.tieWordEmbeddings {
            self._lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false)
        }
        self.logitsScaling = args.logitsScaling
    }

    public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray {
        var out = model(inputs, cache: cache)
        if let lmHead {
            out = lmHead(out)
        } else {
            out = model.embedTokens.asLinear(out)
        }

        return out / logitsScaling
    }
}

public struct GraniteConfiguration: Codable, Sendable {
    var hiddenSize: Int
    var hiddenLayers: Int
    var intermediateSize: Int
    var attentionHeads: Int
    var rmsNormEps: Float
    var vocabularySize: Int
    var logitsScaling: Float
    var attentionMultiplier: Float
    var embeddingMultiplier: Float
    var residualMultiplier: Float
    var maxPositionEmbeddings: Int
    var kvHeads: Int
    var attentionBias: Bool
    var mlpBias: Bool
    var ropeTheta: Float
    var ropeTraditional: Bool = false
    var ropeScaling: [String: StringOrNumber]? = nil
    var tieWordEmbeddings: Bool = true

    enum CodingKeys: String, CodingKey {
        case hiddenSize = "hidden_size"
        case hiddenLayers = "num_hidden_layers"
        case intermediateSize = "intermediate_size"
        case attentionHeads = "num_attention_heads"
        case rmsNormEps = "rms_norm_eps"
        case vocabularySize = "vocab_size"
        case logitsScaling = "logits_scaling"
        case attentionMultiplier = "attention_multiplier"
        case embeddingMultiplier = "embedding_multiplier"
        case residualMultiplier = "residual_multiplier"
        case maxPositionEmbeddings = "max_position_embeddings"
        case kvHeads = "num_key_value_heads"
        case attentionBias = "attention_bias"
        case mlpBias = "mlp_bias"
        case ropeTheta = "rope_theta"
        case ropeScaling = "rope_scaling"
        case tieWordEmbeddings = "tie_word_embeddings"
    }

    public init(from decoder: Decoder) throws {
        let container: KeyedDecodingContainer<GraniteConfiguration.CodingKeys> =
            try decoder.container(keyedBy: GraniteConfiguration.CodingKeys.self)

        self.hiddenSize = try container.decode(Int.self, forKey: .hiddenSize)
        self.hiddenLayers = try container.decode(Int.self, forKey: .hiddenLayers)
        self.intermediateSize = try container.decode(Int.self, forKey: .intermediateSize)
        self.attentionHeads = try container.decode(Int.self, forKey: .attentionHeads)
        self.rmsNormEps = try container.decode(Float.self, forKey: .rmsNormEps)
        self.vocabularySize = try container.decode(Int.self, forKey: .vocabularySize)
        self.logitsScaling = try container.decode(Float.self, forKey: .logitsScaling)
        self.attentionMultiplier = try container.decode(Float.self, forKey: .attentionMultiplier)
        self.embeddingMultiplier = try container.decode(Float.self, forKey: .embeddingMultiplier)
        self.residualMultiplier = try container.decode(Float.self, forKey: .residualMultiplier)
        self.maxPositionEmbeddings = try container.decode(Int.self, forKey: .maxPositionEmbeddings)
        self.kvHeads = try container.decode(Int.self, forKey: .kvHeads)
        self.attentionBias = try container.decode(Bool.self, forKey: .attentionBias)
        self.mlpBias = try container.decodeIfPresent(Bool.self, forKey: .mlpBias) ?? false
        self.ropeTheta = try container.decodeIfPresent(Float.self, forKey: .ropeTheta) ?? 10000000.0
        self.ropeScaling = try container.decodeIfPresent(
            [String: StringOrNumber].self, forKey: .ropeScaling)
        self.tieWordEmbeddings = try container.decode(Bool.self, forKey: .tieWordEmbeddings)
    }
}

// MARK: - LoRA

extension GraniteModel: LoRAModel {
    public func loraLinearLayers() -> LoRALinearLayers {
        model.layers.map { ($0.attention, ["q_proj", "v_proj"]) }
    }
}
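A quick way to sanity-check the Codable mapping above is to decode a hand-written config. A minimal sketch; every numeric value below is illustrative, not taken from a real Granite checkpoint:

import Foundation
import MLXLLM

// Decode an illustrative config.json excerpt into GraniteConfiguration.
let json = """
    {
      "hidden_size": 2048, "num_hidden_layers": 40, "intermediate_size": 8192,
      "num_attention_heads": 32, "rms_norm_eps": 1e-5, "vocab_size": 49155,
      "logits_scaling": 8.0, "attention_multiplier": 0.015625,
      "embedding_multiplier": 12.0, "residual_multiplier": 0.22,
      "max_position_embeddings": 131072, "num_key_value_heads": 8,
      "attention_bias": false, "mlp_bias": false,
      "rope_theta": 10000000.0, "tie_word_embeddings": true
    }
    """.data(using: .utf8)!

let config = try JSONDecoder().decode(GraniteConfiguration.self, from: json)
// tie_word_embeddings is true here, so lm_head is omitted and the logits
// projection reuses embed_tokens; weights still load separately.
let model = GraniteModel(config)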
