@@ -11,6 +11,8 @@ actor LlamaContext {
11
11
private var context : OpaquePointer
12
12
private var batch : llama_batch
13
13
private var tokens_list : [ llama_token ]
14
+ /// This variable is used to store temporarily invalid cchars
15
+ private var temporary_invalid_cchars : [ CChar ]
14
16
15
17
var n_len : Int32 = 512
16
18
var n_cur : Int32 = 0
@@ -21,6 +23,7 @@ actor LlamaContext {
21
23
self . context = context
22
24
self . tokens_list = [ ]
23
25
self . batch = llama_batch_init ( 512 , 0 , 1 )
26
+ self . temporary_invalid_cchars = [ ]
24
27
}
25
28
26
29
deinit {
@@ -61,6 +64,7 @@ actor LlamaContext {
61
64
print ( " attempting to complete \" \( text) \" " )
62
65
63
66
tokens_list = tokenize ( text: text, add_bos: true )
67
+ temporary_invalid_cchars = [ ]
64
68
65
69
let n_ctx = llama_n_ctx ( context)
66
70
let n_kv_req = tokens_list. count + ( Int ( n_len) - tokens_list. count)
@@ -72,7 +76,7 @@ actor LlamaContext {
72
76
}
73
77
74
78
for id in tokens_list {
75
- print ( token_to_piece ( token: id) )
79
+ print ( String ( cString : token_to_piece ( token: id) + [ 0 ] ) )
76
80
}
77
81
78
82
// batch = llama_batch_init(512, 0) // done in init()
@@ -115,10 +119,25 @@ actor LlamaContext {
115
119
116
120
if new_token_id == llama_token_eos ( context) || n_cur == n_len {
117
121
print ( " \n " )
118
- return " "
122
+ let new_token_str = String ( cString: temporary_invalid_cchars + [ 0 ] )
123
+ temporary_invalid_cchars. removeAll ( )
124
+ return new_token_str
119
125
}
120
126
121
- let new_token_str = token_to_piece ( token: new_token_id)
127
+ let new_token_cchars = token_to_piece ( token: new_token_id)
128
+ temporary_invalid_cchars. append ( contentsOf: new_token_cchars)
129
+ let new_token_str : String
130
+ if let string = String ( validatingUTF8: temporary_invalid_cchars + [ 0 ] ) {
131
+ temporary_invalid_cchars. removeAll ( )
132
+ new_token_str = string
133
+ } else if ( 0 ..< temporary_invalid_cchars. count) . contains ( where: { $0 != 0 && String ( validatingUTF8: Array ( temporary_invalid_cchars. suffix ( $0) ) + [ 0 ] ) != nil } ) {
134
+ // in this case, at least the suffix of the temporary_invalid_cchars can be interpreted as UTF8 string
135
+ let string = String ( cString: temporary_invalid_cchars + [ 0 ] )
136
+ temporary_invalid_cchars. removeAll ( )
137
+ new_token_str = string
138
+ } else {
139
+ new_token_str = " "
140
+ }
122
141
print ( new_token_str)
123
142
// tokens_list.append(new_token_id)
124
143
@@ -144,6 +163,7 @@ actor LlamaContext {
144
163
145
164
func clear( ) {
146
165
tokens_list. removeAll ( )
166
+ temporary_invalid_cchars. removeAll ( )
147
167
}
148
168
149
169
private func tokenize( text: String , add_bos: Bool ) -> [ llama_token ] {
@@ -162,7 +182,8 @@ actor LlamaContext {
162
182
return swiftTokens
163
183
}
164
184
165
- private func token_to_piece( token: llama_token ) -> String {
185
+ /// - note: The result does not contain null-terminator
186
+ private func token_to_piece( token: llama_token ) -> [ CChar ] {
166
187
let result = UnsafeMutablePointer< Int8> . allocate( capacity: 8 )
167
188
result. initialize ( repeating: Int8 ( 0 ) , count: 8 )
168
189
defer {
@@ -176,10 +197,12 @@ actor LlamaContext {
176
197
defer {
177
198
newResult. deallocate ( )
178
199
}
179
- _ = llama_token_to_piece ( model, token, newResult, - nTokens)
180
- return String ( cString: newResult)
200
+ let nNewTokens = llama_token_to_piece ( model, token, newResult, - nTokens)
201
+ let bufferPointer = UnsafeBufferPointer ( start: newResult, count: Int ( nNewTokens) )
202
+ return Array ( bufferPointer)
181
203
} else {
182
- return String ( cString: result)
204
+ let bufferPointer = UnsafeBufferPointer ( start: result, count: Int ( nTokens) )
205
+ return Array ( bufferPointer)
183
206
}
184
207
}
185
208
}
0 commit comments