Skip to content

Commit d208995

Browse files
authored
swift : fix concatenation method to avoid invalid UTF8 stringfication (#4325)
1 parent 5c9f90c commit d208995

File tree

1 file changed

+30
-7
lines changed

1 file changed

+30
-7
lines changed

Diff for: examples/llama.swiftui/llama.cpp.swift/LibLlama.swift

+30-7
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ actor LlamaContext {
1111
private var context: OpaquePointer
1212
private var batch: llama_batch
1313
private var tokens_list: [llama_token]
14+
/// This variable is used to store temporarily invalid cchars
15+
private var temporary_invalid_cchars: [CChar]
1416

1517
var n_len: Int32 = 512
1618
var n_cur: Int32 = 0
@@ -21,6 +23,7 @@ actor LlamaContext {
2123
self.context = context
2224
self.tokens_list = []
2325
self.batch = llama_batch_init(512, 0, 1)
26+
self.temporary_invalid_cchars = []
2427
}
2528

2629
deinit {
@@ -61,6 +64,7 @@ actor LlamaContext {
6164
print("attempting to complete \"\(text)\"")
6265

6366
tokens_list = tokenize(text: text, add_bos: true)
67+
temporary_invalid_cchars = []
6468

6569
let n_ctx = llama_n_ctx(context)
6670
let n_kv_req = tokens_list.count + (Int(n_len) - tokens_list.count)
@@ -72,7 +76,7 @@ actor LlamaContext {
7276
}
7377

7478
for id in tokens_list {
75-
print(token_to_piece(token: id))
79+
print(String(cString: token_to_piece(token: id) + [0]))
7680
}
7781

7882
// batch = llama_batch_init(512, 0) // done in init()
@@ -115,10 +119,25 @@ actor LlamaContext {
115119

116120
if new_token_id == llama_token_eos(context) || n_cur == n_len {
117121
print("\n")
118-
return ""
122+
let new_token_str = String(cString: temporary_invalid_cchars + [0])
123+
temporary_invalid_cchars.removeAll()
124+
return new_token_str
119125
}
120126

121-
let new_token_str = token_to_piece(token: new_token_id)
127+
let new_token_cchars = token_to_piece(token: new_token_id)
128+
temporary_invalid_cchars.append(contentsOf: new_token_cchars)
129+
let new_token_str: String
130+
if let string = String(validatingUTF8: temporary_invalid_cchars + [0]) {
131+
temporary_invalid_cchars.removeAll()
132+
new_token_str = string
133+
} else if (0 ..< temporary_invalid_cchars.count).contains(where: {$0 != 0 && String(validatingUTF8: Array(temporary_invalid_cchars.suffix($0)) + [0]) != nil}) {
134+
// in this case, at least the suffix of the temporary_invalid_cchars can be interpreted as UTF8 string
135+
let string = String(cString: temporary_invalid_cchars + [0])
136+
temporary_invalid_cchars.removeAll()
137+
new_token_str = string
138+
} else {
139+
new_token_str = ""
140+
}
122141
print(new_token_str)
123142
// tokens_list.append(new_token_id)
124143

@@ -144,6 +163,7 @@ actor LlamaContext {
144163

145164
func clear() {
146165
tokens_list.removeAll()
166+
temporary_invalid_cchars.removeAll()
147167
}
148168

149169
private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
@@ -162,7 +182,8 @@ actor LlamaContext {
162182
return swiftTokens
163183
}
164184

165-
private func token_to_piece(token: llama_token) -> String {
185+
/// - note: The result does not contain null-terminator
186+
private func token_to_piece(token: llama_token) -> [CChar] {
166187
let result = UnsafeMutablePointer<Int8>.allocate(capacity: 8)
167188
result.initialize(repeating: Int8(0), count: 8)
168189
defer {
@@ -176,10 +197,12 @@ actor LlamaContext {
176197
defer {
177198
newResult.deallocate()
178199
}
179-
_ = llama_token_to_piece(model, token, newResult, -nTokens)
180-
return String(cString: newResult)
200+
let nNewTokens = llama_token_to_piece(model, token, newResult, -nTokens)
201+
let bufferPointer = UnsafeBufferPointer(start: newResult, count: Int(nNewTokens))
202+
return Array(bufferPointer)
181203
} else {
182-
return String(cString: result)
204+
let bufferPointer = UnsafeBufferPointer(start: result, count: Int(nTokens))
205+
return Array(bufferPointer)
183206
}
184207
}
185208
}

0 commit comments

Comments
 (0)