@@ -16,7 +16,6 @@ private class RMSNorm: Module, UnaryLayer {
16
16
public init ( dimensions: Int , eps: Float = 1e-5 ) {
17
17
self . weight = MLXArray . ones ( [ dimensions] )
18
18
self . eps = eps
19
- super. init ( )
20
19
}
21
20
22
21
public func callAsFunction( _ x: MLXArray ) -> MLXArray {
@@ -106,19 +105,13 @@ private class MLP: Module, UnaryLayer {
106
105
}
107
106
108
107
private class TransformerBlock : Module {
109
- let numAttentionHeads : Int
110
- let hiddenSize : Int
111
-
112
108
@ModuleInfo ( key: " self_attn " ) var attention : Attention
113
109
let mlp : MLP
114
110
115
111
@ModuleInfo ( key: " input_layernorm " ) var inputLayerNorm : RMSNorm
116
112
@ModuleInfo ( key: " post_attention_layernorm " ) var postAttentionLayerNorm : RMSNorm
117
113
118
114
public init ( _ args: GemmaConfiguration ) {
119
- self . numAttentionHeads = args. attentionHeads
120
- self . hiddenSize = args. hiddenSize
121
-
122
115
self . _attention. wrappedValue = Attention ( args)
123
116
self . mlp = MLP ( dimensions: args. hiddenSize, hiddenDimensions: args. intermediateSize)
124
117
self . _inputLayerNorm. wrappedValue = RMSNorm (
@@ -207,8 +200,10 @@ public struct GemmaConfiguration: Codable, Sendable {
207
200
var rmsNormEps : Float
208
201
var vocabularySize : Int
209
202
var kvHeads : Int
210
- var ropeTheta : Float = 10_000
211
- var ropeTraditional : Bool = false
203
+ private let _ropeTheta : Float ?
204
+ public var ropeTheta : Float { _ropeTheta ?? 10_000 }
205
+ private let _ropeTraditional : Bool ?
206
+ public var ropeTraditional : Bool { _ropeTraditional ?? false }
212
207
213
208
enum CodingKeys : String , CodingKey {
214
209
case modelType = " model_type "
@@ -220,38 +215,8 @@ public struct GemmaConfiguration: Codable, Sendable {
220
215
case rmsNormEps = " rms_norm_eps "
221
216
case vocabularySize = " vocab_size "
222
217
case kvHeads = " num_key_value_heads "
223
- case ropeTheta = " rope_theta "
224
- case ropeTraditional = " rope_traditional "
225
- }
226
-
227
- public init ( from decoder: Decoder ) throws {
228
- // Custom implementation to handle optional keys with required values
229
- let container : KeyedDecodingContainer < CodingKeys > = try decoder. container (
230
- keyedBy: CodingKeys . self)
231
-
232
- self . modelType = try container. decode (
233
- String . self, forKey: CodingKeys . modelType)
234
- self . hiddenSize = try container. decode (
235
- Int . self, forKey: CodingKeys . hiddenSize)
236
- self . hiddenLayers = try container. decode (
237
- Int . self, forKey: CodingKeys . hiddenLayers)
238
- self . intermediateSize = try container. decode (
239
- Int . self, forKey: CodingKeys . intermediateSize)
240
- self . attentionHeads = try container. decode (
241
- Int . self, forKey: CodingKeys . attentionHeads)
242
- self . headDimensions = try container. decode (
243
- Int . self, forKey: CodingKeys . headDimensions)
244
- self . rmsNormEps = try container. decode (
245
- Float . self, forKey: CodingKeys . rmsNormEps)
246
- self . vocabularySize = try container. decode (
247
- Int . self, forKey: CodingKeys . vocabularySize)
248
- self . kvHeads = try container. decode ( Int . self, forKey: CodingKeys . kvHeads)
249
- self . ropeTheta =
250
- try container. decodeIfPresent ( Float . self, forKey: CodingKeys . ropeTheta)
251
- ?? 10_000
252
- self . ropeTraditional =
253
- try container. decodeIfPresent (
254
- Bool . self, forKey: CodingKeys . ropeTraditional) ?? false
218
+ case _ropeTheta = " rope_theta "
219
+ case _ropeTraditional = " rope_traditional "
255
220
}
256
221
}
257
222
0 commit comments