Skip to content

Commit 8cf3966

Browse files
committed
compile success: set default self extend values in noSSM and griffin
1 parent 02ce1e3 commit 8cf3966

File tree

2 files changed

+275
-0
lines changed

2 files changed

+275
-0
lines changed

gemma/configs.h

+260
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,266 @@ ModelConfig ConfigFromModel(Model model);
192192

193193
// Returns the sub-config for the ViT model of the PaliGemma model.
194194
ModelConfig VitConfig(const ModelConfig& config);
195+
template <class TConfig, typename = void>
196+
struct CacheLayerSize {
197+
constexpr size_t operator()() const {
198+
return TConfig::kKVHeads * TConfig::kQKVDim * 2;
199+
}
200+
};
201+
202+
template <class TConfig, typename = void>
203+
struct CachePosSize {
204+
constexpr size_t operator()() const {
205+
return TConfig::kGemmaLayers * CacheLayerSize<TConfig>()();
206+
}
207+
};
208+
209+
struct ConfigNoSSM {
210+
static constexpr int kGriffinLayers = 0;
211+
212+
static constexpr int kConv1dWidth = 0;
213+
static constexpr bool kFFBiases = false;
214+
static constexpr bool kSoftmaxAttnOutputBiases = false;
215+
static constexpr bool kUseHalfRope = false;
216+
static constexpr bool kUseLocalAttention = false;
217+
static constexpr bool kInterleaveQKV = true;
218+
static constexpr int kNumTensorScales = 0;
219+
220+
static constexpr PostQKType kPostQK = PostQKType::Rope;
221+
static constexpr ActivationType kActivation = ActivationType::Gelu;
222+
static constexpr ResidualType kResidual = ResidualType::Add;
223+
224+
// Self-extend parameters with default values
225+
static constexpr bool kSelfExtend = false;
226+
static constexpr size_t kSelfExtendNgbSize = 0;
227+
static constexpr size_t kSelfExtendGrpSize = 1;
228+
};
229+
230+
struct ConfigBaseGemmaV1 : ConfigNoSSM {
231+
static constexpr float kAttCap = 0.0f;
232+
static constexpr float kFinalCap = 0.0f;
233+
static constexpr PostNormType kPostNorm = PostNormType::None;
234+
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
235+
};
236+
237+
struct ConfigBaseGemmaV2 : ConfigNoSSM {
238+
static constexpr float kAttCap = 50.0f;
239+
static constexpr float kFinalCap = 30.0f;
240+
static constexpr PostNormType kPostNorm = PostNormType::Scale;
241+
};
242+
243+
template <typename TWeight>
244+
struct ConfigGemma27B : public ConfigBaseGemmaV2 {
245+
using Weight = TWeight; // make accessible where we only have a TConfig
246+
247+
static constexpr int kSeqLen = 8192;
248+
static constexpr int kVocabSize = 256000;
249+
static constexpr std::array<LayerAttentionType, 46> kLayerConfig =
250+
FixedLayerConfig<46>(LayerAttentionType::kGemma);
251+
static constexpr std::array<size_t, 46> kAttentionWindowSizes =
252+
RepeatedAttentionWindowSizes<46, 2>({4096, kSeqLen});
253+
static constexpr int kLayers = kLayerConfig.size();
254+
static constexpr int kGemmaLayers = kLayers;
255+
static constexpr int kModelDim = 4608;
256+
static constexpr int kFFHiddenDim = 16 * 4608 / 2; // = 36864
257+
static constexpr int kHeads = 32;
258+
static constexpr int kKVHeads = 16;
259+
static constexpr int kQKVDim = 128; // query size == key size == value size
260+
static constexpr int kTopK = gcpp::kTopK;
261+
static constexpr bool kAbsolutePE = false;
262+
static constexpr QueryScaleType kQueryScale =
263+
QueryScaleType::SqrtModelDimDivNumHeads;
264+
};
265+
266+
template <typename TWeight>
267+
struct ConfigGemma9B : public ConfigBaseGemmaV2 {
268+
using Weight = TWeight; // make accessible where we only have a TConfig
269+
270+
static constexpr int kSeqLen = 8192;
271+
static constexpr int kVocabSize = 256000;
272+
static constexpr std::array<LayerAttentionType, 42> kLayerConfig =
273+
FixedLayerConfig<42>(LayerAttentionType::kGemma);
274+
static constexpr std::array<size_t, 42> kAttentionWindowSizes =
275+
RepeatedAttentionWindowSizes<42, 2>({4096, kSeqLen});
276+
static constexpr int kLayers = kLayerConfig.size();
277+
static constexpr int kGemmaLayers = kLayers;
278+
static constexpr int kModelDim = 3584;
279+
static constexpr int kFFHiddenDim = 8 * 3584 / 2; // = 14336
280+
static constexpr int kHeads = 16;
281+
static constexpr int kKVHeads = 8;
282+
static constexpr int kQKVDim = 256; // query size == key size == value size
283+
static constexpr int kTopK = gcpp::kTopK;
284+
static constexpr bool kAbsolutePE = false;
285+
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
286+
};
287+
288+
template <typename TWeight>
289+
struct ConfigGemma7B : public ConfigBaseGemmaV1 {
290+
using Weight = TWeight; // make accessible where we only have a TConfig
291+
292+
static constexpr int kSeqLen = gcpp::kSeqLen;
293+
static constexpr int kVocabSize = 256000;
294+
static constexpr std::array<LayerAttentionType, 28> kLayerConfig =
295+
FixedLayerConfig<28>(LayerAttentionType::kGemma);
296+
static constexpr std::array<size_t, 28> kAttentionWindowSizes =
297+
FixedAttentionWindowSizes<28>(kSeqLen);
298+
static constexpr int kLayers = kLayerConfig.size();
299+
static constexpr int kGemmaLayers = kLayers;
300+
static constexpr int kModelDim = 3072;
301+
static constexpr int kFFHiddenDim = 16 * 3072 / 2; // = 24576
302+
static constexpr int kHeads = 16;
303+
static constexpr int kKVHeads = 16; // standard MHA
304+
static constexpr int kQKVDim = 256; // query size == key size == value size
305+
static constexpr int kTopK = gcpp::kTopK;
306+
static constexpr bool kAbsolutePE = false;
307+
};
308+
309+
template <typename TWeight>
310+
struct ConfigGemma2B : public ConfigBaseGemmaV1 {
311+
using Weight = TWeight; // make accessible where we only have a TConfig
312+
313+
static constexpr int kSeqLen = gcpp::kSeqLen;
314+
static constexpr int kVocabSize = 256000;
315+
static constexpr std::array<LayerAttentionType, 18> kLayerConfig =
316+
FixedLayerConfig<18>(LayerAttentionType::kGemma);
317+
static constexpr std::array<size_t, 18> kAttentionWindowSizes =
318+
FixedAttentionWindowSizes<18>(kSeqLen);
319+
static constexpr int kLayers = kLayerConfig.size();
320+
static constexpr int kGemmaLayers = kLayers;
321+
static constexpr int kModelDim = 2048;
322+
static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384
323+
static constexpr int kHeads = 8;
324+
static constexpr int kKVHeads = 1;
325+
static constexpr int kQKVDim = 256; // query size == key size == value size
326+
static constexpr int kTopK = gcpp::kTopK;
327+
static constexpr bool kAbsolutePE = false;
328+
};
329+
330+
template <typename TWeight>
331+
struct ConfigGemma2_2B : public ConfigBaseGemmaV2 {
332+
using Weight = TWeight; // make accessible where we only have a TConfig
333+
334+
static constexpr int kSeqLen = 8192;
335+
static constexpr int kVocabSize = 256000;
336+
static constexpr std::array<LayerAttentionType, 26> kLayerConfig =
337+
FixedLayerConfig<26>(LayerAttentionType::kGemma);
338+
static constexpr std::array<size_t, 26> kAttentionWindowSizes =
339+
RepeatedAttentionWindowSizes<26, 2>({4096, kSeqLen});
340+
static constexpr int kLayers = kLayerConfig.size();
341+
static constexpr int kGemmaLayers = kLayers;
342+
static constexpr int kModelDim = 2304;
343+
static constexpr int kFFHiddenDim = 8 * 2304 / 2; // = 9216
344+
static constexpr int kHeads = 8;
345+
static constexpr int kKVHeads = 4;
346+
static constexpr int kQKVDim = 256; // query size == key size == value size
347+
static constexpr int kTopK = gcpp::kTopK;
348+
static constexpr bool kAbsolutePE = false;
349+
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
350+
};
351+
352+
template <typename TWeight>
353+
struct ConfigGemmaTiny : public ConfigNoSSM {
354+
using Weight = TWeight; // make accessible where we only have a TConfig
355+
356+
static constexpr int kSeqLen = 32;
357+
static constexpr int kVocabSize = 64;
358+
static constexpr std::array<LayerAttentionType, 3> kLayerConfig =
359+
FixedLayerConfig<3>(LayerAttentionType::kGemma);
360+
static constexpr std::array<size_t, 3> kAttentionWindowSizes =
361+
FixedAttentionWindowSizes<3>(kSeqLen);
362+
static constexpr int kLayers = kLayerConfig.size();
363+
static constexpr int kGemmaLayers = kLayers;
364+
static constexpr int kModelDim = 128;
365+
static constexpr int kFFHiddenDim = 256;
366+
static constexpr int kHeads = 4;
367+
static constexpr int kKVHeads = 1;
368+
static constexpr int kQKVDim = 16; // query size == key size == value size
369+
static constexpr int kTopK = gcpp::kTopK;
370+
static constexpr bool kAbsolutePE = false;
371+
static constexpr PostNormType kPostNorm = PostNormType::None;
372+
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
373+
374+
static constexpr float kAttCap = 0.0f;
375+
// This is required for optimize_test to pass.
376+
static constexpr float kFinalCap = 30.0f;
377+
};
378+
379+
template <typename TWeight>
380+
struct ConfigGriffin2B {
381+
using Weight = TWeight; // make accessible where we only have a TConfig
382+
383+
// Griffin uses local attention, so kSeqLen is actually the local attention
384+
// window.
385+
static constexpr int kSeqLen = 2048;
386+
static constexpr int kVocabSize = 256000;
387+
static constexpr std::array<LayerAttentionType, 26> kLayerConfig = {
388+
LayerAttentionType::kGriffinRecurrentBlock,
389+
LayerAttentionType::kGriffinRecurrentBlock,
390+
LayerAttentionType::kGemma,
391+
LayerAttentionType::kGriffinRecurrentBlock,
392+
LayerAttentionType::kGriffinRecurrentBlock,
393+
LayerAttentionType::kGemma,
394+
LayerAttentionType::kGriffinRecurrentBlock,
395+
LayerAttentionType::kGriffinRecurrentBlock,
396+
LayerAttentionType::kGemma,
397+
LayerAttentionType::kGriffinRecurrentBlock,
398+
LayerAttentionType::kGriffinRecurrentBlock,
399+
LayerAttentionType::kGemma,
400+
LayerAttentionType::kGriffinRecurrentBlock,
401+
LayerAttentionType::kGriffinRecurrentBlock,
402+
LayerAttentionType::kGemma,
403+
LayerAttentionType::kGriffinRecurrentBlock,
404+
LayerAttentionType::kGriffinRecurrentBlock,
405+
LayerAttentionType::kGemma,
406+
LayerAttentionType::kGriffinRecurrentBlock,
407+
LayerAttentionType::kGriffinRecurrentBlock,
408+
LayerAttentionType::kGemma,
409+
LayerAttentionType::kGriffinRecurrentBlock,
410+
LayerAttentionType::kGriffinRecurrentBlock,
411+
LayerAttentionType::kGemma,
412+
LayerAttentionType::kGriffinRecurrentBlock,
413+
LayerAttentionType::kGriffinRecurrentBlock,
414+
};
415+
static constexpr std::array<size_t, 26> kAttentionWindowSizes =
416+
FixedAttentionWindowSizes<26>(kSeqLen);
417+
static constexpr int kLayers = kLayerConfig.size();
418+
static constexpr int kGemmaLayers =
419+
NumLayersOfTypeBefore(kLayerConfig, LayerAttentionType::kGemma, kLayers);
420+
static constexpr int kGriffinLayers =
421+
NumLayersOfTypeBefore(kLayerConfig,
422+
LayerAttentionType::kGriffinRecurrentBlock,
423+
kLayers);
424+
static constexpr int kModelDim = 2560;
425+
static constexpr int kFFHiddenDim = 7680;
426+
static constexpr int kHeads = 10;
427+
static constexpr int kKVHeads = 1;
428+
static constexpr int kQKVDim = 256; // query size == key size == value size
429+
static constexpr int kTopK = gcpp::kTopK;
430+
static constexpr bool kAbsolutePE = false;
431+
static constexpr PostNormType kPostNorm = PostNormType::None;
432+
433+
// No SoftCap.
434+
static constexpr float kAttCap = 0.0f;
435+
static constexpr float kFinalCap = 0.0f;
436+
437+
// SSM config.
438+
static constexpr int kConv1dWidth = 4;
439+
static constexpr bool kFFBiases = true;
440+
static constexpr bool kSoftmaxAttnOutputBiases = true;
441+
static constexpr bool kUseHalfRope = true;
442+
static constexpr bool kUseLocalAttention = true;
443+
static constexpr bool kInterleaveQKV = false;
444+
static constexpr int kNumTensorScales = 140;
445+
static constexpr PostQKType kPostQK = PostQKType::Rope;
446+
static constexpr ActivationType kActivation = ActivationType::Gelu;
447+
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
448+
static constexpr ResidualType kResidual = ResidualType::Add;
449+
450+
// Self-extend parameters with default values
451+
static constexpr bool kSelfExtend = false;
452+
static constexpr size_t kSelfExtendNgbSize = 0;
453+
static constexpr size_t kSelfExtendGrpSize = 1;
454+
};
195455

