@@ -3,225 +3,181 @@ import Logging
 @preconcurrency import llama

 public actor LLama {
-  private let logger = Logger.llama
-  private let model: OpaquePointer
-  private let context: OpaquePointer
-  private let sampling: UnsafeMutablePointer<llama_sampler>
-  private var batch: llama_batch
-  private var tokensList: [llama_token]
-  private var temporaryInvalidCChars: [CChar]
-  private var isDone = false
-
-  private var nLen: Int32 = 1024
-  private var nCur: Int32 = 0
-  private var nDecode: Int32 = 0
-
-  // MARK: - Init & teardown
-
-  public init(modelPath: String, contextSize: UInt32 = 2048) throws {
-    llama_backend_init()
-    let modelParams = llama_model_default_params()
-
-    #if targetEnvironment(simulator)
-      modelParams.n_gpu_layers = 0
-      logger.debug("Running on simulator, force use n_gpu_layers = 0")
-    #endif
-
-    guard let model = llama_load_model_from_file(modelPath, modelParams) else {
-      llama_backend_free()
-      throw InitializationError(message: "Failed to load model", code: .failedToLoadModel)
+  private let logger = Logger.llama
+  private let modelLoader: Model
+  private let sampling: UnsafeMutablePointer<llama_sampler>
+  private var batch: llama_batch
+  private var tokensList: [llama_token]
+  private var temporaryInvalidCChars: [CChar]
+  private var isDone = false
+
+  private var nLen: Int32 = 1024
+  private var nCur: Int32 = 0
+  private var nDecode: Int32 = 0
+
+  // MARK: - Init & teardown
+
+  public init(modelLoader: Model) {
+    self.modelLoader = modelLoader
+
+    // Initialize sampling
+    let sparams = llama_sampler_chain_default_params()
+    self.sampling = llama_sampler_chain_init(sparams)
+    llama_sampler_chain_add(self.sampling, llama_sampler_init_temp(0.8))
+    llama_sampler_chain_add(self.sampling, llama_sampler_init_softmax())
+    llama_sampler_chain_add(self.sampling, llama_sampler_init_dist(1234))
+
+    // Initialize batch and token list
+    self.batch = llama_batch_init(512, 0, 1)
+    self.tokensList = []
+    self.temporaryInvalidCChars = []
   }
-    self.model = model
-
-    // Initialize context parameters
-    let nThreads = max(1, min(8, ProcessInfo.processInfo.processorCount - 2))
-    logger.debug("Using \(nThreads) threads")
-
-    var ctxParams = llama_context_default_params()
-    ctxParams.n_ctx = contextSize
-    ctxParams.n_threads = Int32(nThreads)
-    ctxParams.n_threads_batch = Int32(nThreads)
-
-    guard let context = llama_new_context_with_model(model, ctxParams) else {
-      llama_free_model(model)
-      llama_backend_free()
-      throw InitializationError(
-        message: "Failed to initialize context", code: .failedToInitializeContext)
-    }
-    self.context = context
-
-    // Initialize sampling
-    let sparams = llama_sampler_chain_default_params()
-    self.sampling = llama_sampler_chain_init(sparams)
-    llama_sampler_chain_add(self.sampling, llama_sampler_init_temp(0.8))
-    llama_sampler_chain_add(self.sampling, llama_sampler_init_softmax())
-    llama_sampler_chain_add(self.sampling, llama_sampler_init_dist(1234))
-
-    // Initialize batch and token list
-    self.batch = llama_batch_init(512, 0, 1)
-    self.tokensList = []
-    self.temporaryInvalidCChars = []
-  }

-  deinit {
-    llama_batch_free(batch)
-    llama_backend_free()
-  }
+  deinit {
+    llama_batch_free(batch)
+    // llama_sampler_free(sampling)
+  }

-  // MARK: - Inference
-  public func infer(prompt: String, maxTokens: Int32 = 128) -> AsyncThrowingStream<String, Error> {
-    return AsyncThrowingStream(String.self, bufferingPolicy: .unbounded) { continuation in
-      Task {
-        do {
-          try completionInit(text: prompt)
-        } catch {
-          continuation.finish(throwing: error)
-          return
+  // MARK: - Inference
+
+  public func infer(prompt: String, maxTokens: Int32 = 128) -> AsyncThrowingStream<String, Error> {
+    return AsyncThrowingStream { continuation in
+      Task {
+        do {
+          try self.completionInit(text: prompt)
+        } catch {
+          continuation.finish(throwing: error)
+          return
+        }
+        while !self.isDone && self.nCur < self.nLen && self.nCur - self.batch.n_tokens < maxTokens {
+          guard !Task.isCancelled else {
+            continuation.finish()
+            return
+          }
+          let newTokenStr = self.completionLoop()
+          continuation.yield(newTokenStr)
+        }
+        continuation.finish()
+      }
    }
-        while !isDone && nCur < nLen && nCur - batch.n_tokens < maxTokens {
-          guard !Task.isCancelled else {
-            continuation.finish()
-            return
-          }
-          let newTokenStr = completionLoop()
-          continuation.yield(newTokenStr)
-        }
-        continuation.finish()
-      }
  }
-  }

-  // MARK: - Private helpers
+  // MARK: - Private helpers

-  private func llamaBatchClear(_ batch: inout llama_batch) {
-    batch.n_tokens = 0
-  }
+  private func completionInit(text: String) throws {
+    logger.debug("Attempting to complete \"\(text)\"")

-  private func llamaBatchAdd(
-    _ batch: inout llama_batch, _ id: llama_token, _ pos: llama_pos, _ seq_ids: [llama_seq_id],
-    _ logits: Bool
-  ) {
-    batch.token[Int(batch.n_tokens)] = id
-    batch.pos[Int(batch.n_tokens)] = pos
-    batch.n_seq_id[Int(batch.n_tokens)] = Int32(seq_ids.count)
-    for i in 0..<seq_ids.count {
-      batch.seq_id[Int(batch.n_tokens)]![Int(i)] = seq_ids[i]
-    }
-    batch.logits[Int(batch.n_tokens)] = logits ? 1 : 0
+    tokensList = tokenize(text: text, add_bos: true)
+    temporaryInvalidCChars = []

-    batch.n_tokens += 1
-  }
+    let nCtx = llama_n_ctx(modelLoader.context)
+    let nKvReq = tokensList.count + Int(nLen) - tokensList.count

-  private func completionInit(text: String) throws {
-    logger.debug("Attempting to complete \"\(text)\"")
+    logger.debug("\n n_len = \(self.nLen), n_ctx = \(nCtx), n_kv_req = \(nKvReq)")

-    tokensList = tokenize(text: text, add_bos: true)
-    temporaryInvalidCChars = []
-
-    let nCtx = llama_n_ctx(context)
-    let nKvReq = tokensList.count + Int(nLen) - tokensList.count
+    if nKvReq > nCtx {
+      logger.error("Error: n_kv_req > n_ctx, the required KV cache size is not big enough")
+      throw InferError(message: "KV cache too small", code: .kvCacheFailure)
+    }

-    logger.debug("\n n_len = \(self.nLen), n_ctx = \(nCtx), n_kv_req = \(nKvReq)")
+    batch.clear()

-    if nKvReq > nCtx {
-      logger.error("Error: n_kv_req > n_ctx, the required KV cache size is not big enough")
-      throw InferError(message: "KV cache too small", code: .kvCacheFailure)
-    }
+    for (i, token) in tokensList.enumerated() {
+      llamaBatchAdd(&batch, token, Int32(i), [0], false)
+    }
+    if batch.n_tokens > 0 {
+      batch.logits[Int(batch.n_tokens) - 1] = 1  // true
+    }

-    batch.clear()
+    if llama_decode(modelLoader.context, batch) != 0 {
+      throw InferError(message: "llama_decode failed", code: .decodingFailure)
+    }

-    for (i, token) in tokensList.enumerated() {
-      llamaBatchAdd(&batch, token, Int32(i), [0], false)
-    }
-    if batch.n_tokens > 0 {
-      batch.logits[Int(batch.n_tokens) - 1] = 1  // true
+    nCur = batch.n_tokens
  }

-    if llama_decode(context, batch) != 0 {
-      throw InferError(message: "llama_decode failed", code: .decodingFailure)
-    }
+  private func completionLoop() -> String {
+    var newTokenID: llama_token = 0
+    newTokenID = llama_sampler_sample(sampling, modelLoader.context, batch.n_tokens - 1)

-    nCur = batch.n_tokens
-  }
+    if llama_token_is_eog(modelLoader.model, newTokenID) || nCur == nLen {
+      isDone = true
+      let newTokenStr = String(cString: temporaryInvalidCChars + [0])
+      temporaryInvalidCChars.removeAll()
+      return newTokenStr
+    }

-  private func completionLoop() -> String {
-    var newTokenID: llama_token = 0
-    newTokenID = llama_sampler_sample(sampling, context, batch.n_tokens - 1)
+    let newTokenCChars = tokenToPieceArray(token: newTokenID)
+    temporaryInvalidCChars.append(contentsOf: newTokenCChars + [0])
+    let newTokenStr: String
+
+    if let string = String(validatingUTF8: temporaryInvalidCChars) {
+      temporaryInvalidCChars.removeAll()
+      newTokenStr = string
+    } else if let partialStr = attemptPartialString(from: temporaryInvalidCChars) {
+      temporaryInvalidCChars.removeAll()
+      newTokenStr = partialStr
+    } else {
+      newTokenStr = ""
+    }

-    if llama_token_is_eog(model, newTokenID) || nCur == nLen {
-      isDone = true
-      let newTokenStr = String(cString: temporaryInvalidCChars + [0])
-      temporaryInvalidCChars.removeAll()
-      return newTokenStr
-    }
+    batch.clear()
+    llamaBatchAdd(&batch, newTokenID, nCur, [0], true)

-    let newTokenCChars = tokenToPieceArray(token: newTokenID)
-    temporaryInvalidCChars.append(contentsOf: newTokenCChars + [0])
-    let newTokenStr: String
-
-    if let string = String(validatingUTF8: temporaryInvalidCChars) {
-      temporaryInvalidCChars.removeAll()
-      newTokenStr = string
-    } else if let partialStr = attemptPartialString(from: temporaryInvalidCChars) {
-      temporaryInvalidCChars.removeAll()
-      newTokenStr = partialStr
-    } else {
-      newTokenStr = ""
-    }
+    nDecode += 1
+    nCur += 1

-    batch.clear()
-    llamaBatchAdd(&batch, newTokenID, nCur, [0], true)
+    if llama_decode(modelLoader.context, batch) != 0 {
+      logger.error("Failed to evaluate llama!")
+    }

-    nDecode += 1
-    nCur += 1
+    return newTokenStr
+  }

-    if llama_decode(context, batch) != 0 {
-      print("Failed to evaluate llama!")
+  private func llamaBatchAdd(
+    _ batch: inout llama_batch, _ id: llama_token, _ pos: llama_pos, _ seq_ids: [llama_seq_id],
+    _ logits: Bool
+  ) {
+    batch.token[Int(batch.n_tokens)] = id
+    batch.pos[Int(batch.n_tokens)] = pos
+    batch.n_seq_id[Int(batch.n_tokens)] = Int32(seq_ids.count)
+    for i in 0..<seq_ids.count {
+      batch.seq_id[Int(batch.n_tokens)]![Int(i)] = seq_ids[i]
    }
+    batch.logits[Int(batch.n_tokens)] = logits ? 1 : 0

-    return newTokenStr
+    batch.n_tokens += 1
  }

-  private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
-    let utf8Data = text.utf8CString
-    let nTokens = Int32(utf8Data.count) + (add_bos ? 1 : 0)
-    let tokens = UnsafeMutablePointer<llama_token>.allocate(capacity: Int(nTokens))
-    defer { tokens.deallocate() }

-    let tokenCount = llama_tokenize(
-      model, text, Int32(utf8Data.count), tokens, Int32(nTokens), add_bos, false)
-    guard tokenCount > 0 else {
-      return []
-    }
+  private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
+    let utf8Data = text.utf8CString
+    let nTokens = Int32(utf8Data.count) + (add_bos ? 1 : 0)
+    let tokens = UnsafeMutablePointer<llama_token>.allocate(capacity: Int(nTokens))
+    defer { tokens.deallocate() }

-    return Array(UnsafeBufferPointer(start: tokens, count: Int(tokenCount)))
-  }
-
-  private func tokenToPiece(token: llama_token) -> String? {
-    var result = [CChar](repeating: 0, count: 8)
-    var nTokens = llama_token_to_piece(model, token, &result, 8, 0, false)
+    let tokenCount = llama_tokenize(
+      modelLoader.model, text, Int32(utf8Data.count), tokens, Int32(nTokens), add_bos, false)
+    guard tokenCount > 0 else {
+      return []
+    }

-    if nTokens < 0 {
-      let requiredSize = -nTokens
-      result = [CChar](repeating: 0, count: Int(requiredSize))
-      nTokens = llama_token_to_piece(model, token, &result, requiredSize, 0, false)
+    return Array(UnsafeBufferPointer(start: tokens, count: Int(tokenCount)))
  }

-    return String(cString: result)
-  }
+  private func tokenToPieceArray(token: llama_token) -> [CChar] {
+    var buffer = [CChar](repeating: 0, count: 8)
+    var nTokens = llama_token_to_piece(modelLoader.model, token, &buffer, 8, 0, false)

-  private func tokenToPieceArray(token: llama_token) -> [CChar] {
-    var buffer = [CChar](repeating: 0, count: 8)
-    var nTokens = llama_token_to_piece(model, token, &buffer, 8, 0, false)
+    if nTokens < 0 {
+      let requiredSize = -nTokens
+      buffer = [CChar](repeating: 0, count: Int(requiredSize))
+      nTokens = llama_token_to_piece(modelLoader.model, token, &buffer, requiredSize, 0, false)
+    }

-    if nTokens < 0 {
-      let requiredSize = -nTokens
-      buffer = [CChar](repeating: 0, count: Int(requiredSize))
-      nTokens = llama_token_to_piece(model, token, &buffer, requiredSize, 0, false)
+    return Array(buffer.prefix(Int(nTokens)))
  }

-    return Array(buffer.prefix(Int(nTokens)))
-  }

  private func attemptPartialString(from cchars: [CChar]) -> String? {
    for i in (1..<cchars.count).reversed() {
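
For context, a minimal usage sketch of the refactored API, assuming the `Model` loader type introduced elsewhere in this change exposes a throwing `init(modelPath:)` that owns the llama model/context pair referenced above as `modelLoader.model` / `modelLoader.context`; that initializer name is an assumption, not taken from this hunk.

import Foundation

func runCompletion() async throws {
  let model = try Model(modelPath: "/path/to/model.gguf")  // assumed initializer, see note above
  let llama = LLama(modelLoader: model)

  // Consume the AsyncThrowingStream of generated pieces as they arrive.
  for try await piece in await llama.infer(prompt: "Hello", maxTokens: 64) {
    print(piece, terminator: "")
  }
}

The new code also calls `batch.clear()` where the old code used `llamaBatchClear(_:)`; the matching extension is not shown in this hunk, but presumably looks something like:

extension llama_batch {
  /// Reset the batch so it can be refilled for the next decode call.
  mutating func clear() {
    n_tokens = 0
  }
}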