Skip to content

Separate KV cache from GemmaImpl #81

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 49 additions & 32 deletions gemma.cc
Original file line number Diff line number Diff line change
Expand Up @@ -235,13 +235,30 @@ struct GemmaInterface {

virtual void Generate(size_t max_tokens, size_t max_generated_tokens,
float temperature, const std::vector<int>& prompt,
size_t start_pos, hwy::ThreadPool& pool,
hwy::ThreadPool& inner_pool,
size_t start_pos, KVCache& kv_cache,
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
const StreamFunc& stream_token,
const AcceptFunc& accept_token, std::mt19937& gen,
int verbosity) = 0;
};

// Builds a KV cache whose dimensions are derived from the model's
// compile-time configuration: one slot per (layer, KV head, QKV dim)
// for each of Config::kSeqLen positions.
template <class Config>
KVCache CreateKVCache() {
  constexpr size_t kCachePosSize =
      Config::kLayers * Config::kKVHeads * Config::kQKVDim;
  return CreateKVCache(kCachePosSize, Config::kSeqLen);
}

// Runtime dispatch: allocates a KV cache sized for the given model
// variant. Aborts on an unrecognized Model value, since a wrongly sized
// cache would corrupt generation.
KVCache CreateKVCache(Model type) {
  if (type == Model::GEMMA_2B) {
    return CreateKVCache<ConfigGemma2B>();
  }
  if (type == Model::GEMMA_7B) {
    return CreateKVCache<ConfigGemma7B>();
  }
  HWY_ABORT("Model type %d unknown.", static_cast<int>(type));
}

template <class Config>
struct GemmaImpl : public GemmaInterface {
GemmaImpl( // const LoaderArgs& args,
Expand All @@ -255,22 +272,22 @@ struct GemmaImpl : public GemmaInterface {
c_weights->c_layer_ptrs.~CompressedLayerPointers<Config>();
}

const sentencepiece::SentencePieceProcessor* Tokenizer() const {
const sentencepiece::SentencePieceProcessor* Tokenizer() const override {
return tokenizer.get();
}

void Generate(size_t max_tokens, size_t max_generated_tokens,
float temperature, const std::vector<int>& prompt,
size_t start_pos, hwy::ThreadPool& pool,
size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
const AcceptFunc& accept_token, std::mt19937&, int verbosity);
const AcceptFunc& accept_token, std::mt19937&,
int verbosity) override;

std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer;

hwy::AlignedFreeUniquePtr<uint8_t[]> compressed_weights;
hwy::AlignedUniquePtr<Activations<Config, kPrefillBatchSize>> prefill;
hwy::AlignedUniquePtr<Activations<Config, 1>> state;
KVCache kv_cache;
};

} // namespace gcpp
Expand Down Expand Up @@ -503,7 +520,7 @@ void Transformer(int token, size_t pos,
template <class TConfig>
void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
size_t max_generated_tokens, float temperature,
const std::vector<int>& prompt, size_t pos,
const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
const StreamFunc& stream_token,
const AcceptFunc& accept_token, std::mt19937& gen,
Expand All @@ -517,7 +534,6 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
const CompressedWeights<TConfig>& c_weights =
*reinterpret_cast<CompressedWeights<TConfig>*>(
gemma.compressed_weights.get());
KVCache& kv_cache = gemma.kv_cache;
int token;

// pos indexes the KV cache. In the first turn of a chat, pos = 0.
Expand Down Expand Up @@ -612,23 +628,25 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
void Generate2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
size_t max_generated_tokens, float temperature,
const std::vector<int>& prompt, size_t start_pos,
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
const StreamFunc& stream_token, const AcceptFunc& accept_token,
std::mt19937& gen, int verbosity) {
KVCache& kv_cache, hwy::ThreadPool& pool,
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
const AcceptFunc& accept_token, std::mt19937& gen,
int verbosity) {
GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
start_pos, pool, inner_pool, stream_token, accept_token, gen,
verbosity);
start_pos, kv_cache, pool, inner_pool, stream_token,
accept_token, gen, verbosity);
}

void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
size_t max_generated_tokens, float temperature,
const std::vector<int>& prompt, size_t start_pos,
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
const StreamFunc& stream_token, const AcceptFunc& accept_token,
std::mt19937& gen, int verbosity) {
KVCache& kv_cache, hwy::ThreadPool& pool,
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
const AcceptFunc& accept_token, std::mt19937& gen,
int verbosity) {
GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
start_pos, pool, inner_pool, stream_token, accept_token, gen,
verbosity);
start_pos, kv_cache, pool, inner_pool, stream_token,
accept_token, gen, verbosity);
}

