Skip to content

Commit f35df96

Browse files
authored
fix #231 -- use Embedding.asLinear (#232)
1 parent ea86d94 commit f35df96

File tree

3 files changed

+3
-3
lines changed

3 files changed

+3
-3
lines changed

Libraries/MLXLLM/Models/Cohere.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ public class CohereModel: Module, LLMModel, KVCacheDimensionProvider {
163163

164164
public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray {
165165
var out = model(inputs, cache: cache)
166-
out = matmul(out, model.embedTokens.weight.T)
166+
out = model.embedTokens.asLinear(out)
167167
out = out * self.logitScale
168168
return out
169169
}

Libraries/MLXLLM/Models/OpenELM.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ public class OpenELMModel: Module, LLMModel, KVCacheDimensionProvider {
199199
if let lmHead {
200200
out = lmHead(out)
201201
} else {
202-
out = matmul(out, transformer.embedTokens.weight.T)
202+
out = transformer.embedTokens.asLinear(out)
203203
}
204204

205205
return out

Libraries/MLXLLM/Models/Starcoder2.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ public class Starcoder2Model: Module, LLMModel, KVCacheDimensionProvider {
173173
if !tieWordEmbeddings {
174174
return lmHead(out)
175175
} else {
176-
out = matmul(out, model.embedTokens.weight.T)
176+
out = model.embedTokens.asLinear(out)
177177
return out
178178
}
179179
}

0 commit comments

Comments
 (0)