Skip to content

Commit 0e89203

Browse files
authored
speculative : add tree-based sampling example (ggml-org#3624)
* sampling : one sequence per sampling context ggml-ci * speculative : add tree-based sampling support ggml-ci * speculative : reuse the n_parallel CLI param * speculative : refactor sampling * examples : fix build after sampling refactoring ggml-ci * batched : fix n_seq_id * sampling : fix malloc ggml-ci * swift : fix build ggml-ci * swift : try to fix build ggml-ci * prompts : add assistant.txt * common : add llama_batch_add() and llama_batch_clear() helpers * speculative : minor refactor ggml-ci * minor : comments + rename ggml-ci * speculative : fix off-by-one for n_drafted * speculative : fix the n_drafted fix + p constants
1 parent c67fe68 commit 0e89203

File tree

21 files changed

+736
-577
lines changed

21 files changed

+736
-577
lines changed

Makefile

+1-1
Original file line numberDiff line numberDiff line change
@@ -545,7 +545,7 @@ llama.o: llama.cpp ggml.h ggml-alloc.h ggml-backend.h ggml-cuda.h ggml-metal.h l
545545
$(CXX) $(CXXFLAGS) -c $< -o $@
546546

547547
COMMON_H_DEPS = common/common.h common/sampling.h build-info.h common/log.h
548-
COMMON_DEPS = $(COMMON_H_DEPS) common.o sampling.o
548+
COMMON_DEPS = $(COMMON_H_DEPS) common.o sampling.o grammar-parser.o
549549

550550
common.o: common/common.cpp $(COMMON_H_DEPS)
551551
$(CXX) $(CXXFLAGS) -c $< -o $@

common/common.cpp

+21
Original file line numberDiff line numberDiff line change
@@ -820,6 +820,27 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
820820
return cparams;
821821
}
822822

823+
// Reset a batch so it can be refilled from scratch.
// Only the token count is zeroed; the buffers allocated for the batch
// keep their capacity and are simply overwritten by subsequent adds.
void llama_batch_clear(struct llama_batch & batch) {
    batch.n_tokens = 0;
}
826+
827+
void llama_batch_add(
828+
struct llama_batch & batch,
829+
llama_token id,
830+
llama_pos pos,
831+
const std::vector<llama_seq_id> & seq_ids,
832+
bool logits) {
833+
batch.token [batch.n_tokens] = id;
834+
batch.pos [batch.n_tokens] = pos,
835+
batch.n_seq_id[batch.n_tokens] = seq_ids.size();
836+
for (size_t i = 0; i < seq_ids.size(); ++i) {
837+
batch.seq_id[batch.n_tokens][i] = seq_ids[i];
838+
}
839+
batch.logits [batch.n_tokens] = logits;
840+
841+
batch.n_tokens++;
842+
}
843+
823844
std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params) {
824845
auto mparams = llama_model_params_from_gpt_params(params);
825846

common/common.h

+15-1
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ struct gpt_params {
7070
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
7171
std::string logdir = ""; // directory in which to save YAML log files
7272

73+
// TODO: avoid tuple, use struct
7374
std::vector<std::tuple<std::string, float>> lora_adapter; // lora adapter path with user defined scale
7475
std::string lora_base = ""; // base model path for the lora adapter
7576

@@ -124,10 +125,23 @@ void process_escapes(std::string& input);
124125
// Model utils
125126
//
126127

128+
// TODO: avoid tuple, use struct
127129
std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params);
128-
struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & params);
130+
131+
struct llama_model_params llama_model_params_from_gpt_params (const gpt_params & params);
129132
struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params);
130133

134+
// Batch utils
135+
136+
void llama_batch_clear(struct llama_batch & batch);
137+
138+
void llama_batch_add(
139+
struct llama_batch & batch,
140+
llama_token id,
141+
llama_pos pos,
142+
const std::vector<llama_seq_id> & seq_ids,
143+
bool logits);
144+
131145
//
132146
// Vocab utils
133147
//

common/log.h

+69-32
Original file line numberDiff line numberDiff line change
@@ -579,38 +579,75 @@ inline std::string log_var_to_string_impl(const std::vector<int> & var)
579579
return buf.str();
580580
}
581581

582-
#define LOG_TOKENS_TOSTR_PRETTY(ctx, tokens) \
583-
[&tokens, &ctx]() \
584-
{ \
585-
std::stringstream buf; \
586-
buf << "[ "; \
587-
\
588-
bool first = true; \
589-
for (const auto &token : tokens) \
590-
{ \
591-
if (!first) \
592-
buf << ", "; \
593-
else \
594-
first = false; \
595-
\
596-
auto detokenized = llama_token_to_piece(ctx, token); \
597-
\
598-
detokenized.erase( \
599-
std::remove_if( \
600-
detokenized.begin(), \
601-
detokenized.end(), \
602-
[](const unsigned char c) { return !std::isprint(c); }), \
603-
detokenized.end()); \
604-
\
605-
buf \
606-
<< "'" << detokenized << "'" \
607-
<< ":" << std::to_string(token); \
608-
} \
609-
buf << " ]"; \
610-
\
611-
return buf.str(); \
612-
}() \
613-
.c_str()
582+
// Render a token container as "[ 'piece':id, 'piece':id, ... ]" for logging.
// Each token is detokenized via llama_token_to_piece and any non-printable
// characters are stripped so the log line stays readable.
// (Kept in macro-style ALL_CAPS for compatibility with the previous macro.)
template <typename C, typename T>
inline std::string LOG_TOKENS_TOSTR_PRETTY(const C & ctx, const T & tokens)
{
    std::stringstream out;
    out << "[ ";

    bool need_sep = false;
    for (const auto & tok : tokens) {
        if (need_sep) {
            out << ", ";
        }
        need_sep = true;

        auto piece = llama_token_to_piece(ctx, tok);

        // strip characters that would garble the log output
        piece.erase(
            std::remove_if(
                piece.begin(),
                piece.end(),
                [](const unsigned char c) { return !std::isprint(c); }),
            piece.end());

        out << "'" << piece << "'" << ":" << std::to_string(tok);
    }
    out << " ]";

    return out.str();
}
614+
615+
// Render a llama_batch-like object, one entry per line, for logging:
// index, detokenized piece, position, sequence-id count, first sequence id,
// and the logits flag. Non-printable characters are stripped from each piece.
// NOTE: only seq_id[i][0] is printed, even when n_seq_id[i] > 1.
template <typename C, typename B>
inline std::string LOG_BATCH_TOSTR_PRETTY(const C & ctx, const B & batch)
{
    std::stringstream out;
    out << "[ ";

    for (int i = 0; i < batch.n_tokens; ++i) {
        if (i > 0) {
            out << ", ";
        }

        auto piece = llama_token_to_piece(ctx, batch.token[i]);

        // strip characters that would garble the log output
        piece.erase(
            std::remove_if(
                piece.begin(),
                piece.end(),
                [](const unsigned char c) { return !std::isprint(c); }),
            piece.end());

        out << "\n" << std::to_string(i)
            << ":token '"   << piece << "'"
            << ":pos "      << std::to_string(batch.pos[i])
            << ":n_seq_id " << std::to_string(batch.n_seq_id[i])
            << ":seq_id "   << std::to_string(batch.seq_id[i][0])
            << ":logits "   << std::to_string(batch.logits[i]);
    }
    out << " ]";

    return out.str();
}
614651

615652
#ifdef LOG_DISABLE_LOGS
616653

0 commit comments

Comments
 (0)