@@ -235,13 +235,30 @@ struct GemmaInterface {
235
235
236
236
virtual void Generate (size_t max_tokens, size_t max_generated_tokens,
237
237
float temperature, const std::vector<int >& prompt,
238
- size_t start_pos, hwy::ThreadPool& pool ,
239
- hwy::ThreadPool& inner_pool,
238
+ size_t start_pos, KVCache& kv_cache ,
239
+ hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
240
240
const StreamFunc& stream_token,
241
241
const AcceptFunc& accept_token, std::mt19937& gen,
242
242
int verbosity) = 0;
243
243
};
244
244
245
// Builds a KVCache sized from the compile-time model configuration `Config`.
// Forwards Config::kLayers * Config::kKVHeads * Config::kQKVDim and
// Config::kSeqLen to the two-argument CreateKVCache overload (declared
// elsewhere). This is the same expression the GemmaImpl constructor used
// before this change moved cache ownership out of GemmaImpl.
// NOTE(review): the two arguments are presumably the per-position cache width
// and the maximum sequence length - confirm against the two-argument overload.
+ template <class Config >
246
+ KVCache CreateKVCache () {
247
+ return CreateKVCache (Config::kLayers * Config::kKVHeads * Config::kQKVDim ,
248
+ Config::kSeqLen );
249
+ }
250
+
251
// Runtime counterpart of CreateKVCache<Config>(): maps a Model enumerator to
// the matching compile-time config and returns a KVCache created for it.
// Model::GEMMA_2B selects ConfigGemma2B and Model::GEMMA_7B selects
// ConfigGemma7B; any other value reports the numeric enumerator through
// HWY_ABORT.
+ KVCache CreateKVCache (Model type) {
252
+ switch (type) {
253
+ case Model::GEMMA_2B:
254
+ return CreateKVCache<ConfigGemma2B>();
255
+ case Model::GEMMA_7B:
256
+ return CreateKVCache<ConfigGemma7B>();
257
+ default :
// No return follows this branch.
// NOTE(review): relies on HWY_ABORT not returning - confirm.
258
+ HWY_ABORT (" Model type %d unknown." , static_cast <int >(type));
259
+ }
260
+ }
261
+
245
262
template <class Config >
246
263
struct GemmaImpl : public GemmaInterface {
247
264
GemmaImpl ( // const LoaderArgs& args,
@@ -255,22 +272,22 @@ struct GemmaImpl : public GemmaInterface {
255
272
c_weights->c_layer_ptrs .~CompressedLayerPointers<Config>();
256
273
}
257
274
258
- const sentencepiece::SentencePieceProcessor* Tokenizer () const {
275
+ const sentencepiece::SentencePieceProcessor* Tokenizer () const override {
259
276
return tokenizer.get ();
260
277
}
261
278
262
279
void Generate (size_t max_tokens, size_t max_generated_tokens,
263
280
float temperature, const std::vector<int >& prompt,
264
- size_t start_pos, hwy::ThreadPool& pool,
281
+ size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
265
282
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
266
- const AcceptFunc& accept_token, std::mt19937&, int verbosity);
283
+ const AcceptFunc& accept_token, std::mt19937&,
284
+ int verbosity) override ;
267
285
268
286
std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer;
269
287
270
288
hwy::AlignedFreeUniquePtr<uint8_t []> compressed_weights;
271
289
hwy::AlignedUniquePtr<Activations<Config, kPrefillBatchSize >> prefill;
272
290
hwy::AlignedUniquePtr<Activations<Config, 1 >> state;
273
- KVCache kv_cache;
274
291
};
275
292
276
293
} // namespace gcpp
@@ -503,7 +520,7 @@ void Transformer(int token, size_t pos,
503
520
template <class TConfig >
504
521
void GenerateImpl (GemmaImpl<TConfig>& gemma, size_t max_tokens,
505
522
size_t max_generated_tokens, float temperature,
506
- const std::vector<int >& prompt, size_t pos,
523
+ const std::vector<int >& prompt, size_t pos, KVCache& kv_cache,
507
524
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
508
525
const StreamFunc& stream_token,
509
526
const AcceptFunc& accept_token, std::mt19937& gen,
@@ -517,7 +534,6 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
517
534
const CompressedWeights<TConfig>& c_weights =
518
535
*reinterpret_cast <CompressedWeights<TConfig>*>(
519
536
gemma.compressed_weights .get ());
520
- KVCache& kv_cache = gemma.kv_cache ;
521
537
int token;
522
538
523
539
// pos indexes the KV cache. In the first turn of a chat, pos = 0.
@@ -612,23 +628,25 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
612
628
// Non-template entry point for the 2B configuration; it is the function named
// in HWY_DYNAMIC_DISPATCH(Generate2B) by GemmaImpl<ConfigGemma2B>::Generate.
// Forwards every argument unchanged to GenerateImpl.
// In this diff the removed (-) lines are the old signature/call without a
// cache argument; the added (+) lines insert the caller-supplied `kv_cache`
// immediately after `start_pos` in both the parameter list and the call.
void Generate2B (GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
613
629
size_t max_generated_tokens, float temperature,
614
630
const std::vector<int >& prompt, size_t start_pos,
615
- hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
616
- const StreamFunc& stream_token, const AcceptFunc& accept_token,
617
- std::mt19937& gen, int verbosity) {
631
+ KVCache& kv_cache, hwy::ThreadPool& pool,
632
+ hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
633
+ const AcceptFunc& accept_token, std::mt19937& gen,
634
+ int verbosity) {
618
635
GenerateImpl (gemma, max_tokens, max_generated_tokens, temperature, prompt,
619
- start_pos, pool, inner_pool, stream_token, accept_token, gen ,
620
- verbosity);
636
+ start_pos, kv_cache, pool, inner_pool, stream_token,
637
+ accept_token, gen, verbosity);
621
638
}
622
639
623
640
// Non-template entry point for the 7B configuration; it is the function named
// in HWY_DYNAMIC_DISPATCH(Generate7B) by GemmaImpl<ConfigGemma7B>::Generate.
// Forwards every argument unchanged to GenerateImpl.
// In this diff the removed (-) lines are the old signature/call without a
// cache argument; the added (+) lines insert the caller-supplied `kv_cache`
// immediately after `start_pos` in both the parameter list and the call.
void Generate7B (GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
624
641
size_t max_generated_tokens, float temperature,
625
642
const std::vector<int >& prompt, size_t start_pos,
626
- hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
627
- const StreamFunc& stream_token, const AcceptFunc& accept_token,
628
- std::mt19937& gen, int verbosity) {
643
+ KVCache& kv_cache, hwy::ThreadPool& pool,
644
+ hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
645
+ const AcceptFunc& accept_token, std::mt19937& gen,
646
+ int verbosity) {
629
647
GenerateImpl (gemma, max_tokens, max_generated_tokens, temperature, prompt,
630
- start_pos, pool, inner_pool, stream_token, accept_token, gen ,
631
- verbosity);
648
+ start_pos, kv_cache, pool, inner_pool, stream_token,
649
+ accept_token, gen, verbosity);
632
650
}
633
651
634
652
// Calls func(name, float*, CompressedArray&) for each tensor. float* is null
@@ -753,9 +771,6 @@ GemmaImpl<Config>::GemmaImpl(
753
771
// HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)(args, pool)),
754
772
prefill (hwy::MakeUniqueAligned<Activations<Config, kPrefillBatchSize >>()),
755
773
state(hwy::MakeUniqueAligned<Activations<Config, 1 >>()),
756
- kv_cache(
757
- CreateKVCache (Config::kLayers * Config::kKVHeads * Config::kQKVDim ,
758
- Config::kSeqLen )),
759
774
tokenizer(std::move(tokenizer)) {
760
775
// PROFILER_ZONE("Startup.tokenizer");
761
776
// HWY_ASSERT(tokenizer.Load(args.tokenizer.path).ok());
@@ -764,22 +779,24 @@ GemmaImpl<Config>::GemmaImpl(
764
779
// Explicit specialization of GemmaImpl::Generate for the 2B configuration.
// Passes *this plus all arguments, in order, to Generate2B, selected through
// HWY_DYNAMIC_DISPATCH. The added (+) lines thread the caller-supplied
// `kv_cache` through, placed right after `start_pos`; the removed (-) lines
// are the previous signature/call without it.
template <>
765
780
void GemmaImpl<ConfigGemma2B>::Generate(
766
781
size_t max_tokens, size_t max_generated_tokens, float temperature,
767
- const std::vector<int >& prompt, size_t start_pos, hwy::ThreadPool& pool,
768
- hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
769
- const AcceptFunc& accept_token, std::mt19937& gen, int verbosity) {
782
+ const std::vector<int >& prompt, size_t start_pos, KVCache& kv_cache,
783
+ hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
784
+ const StreamFunc& stream_token, const AcceptFunc& accept_token,
785
+ std::mt19937& gen, int verbosity) {
770
786
HWY_DYNAMIC_DISPATCH (Generate2B)
771
787
(*this , max_tokens, max_generated_tokens, temperature, prompt, start_pos,
772
- pool, inner_pool, stream_token, accept_token, gen, verbosity);
788
+ kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
773
789
}
774
790
// Explicit specialization of GemmaImpl::Generate for the 7B configuration.
// Passes *this plus all arguments, in order, to Generate7B, selected through
// HWY_DYNAMIC_DISPATCH. The added (+) lines thread the caller-supplied
// `kv_cache` through, placed right after `start_pos`; the removed (-) lines
// are the previous signature/call without it.
template <>
775
791
void GemmaImpl<ConfigGemma7B>::Generate(
776
792
size_t max_tokens, size_t max_generated_tokens, float temperature,
777
- const std::vector<int >& prompt, size_t start_pos, hwy::ThreadPool& pool,
778
- hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
779
- const AcceptFunc& accept_token, std::mt19937& gen, int verbosity) {
793
+ const std::vector<int >& prompt, size_t start_pos, KVCache& kv_cache,
794
+ hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
795
+ const StreamFunc& stream_token, const AcceptFunc& accept_token,
796
+ std::mt19937& gen, int verbosity) {
780
797
HWY_DYNAMIC_DISPATCH (Generate7B)
781
798
(*this , max_tokens, max_generated_tokens, temperature, prompt, start_pos,
782
- pool, inner_pool, stream_token, accept_token, gen, verbosity);
799
+ kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
783
800
}
784
801
785
802
// TODO: Make Gemma type independent of LoaderArgs, create a factory function
@@ -814,14 +831,14 @@ const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const {
814
831
815
832
// Public generation entry point taking the type-erased Gemma wrapper.
// Delegates to gemma.impl_->Generate with all arguments forwarded in order.
// The call is bracketed by two wait-mode changes on `pool`: kSpin before
// generation and kBlock afterwards. `inner_pool` is forwarded but its wait
// mode is not touched here.
// The added (+) lines introduce the caller-supplied `kv_cache` parameter
// (after `start_pos`) and forward it; the removed (-) lines are the old form.
void GenerateGemma (Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
816
833
float temperature, const std::vector<int >& prompt,
817
- size_t start_pos, hwy::ThreadPool& pool,
834
+ size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
818
835
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
819
836
const AcceptFunc& accept_token, std::mt19937& gen,
820
837
int verbosity) {
821
838
pool.SetWaitMode (hwy::PoolWaitMode::kSpin );
822
839
gemma.impl_ ->Generate (max_tokens, max_generated_tokens, temperature, prompt,
823
- start_pos, pool, inner_pool, stream_token, accept_token ,
824
- gen, verbosity);
840
+ start_pos, kv_cache, pool, inner_pool, stream_token,
841
+ accept_token, gen, verbosity);
// Unconditionally sets kBlock; the wait mode in effect before this call is
// not recorded or reinstated.
825
842
pool.SetWaitMode (hwy::PoolWaitMode::kBlock );
826
843
}
827
844
0 commit comments