
Commit 1b21b4e

Extracted model loading logic into a separate class to help with readability and memory management
1 parent f4ce44b commit 1b21b4e

File tree

3 files changed: 188 additions, 183 deletions


Sources/llama-cpp-swift/LLama.swift

Lines changed: 138 additions & 182 deletions
@@ -3,225 +3,181 @@ import Logging
 @preconcurrency import llama
 
 public actor LLama {
-  private let logger = Logger.llama
-  private let model: OpaquePointer
-  private let context: OpaquePointer
-  private let sampling: UnsafeMutablePointer<llama_sampler>
-  private var batch: llama_batch
-  private var tokensList: [llama_token]
-  private var temporaryInvalidCChars: [CChar]
-  private var isDone = false
-
-  private var nLen: Int32 = 1024
-  private var nCur: Int32 = 0
-  private var nDecode: Int32 = 0
-
-  // MARK: - Init & teardown
-
-  public init(modelPath: String, contextSize: UInt32 = 2048) throws {
-    llama_backend_init()
-    let modelParams = llama_model_default_params()
-
-    #if targetEnvironment(simulator)
-      modelParams.n_gpu_layers = 0
-      logger.debug("Running on simulator, force use n_gpu_layers = 0")
-    #endif
-
-    guard let model = llama_load_model_from_file(modelPath, modelParams) else {
-      llama_backend_free()
-      throw InitializationError(message: "Failed to load model", code: .failedToLoadModel)
+  private let logger = Logger.llama
+  private let modelLoader: Model
+  private let sampling: UnsafeMutablePointer<llama_sampler>
+  private var batch: llama_batch
+  private var tokensList: [llama_token]
+  private var temporaryInvalidCChars: [CChar]
+  private var isDone = false
+
+  private var nLen: Int32 = 1024
+  private var nCur: Int32 = 0
+  private var nDecode: Int32 = 0
+
+  // MARK: - Init & teardown
+
+  public init(modelLoader: Model) {
+    self.modelLoader = modelLoader
+
+    // Initialize sampling
+    let sparams = llama_sampler_chain_default_params()
+    self.sampling = llama_sampler_chain_init(sparams)
+    llama_sampler_chain_add(self.sampling, llama_sampler_init_temp(0.8))
+    llama_sampler_chain_add(self.sampling, llama_sampler_init_softmax())
+    llama_sampler_chain_add(self.sampling, llama_sampler_init_dist(1234))
+
+    // Initialize batch and token list
+    self.batch = llama_batch_init(512, 0, 1)
+    self.tokensList = []
+    self.temporaryInvalidCChars = []
   }
-    self.model = model
-
-    // Initialize context parameters
-    let nThreads = max(1, min(8, ProcessInfo.processInfo.processorCount - 2))
-    logger.debug("Using \(nThreads) threads")
-
-    var ctxParams = llama_context_default_params()
-    ctxParams.n_ctx = contextSize
-    ctxParams.n_threads = Int32(nThreads)
-    ctxParams.n_threads_batch = Int32(nThreads)
-
-    guard let context = llama_new_context_with_model(model, ctxParams) else {
-      llama_free_model(model)
-      llama_backend_free()
-      throw InitializationError(
-        message: "Failed to initialize context", code: .failedToInitializeContext)
-    }
-    self.context = context
-
-    // Initialize sampling
-    let sparams = llama_sampler_chain_default_params()
-    self.sampling = llama_sampler_chain_init(sparams)
-    llama_sampler_chain_add(self.sampling, llama_sampler_init_temp(0.8))
-    llama_sampler_chain_add(self.sampling, llama_sampler_init_softmax())
-    llama_sampler_chain_add(self.sampling, llama_sampler_init_dist(1234))
-
-    // Initialize batch and token list
-    self.batch = llama_batch_init(512, 0, 1)
-    self.tokensList = []
-    self.temporaryInvalidCChars = []
-  }
 
-  deinit {
-    llama_batch_free(batch)
-    llama_backend_free()
-  }
+  deinit {
+    llama_batch_free(batch)
+    // llama_sampler_free(sampling)
+  }
 
-  // MARK: - Inference
-  public func infer(prompt: String, maxTokens: Int32 = 128) -> AsyncThrowingStream<String, Error> {
-    return AsyncThrowingStream(String.self, bufferingPolicy: .unbounded) { continuation in
-      Task {
-        do {
-          try completionInit(text: prompt)
-        } catch {
-          continuation.finish(throwing: error)
-          return
+  // MARK: - Inference
+
+  public func infer(prompt: String, maxTokens: Int32 = 128) -> AsyncThrowingStream<String, Error> {
+    return AsyncThrowingStream { continuation in
+      Task {
+        do {
+          try self.completionInit(text: prompt)
+        } catch {
+          continuation.finish(throwing: error)
+          return
+        }
+        while !self.isDone && self.nCur < self.nLen && self.nCur - self.batch.n_tokens < maxTokens {
+          guard !Task.isCancelled else {
+            continuation.finish()
+            return
+          }
+          let newTokenStr = self.completionLoop()
+          continuation.yield(newTokenStr)
+        }
+        continuation.finish()
+      }
     }
-        while !isDone && nCur < nLen && nCur - batch.n_tokens < maxTokens {
-          guard !Task.isCancelled else {
-            continuation.finish()
-            return
-          }
-          let newTokenStr = completionLoop()
-          continuation.yield(newTokenStr)
-        }
-        continuation.finish()
-      }
   }
-  }
 
-  // MARK: - Private helpers
+  // MARK: - Private helpers
 
-  private func llamaBatchClear(_ batch: inout llama_batch) {
-    batch.n_tokens = 0
-  }
+  private func completionInit(text: String) throws {
+    logger.debug("Attempting to complete \"\(text)\"")
 
-  private func llamaBatchAdd(
-    _ batch: inout llama_batch, _ id: llama_token, _ pos: llama_pos, _ seq_ids: [llama_seq_id],
-    _ logits: Bool
-  ) {
-    batch.token[Int(batch.n_tokens)] = id
-    batch.pos[Int(batch.n_tokens)] = pos
-    batch.n_seq_id[Int(batch.n_tokens)] = Int32(seq_ids.count)
-    for i in 0..<seq_ids.count {
-      batch.seq_id[Int(batch.n_tokens)]![Int(i)] = seq_ids[i]
-    }
-    batch.logits[Int(batch.n_tokens)] = logits ? 1 : 0
+    tokensList = tokenize(text: text, add_bos: true)
+    temporaryInvalidCChars = []
 
-    batch.n_tokens += 1
-  }
+    let nCtx = llama_n_ctx(modelLoader.context)
+    let nKvReq = tokensList.count + Int(nLen) - tokensList.count
 
-  private func completionInit(text: String) throws {
-    logger.debug("Attempting to complete \"\(text)\"")
+    logger.debug("\nn_len = \(self.nLen), n_ctx = \(nCtx), n_kv_req = \(nKvReq)")
 
-    tokensList = tokenize(text: text, add_bos: true)
-    temporaryInvalidCChars = []
-
-    let nCtx = llama_n_ctx(context)
-    let nKvReq = tokensList.count + Int(nLen) - tokensList.count
+    if nKvReq > nCtx {
+      logger.error("Error: n_kv_req > n_ctx, the required KV cache size is not big enough")
+      throw InferError(message: "KV cache too small", code: .kvCacheFailure)
+    }
 
-    logger.debug("\nn_len = \(self.nLen), n_ctx = \(nCtx), n_kv_req = \(nKvReq)")
+    batch.clear()
 
-    if nKvReq > nCtx {
-      logger.error("Error: n_kv_req > n_ctx, the required KV cache size is not big enough")
-      throw InferError(message: "KV cache too small", code: .kvCacheFailure)
-    }
+    for (i, token) in tokensList.enumerated() {
+      llamaBatchAdd(&batch, token, Int32(i), [0], false)
+    }
+    if batch.n_tokens > 0 {
+      batch.logits[Int(batch.n_tokens) - 1] = 1  // true
+    }
 
-    batch.clear()
+    if llama_decode(modelLoader.context, batch) != 0 {
+      throw InferError(message: "llama_decode failed", code: .decodingFailure)
+    }
 
-    for (i, token) in tokensList.enumerated() {
-      llamaBatchAdd(&batch, token, Int32(i), [0], false)
-    }
-    if batch.n_tokens > 0 {
-      batch.logits[Int(batch.n_tokens) - 1] = 1  // true
+    nCur = batch.n_tokens
   }
 
-    if llama_decode(context, batch) != 0 {
-      throw InferError(message: "llama_decode failed", code: .decodingFailure)
-    }
+  private func completionLoop() -> String {
+    var newTokenID: llama_token = 0
+    newTokenID = llama_sampler_sample(sampling, modelLoader.context, batch.n_tokens - 1)
 
-    nCur = batch.n_tokens
-  }
+    if llama_token_is_eog(modelLoader.model, newTokenID) || nCur == nLen {
+      isDone = true
+      let newTokenStr = String(cString: temporaryInvalidCChars + [0])
+      temporaryInvalidCChars.removeAll()
+      return newTokenStr
+    }
 
-  private func completionLoop() -> String {
-    var newTokenID: llama_token = 0
-    newTokenID = llama_sampler_sample(sampling, context, batch.n_tokens - 1)
+    let newTokenCChars = tokenToPieceArray(token: newTokenID)
+    temporaryInvalidCChars.append(contentsOf: newTokenCChars + [0])
+    let newTokenStr: String
+
+    if let string = String(validatingUTF8: temporaryInvalidCChars) {
+      temporaryInvalidCChars.removeAll()
+      newTokenStr = string
+    } else if let partialStr = attemptPartialString(from: temporaryInvalidCChars) {
+      temporaryInvalidCChars.removeAll()
+      newTokenStr = partialStr
+    } else {
+      newTokenStr = ""
+    }
 
-    if llama_token_is_eog(model, newTokenID) || nCur == nLen {
-      isDone = true
-      let newTokenStr = String(cString: temporaryInvalidCChars + [0])
-      temporaryInvalidCChars.removeAll()
-      return newTokenStr
-    }
+    batch.clear()
+    llamaBatchAdd(&batch, newTokenID, nCur, [0], true)
 
-    let newTokenCChars = tokenToPieceArray(token: newTokenID)
-    temporaryInvalidCChars.append(contentsOf: newTokenCChars + [0])
-    let newTokenStr: String
-
-    if let string = String(validatingUTF8: temporaryInvalidCChars) {
-      temporaryInvalidCChars.removeAll()
-      newTokenStr = string
-    } else if let partialStr = attemptPartialString(from: temporaryInvalidCChars) {
-      temporaryInvalidCChars.removeAll()
-      newTokenStr = partialStr
-    } else {
-      newTokenStr = ""
-    }
+    nDecode += 1
+    nCur += 1
 
-    batch.clear()
-    llamaBatchAdd(&batch, newTokenID, nCur, [0], true)
+    if llama_decode(modelLoader.context, batch) != 0 {
+      logger.error("Failed to evaluate llama!")
+    }
 
-    nDecode += 1
-    nCur += 1
+    return newTokenStr
+  }
 
-    if llama_decode(context, batch) != 0 {
-      print("Failed to evaluate llama!")
+  private func llamaBatchAdd(
+    _ batch: inout llama_batch, _ id: llama_token, _ pos: llama_pos, _ seq_ids: [llama_seq_id],
+    _ logits: Bool
+  ) {
+    batch.token[Int(batch.n_tokens)] = id
+    batch.pos[Int(batch.n_tokens)] = pos
+    batch.n_seq_id[Int(batch.n_tokens)] = Int32(seq_ids.count)
+    for i in 0..<seq_ids.count {
+      batch.seq_id[Int(batch.n_tokens)]![Int(i)] = seq_ids[i]
     }
+    batch.logits[Int(batch.n_tokens)] = logits ? 1 : 0
 
-    return newTokenStr
+    batch.n_tokens += 1
   }
 
-  private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
-    let utf8Data = text.utf8CString
-    let nTokens = Int32(utf8Data.count) + (add_bos ? 1 : 0)
-    let tokens = UnsafeMutablePointer<llama_token>.allocate(capacity: Int(nTokens))
-    defer { tokens.deallocate() }
+  private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
+    let utf8Data = text.utf8CString
+    let nTokens = Int32(utf8Data.count) + (add_bos ? 1 : 0)
+    let tokens = UnsafeMutablePointer<llama_token>.allocate(capacity: Int(nTokens))
+    defer { tokens.deallocate() }
 
-    let tokenCount = llama_tokenize(
-      model, text, Int32(utf8Data.count), tokens, Int32(nTokens), add_bos, false)
-    guard tokenCount > 0 else {
-      return []
-    }
+    let tokenCount = llama_tokenize(
+      modelLoader.model, text, Int32(utf8Data.count), tokens, Int32(nTokens), add_bos, false)
+    guard tokenCount > 0 else {
+      return []
+    }
 
-    return Array(UnsafeBufferPointer(start: tokens, count: Int(tokenCount)))
-  }
-
-  private func tokenToPiece(token: llama_token) -> String? {
-    var result = [CChar](repeating: 0, count: 8)
-    var nTokens = llama_token_to_piece(model, token, &result, 8, 0, false)
+    return Array(UnsafeBufferPointer(start: tokens, count: Int(tokenCount)))
   }
 
-    if nTokens < 0 {
-      let requiredSize = -nTokens
-      result = [CChar](repeating: 0, count: Int(requiredSize))
-      nTokens = llama_token_to_piece(model, token, &result, requiredSize, 0, false)
+  private func tokenToPieceArray(token: llama_token) -> [CChar] {
+    var buffer = [CChar](repeating: 0, count: 8)
+    var nTokens = llama_token_to_piece(modelLoader.model, token, &buffer, 8, 0, false)
 
-    return String(cString: result)
-  }
+    if nTokens < 0 {
+      let requiredSize = -nTokens
+      buffer = [CChar](repeating: 0, count: Int(requiredSize))
+      nTokens = llama_token_to_piece(modelLoader.model, token, &buffer, requiredSize, 0, false)
+    }
 
-  private func tokenToPieceArray(token: llama_token) -> [CChar] {
-    var buffer = [CChar](repeating: 0, count: 8)
-    var nTokens = llama_token_to_piece(model, token, &buffer, 8, 0, false)
+    return Array(buffer.prefix(Int(nTokens)))
   }
 
-    if nTokens < 0 {
-      let requiredSize = -nTokens
-      buffer = [CChar](repeating: 0, count: Int(requiredSize))
-      nTokens = llama_token_to_piece(model, token, &buffer, requiredSize, 0, false)
-
-    return Array(buffer.prefix(Int(nTokens)))
-  }
-
 
   private func attemptPartialString(from cchars: [CChar]) -> String? {
     for i in (1..<cchars.count).reversed() {
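The `Model` type that `LLama` now depends on lives in one of the other files touched by this commit and is not shown in this diff. As a rough sketch, assuming the loading and context-setup code removed from `LLama.init` moved there largely verbatim: only the `model` and `context` properties are confirmed by the `modelLoader.model` / `modelLoader.context` calls above, while the initializer signature, the class-vs-struct choice, and the deinit cleanup are assumptions.

import Foundation
@preconcurrency import llama

// Hypothetical sketch of the extracted loader, not the actual file from this commit.
public final class Model {
  let model: OpaquePointer    // llama_model *
  let context: OpaquePointer  // llama_context *

  public init(modelPath: String, contextSize: UInt32 = 2048) throws {
    llama_backend_init()

    var modelParams = llama_model_default_params()
    #if targetEnvironment(simulator)
      modelParams.n_gpu_layers = 0  // no GPU offload in the simulator
    #endif

    guard let model = llama_load_model_from_file(modelPath, modelParams) else {
      llama_backend_free()
      throw InitializationError(message: "Failed to load model", code: .failedToLoadModel)
    }
    self.model = model

    // Same thread heuristic the removed LLama.init used.
    let nThreads = max(1, min(8, ProcessInfo.processInfo.processorCount - 2))
    var ctxParams = llama_context_default_params()
    ctxParams.n_ctx = contextSize
    ctxParams.n_threads = Int32(nThreads)
    ctxParams.n_threads_batch = Int32(nThreads)

    guard let context = llama_new_context_with_model(model, ctxParams) else {
      llama_free_model(model)
      llama_backend_free()
      throw InitializationError(
        message: "Failed to initialize context", code: .failedToInitializeContext)
    }
    self.context = context
  }

  deinit {
    // Centralising ownership of the raw pointers here is presumably what lets
    // LLama.deinit drop its old llama_backend_free() call.
    llama_free(context)
    llama_free_model(model)
    llama_backend_free()
  }
}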