@@ -192,6 +192,266 @@ ModelConfig ConfigFromModel(Model model);
192
192
193
193
// Returns the sub-config for the ViT model of the PaliGemma model.
194
194
ModelConfig VitConfig (const ModelConfig& config);
195
// Functor returning TConfig::kKVHeads * TConfig::kQKVDim * 2 as a size_t.
// NOTE(review): the factor of 2 presumably accounts for keys plus values;
// confirm against the cache layout. The second template parameter is unused
// here and defaults to void.
template <typename TConfig, typename Enable = void>
struct CacheLayerSize {
  constexpr size_t operator()() const {
    const auto heads_times_dim = TConfig::kKVHeads * TConfig::kQKVDim;
    return heads_times_dim * 2;
  }
};
201
+
202
+ template <class TConfig , typename = void >
203
+ struct CachePosSize {
204
+ constexpr size_t operator ()() const {
205
+ return TConfig::kGemmaLayers * CacheLayerSize<TConfig>()();
206
+ }
207
+ };
208
+
209
+ struct ConfigNoSSM {
210
+ static constexpr int kGriffinLayers = 0 ;
211
+
212
+ static constexpr int kConv1dWidth = 0 ;
213
+ static constexpr bool kFFBiases = false ;
214
+ static constexpr bool kSoftmaxAttnOutputBiases = false ;
215
+ static constexpr bool kUseHalfRope = false ;
216
+ static constexpr bool kUseLocalAttention = false ;
217
+ static constexpr bool kInterleaveQKV = true ;
218
+ static constexpr int kNumTensorScales = 0 ;
219
+
220
+ static constexpr PostQKType kPostQK = PostQKType::Rope;
221
+ static constexpr ActivationType kActivation = ActivationType::Gelu;
222
+ static constexpr ResidualType kResidual = ResidualType::Add;
223
+
224
+ // Self-extend parameters with defaul values
225
+ static constexpr bool kSelfExtend = false ;
226
+ static constexpr size_t kSelfExtendNgbSize = 0 ;
227
+ static constexpr size_t kSelfExtendGrpSize = 1 ;
228
+ };
229
+
230
+ struct ConfigBaseGemmaV1 : ConfigNoSSM {
231
+ static constexpr float kAttCap = 0 .0f ;
232
+ static constexpr float kFinalCap = 0 .0f ;
233
+ static constexpr PostNormType kPostNorm = PostNormType::None;
234
+ static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
235
+ };
236
+
237
+ struct ConfigBaseGemmaV2 : ConfigNoSSM {
238
+ static constexpr float kAttCap = 50 .0f ;
239
+ static constexpr float kFinalCap = 30 .0f ;
240
+ static constexpr PostNormType kPostNorm = PostNormType::Scale;
241
+ };
242
+
243
+ template <typename TWeight>
244
+ struct ConfigGemma27B : public ConfigBaseGemmaV2 {
245
+ using Weight = TWeight; // make accessible where we only have a TConfig
246
+
247
+ static constexpr int kSeqLen = 8192 ;
248
+ static constexpr int kVocabSize = 256000 ;
249
+ static constexpr std::array<LayerAttentionType, 46 > kLayerConfig =
250
+ FixedLayerConfig<46 >(LayerAttentionType::kGemma );
251
+ static constexpr std::array<size_t , 46 > kAttentionWindowSizes =
252
+ RepeatedAttentionWindowSizes<46 , 2 >({4096 , kSeqLen });
253
+ static constexpr int kLayers = kLayerConfig .size();
254
+ static constexpr int kGemmaLayers = kLayers ;
255
+ static constexpr int kModelDim = 4608 ;
256
+ static constexpr int kFFHiddenDim = 16 * 4608 / 2 ; // = 36864
257
+ static constexpr int kHeads = 32 ;
258
+ static constexpr int kKVHeads = 16 ;
259
+ static constexpr int kQKVDim = 128 ; // query size == key size == value size
260
+ static constexpr int kTopK = gcpp::kTopK ;
261
+ static constexpr bool kAbsolutePE = false ;
262
+ static constexpr QueryScaleType kQueryScale =
263
+ QueryScaleType::SqrtModelDimDivNumHeads;
264
+ };
265
+
266
+ template <typename TWeight>
267
+ struct ConfigGemma9B : public ConfigBaseGemmaV2 {
268
+ using Weight = TWeight; // make accessible where we only have a TConfig
269
+
270
+ static constexpr int kSeqLen = 8192 ;
271
+ static constexpr int kVocabSize = 256000 ;
272
+ static constexpr std::array<LayerAttentionType, 42 > kLayerConfig =
273
+ FixedLayerConfig<42 >(LayerAttentionType::kGemma );
274
+ static constexpr std::array<size_t , 42 > kAttentionWindowSizes =
275
+ RepeatedAttentionWindowSizes<42 , 2 >({4096 , kSeqLen });
276
+ static constexpr int kLayers = kLayerConfig .size();
277
+ static constexpr int kGemmaLayers = kLayers ;
278
+ static constexpr int kModelDim = 3584 ;
279
+ static constexpr int kFFHiddenDim = 8 * 3584 / 2 ; // = 14336
280
+ static constexpr int kHeads = 16 ;
281
+ static constexpr int kKVHeads = 8 ;
282
+ static constexpr int kQKVDim = 256 ; // query size == key size == value size
283
+ static constexpr int kTopK = gcpp::kTopK ;
284
+ static constexpr bool kAbsolutePE = false ;
285
+ static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
286
+ };
287
+
288
+ template <typename TWeight>
289
+ struct ConfigGemma7B : public ConfigBaseGemmaV1 {
290
+ using Weight = TWeight; // make accessible where we only have a TConfig
291
+
292
+ static constexpr int kSeqLen = gcpp::kSeqLen ;
293
+ static constexpr int kVocabSize = 256000 ;
294
+ static constexpr std::array<LayerAttentionType, 28 > kLayerConfig =
295
+ FixedLayerConfig<28 >(LayerAttentionType::kGemma );
296
+ static constexpr std::array<size_t , 28 > kAttentionWindowSizes =
297
+ FixedAttentionWindowSizes<28 >(kSeqLen );
298
+ static constexpr int kLayers = kLayerConfig .size();
299
+ static constexpr int kGemmaLayers = kLayers ;
300
+ static constexpr int kModelDim = 3072 ;
301
+ static constexpr int kFFHiddenDim = 16 * 3072 / 2 ; // = 24576
302
+ static constexpr int kHeads = 16 ;
303
+ static constexpr int kKVHeads = 16 ; // standard MHA
304
+ static constexpr int kQKVDim = 256 ; // query size == key size == value size
305
+ static constexpr int kTopK = gcpp::kTopK ;
306
+ static constexpr bool kAbsolutePE = false ;
307
+ };
308
+
309
+ template <typename TWeight>
310
+ struct ConfigGemma2B : public ConfigBaseGemmaV1 {
311
+ using Weight = TWeight; // make accessible where we only have a TConfig
312
+
313
+ static constexpr int kSeqLen = gcpp::kSeqLen ;
314
+ static constexpr int kVocabSize = 256000 ;
315
+ static constexpr std::array<LayerAttentionType, 18 > kLayerConfig =
316
+ FixedLayerConfig<18 >(LayerAttentionType::kGemma );
317
+ static constexpr std::array<size_t , 18 > kAttentionWindowSizes =
318
+ FixedAttentionWindowSizes<18 >(kSeqLen );
319
+ static constexpr int kLayers = kLayerConfig .size();
320
+ static constexpr int kGemmaLayers = kLayers ;
321
+ static constexpr int kModelDim = 2048 ;
322
+ static constexpr int kFFHiddenDim = 16 * 2048 / 2 ; // = 16384
323
+ static constexpr int kHeads = 8 ;
324
+ static constexpr int kKVHeads = 1 ;
325
+ static constexpr int kQKVDim = 256 ; // query size == key size == value size
326
+ static constexpr int kTopK = gcpp::kTopK ;
327
+ static constexpr bool kAbsolutePE = false ;
328
+ };
329
+
330
+ template <typename TWeight>
331
+ struct ConfigGemma2_2B : public ConfigBaseGemmaV2 {
332
+ using Weight = TWeight; // make accessible where we only have a TConfig
333
+
334
+ static constexpr int kSeqLen = 8192 ;
335
+ static constexpr int kVocabSize = 256000 ;
336
+ static constexpr std::array<LayerAttentionType, 26 > kLayerConfig =
337
+ FixedLayerConfig<26 >(LayerAttentionType::kGemma );
338
+ static constexpr std::array<size_t , 26 > kAttentionWindowSizes =
339
+ RepeatedAttentionWindowSizes<26 , 2 >({4096 , kSeqLen });
340
+ static constexpr int kLayers = kLayerConfig .size();
341
+ static constexpr int kGemmaLayers = kLayers ;
342
+ static constexpr int kModelDim = 2304 ;
343
+ static constexpr int kFFHiddenDim = 8 * 2304 / 2 ; // = 9216
344
+ static constexpr int kHeads = 8 ;
345
+ static constexpr int kKVHeads = 4 ;
346
+ static constexpr int kQKVDim = 256 ; // query size == key size == value size
347
+ static constexpr int kTopK = gcpp::kTopK ;
348
+ static constexpr bool kAbsolutePE = false ;
349
+ static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
350
+ };
351
+
352
+ template <typename TWeight>
353
+ struct ConfigGemmaTiny : public ConfigNoSSM {
354
+ using Weight = TWeight; // make accessible where we only have a TConfig
355
+
356
+ static constexpr int kSeqLen = 32 ;
357
+ static constexpr int kVocabSize = 64 ;
358
+ static constexpr std::array<LayerAttentionType, 3 > kLayerConfig =
359
+ FixedLayerConfig<3 >(LayerAttentionType::kGemma );
360
+ static constexpr std::array<size_t , 3 > kAttentionWindowSizes =
361
+ FixedAttentionWindowSizes<3 >(kSeqLen );
362
+ static constexpr int kLayers = kLayerConfig .size();
363
+ static constexpr int kGemmaLayers = kLayers ;
364
+ static constexpr int kModelDim = 128 ;
365
+ static constexpr int kFFHiddenDim = 256 ;
366
+ static constexpr int kHeads = 4 ;
367
+ static constexpr int kKVHeads = 1 ;
368
+ static constexpr int kQKVDim = 16 ; // query size == key size == value size
369
+ static constexpr int kTopK = gcpp::kTopK ;
370
+ static constexpr bool kAbsolutePE = false ;
371
+ static constexpr PostNormType kPostNorm = PostNormType::None;
372
+ static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
373
+
374
+ static constexpr float kAttCap = 0 .0f ;
375
+ // This is required for optimize_test to pass.
376
+ static constexpr float kFinalCap = 30 .0f ;
377
+ };
378
+
379
+ template <typename TWeight>
380
+ struct ConfigGriffin2B {
381
+ using Weight = TWeight; // make accessible where we only have a TConfig
382
+
383
+ // Griffin uses local attention, so kSeqLen is actually the local attention
384
+ // window.
385
+ static constexpr int kSeqLen = 2048 ;
386
+ static constexpr int kVocabSize = 256000 ;
387
+ static constexpr std::array<LayerAttentionType, 26 > kLayerConfig = {
388
+ LayerAttentionType::kGriffinRecurrentBlock ,
389
+ LayerAttentionType::kGriffinRecurrentBlock ,
390
+ LayerAttentionType::kGemma ,
391
+ LayerAttentionType::kGriffinRecurrentBlock ,
392
+ LayerAttentionType::kGriffinRecurrentBlock ,
393
+ LayerAttentionType::kGemma ,
394
+ LayerAttentionType::kGriffinRecurrentBlock ,
395
+ LayerAttentionType::kGriffinRecurrentBlock ,
396
+ LayerAttentionType::kGemma ,
397
+ LayerAttentionType::kGriffinRecurrentBlock ,
398
+ LayerAttentionType::kGriffinRecurrentBlock ,
399
+ LayerAttentionType::kGemma ,
400
+ LayerAttentionType::kGriffinRecurrentBlock ,
401
+ LayerAttentionType::kGriffinRecurrentBlock ,
402
+ LayerAttentionType::kGemma ,
403
+ LayerAttentionType::kGriffinRecurrentBlock ,
404
+ LayerAttentionType::kGriffinRecurrentBlock ,
405
+ LayerAttentionType::kGemma ,
406
+ LayerAttentionType::kGriffinRecurrentBlock ,
407
+ LayerAttentionType::kGriffinRecurrentBlock ,
408
+ LayerAttentionType::kGemma ,
409
+ LayerAttentionType::kGriffinRecurrentBlock ,
410
+ LayerAttentionType::kGriffinRecurrentBlock ,
411
+ LayerAttentionType::kGemma ,
412
+ LayerAttentionType::kGriffinRecurrentBlock ,
413
+ LayerAttentionType::kGriffinRecurrentBlock ,
414
+ };
415
+ static constexpr std::array<size_t , 26 > kAttentionWindowSizes =
416
+ FixedAttentionWindowSizes<26 >(kSeqLen );
417
+ static constexpr int kLayers = kLayerConfig .size();
418
+ static constexpr int kGemmaLayers =
419
+ NumLayersOfTypeBefore (kLayerConfig , LayerAttentionType::kGemma , kLayers );
420
+ static constexpr int kGriffinLayers =
421
+ NumLayersOfTypeBefore (kLayerConfig ,
422
+ LayerAttentionType::kGriffinRecurrentBlock ,
423
+ kLayers );
424
+ static constexpr int kModelDim = 2560 ;
425
+ static constexpr int kFFHiddenDim = 7680 ;
426
+ static constexpr int kHeads = 10 ;
427
+ static constexpr int kKVHeads = 1 ;
428
+ static constexpr int kQKVDim = 256 ; // query size == key size == value size
429
+ static constexpr int kTopK = gcpp::kTopK ;
430
+ static constexpr bool kAbsolutePE = false ;
431
+ static constexpr PostNormType kPostNorm = PostNormType::None;
432
+
433
+ // No SoftCap.
434
+ static constexpr float kAttCap = 0 .0f ;
435
+ static constexpr float kFinalCap = 0 .0f ;
436
+
437
+ // SSM config.
438
+ static constexpr int kConv1dWidth = 4 ;
439
+ static constexpr bool kFFBiases = true ;
440
+ static constexpr bool kSoftmaxAttnOutputBiases = true ;
441
+ static constexpr bool kUseHalfRope = true ;
442
+ static constexpr bool kUseLocalAttention = true ;
443
+ static constexpr bool kInterleaveQKV = false ;
444
+ static constexpr int kNumTensorScales = 140 ;
445
+ static constexpr PostQKType kPostQK = PostQKType::Rope;
446
+ static constexpr ActivationType kActivation = ActivationType::Gelu;
447
+ static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
448
+ static constexpr ResidualType kResidual = ResidualType::Add;
449
+
450
+ // Self-extend parameters with defaul values
451
+ static constexpr bool kSelfExtend = false ;
452
+ static constexpr size_t kSelfExtendNgbSize = 0 ;
453
+ static constexpr size_t kSelfExtendGrpSize = 1 ;
454
+ };
195
455
196
456
} // namespace gcpp
197
457
0 commit comments