@@ -1600,12 +1600,12 @@ struct llama_mlock {
1600
1600
};
1601
1601
using llama_mlocks = std::vector<std::unique_ptr<llama_mlock>>;
1602
1602
1603
- static std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token) {
1603
+ static std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token, bool special ) {
1604
1604
std::vector<char> result(8, 0);
1605
- const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size());
1605
+ const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), special );
1606
1606
if (n_tokens < 0) {
1607
1607
result.resize(-n_tokens);
1608
- int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size());
1608
+ int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), special );
1609
1609
GGML_ASSERT(check == -n_tokens);
1610
1610
}
1611
1611
else {
@@ -13312,7 +13312,8 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
13312
13312
13313
13313
for (size_t i = 0; i < candidates->size; ++i) {
13314
13314
const llama_token id = candidates->data[i].id;
13315
- const std::string piece = llama_token_to_piece(ctx, id);
13315
+ const std::string piece = llama_token_to_piece(ctx, id, false);
13316
+
13316
13317
if (llama_token_is_eog(&ctx->model, id)) {
13317
13318
if (!allow_eog) {
13318
13319
candidates->data[i].logit = -INFINITY;
@@ -13512,7 +13513,7 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
13512
13513
GGML_ASSERT(false);
13513
13514
}
13514
13515
13515
- const std::string piece = llama_token_to_piece(ctx, token);
13516
+ const std::string piece = llama_token_to_piece(ctx, token, false );
13516
13517
13517
13518
// Note terminating 0 in decoded string
13518
13519
const auto decoded = decode_utf8(piece, grammar->partial_utf8);
@@ -16991,7 +16992,7 @@ static std::string llama_decode_text(const std::string & text) {
16991
16992
}
16992
16993
16993
16994
// does not write null-terminator to buf
16994
- int32_t llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int32_t length) {
16995
+ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int32_t length, bool special ) {
16995
16996
if (0 <= token && token < llama_n_vocab(model)) {
16996
16997
switch (llama_vocab_get_type(model->vocab)) {
16997
16998
case LLAMA_VOCAB_TYPE_WPM:
@@ -17006,7 +17007,9 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token
17006
17007
}
17007
17008
memcpy(buf, result.c_str(), result.length());
17008
17009
return result.length();
17009
- } else if (llama_is_user_defined_token(model->vocab, token)) {
17010
+ } else if (
17011
+ (llama_is_user_defined_token(model->vocab, token)) ||
17012
+ (llama_is_control_token (model->vocab, token) && special)) {
17010
17013
std::string result = model->vocab.id_to_token[token].text;
17011
17014
if (length < (int) result.length()) {
17012
17015
return -(int) result.length();
@@ -17019,8 +17022,6 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token
17019
17022
}
17020
17023
memcpy(buf, "\xe2\x96\x85", 3);
17021
17024
return 3;
17022
- } else if (llama_is_control_token(model->vocab, token)) {
17023
- ;
17024
17025
} else if (llama_is_byte_token(model->vocab, token)) {
17025
17026
if (length < 1) {
17026
17027
return -1;
@@ -17041,15 +17042,15 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token
17041
17042
}
17042
17043
memcpy(buf, result.c_str(), result.length());
17043
17044
return result.length();
17044
- } else if (llama_is_user_defined_token(model->vocab, token)) {
17045
+ } else if (
17046
+ (llama_is_user_defined_token(model->vocab, token)) ||
17047
+ (llama_is_control_token (model->vocab, token) && special)) {
17045
17048
std::string result = model->vocab.id_to_token[token].text;
17046
17049
if (length < (int) result.length()) {
17047
17050
return -(int) result.length();
17048
17051
}
17049
17052
memcpy(buf, result.c_str(), result.length());
17050
17053
return result.length();
17051
- } else if (llama_is_control_token(model->vocab, token)) {
17052
- ;
17053
17054
}
17054
17055
break;
17055
17056
}
0 commit comments