Skip to content

Commit 3df06f6

Browse files
authored
Merge pull request #81 from ufownl/feature/separated_kvcache
Separate KV cache from GemmaImpl
2 parents 6c0388e + 170a9b4 commit 3df06f6

File tree

3 files changed

+60
-39
lines changed

3 files changed

+60
-39
lines changed

gemma.cc

+49-32
Original file line number | Diff line number | Diff line change
@@ -235,13 +235,30 @@ struct GemmaInterface {
235235

236236
virtual void Generate(size_t max_tokens, size_t max_generated_tokens,
237237
float temperature, const std::vector<int>& prompt,
238-
size_t start_pos, hwy::ThreadPool& pool,
239-
hwy::ThreadPool& inner_pool,
238+
size_t start_pos, KVCache& kv_cache,
239+
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
240240
const StreamFunc& stream_token,
241241
const AcceptFunc& accept_token, std::mt19937& gen,
242242
int verbosity) = 0;
243243
};
244244

245+
template <class Config>
246+
KVCache CreateKVCache() {
247+
return CreateKVCache(Config::kLayers * Config::kKVHeads * Config::kQKVDim,
248+
Config::kSeqLen);
249+
}
250+
251+
KVCache CreateKVCache(Model type) {
252+
switch (type) {
253+
case Model::GEMMA_2B:
254+
return CreateKVCache<ConfigGemma2B>();
255+
case Model::GEMMA_7B:
256+
return CreateKVCache<ConfigGemma7B>();
257+
default:
258+
HWY_ABORT("Model type %d unknown.", static_cast<int>(type));
259+
}
260+
}
261+
245262
template <class Config>
246263
struct GemmaImpl : public GemmaInterface {
247264
GemmaImpl( // const LoaderArgs& args,
@@ -255,22 +272,22 @@ struct GemmaImpl : public GemmaInterface {
255272
c_weights->c_layer_ptrs.~CompressedLayerPointers<Config>();
256273
}
257274

258-
const sentencepiece::SentencePieceProcessor* Tokenizer() const {
275+
const sentencepiece::SentencePieceProcessor* Tokenizer() const override {
259276
return tokenizer.get();
260277
}
261278

262279
void Generate(size_t max_tokens, size_t max_generated_tokens,
263280
float temperature, const std::vector<int>& prompt,
264-
size_t start_pos, hwy::ThreadPool& pool,
281+
size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
265282
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
266-
const AcceptFunc& accept_token, std::mt19937&, int verbosity);
283+
const AcceptFunc& accept_token, std::mt19937&,
284+
int verbosity) override;
267285

268286
std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer;
269287

270288
hwy::AlignedFreeUniquePtr<uint8_t[]> compressed_weights;
271289
hwy::AlignedUniquePtr<Activations<Config, kPrefillBatchSize>> prefill;
272290
hwy::AlignedUniquePtr<Activations<Config, 1>> state;
273-
KVCache kv_cache;
274291
};
275292

276293
} // namespace gcpp
@@ -503,7 +520,7 @@ void Transformer(int token, size_t pos,
503520
template <class TConfig>
504521
void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
505522
size_t max_generated_tokens, float temperature,
506-
const std::vector<int>& prompt, size_t pos,
523+
const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
507524
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
508525
const StreamFunc& stream_token,
509526
const AcceptFunc& accept_token, std::mt19937& gen,
@@ -517,7 +534,6 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
517534
const CompressedWeights<TConfig>& c_weights =
518535
*reinterpret_cast<CompressedWeights<TConfig>*>(
519536
gemma.compressed_weights.get());
520-
KVCache& kv_cache = gemma.kv_cache;
521537
int token;
522538

523539
// pos indexes the KV cache. In the first turn of a chat, pos = 0.
@@ -612,23 +628,25 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
612628
void Generate2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
613629
size_t max_generated_tokens, float temperature,
614630
const std::vector<int>& prompt, size_t start_pos,
615-
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
616-
const StreamFunc& stream_token, const AcceptFunc& accept_token,
617-
std::mt19937& gen, int verbosity) {
631+
KVCache& kv_cache, hwy::ThreadPool& pool,
632+
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
633+
const AcceptFunc& accept_token, std::mt19937& gen,
634+
int verbosity) {
618635
GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
619-
start_pos, pool, inner_pool, stream_token, accept_token, gen,
620-
verbosity);
636+
start_pos, kv_cache, pool, inner_pool, stream_token,
637+
accept_token, gen, verbosity);
621638
}
622639

623640
void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
624641
size_t max_generated_tokens, float temperature,
625642
const std::vector<int>& prompt, size_t start_pos,
626-
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
627-
const StreamFunc& stream_token, const AcceptFunc& accept_token,
628-
std::mt19937& gen, int verbosity) {
643+
KVCache& kv_cache, hwy::ThreadPool& pool,
644+
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
645+
const AcceptFunc& accept_token, std::mt19937& gen,
646+
int verbosity) {
629647
GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
630-
start_pos, pool, inner_pool, stream_token, accept_token, gen,
631-
verbosity);
648+
start_pos, kv_cache, pool, inner_pool, stream_token,
649+
accept_token, gen, verbosity);
632650
}
633651

