Skip to content

Commit 581c6cb

Browse files
davidkoski and awni authored
add some documentation on porting models (#264)
* add some documentation on porting models

Co-authored-by: Awni Hannun <[email protected]>
1 parent 1029eef commit 581c6cb

File tree

6 files changed

+731
-42
lines changed

6 files changed

+731
-42
lines changed

Libraries/MLXLLM/Models/Gemma.swift

Lines changed: 6 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ private class RMSNorm: Module, UnaryLayer {
1616
public init(dimensions: Int, eps: Float = 1e-5) {
1717
self.weight = MLXArray.ones([dimensions])
1818
self.eps = eps
19-
super.init()
2019
}
2120

2221
public func callAsFunction(_ x: MLXArray) -> MLXArray {
@@ -106,19 +105,13 @@ private class MLP: Module, UnaryLayer {
106105
}
107106

108107
private class TransformerBlock: Module {
109-
let numAttentionHeads: Int
110-
let hiddenSize: Int
111-
112108
@ModuleInfo(key: "self_attn") var attention: Attention
113109
let mlp: MLP
114110

115111
@ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm
116112
@ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm
117113

118114
public init(_ args: GemmaConfiguration) {
119-
self.numAttentionHeads = args.attentionHeads
120-
self.hiddenSize = args.hiddenSize
121-
122115
self._attention.wrappedValue = Attention(args)
123116
self.mlp = MLP(dimensions: args.hiddenSize, hiddenDimensions: args.intermediateSize)
124117
self._inputLayerNorm.wrappedValue = RMSNorm(
@@ -207,8 +200,10 @@ public struct GemmaConfiguration: Codable, Sendable {
207200
var rmsNormEps: Float
208201
var vocabularySize: Int
209202
var kvHeads: Int
210-
var ropeTheta: Float = 10_000
211-
var ropeTraditional: Bool = false
203+
private let _ropeTheta: Float?
204+
public var ropeTheta: Float { _ropeTheta ?? 10_000 }
205+
private let _ropeTraditional: Bool?
206+
public var ropeTraditional: Bool { _ropeTraditional ?? false }
212207

213208
enum CodingKeys: String, CodingKey {
214209
case modelType = "model_type"
@@ -220,38 +215,8 @@ public struct GemmaConfiguration: Codable, Sendable {
220215
case rmsNormEps = "rms_norm_eps"
221216
case vocabularySize = "vocab_size"
222217
case kvHeads = "num_key_value_heads"
223-
case ropeTheta = "rope_theta"
224-
case ropeTraditional = "rope_traditional"
225-
}
226-
227-
public init(from decoder: Decoder) throws {
228-
// Custom implementation to handle optional keys with required values
229-
let container: KeyedDecodingContainer<CodingKeys> = try decoder.container(
230-
keyedBy: CodingKeys.self)
231-
232-
self.modelType = try container.decode(
233-
String.self, forKey: CodingKeys.modelType)
234-
self.hiddenSize = try container.decode(
235-
Int.self, forKey: CodingKeys.hiddenSize)
236-
self.hiddenLayers = try container.decode(
237-
Int.self, forKey: CodingKeys.hiddenLayers)
238-
self.intermediateSize = try container.decode(
239-
Int.self, forKey: CodingKeys.intermediateSize)
240-
self.attentionHeads = try container.decode(
241-
Int.self, forKey: CodingKeys.attentionHeads)
242-
self.headDimensions = try container.decode(
243-
Int.self, forKey: CodingKeys.headDimensions)
244-
self.rmsNormEps = try container.decode(
245-
Float.self, forKey: CodingKeys.rmsNormEps)
246-
self.vocabularySize = try container.decode(
247-
Int.self, forKey: CodingKeys.vocabularySize)
248-
self.kvHeads = try container.decode(Int.self, forKey: CodingKeys.kvHeads)
249-
self.ropeTheta =
250-
try container.decodeIfPresent(Float.self, forKey: CodingKeys.ropeTheta)
251-
?? 10_000
252-
self.ropeTraditional =
253-
try container.decodeIfPresent(
254-
Bool.self, forKey: CodingKeys.ropeTraditional) ?? false
218+
case _ropeTheta = "rope_theta"
219+
case _ropeTraditional = "rope_traditional"
255220
}
256221
}
257222

Libraries/MLXLLM/README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
11
# MLXLLM
22

3+
# Documentation
4+
5+
- [Porting and implementing models](https://swiftpackageindex.com/ml-explore/mlx-swift-examples/main/documentation/mlxlmcommon/porting)
6+
- [MLXLLMCommon](https://swiftpackageindex.com/ml-explore/mlx-swift-examples/main/documentation/mlxlmcommon) -- common API for LLM and VLM
7+
- [MLXLLM](https://swiftpackageindex.com/ml-explore/mlx-swift-examples/main/documentation/mlxllm) -- large language model example implementations
8+
- [MLXVLM](https://swiftpackageindex.com/ml-explore/mlx-swift-examples/main/documentation/mlxvlm) -- vision language model example implementations
9+
10+
# Contents
11+
312
This is a port of several models from:
413

514
- https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/

0 commit comments

Comments (0)