// Calls func(name, float*, CompressedArray&) for each tensor. float* is null
Expand Down Expand Up @@ -753,9 +771,6 @@ GemmaImpl<Config>::GemmaImpl(
// HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)(args, pool)),
prefill(hwy::MakeUniqueAligned<Activations<Config, kPrefillBatchSize>>()),
state(hwy::MakeUniqueAligned<Activations<Config, 1>>()),
kv_cache(
CreateKVCache(Config::kLayers * Config::kKVHeads * Config::kQKVDim,
Config::kSeqLen)),
tokenizer(std::move(tokenizer)) {
// PROFILER_ZONE("Startup.tokenizer");
// HWY_ASSERT(tokenizer.Load(args.tokenizer.path).ok());
Expand All @@ -764,22 +779,24 @@ GemmaImpl<Config>::GemmaImpl(
template <>
void GemmaImpl<ConfigGemma2B>::Generate(
size_t max_tokens, size_t max_generated_tokens, float temperature,
const std::vector<int>& prompt, size_t start_pos, hwy::ThreadPool& pool,
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
const AcceptFunc& accept_token, std::mt19937& gen, int verbosity) {
const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
const StreamFunc& stream_token, const AcceptFunc& accept_token,
std::mt19937& gen, int verbosity) {
HWY_DYNAMIC_DISPATCH(Generate2B)
(*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
pool, inner_pool, stream_token, accept_token, gen, verbosity);
kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
}
template <>
void GemmaImpl<ConfigGemma7B>::Generate(
size_t max_tokens, size_t max_generated_tokens, float temperature,
const std::vector<int>& prompt, size_t start_pos, hwy::ThreadPool& pool,
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
const AcceptFunc& accept_token, std::mt19937& gen, int verbosity) {
const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
const StreamFunc& stream_token, const AcceptFunc& accept_token,
std::mt19937& gen, int verbosity) {
HWY_DYNAMIC_DISPATCH(Generate7B)
(*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
pool, inner_pool, stream_token, accept_token, gen, verbosity);
kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
}

// TODO: Make Gemma type independent of LoaderArgs, create a factory function
Expand Down Expand Up @@ -814,14 +831,14 @@ const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const {

void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
float temperature, const std::vector<int>& prompt,
size_t start_pos, hwy::ThreadPool& pool,
size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
const AcceptFunc& accept_token, std::mt19937& gen,
int verbosity) {
pool.SetWaitMode(hwy::PoolWaitMode::kSpin);
gemma.impl_->Generate(max_tokens, max_generated_tokens, temperature, prompt,
start_pos, pool, inner_pool, stream_token, accept_token,
gen, verbosity);
start_pos, kv_cache, pool, inner_pool, stream_token,
accept_token, gen, verbosity);
pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
}

Expand Down
5 changes: 4 additions & 1 deletion gemma.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,9 @@ struct Gemma {
gcpp::ModelTraining model_training;
};

KVCache CreateKVCache(Model type); // convenient workaround for now
KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len);

// StreamFunc is called with (token, probability). For prompt tokens,
// probability is 0.0f.
using StreamFunc = std::function<bool(int, float)>;
Expand Down Expand Up @@ -211,7 +214,7 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {

void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
float temperature, const std::vector<int>& prompt,
size_t start_pos, hwy::ThreadPool& pool,
size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
const AcceptFunc& accept_token, std::mt19937& gen,
int verbosity);
Expand Down
13 changes: 7 additions & 6 deletions run.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,10 @@ void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference,
std::cerr << "\n";
}

void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
hwy::ThreadPool& inner_pool, const InferenceArgs& args,
int verbosity, const gcpp::AcceptFunc& accept_token,
std::string& eot_line) {
void ReplGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache,
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
const InferenceArgs& args, int verbosity,
const gcpp::AcceptFunc& accept_token, std::string& eot_line) {
PROFILER_ZONE("Gen.misc");
int abs_pos = 0; // absolute token index over all turns
int current_pos = 0; // token index within the current turn
Expand Down Expand Up @@ -205,7 +205,7 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,

const double time_start = hwy::platform::Now();
GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
args.temperature, prompt, abs_pos, pool, inner_pool,
args.temperature, prompt, abs_pos, kv_cache, pool, inner_pool,
stream_token, accept_token, gen, verbosity);
const double time_end = hwy::platform::Now();
const double tok_sec = current_pos / (time_end - time_start);
Expand Down Expand Up @@ -236,6 +236,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
}

gcpp::Gemma model(loader, pool);
auto kv_cache = CreateKVCache(loader.ModelType());

if (const char* error = inference.Validate()) {
ShowHelp(loader, inference, app);
Expand Down Expand Up @@ -273,7 +274,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
}

ReplGemma(
model, pool, inner_pool, inference, app.verbosity,
model, kv_cache, pool, inner_pool, inference, app.verbosity,
/*accept_token=*/[](int) { return true; }, app.eot_line);
}

Expand Down