@@ -98,12 +98,13 @@ open class Module {
98
98
99
99
/// Flag to indicate whether the module is being trained. Manipulated via
100
100
/// ``train(_:)``.
101
+ ///
102
+ /// ### See Also
103
+ /// - ``didSetTrain(_:)``
101
104
public private( set) var training = true
102
105
103
- /// Set of property names that are frozen. Maniupulated via
104
- /// ``freeze(recursive:keys:strict:)`` and
105
- /// ``unfreeze(recursive:keys:strict:)``.
106
- public private( set) var noGrad = Set < String > ( )
106
+ /// See ``noGrad()``
107
+ private var _noGrad = Set < String > ( )
107
108
108
109
private var _items : ModuleItems !
109
110
private var _setters : [ String : TypeErasedSetter ] !
@@ -139,7 +140,7 @@ open class Module {
139
140
/// and ``update(parameters:)`` for example.
140
141
///
141
142
/// Subclasses could potentially override this to provide custom introspection.
142
- public func items( ) -> ModuleItems {
143
+ open func items( ) -> ModuleItems {
143
144
_items
144
145
}
145
146
@@ -222,7 +223,7 @@ open class Module {
222
223
/// - ``mapParameters(map:isLeaf:)``
223
224
/// - ``modules()``
224
225
/// - ``items()``
225
- public func filterMap< Result> (
226
+ open func filterMap< Result> (
226
227
filter: ( Module , String , ModuleItem ) -> Bool ,
227
228
map: ( ModuleItem ) -> Result ? = { $0 } ,
228
229
isLeaf: ( Module , String , ModuleItem ) -> Bool = Module . isLeafDefault
@@ -331,7 +332,7 @@ open class Module {
331
332
/// ### See Also
332
333
/// - <doc:module-filters>
333
334
/// - ``mapParameters(map:)``
334
- public func mapParameters< Result> (
335
+ open func mapParameters< Result> (
335
336
map: @escaping ( MLXArray ) -> Result ? = { $0 } ,
336
337
isLeaf: ( Module , String , ModuleItem ) -> Bool = Module . isLeafDefault
337
338
) -> NestedDictionary < String , Result > {
@@ -343,28 +344,28 @@ open class Module {
343
344
344
345
/// Return a `NestedDictionary<String, MLXArray>` for all parameters in the
345
346
/// model (all layers).
346
- public func parameters( ) -> ModuleParameters {
347
+ open func parameters( ) -> ModuleParameters {
347
348
filterMap ( filter: Self . filterValidParameters, map: Self . mapParameters ( ) )
348
349
}
349
350
350
351
/// Return a `NestedDictionary<String, MLXArray>` for all trainable parameters in the
351
352
/// model (all layers).
352
353
///
353
354
/// This omits ``freeze(recursive:keys:strict:)`` (frozen) parameters.
354
- public func trainableParameters( ) -> ModuleParameters {
355
+ open func trainableParameters( ) -> ModuleParameters {
355
356
filterMap ( filter: Self . filterTrainableParameters, map: Self . mapParameters ( ) )
356
357
}
357
358
358
359
/// Produces a `NestedDictionary<String, Module>` for all direct children of the module.
359
- public func children( ) -> ModuleChildren {
360
+ open func children( ) -> ModuleChildren {
360
361
filterMap ( filter: Self . filterValidChild, map: Self . mapModule ( ) , isLeaf: Self . isLeafModule)
361
362
}
362
363
363
364
/// Produces a `NestedDictionary<String, Module>` for all leaf modules module.
364
365
///
365
366
/// ### See Also
366
367
/// - ``isLeafModuleNoChildren``
367
- public func leafModules( ) -> ModuleChildren {
368
+ open func leafModules( ) -> ModuleChildren {
368
369
filterMap (
369
370
filter: Self . filterValidChild, map: Self . mapModule ( ) ,
370
371
isLeaf: Self . isLeafModuleNoChildren)
@@ -710,7 +711,23 @@ open class Module {
710
711
return self
711
712
}
712
713
713
- private func updateModule( key: String , _ value: Any ) throws {
714
+ /// Set a module to a new value.
715
+ ///
716
+ /// The module property must be wrapped in a ``ModuleInfo``:
717
+ ///
718
+ /// ```swift
719
+ /// @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm
720
+ /// ```
721
+ ///
722
+ /// and the value must be a compatible type.
723
+ ///
724
+ /// This method is called via ``update(modules:)`` and is not typically called directly. This
725
+ /// is exposed as an overridable method for subclasses.
726
+ ///
727
+ /// - Parameters:
728
+ /// - key: module key, see ``ModuleInfo``
729
+ /// - value: the replacement module
730
+ open func updateModule( key: String , _ value: Any ) throws {
714
731
if let setter = _setters [ key] {
715
732
do {
716
733
try setter. updateModule ( value)
@@ -727,7 +744,7 @@ open class Module {
727
744
}
728
745
729
746
// `apply_to_modules()`
730
- public func visit( modules visitor: ( String , Module ) throws -> Void ) rethrows {
747
+ open func visit( modules visitor: ( String , Module ) throws -> Void ) rethrows {
731
748
var stack = [ ( String, Module) ] ( )
732
749
stack. append ( ( " " , self ) )
733
750
@@ -746,7 +763,7 @@ open class Module {
746
763
/// - ``namedModules()``
747
764
/// - ``children()``
748
765
/// - ``leafModules()``
749
- public func modules( ) -> [ Module ] {
766
+ open func modules( ) -> [ Module ] {
750
767
var result = [ Module] ( )
751
768
visit {
752
769
result. append ( $1)
@@ -760,7 +777,7 @@ open class Module {
760
777
/// - ``modules()``
761
778
/// - ``children()``
762
779
/// - ``leafModules()``
763
- public func namedModules( ) -> [ ( String , Module ) ] {
780
+ open func namedModules( ) -> [ ( String , Module ) ] {
764
781
var result = [ ( String, Module) ] ( )
765
782
visit {
766
783
result. append ( ( $0, $1) )
@@ -822,7 +839,8 @@ open class Module {
822
839
/// - ``unfreeze(recursive:keys:strict:)``
823
840
open func freeze( recursive: Bool = true , keys: [ String ] ? = nil , strict: Bool = false ) throws {
824
841
let visitor = freezeVisitor ( keys: keys, strict: strict) {
825
- $0. noGrad. formUnion ( $1)
842
+ $0. _noGrad. formUnion ( $1)
843
+ $0. didSetNoGrad ( $0. _noGrad)
826
844
}
827
845
828
846
if recursive {
@@ -859,7 +877,8 @@ open class Module {
859
877
/// - ``Module/unfreeze(recursive:keys:strict:)``
860
878
open func unfreeze( recursive: Bool = true , keys: [ String ] ? = nil , strict: Bool = false ) throws {
861
879
let visitor = freezeVisitor ( keys: keys, strict: strict) {
862
- $0. noGrad. subtract ( $1)
880
+ $0. _noGrad. subtract ( $1)
881
+ $0. didSetNoGrad ( $0. _noGrad)
863
882
}
864
883
865
884
if recursive {
@@ -869,6 +888,24 @@ open class Module {
869
888
}
870
889
}
871
890
891
+ /// Set of property names that are frozen. Maniupulated via
892
+ /// ``freeze(recursive:keys:strict:)`` and
893
+ /// ``unfreeze(recursive:keys:strict:)``.
894
+ open func noGrad( ) -> Set < String > {
895
+ _noGrad
896
+ }
897
+
898
+ /// Called when ``noGrad()`` is updated.
899
+ ///
900
+ /// This is provided for subclasses to override.
901
+ ///
902
+ /// - Parameter noGrad: set of properties that are frozen
903
+ ///
904
+ /// ### See Also
905
+ /// - ``noGrad()``
906
+ open func didSetNoGrad( _ noGrad: Set < String > ) {
907
+ }
908
+
872
909
/// Recursively set the model's training mode.
873
910
///
874
911
/// Training mode only applies to certain layers. For example
@@ -877,11 +914,21 @@ open class Module {
877
914
///
878
915
/// ### See Also
879
916
/// - ``training``
917
+ /// - ``didSetTrain(_:)``
880
918
public func train( _ mode: Bool = true ) {
881
919
visit ( modules: {
882
920
$1. training = mode
921
+ $1. didSetTrain ( mode)
883
922
} )
884
923
}
924
+
925
+ /// Called when ``train(_:)`` is updated.
926
+ ///
927
+ /// This is provided for subclasses to override.
928
+ ///
929
+ /// - Parameter mode: `true` is training
930
+ open func didSetTrain( _ mode: Bool ) {
931
+ }
885
932
}
886
933
887
934
extension Module : IndentedDescription {
@@ -922,7 +969,7 @@ extension Module: Updatable, Evaluatable {
922
969
/// ### See Also
923
970
/// - <doc:layers>
924
971
/// - ``Sequential``
925
- public protocol UnaryLayer {
972
+ public protocol UnaryLayer : Module {
926
973
func callAsFunction( _ x: MLXArray ) -> MLXArray
927
974
}
928
975
@@ -996,7 +1043,7 @@ extension Module {
996
1043
( module: Module , key: String , item: ModuleItem ) in
997
1044
switch item {
998
1045
case . array, . dictionary, . value( . parameters) , . value( . module) :
999
- !key. hasPrefix ( " _ " ) && !module. noGrad. contains ( key)
1046
+ !key. hasPrefix ( " _ " ) && !module. noGrad ( ) . contains ( key)
1000
1047
default : false
1001
1048
}
1002
1049
}
0 commit comments