From 04c93469fff1001c9e2709ee5ad3dc83d724f82b Mon Sep 17 00:00:00 2001
From: David Koski
Date: Mon, 16 Dec 2024 13:32:37 -0800
Subject: [PATCH 1/2] exploration of LoRA using composition

---
 Libraries/MLXLLM/Models/Gemma2.swift | 126 ++++++++++++++++++++++++++-
 1 file changed, 122 insertions(+), 4 deletions(-)

diff --git a/Libraries/MLXLLM/Models/Gemma2.swift b/Libraries/MLXLLM/Models/Gemma2.swift
index 48e04820..0c222d35 100644
--- a/Libraries/MLXLLM/Models/Gemma2.swift
+++ b/Libraries/MLXLLM/Models/Gemma2.swift
@@ -5,6 +5,7 @@ import MLX
 import MLXFast
 import MLXLMCommon
 import MLXNN
+import MLXRandom
 
 // Port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/gemma2.py
 
@@ -33,10 +34,10 @@ private class Attention: Module {
     let nKVHeads: Int
     let repeats: Int
 
-    @ModuleInfo(key: "q_proj") var wq: Linear
-    @ModuleInfo(key: "k_proj") var wk: Linear
-    @ModuleInfo(key: "v_proj") var wv: Linear
-    @ModuleInfo(key: "o_proj") var wo: Linear
+    @ModuleInfo(key: "q_proj") var wq: UnaryLayer
+    @ModuleInfo(key: "k_proj") var wk: UnaryLayer
+    @ModuleInfo(key: "v_proj") var wv: UnaryLayer
+    @ModuleInfo(key: "o_proj") var wo: UnaryLayer
 
     let rope: RoPE
 
@@ -288,3 +289,120 @@ extension Gemma2Model: LoRAModel {
         model.layers.map { ($0.attention, ["q_proj", "v_proj"]) }
     }
 }
+
+// TODO - notes
+//
+// - make UnaryLayer extend Module
+// - make a Quantized protocol that provides the groupSize and bits
+// - make the QuantizedLinear shape produce the expanded shape
+// - make `items()` open
+// - make `updateModule(key:_:)` open
+//
+// - evaluation and training should work as expected
+// - this flattens the weights and modules into one layer
+// - to match the normal lora implementation
+// - see items() and updateModule()
+
+// TODO: make UnaryLayer extend Module
+public protocol UnaryLayer2: Module {
+    func callAsFunction(_ x: MLXArray) -> MLXArray
+}
+
+/// LoRA layer that can wrap any UnaryLayer
+class LoRA: Module, UnaryLayer2 {
+
+    let adapts: UnaryLayer2
+    let scale: Float
+
+    @ParameterInfo(key: "lora_a") var loraA: MLXArray
+    @ParameterInfo(key: "lora_b") var loraB: MLXArray
+
+    public init(
+        adapts: UnaryLayer2, inputDimensions: Int, outputDimensions: Int, rank: Int = 8,
+        scale: Float = 20.0
+    ) {
+        self.adapts = adapts
+
+        self.scale = scale
+
+        let loraScale = 1 / sqrt(Float(inputDimensions))
+        self._loraA.wrappedValue = MLXRandom.uniform(
+            low: -loraScale, high: loraScale, [inputDimensions, rank])
+        self._loraB.wrappedValue = MLXArray.zeros([rank, outputDimensions])
+
+        freeze()
+    }
+
+    // TODO: in LoRALinear this is
+    // public static func from(linear: Linear, rank: Int = 8) -> LoRA
+    public convenience init(linear: Linear, rank: Int = 8, scale: Float = 20.0) {
+        var (outputDimensions, inputDimensions) = linear.shape
+
+        if let linear = linear as? QuantizedLinear {
+            // TODO Linear should probably have a property to return these directly
+            // rather than shape which represents the physical shape of the layers
+            inputDimensions = inputDimensions * 32 / linear.bits
+        }
+
+        self.init(
+            adapts: linear,
+            inputDimensions: inputDimensions, outputDimensions: outputDimensions,
+            rank: rank, scale: scale)
+    }
+
+    // produce a merged view of properties (flatten LoRA into adapts)
+    override func items() -> ModuleItems {
+        var result = adapts.items()
+        for (key, value) in super.items() {
+            if key == "adapts" { continue }
+            result[key] = value
+        }
+        return result
+    }
+
+    // forward module updates -> adapts
+    func updateModule(key: String, _ value: Any) throws {
+        try adapts.updateModule(key: key, value)
+    }
+
+    override func freeze(recursive: Bool = true, keys: [String]? = nil, strict: Bool = false) throws
+    {
+        try adapts.freeze(recursive: recursive, keys: keys, strict: strict)
+    }
+
+    // TODO: this requires knowledge of the innards of the adapted layer so it
+    // is specific to Linear (and QuantizedLinear).
+    public func toLinear(deQuantize: Bool = false) -> Linear {
+        // TODO throws? failable?
+        guard let linear = adapts as? Linear else { fatalError("Not a Linear") }
+
+        var weight: MLXArray
+        if let quantized = linear as? QuantizedLinear {
+            weight = dequantized(
+                quantized.weight, scales: quantized.scales, biases: quantized.biases,
+                groupSize: quantized.groupSize, bits: quantized.bits)
+        } else {
+            weight = linear.weight
+        }
+
+        let loraB = (scale * loraB.T).asType(.float16)
+        let loraA = loraA.T.asType(.float16)
+        let mergedWeight = weight + matmul(loraB, loraA)
+
+        // TODO maybe add a protocol for Quantized
+        if let quantized = linear as? QuantizedLinear {
+            return QuantizedLinear(
+                weight: mergedWeight, bias: quantized.bias,
+                groupSize: quantized.groupSize, bits: quantized.bits)
+        } else {
+            return Linear(weight: mergedWeight, bias: linear.bias)
+        }
+    }
+
+    public func callAsFunction(_ x: MLXArray) -> MLXArray {
+        // TODO let y = super.callAsFunction(x.asType(scales.dtype)) -- ignoring the asType here
+        let y = adapts(x)
+        let z = matmul(matmul(x, self.loraA), self.loraB)
+        return y + scale * z
+    }
+}

From 42d0049271f772f1ec4844dced03fe45ce86f474 Mon Sep 17 00:00:00 2001
From: David Koski
Date: Mon, 16 Dec 2024 13:40:13 -0800
Subject: [PATCH 2/2] add note on noGrad

---
 Libraries/MLXLLM/Models/Gemma2.swift | 1 +
 1 file changed, 1 insertion(+)

diff --git a/Libraries/MLXLLM/Models/Gemma2.swift b/Libraries/MLXLLM/Models/Gemma2.swift
index 0c222d35..213ecbf4 100644
--- a/Libraries/MLXLLM/Models/Gemma2.swift
+++ b/Libraries/MLXLLM/Models/Gemma2.swift
@@ -297,6 +297,7 @@ extension Gemma2Model: LoRAModel {
 // - make the QuantizedLinear shape produce the expanded shape
 // - make `items()` open
 // - make `updateModule(key:_:)` open
+// - make `noGrad` overridable (turn into function?)
 //
 // - evaluation and training should work as expected
 // - this flattens the weights and modules into one layer
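
For reference, a minimal usage sketch of the composition approach introduced in PATCH 1/2. It assumes the TODO items in the notes (UnaryLayer extending Module, open items()/updateModule(key:_:), overridable noGrad) have landed in MLXNN, and that the code lives in the same module as the LoRA class; the layer sizes and hyperparameters below are illustrative and not taken from the patch.

    import MLX
    import MLXNN
    import MLXRandom

    // a base projection standing in for e.g. q_proj
    let base = Linear(256, 256, bias: false)

    // compose: LoRA wraps the existing layer instead of subclassing Linear
    let adapted = LoRA(linear: base, rank: 8, scale: 20.0)

    // forward pass: adapted(x) == base(x) + scale * ((x @ lora_a) @ lora_b)
    let x = MLXRandom.normal([1, 256])
    let y = adapted(x)

    // after training, fold the adapter back into a single (possibly quantized) layer
    let fused = adapted.toLinear()

Compared to the existing LoRALinear subclassing approach, the wrapper keeps the adapted layer intact and flattens its parameters through items()/updateModule(key:_:), so plain and quantized projections can be adapted through the same code path.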