Commit 739f84b

address issues that prevent using composition for layers like LoRA
- see ml-explore/mlx-swift-examples#167
- also fixes an issue where quantize() could quantize an already-quantized layer!
1 parent 15f12e4 commit 739f84b

5 files changed, +204 -29 lines changed

Source/MLXNN/Linear.swift (+1 -1)

```diff
@@ -73,7 +73,7 @@ open class Linear: Module, UnaryLayer, Quantizable {
     public let weight: MLXArray
     public let bias: MLXArray?
 
-    public var shape: (Int, Int) {
+    open var shape: (Int, Int) {
         weight.shape2
     }
 
```
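Making `shape` `open` rather than merely `public` lets subclasses substitute their own notion of shape; the `QuantizedLinear` override later in this commit uses exactly this to report the logical (unpacked) shape. A sketch of the arithmetic involved, with illustrative numbers that are not from the commit:

```swift
// A 4-bit quantized weight packs 32 / 4 = 8 values into each stored UInt32,
// so a stored [out, in / 8] matrix logically represents [out, in].
let bits = 4
let stored = (1024, 512)                       // packed weight.shape2
let logical = (stored.0, stored.1 * 32 / bits) // (1024, 4096)
```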

Source/MLXNN/Module.swift (+66 -19)

```diff
@@ -98,12 +98,13 @@ open class Module {
 
     /// Flag to indicate whether the module is being trained. Manipulated via
     /// ``train(_:)``.
+    ///
+    /// ### See Also
+    /// - ``didSetTrain(_:)``
     public private(set) var training = true
 
-    /// Set of property names that are frozen. Maniupulated via
-    /// ``freeze(recursive:keys:strict:)`` and
-    /// ``unfreeze(recursive:keys:strict:)``.
-    public private(set) var noGrad = Set<String>()
+    /// See ``noGrad()``
+    private var _noGrad = Set<String>()
 
     private var _items: ModuleItems!
     private var _setters: [String: TypeErasedSetter]!
@@ -139,7 +140,7 @@ open class Module {
     /// and ``update(parameters:)`` for example.
     ///
     /// Subclasses could potentially override this to provide custom introspection.
-    public func items() -> ModuleItems {
+    open func items() -> ModuleItems {
         _items
     }
 
@@ -222,7 +223,7 @@ open class Module {
     /// - ``mapParameters(map:isLeaf:)``
     /// - ``modules()``
     /// - ``items()``
-    public func filterMap<Result>(
+    open func filterMap<Result>(
         filter: (Module, String, ModuleItem) -> Bool,
         map: (ModuleItem) -> Result? = { $0 },
         isLeaf: (Module, String, ModuleItem) -> Bool = Module.isLeafDefault
@@ -331,7 +332,7 @@ open class Module {
     /// ### See Also
     /// - <doc:module-filters>
     /// - ``mapParameters(map:)``
-    public func mapParameters<Result>(
+    open func mapParameters<Result>(
         map: @escaping (MLXArray) -> Result? = { $0 },
         isLeaf: (Module, String, ModuleItem) -> Bool = Module.isLeafDefault
     ) -> NestedDictionary<String, Result> {
@@ -343,28 +344,28 @@ open class Module {
 
     /// Return a `NestedDictionary<String, MLXArray>` for all parameters in the
     /// model (all layers).
-    public func parameters() -> ModuleParameters {
+    open func parameters() -> ModuleParameters {
         filterMap(filter: Self.filterValidParameters, map: Self.mapParameters())
     }
 
     /// Return a `NestedDictionary<String, MLXArray>` for all trainable parameters in the
     /// model (all layers).
     ///
     /// This omits ``freeze(recursive:keys:strict:)`` (frozen) parameters.
-    public func trainableParameters() -> ModuleParameters {
+    open func trainableParameters() -> ModuleParameters {
         filterMap(filter: Self.filterTrainableParameters, map: Self.mapParameters())
     }
 
     /// Produces a `NestedDictionary<String, Module>` for all direct children of the module.
-    public func children() -> ModuleChildren {
+    open func children() -> ModuleChildren {
         filterMap(filter: Self.filterValidChild, map: Self.mapModule(), isLeaf: Self.isLeafModule)
     }
 
     /// Produces a `NestedDictionary<String, Module>` for all leaf modules module.
     ///
     /// ### See Also
     /// - ``isLeafModuleNoChildren``
-    public func leafModules() -> ModuleChildren {
+    open func leafModules() -> ModuleChildren {
         filterMap(
             filter: Self.filterValidChild, map: Self.mapModule(),
             isLeaf: Self.isLeafModuleNoChildren)
@@ -710,7 +711,23 @@ open class Module {
         return self
     }
 
-    private func updateModule(key: String, _ value: Any) throws {
+    /// Set a module to a new value.
+    ///
+    /// The module property must be wrapped in a ``ModuleInfo``:
+    ///
+    /// ```swift
+    /// @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm
+    /// ```
+    ///
+    /// and the value must be a compatible type.
+    ///
+    /// This method is called via ``update(modules:)`` and is not typically called directly. This
+    /// is exposed as an overridable method for subclasses.
+    ///
+    /// - Parameters:
+    ///   - key: module key, see ``ModuleInfo``
+    ///   - value: the replacement module
+    open func updateModule(key: String, _ value: Any) throws {
         if let setter = _setters[key] {
             do {
                 try setter.updateModule(value)
@@ -727,7 +744,7 @@
     }
 
     // `apply_to_modules()`
-    public func visit(modules visitor: (String, Module) throws -> Void) rethrows {
+    open func visit(modules visitor: (String, Module) throws -> Void) rethrows {
         var stack = [(String, Module)]()
         stack.append(("", self))
 
@@ -746,7 +763,7 @@
     /// - ``namedModules()``
     /// - ``children()``
     /// - ``leafModules()``
-    public func modules() -> [Module] {
+    open func modules() -> [Module] {
         var result = [Module]()
         visit {
             result.append($1)
@@ -760,7 +777,7 @@
     /// - ``modules()``
     /// - ``children()``
     /// - ``leafModules()``
-    public func namedModules() -> [(String, Module)] {
+    open func namedModules() -> [(String, Module)] {
         var result = [(String, Module)]()
         visit {
             result.append(($0, $1))
@@ -822,7 +839,8 @@
     /// - ``unfreeze(recursive:keys:strict:)``
     open func freeze(recursive: Bool = true, keys: [String]? = nil, strict: Bool = false) throws {
         let visitor = freezeVisitor(keys: keys, strict: strict) {
-            $0.noGrad.formUnion($1)
+            $0._noGrad.formUnion($1)
+            $0.didSetNoGrad($0._noGrad)
         }
 
         if recursive {
@@ -859,7 +877,8 @@
     /// - ``Module/unfreeze(recursive:keys:strict:)``
     open func unfreeze(recursive: Bool = true, keys: [String]? = nil, strict: Bool = false) throws {
         let visitor = freezeVisitor(keys: keys, strict: strict) {
-            $0.noGrad.subtract($1)
+            $0._noGrad.subtract($1)
+            $0.didSetNoGrad($0._noGrad)
         }
 
         if recursive {
@@ -869,6 +888,24 @@
         }
     }
 
+    /// Set of property names that are frozen. Manipulated via
+    /// ``freeze(recursive:keys:strict:)`` and
+    /// ``unfreeze(recursive:keys:strict:)``.
+    open func noGrad() -> Set<String> {
+        _noGrad
+    }
+
+    /// Called when ``noGrad()`` is updated.
+    ///
+    /// This is provided for subclasses to override.
+    ///
+    /// - Parameter noGrad: set of properties that are frozen
+    ///
+    /// ### See Also
+    /// - ``noGrad()``
+    open func didSetNoGrad(_ noGrad: Set<String>) {
+    }
+
     /// Recursively set the model's training mode.
     ///
     /// Training mode only applies to certain layers. For example
@@ -877,11 +914,21 @@
     ///
     /// ### See Also
     /// - ``training``
+    /// - ``didSetTrain(_:)``
     public func train(_ mode: Bool = true) {
         visit(modules: {
             $1.training = mode
+            $1.didSetTrain(mode)
         })
     }
+
+    /// Called when ``train(_:)`` is updated.
+    ///
+    /// This is provided for subclasses to override.
+    ///
+    /// - Parameter mode: `true` if training
+    open func didSetTrain(_ mode: Bool) {
+    }
 }
 
 extension Module: IndentedDescription {
@@ -922,7 +969,7 @@ extension Module: Updatable, Evaluatable {
 /// ### See Also
 /// - <doc:layers>
 /// - ``Sequential``
-public protocol UnaryLayer {
+public protocol UnaryLayer: Module {
     func callAsFunction(_ x: MLXArray) -> MLXArray
 }
 
@@ -996,7 +1043,7 @@ extension Module {
         (module: Module, key: String, item: ModuleItem) in
         switch item {
         case .array, .dictionary, .value(.parameters), .value(.module):
-            !key.hasPrefix("_") && !module.noGrad.contains(key)
+            !key.hasPrefix("_") && !module.noGrad().contains(key)
         default: false
         }
     }
```
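The notable changes here are the `didSetNoGrad(_:)`/`didSetTrain(_:)` hooks and the `Module` constraint on `UnaryLayer`: together they let a layer built by composition observe freeze and train-mode changes without subclassing the built-in layers. A minimal sketch of the hook mechanism, assuming the post-commit API; `NoisyLinear` is hypothetical, not part of MLXNN:

```swift
import MLX
import MLXNN
import MLXRandom

// Hypothetical layer: composes a Linear and perturbs its activations,
// but only while the model is in training mode.
class NoisyLinear: Module, UnaryLayer {
    @ModuleInfo var inner: Linear
    private var noiseEnabled = true

    init(_ inputDims: Int, _ outputDims: Int) {
        self._inner.wrappedValue = Linear(inputDims, outputDims)
        super.init()
    }

    // train(_:) now calls this on every module in the tree
    override func didSetTrain(_ mode: Bool) {
        noiseEnabled = mode
    }

    func callAsFunction(_ x: MLXArray) -> MLXArray {
        let y = inner(x)
        return noiseEnabled ? y + MLXRandom.normal(y.shape) * 0.01 : y
    }
}
```

A call like `model.train(false)` now both clears `training` and invokes `didSetTrain(false)` on every module, so the wrapper above stays in sync without polling the flag.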

Source/MLXNN/Quantized.swift (+19 -5)

```diff
@@ -11,9 +11,18 @@ public protocol Quantizable {
     func toQuantized(groupSize: Int, bits: Int) -> Module
 }
 
-public func quantizeSingle(layer: Module, groupSize: Int = 64, bits: Int = 4) -> Module? {
-    if let quantizable = layer as? Quantizable {
-        quantizable.toQuantized(groupSize: groupSize, bits: bits)
+/// Protocol for layers that are quantized.
+public protocol Quantized: Module {
+    var groupSize: Int { get }
+    var bits: Int { get }
+}
+
+public func quantizeSingle(layer: Module, groupSize: Int = 64, bits: Int = 4) -> Quantized? {
+    if layer is Quantized {
+        // already quantized
+        nil
+    } else if let quantizable = layer as? Quantizable {
+        quantizable.toQuantized(groupSize: groupSize, bits: bits) as? Quantized
     } else {
         nil
     }
@@ -52,7 +61,7 @@ public func quantize(
 }
 
 /// The same as ``Embedding`` but with a quantized weight matrix.
-open class QuantizedEmbedding: Embedding {
+open class QuantizedEmbedding: Embedding, Quantized {
 
     public let groupSize: Int
     public let bits: Int
@@ -121,14 +130,19 @@ open class QuantizedEmbedding {
 ///
 /// ### See Also
 /// - ``init(weight:bias:groupSize:bits:)``
-open class QuantizedLinear: Linear {
+open class QuantizedLinear: Linear, Quantized {
 
     public let groupSize: Int
     public let bits: Int
 
     public let scales: MLXArray
     public let biases: MLXArray
 
+    open override var shape: (Int, Int) {
+        let shape = weight.shape2
+        return (shape.0, shape.1 * 32 / bits)
+    }
+
     /// Applies an affine transformation to the input using a quantized weight matrix.
     ///
     /// This is the quantized version of ``Linear``. Typically this is used via ``quantize(model:groupSize:bits:predicate:)``.
```
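The `Quantized` marker protocol is what makes quantization idempotent: `quantizeSingle` now declines layers that are already quantized, which fixes the `quantize()` issue called out in the commit message. A small usage sketch against the post-commit API:

```swift
import MLXNN

let linear = Linear(256, 128)

if let first = quantizeSingle(layer: linear) {
    // Linear is Quantizable, so the first pass converts it
    assert(first is QuantizedLinear)

    // the second pass sees a Quantized layer and declines
    let second = quantizeSingle(layer: first)
    assert(second == nil)
}
```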

Source/MLXNN/Transformer.swift (+4 -4)

```diff
@@ -12,10 +12,10 @@ open class MultiHeadAttention: Module {
 
     public let numHeads: Int
 
-    @ModuleInfo(key: "query_proj") public var queryProjection: Linear
-    @ModuleInfo(key: "key_proj") public var keyProjection: Linear
-    @ModuleInfo(key: "value_proj") public var valueProjection: Linear
-    @ModuleInfo(key: "out_proj") public var outProjection: Linear
+    @ModuleInfo(key: "query_proj") public var queryProjection: UnaryLayer
+    @ModuleInfo(key: "key_proj") public var keyProjection: UnaryLayer
+    @ModuleInfo(key: "value_proj") public var valueProjection: UnaryLayer
+    @ModuleInfo(key: "out_proj") public var outProjection: UnaryLayer
 
     /// Implements the scaled dot product attention with multiple heads.
     ///
```
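Retyping the projections as `UnaryLayer` (which now requires `Module`) is the composition hook from the commit message: any unary module, such as a LoRA adapter wrapping a `Linear`, can now sit in these slots. A minimal sketch along the lines of ml-explore/mlx-swift-examples#167; `LoRALinear` here is illustrative, the dimensions are arbitrary, and the `ModuleChildren.unflattened` construction is an assumption about the NestedDictionary API:

```swift
import MLX
import MLXNN
import MLXRandom

// Hypothetical LoRA adapter that composes (rather than subclasses) a Linear.
class LoRALinear: Module, UnaryLayer {
    let scale: Float

    @ModuleInfo var linear: Linear
    @ParameterInfo(key: "lora_a") var loraA: MLXArray
    @ParameterInfo(key: "lora_b") var loraB: MLXArray

    init(_ linear: Linear, rank: Int = 8, scale: Float = 10.0) {
        let (outputDims, inputDims) = linear.shape
        self.scale = scale
        self._linear.wrappedValue = linear
        self._loraA.wrappedValue = MLXRandom.uniform(low: -0.1, high: 0.1, [inputDims, rank])
        self._loraB.wrappedValue = MLXArray.zeros([rank, outputDims])
        super.init()
    }

    func callAsFunction(_ x: MLXArray) -> MLXArray {
        // base projection plus the scaled low-rank update
        linear(x) + matmul(matmul(x, loraA), loraB) * scale
    }
}

// Swap the adapter into an attention block. update(modules:) routes through
// the (now overridable) updateModule(key:_:), and the replacement type-checks
// because the projection properties are UnaryLayer.
let attention = MultiHeadAttention(dimensions: 64, numHeads: 4)
if let base = attention.queryProjection as? Linear {
    let adapter = LoRALinear(base)
    try adapter.linear.freeze()  // train only lora_a / lora_b
    try attention.update(modules: ModuleChildren.unflattened(["query_proj": adapter]))
}
```

Because `Linear.shape` is now `open` and `QuantizedLinear` overrides it to report the logical (unpacked) shape, the same sizing logic should also work when the wrapped layer is quantized.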