196456
} // namespace gcpp
197457

gemma/gemma-inl.h

+15
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,13 @@ class GemmaAttention {
327327
PositionalEncodingQK(is_mha_ ? mha_kv : kv, pos, layer_, 1.0f,
328328
kv);
329329

330+
// When embedding the position, use the grouped key position.
331+
if constexpr (TConfig::kSelfExtend) {
332+
if (pos > ngb_size) {
333+
pos /= grp_size;
334+
}
335+
}
336+
330337
// If MHA, also copy V into KVCache.
331338
if (is_mha_) {
332339
hwy::CopyBytes(mha_kv + layer_config_.qkv_dim,
@@ -417,6 +424,14 @@ class GemmaAttention {
417424

418425
// Apply rope and scaling to Q.
419426
const size_t pos = queries_pos_[query_idx] + batch_idx;
427+
if constexpr (TConfig::kSelfExtend) {
428+
if (pos > ngb_size) {
429+
const size_t grp_pos = pos / grp_size;
430+
const size_t shift = ngb_size - ngb_size / grp_size;
431+
const size_t shifted_grouped_pos = grp_pos + shift;
432+
pos = shifted_grouped_pos;
433+
}
434+
}
420435
PositionalEncodingQK(q, pos, layer_, query_scale, q);
421436

422437
const size_t start_pos = StartPos(pos, layer_);

0 commit comments

Comments
 (0)