634652
// Calls func(name, float*, CompressedArray&) for each tensor. float* is null
@@ -753,9 +771,6 @@ GemmaImpl<Config>::GemmaImpl(
753771
// HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)(args, pool)),
754772
prefill(hwy::MakeUniqueAligned<Activations<Config, kPrefillBatchSize>>()),
755773
state(hwy::MakeUniqueAligned<Activations<Config, 1>>()),
756-
kv_cache(
757-
CreateKVCache(Config::kLayers * Config::kKVHeads * Config::kQKVDim,
758-
Config::kSeqLen)),
759774
tokenizer(std::move(tokenizer)) {
760775
// PROFILER_ZONE("Startup.tokenizer");
761776
// HWY_ASSERT(tokenizer.Load(args.tokenizer.path).ok());
@@ -764,22 +779,24 @@ GemmaImpl<Config>::GemmaImpl(
764779
template <>
765780
void GemmaImpl<ConfigGemma2B>::Generate(
766781
size_t max_tokens, size_t max_generated_tokens, float temperature,
767-
const std::vector<int>& prompt, size_t start_pos, hwy::ThreadPool& pool,
768-
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
769-
const AcceptFunc& accept_token, std::mt19937& gen, int verbosity) {
782+
const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
783+
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
784+
const StreamFunc& stream_token, const AcceptFunc& accept_token,
785+
std::mt19937& gen, int verbosity) {
770786
HWY_DYNAMIC_DISPATCH(Generate2B)
771787
(*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
772-
pool, inner_pool, stream_token, accept_token, gen, verbosity);
788+
kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
773789
}
774790
template <>
775791
void GemmaImpl<ConfigGemma7B>::Generate(
776792
size_t max_tokens, size_t max_generated_tokens, float temperature,
777-
const std::vector<int>& prompt, size_t start_pos, hwy::ThreadPool& pool,
778-
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
779-
const AcceptFunc& accept_token, std::mt19937& gen, int verbosity) {
793+
const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
794+
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
795+
const StreamFunc& stream_token, const AcceptFunc& accept_token,
796+
std::mt19937& gen, int verbosity) {
780797
HWY_DYNAMIC_DISPATCH(Generate7B)
781798
(*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
782-
pool, inner_pool, stream_token, accept_token, gen, verbosity);
799+
kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
783800
}
784801

785802
// TODO: Make Gemma type independent of LoaderArgs, create a factory function
@@ -814,14 +831,14 @@ const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const {
814831

815832
void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
816833
float temperature, const std::vector<int>& prompt,
817-
size_t start_pos, hwy::ThreadPool& pool,
834+
size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
818835
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
819836
const AcceptFunc& accept_token, std::mt19937& gen,
820837
int verbosity) {
821838
pool.SetWaitMode(hwy::PoolWaitMode::kSpin);
822839
gemma.impl_->Generate(max_tokens, max_generated_tokens, temperature, prompt,
823-
start_pos, pool, inner_pool, stream_token, accept_token,
824-
gen, verbosity);
840+
start_pos, kv_cache, pool, inner_pool, stream_token,
841+
accept_token, gen, verbosity);
825842
pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
826843
}
827844

gemma.h

+4-1
Original file line number | Diff line number | Diff line change
@@ -163,6 +163,9 @@ struct Gemma {
163163
gcpp::ModelTraining model_training;
164164
};
165165

166+
KVCache CreateKVCache(Model type); // convenient workaround for now
167+
KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len);
168+
166169
// StreamFunc is called with (token, probability). For prompt tokens,
167170
// probability is 0.0f.
168171
using StreamFunc = std::function<bool(int, float)>;
@@ -211,7 +214,7 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
211214

212215
void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
213216
float temperature, const std::vector<int>& prompt,
214-
size_t start_pos, hwy::ThreadPool& pool,
217+
size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
215218
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
216219
const AcceptFunc& accept_token, std::mt19937& gen,
217220
int verbosity);

run.cc

+7-6
Original file line number | Diff line number | Diff line change
@@ -96,10 +96,10 @@ void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference,
9696
std::cerr << "\n";
9797
}
9898

99-
void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
100-
hwy::ThreadPool& inner_pool, const InferenceArgs& args,
101-
int verbosity, const gcpp::AcceptFunc& accept_token,
102-
std::string& eot_line) {
99+
void ReplGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache,
100+
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
101+
const InferenceArgs& args, int verbosity,
102+
const gcpp::AcceptFunc& accept_token, std::string& eot_line) {
103103
PROFILER_ZONE("Gen.misc");
104104
int abs_pos = 0; // absolute token index over all turns
105105
int current_pos = 0; // token index within the current turn
@@ -205,7 +205,7 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
205205

206206
const double time_start = hwy::platform::Now();
207207
GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
208-
args.temperature, prompt, abs_pos, pool, inner_pool,
208+
args.temperature, prompt, abs_pos, kv_cache, pool, inner_pool,
209209
stream_token, accept_token, gen, verbosity);
210210
const double time_end = hwy::platform::Now();
211211
const double tok_sec = current_pos / (time_end - time_start);
@@ -236,6 +236,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
236236
}
237237

238238
gcpp::Gemma model(loader, pool);
239+
auto kv_cache = CreateKVCache(loader.ModelType());
239240

240241
if (const char* error = inference.Validate()) {
241242
ShowHelp(loader, inference, app);
@@ -273,7 +274,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
273274
}
274275

275276
ReplGemma(
276-
model, pool, inner_pool, inference, app.verbosity,
277+
model, kv_cache, pool, inner_pool, inference, app.verbosity,
277278
/*accept_token=*/[](int) { return true; }, app.eot_line);
278279
}
279280

0 commit comments

Comments (0)