
Commit c638457

danielkeysers authored and copybara-github committed
Fix PaliGemma's GenerateImageTokensT().
Move image-related config values from LayerConfig to ModelConfig.

Minor changes: add a few comments, remove gcpp:: qualification where it wasn't needed in a few places, define local constants in VitAttention::DotSoftmaxWeightedSum().

PiperOrigin-RevId: 687210519
1 parent 0d68555 commit c638457
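
For orientation, the patch-grid arithmetic that this change centralizes in ModelConfig is sketched below in a standalone form (not repository code; the constants simply mirror ConfigPaliGemma_224() in the diff):

#include <cstdio>

int main() {
  // Standalone sketch mirroring ConfigPaliGemma_224(); not repository code.
  // PaliGemma-224: a 224x224 RGB image is split into 14x14 patches.
  const int image_size = 224;
  const int patch_width = 14;
  const int num_patches = image_size / patch_width;   // 16 patches per side
  const int vit_seq_len = num_patches * num_patches;  // 256 image tokens
  std::printf("grid %dx%d -> %d ViT tokens\n", num_patches, num_patches,
              vit_seq_len);
  return 0;
}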

File tree

gemma/common.h
gemma/configs.cc
gemma/configs.h
gemma/gemma-inl.h
gemma/weights.h

5 files changed, +83 -70 lines

gemma/common.h (+5 -4)

@@ -20,12 +20,13 @@
 
 #include <string>
 
+#include "compression/shared.h"  // ModelTraining
 #include "gemma/configs.h"  // IWYU pragma: export
 #include "hwy/base.h"  // ConvertScalarTo
 
 namespace gcpp {
 
-// TODO(janwas): merge with functions below.
+// Struct to bundle model information.
 struct ModelInfo {
   Model model;
   ModelTraining training;
@@ -42,13 +43,13 @@ const char* ParseType(const std::string& type_string, Type& type);
 const char* ModelString(Model model, ModelTraining training);
 const char* StringFromType(Type type);
 
+// Wraps the given prompt using the expected control tokens for IT models.
 void Wrap(const ModelInfo& info, size_t pos, std::string& prompt);
 
-// ----------------------------------------------------------------------------
-//
-
+// Returns the scale value to use for the embedding (basically sqrt model_dim).
 float EmbeddingScaling(size_t model_dim);
 
+// Returns the scale value to use for the query in the attention computation.
 float ChooseQueryScale(const ModelConfig& config);
 
 }  // namespace gcpp
gemma/configs.cc (+25 -13)

@@ -40,7 +40,7 @@ static ModelConfig ConfigGemma2_27B() {
   config.model_name = "Gemma2_27B";
   config.model = Model::GEMMA2_27B;
   config.model_dim = 4608;
-  config.vocab_size = gcpp::kVocabSize;
+  config.vocab_size = kVocabSize;
   config.seq_len = 8192;
   LayerConfig layer_config = {.model_dim = config.model_dim,
                               .ff_hidden_dim = 16 * 4608 / 2,  // = 36864
@@ -61,7 +61,7 @@ static ModelConfig ConfigGemma2_9B() {
   config.model_name = "Gemma2_9B";
   config.model = Model::GEMMA2_9B;
   config.model_dim = 3584;
-  config.vocab_size = gcpp::kVocabSize;
+  config.vocab_size = kVocabSize;
   config.seq_len = 8192;
   LayerConfig layer_config = {.model_dim = config.model_dim,
                               .ff_hidden_dim = 8 * 3584 / 2,  // = 14336
@@ -82,7 +82,7 @@ static ModelConfig ConfigGemma2_2B() {
   config.model_name = "Gemma2_2B";
   config.model = Model::GEMMA2_2B;
   config.model_dim = 2304;
-  config.vocab_size = gcpp::kVocabSize;
+  config.vocab_size = kVocabSize;
   config.seq_len = 8192;
   LayerConfig layer_config = {.model_dim = config.model_dim,
                               .ff_hidden_dim = 8 * 2304 / 2,  // = 9216
@@ -103,8 +103,8 @@ static ModelConfig ConfigGemma7B() {
   config.model_name = "Gemma7B";
   config.model = Model::GEMMA_7B;
   config.model_dim = 3072;
-  config.vocab_size = gcpp::kVocabSize;
-  config.seq_len = gcpp::kSeqLen;
+  config.vocab_size = kVocabSize;
+  config.seq_len = kSeqLen;
   LayerConfig layer_config = {
       .model_dim = config.model_dim,
       .ff_hidden_dim = 16 * 3072 / 2,  // = 24576
@@ -115,7 +115,7 @@ static ModelConfig ConfigGemma7B() {
   config.layer_configs = {28, layer_config};
   config.num_tensor_scales = 4 * config.layer_configs.size();
   config.query_scale = QueryScaleType::SqrtKeySize;
-  config.attention_window_sizes = FixedAttentionWindowSizes<28>(gcpp::kSeqLen);
+  config.attention_window_sizes = FixedAttentionWindowSizes<28>(kSeqLen);
   return config;
 }
 
@@ -124,8 +124,8 @@ static ModelConfig ConfigGemma2B() {
   config.model_name = "Gemma2B";
   config.model = Model::GEMMA_2B;
   config.model_dim = 2048;
-  config.vocab_size = gcpp::kVocabSize;
-  config.seq_len = gcpp::kSeqLen;
+  config.vocab_size = kVocabSize;
+  config.seq_len = kSeqLen;
   LayerConfig layer_config = {
       .model_dim = config.model_dim,
       .ff_hidden_dim = 16 * 2048 / 2,  // = 16384
@@ -135,7 +135,7 @@ static ModelConfig ConfigGemma2B() {
   };
   config.layer_configs = {18, layer_config};
   config.num_tensor_scales = 4 * config.layer_configs.size();
-  config.attention_window_sizes = FixedAttentionWindowSizes<18>(gcpp::kSeqLen);
+  config.attention_window_sizes = FixedAttentionWindowSizes<18>(kSeqLen);
   return config;
 }
 
@@ -169,7 +169,7 @@ static ModelConfig ConfigGriffin2B() {
   // Griffin uses local attention, so kSeqLen is actually the local attention
   // window.
   config.model_dim = 2560;
-  config.vocab_size = gcpp::kVocabSize;
+  config.vocab_size = kVocabSize;
   config.seq_len = 2048;
   LayerConfig layer_config = {
       .model_dim = config.model_dim,
@@ -204,22 +204,34 @@ static ModelConfig ConfigPaliGemma_224() {
   config.model = Model::PALIGEMMA_224;
   config.vit_model_dim = 1152;
   config.vocab_size = 256000 + 1024 + 128;  // = 257152
-  config.vit_seq_len = 16 * 16;
+  config.image_size = 224;
+  config.patch_width = 14;
+  const size_t num_patches = config.image_size / config.patch_width;
+  config.vit_seq_len = num_patches * num_patches;
   LayerConfig layer_config = {
       .model_dim = config.vit_model_dim,
       .ff_hidden_dim = 4304,
       .heads = 16,
       .kv_heads = 16,
       .qkv_dim = 72,
+      .ff_biases = true,
       .type = LayerAttentionType::kVit,
-      .patch_width = 14,
-      .image_size = 224,
   };
   config.vit_layer_configs = {27, layer_config};
   config.num_vit_scales = 4 * config.vit_layer_configs.size();
   return config;
 }
 
+ModelConfig VitConfig(const ModelConfig& config) {
+  ModelConfig vit_config = ConfigNoSSM();
+  vit_config.model_dim = config.vit_model_dim;
+  vit_config.seq_len = config.vit_seq_len;
+  vit_config.layer_configs = config.vit_layer_configs;
+  // The Vit part does not have a vocabulary, the image patches are embedded.
+  vit_config.vocab_size = 0;
+  return vit_config;
+}
+
 ModelConfig ConfigFromModel(Model model) {
   switch (model) {
     case Model::GEMMA_2B:
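
A hedged usage sketch of the new VitConfig() helper (assumes the gemma.cpp headers touched by this commit; only names visible in the diff are used, and the casts are just for printing):

#include <cstdio>
#include "gemma/configs.h"  // ConfigFromModel, VitConfig

void PrintPaliGemmaVitShape() {
  // Hedged sketch: assumes gemma/configs.h as changed in this commit.
  const gcpp::ModelConfig config =
      gcpp::ConfigFromModel(gcpp::Model::PALIGEMMA_224);
  // Derive the ViT-only sub-config; its seq_len is the number of image tokens.
  const gcpp::ModelConfig vit_config = gcpp::VitConfig(config);
  std::printf("patches per side: %d, image tokens: %d, ViT layers: %d\n",
              config.image_size / config.patch_width,
              static_cast<int>(vit_config.seq_len),
              static_cast<int>(vit_config.layer_configs.size()));
}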

gemma/configs.h (+6 -3)

@@ -131,9 +131,6 @@ struct LayerConfig {
   LayerAttentionType type = LayerAttentionType::kGemma;
   ActivationType activation = ActivationType::Gelu;
   PostQKType post_qk = PostQKType::Rope;
-  // Dimensions related to image processing.
-  int patch_width = 14;
-  int image_size = 224;
 };
 
 struct ModelConfig {
@@ -185,11 +182,17 @@ struct ModelConfig {
   std::unordered_set<std::string> scale_names;
   int norm_num_groups = 1;
   int model_family_version = 1;
+  // Dimensions related to image processing.
+  int patch_width = 14;
+  int image_size = 224;
 };
 
 // Returns the config for the given model.
 ModelConfig ConfigFromModel(Model model);
 
+// Returns the sub-config for the ViT model of the PaliGemma model.
+ModelConfig VitConfig(const ModelConfig& config);
+
 }  // namespace gcpp
 
 #endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_CONFIGS_H_

gemma/gemma-inl.h (+43 -48)

@@ -615,45 +615,41 @@ class VitAttention {
   }
 
   HWY_NOINLINE void DotSoftmaxWeightedSum() {
-    const float query_scale =
-        1.0f / sqrtf(static_cast<float>(layer_config_.qkv_dim));
+    const size_t qkv_dim = layer_config_.qkv_dim;
+    const size_t heads = layer_config_.heads;
+    HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA");
+    const size_t seq_len = activations_.seq_len;
+    const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim));
     PROFILER_ZONE("Gen.VitAttention.DotSoftmax");
-    // A "head group" in the context of GQA refers to a collection of query
-    // heads that share the same key and value heads.
-    HWY_ASSERT_M(layer_config_.heads == layer_config_.kv_heads,
-                 "Vit expects MHA");
 
     // Compute Q.K, softmax, and weighted V.
-    pool_.Run(
-        0, layer_config_.heads * num_tokens_,
-        [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
-          const size_t head = task % layer_config_.heads;
-          const size_t token = task / layer_config_.heads;
-          // Compute Q.K scores, which are "logits" stored in head_att.
-          float* HWY_RESTRICT q =
-              activations_.q.Batch(token) + head * 3 * layer_config_.qkv_dim;
-          MulByConst(query_scale, q, layer_config_.qkv_dim);
-          float* HWY_RESTRICT head_att =
-              activations_.att.Batch(token) + head * activations_.seq_len;
-          for (size_t i = 0; i < activations_.seq_len; ++i) {
-            float* HWY_RESTRICT k = activations_.q.Batch(i) +
-                                    head * 3 * layer_config_.qkv_dim +
-                                    layer_config_.qkv_dim;
-            head_att[i] = Dot(q, k, layer_config_.qkv_dim);  // score = q.k
-          }
-          // SoftMax yields "probabilities" in head_att.
-          Softmax(head_att, activations_.seq_len);
-          // Compute weighted sum of v into att_out.
-          float* HWY_RESTRICT att_out =
-              activations_.att_out.Batch(token) + head * layer_config_.qkv_dim;
-          hwy::ZeroBytes(att_out, layer_config_.qkv_dim * sizeof(*att_out));
-          for (size_t i = 0; i < activations_.seq_len; ++i) {
-            float* HWY_RESTRICT v = activations_.q.Batch(i) +
-                                    head * 3 * layer_config_.qkv_dim +
-                                    2 * layer_config_.qkv_dim;
-            MulByConstAndAdd(head_att[i], v, att_out, layer_config_.qkv_dim);
-          }
-        });
+    pool_.Run(0, layer_config_.heads * num_tokens_,
+              [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
+                const size_t head = task % layer_config_.heads;
+                const size_t token = task / layer_config_.heads;
+                // Compute Q.K scores, which are "logits" stored in head_att.
+                float* HWY_RESTRICT q =
+                    activations_.q.Batch(token) + head * 3 * qkv_dim;
+                MulByConst(query_scale, q, qkv_dim);
+                float* HWY_RESTRICT head_att =
+                    activations_.att.Batch(token) + head * activations_.seq_len;
+                for (size_t i = 0; i < seq_len; ++i) {
+                  float* HWY_RESTRICT k =
+                      activations_.q.Batch(i) + head * 3 * qkv_dim + qkv_dim;
+                  head_att[i] = Dot(q, k, qkv_dim);  // score = q.k
+                }
+                // SoftMax yields "probabilities" in head_att.
+                Softmax(head_att, seq_len);
+                // Compute weighted sum of v into att_out.
+                float* HWY_RESTRICT att_out =
+                    activations_.att_out.Batch(token) + head * qkv_dim;
+                hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out));
+                for (size_t i = 0; i < seq_len; ++i) {
+                  float* HWY_RESTRICT v = activations_.q.Batch(i) +
+                                          head * 3 * qkv_dim + 2 * qkv_dim;
+                  MulByConstAndAdd(head_att[i], v, att_out, qkv_dim);
+                }
+              });
   }
 
   // Sums encoded (`att_out`) over num_heads (`layer_config_.heads`) and
@@ -965,6 +961,7 @@ HWY_NOINLINE void VitTransformerLayer(size_t num_tokens, size_t layer,
                    layer_weights->vit.layer_norm_0_scale.data_scale1(),
                    layer_weights->vit.layer_norm_0_bias.data_scale1(),
                    activations.pre_att_rms_out.All(), model_dim);
+
   // y = out["sa"] = nn.MultiHeadDotProductAttention(...)(y, y)
   // y ~ att_sums
   VitAttention<T>(num_tokens, layer, activations, layer_weights)();
@@ -1104,8 +1101,7 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image,
                                     const ModelWeightsPtrs<T>& weights,
                                     Activations& activations) {
   const size_t model_dim = weights.weights_config.vit_model_dim;
-  const size_t patch_width =
-      weights.weights_config.vit_layer_configs[0].patch_width;
+  const size_t patch_width = weights.weights_config.patch_width;
   const size_t seq_len = weights.weights_config.vit_seq_len;
   const size_t patch_size = patch_width * patch_width * 3;
   HWY_DASSERT(weights.vit_img_embedding_kernel.NumElements() ==
@@ -1483,17 +1479,16 @@ void GenerateImageTokensT(const ModelWeightsStorage& model,
                           const Image& image, ImageTokens& image_tokens,
                           PerClusterPools& pools) {
   if (model.Config().vit_layer_configs.empty()) {
-    return;
-  } else {
-    Activations prefill_activations(model.Config());
-    RuntimeConfig prefill_runtime_config = runtime_config;
-    prefill_runtime_config.prefill_tbatch_size = model.Config().vit_seq_len;
-    prefill_activations.Allocate(prefill_runtime_config.prefill_tbatch_size,
-                                 pools);
-    // Weights are for the full PaliGemma model, not just the ViT part.
-    PrefillVit(*model.GetWeightsOfType<T>(), prefill_runtime_config, image,
-               image_tokens, prefill_activations);
+    HWY_ABORT("Model does not support generating image tokens.");
   }
+  RuntimeConfig prefill_runtime_config = runtime_config;
+  ModelConfig vit_config = VitConfig(model.Config());
+  prefill_runtime_config.prefill_tbatch_size = vit_config.seq_len;
+  Activations prefill_activations(vit_config);
+  prefill_activations.Allocate(vit_config.seq_len, pools);
+  // Weights are for the full PaliGemma model, not just the ViT part.
+  PrefillVit(*model.GetWeightsOfType<T>(), prefill_runtime_config, image,
+             image_tokens, prefill_activations);
 }
 
 }  // namespace HWY_NAMESPACE
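
For readers skimming the refactored kernel above, here is a plain scalar reference (a sketch, not repository code) of what VitAttention::DotSoftmaxWeightedSum() computes for a single head: scaled Q.K logits, softmax over positions, then a weighted sum of V. The Highway version interleaves q/k/v in one buffer and parallelizes over head x token; this sketch uses separate arrays for clarity:

#include <algorithm>
#include <cmath>
#include <vector>

// Scalar reference sketch; the repository kernel uses Highway SIMD and a
// fused q/k/v buffer. For one head:
//   out[t] = sum_i softmax_i((q[t] . k[i]) / sqrt(d)) * v[i]
// q, k, v, out: [seq_len][qkv_dim].
void SingleHeadAttention(const std::vector<std::vector<float>>& q,
                         const std::vector<std::vector<float>>& k,
                         const std::vector<std::vector<float>>& v,
                         std::vector<std::vector<float>>& out) {
  const size_t seq_len = q.size();
  const size_t qkv_dim = q[0].size();
  const float query_scale = 1.0f / std::sqrt(static_cast<float>(qkv_dim));
  for (size_t t = 0; t < seq_len; ++t) {
    // Q.K logits for token t against all positions.
    std::vector<float> att(seq_len);
    for (size_t i = 0; i < seq_len; ++i) {
      float dot = 0.0f;
      for (size_t d = 0; d < qkv_dim; ++d) dot += q[t][d] * k[i][d];
      att[i] = dot * query_scale;
    }
    // Softmax turns logits into weights.
    const float max_logit = *std::max_element(att.begin(), att.end());
    float sum = 0.0f;
    for (float& a : att) { a = std::exp(a - max_logit); sum += a; }
    for (float& a : att) a /= sum;
    // Weighted sum of V.
    out[t].assign(qkv_dim, 0.0f);
    for (size_t i = 0; i < seq_len; ++i) {
      for (size_t d = 0; d < qkv_dim; ++d) out[t][d] += att[i] * v[i][d];
    }
  }
}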

gemma/weights.h (+4 -2)

@@ -349,8 +349,10 @@ struct ModelWeightsPtrs {
         vit_encoder_norm_bias("enc_norm_bias", 1, config.vit_model_dim),
         vit_encoder_norm_scale("enc_norm_scale", 1, config.vit_model_dim),
         vit_img_embedding_bias("img_emb_bias", 1, config.vit_model_dim),
-        vit_img_embedding_kernel("img_emb_kernel", 14 * 14 * 3,
-                                 config.vit_model_dim),
+        vit_img_embedding_kernel(
+            "img_emb_kernel",
+            config.patch_width * config.patch_width * 3,
+            config.vit_model_dim),
         vit_img_pos_embedding("img_pos_emb", 256, config.vit_model_dim),
         vit_img_head_bias("img_head_bias", 1, config.model_dim),
         vit_img_head_kernel("img_head_kernel", config.vit_model_dim,
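
The embedding kernel's first dimension now follows from the config instead of a literal; a quick standalone check of the PaliGemma-224 numbers (arithmetic only, not repository code):

#include <cstdio>

int main() {
  // Arithmetic-only check; not repository code.
  const int patch_width = 14, image_size = 224, vit_model_dim = 1152;
  // One flattened RGB patch: 588 values, previously hard-coded as 14 * 14 * 3.
  const int patch_values = patch_width * patch_width * 3;
  // 256 patches, matching the 256 rows of img_pos_emb above.
  const int num_patches =
      (image_size / patch_width) * (image_size / patch_width);
  std::printf("img_emb_kernel: %d x %d, positions: %d\n", patch_values,
              vit_model_dim, num_patches);
  return 0;
}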
