@@ -615,45 +615,41 @@ class VitAttention {
   }

   HWY_NOINLINE void DotSoftmaxWeightedSum() {
-    const float query_scale =
-        1.0f / sqrtf(static_cast<float>(layer_config_.qkv_dim));
+    const size_t qkv_dim = layer_config_.qkv_dim;
+    const size_t heads = layer_config_.heads;
+    HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA");
+    const size_t seq_len = activations_.seq_len;
+    const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim));
     PROFILER_ZONE("Gen.VitAttention.DotSoftmax");
-    // A "head group" in the context of GQA refers to a collection of query
-    // heads that share the same key and value heads.
-    HWY_ASSERT_M(layer_config_.heads == layer_config_.kv_heads,
-                 "Vit expects MHA");

     // Compute Q.K, softmax, and weighted V.
-    pool_.Run(
-        0, layer_config_.heads * num_tokens_,
-        [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
-          const size_t head = task % layer_config_.heads;
-          const size_t token = task / layer_config_.heads;
-          // Compute Q.K scores, which are "logits" stored in head_att.
-          float* HWY_RESTRICT q =
-              activations_.q.Batch(token) + head * 3 * layer_config_.qkv_dim;
-          MulByConst(query_scale, q, layer_config_.qkv_dim);
-          float* HWY_RESTRICT head_att =
-              activations_.att.Batch(token) + head * activations_.seq_len;
-          for (size_t i = 0; i < activations_.seq_len; ++i) {
-            float* HWY_RESTRICT k = activations_.q.Batch(i) +
-                                    head * 3 * layer_config_.qkv_dim +
-                                    layer_config_.qkv_dim;
-            head_att[i] = Dot(q, k, layer_config_.qkv_dim);  // score = q.k
-          }
-          // SoftMax yields "probabilities" in head_att.
-          Softmax(head_att, activations_.seq_len);
-          // Compute weighted sum of v into att_out.
-          float* HWY_RESTRICT att_out =
-              activations_.att_out.Batch(token) + head * layer_config_.qkv_dim;
-          hwy::ZeroBytes(att_out, layer_config_.qkv_dim * sizeof(*att_out));
-          for (size_t i = 0; i < activations_.seq_len; ++i) {
-            float* HWY_RESTRICT v = activations_.q.Batch(i) +
-                                    head * 3 * layer_config_.qkv_dim +
-                                    2 * layer_config_.qkv_dim;
-            MulByConstAndAdd(head_att[i], v, att_out, layer_config_.qkv_dim);
-          }
-        });
+    pool_.Run(0, layer_config_.heads * num_tokens_,
+              [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
+                const size_t head = task % layer_config_.heads;
+                const size_t token = task / layer_config_.heads;
+                // Compute Q.K scores, which are "logits" stored in head_att.
+                float* HWY_RESTRICT q =
+                    activations_.q.Batch(token) + head * 3 * qkv_dim;
+                MulByConst(query_scale, q, qkv_dim);
+                float* HWY_RESTRICT head_att =
+                    activations_.att.Batch(token) + head * activations_.seq_len;
+                for (size_t i = 0; i < seq_len; ++i) {
+                  float* HWY_RESTRICT k =
+                      activations_.q.Batch(i) + head * 3 * qkv_dim + qkv_dim;
+                  head_att[i] = Dot(q, k, qkv_dim);  // score = q.k
+                }
+                // SoftMax yields "probabilities" in head_att.
+                Softmax(head_att, seq_len);
+                // Compute weighted sum of v into att_out.
+                float* HWY_RESTRICT att_out =
+                    activations_.att_out.Batch(token) + head * qkv_dim;
+                hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out));
+                for (size_t i = 0; i < seq_len; ++i) {
+                  float* HWY_RESTRICT v = activations_.q.Batch(i) +
+                                          head * 3 * qkv_dim + 2 * qkv_dim;
+                  MulByConstAndAdd(head_att[i], v, att_out, qkv_dim);
+                }
+              });
   }

   // Sums encoded (`att_out`) over num_heads (`layer_config_.heads`) and
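Worth noting for review: the kernel reads Q, K and V out of the single `activations_.q` buffer, which packs `heads` interleaved `[q, k, v]` triples of `qkv_dim` floats per token; that is what the offsets `head * 3 * qkv_dim + {0, qkv_dim, 2 * qkv_dim}` encode. Below is a minimal scalar sketch of one `(head, token)` task under that layout assumption; the helper name and the plain loops are illustrative stand-ins, not the Highway-vectorized `Dot`/`Softmax`/`MulByConstAndAdd` used above.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Scalar reference for one (head, token) task of DotSoftmaxWeightedSum.
// `qkv` packs, per token, `heads` interleaved [q, k, v] triples of length
// `qkv_dim`, matching the indexing in the kernel above.
void AttendOneHead(const std::vector<float>& qkv, size_t heads,
                   size_t qkv_dim, size_t seq_len, size_t head, size_t token,
                   std::vector<float>& att_out) {
  const size_t row = heads * 3 * qkv_dim;  // floats per token
  const float query_scale = 1.0f / std::sqrt(static_cast<float>(qkv_dim));
  const float* q = qkv.data() + token * row + head * 3 * qkv_dim;

  // Q.K logits for every position in the (unmasked) sequence.
  std::vector<float> att(seq_len);
  for (size_t i = 0; i < seq_len; ++i) {
    const float* k = qkv.data() + i * row + head * 3 * qkv_dim + qkv_dim;
    float dot = 0.0f;
    for (size_t d = 0; d < qkv_dim; ++d) dot += q[d] * k[d];
    att[i] = query_scale * dot;
  }

  // Numerically stable softmax turns logits into probabilities.
  const float max_logit = *std::max_element(att.begin(), att.end());
  float sum = 0.0f;
  for (float& a : att) sum += (a = std::exp(a - max_logit));
  for (float& a : att) a /= sum;

  // Weighted sum of V rows into att_out.
  att_out.assign(qkv_dim, 0.0f);
  for (size_t i = 0; i < seq_len; ++i) {
    const float* v = qkv.data() + i * row + head * 3 * qkv_dim + 2 * qkv_dim;
    for (size_t d = 0; d < qkv_dim; ++d) att_out[d] += att[i] * v[d];
  }
}
```

Scaling the logits by `query_scale` here is equivalent to the in-place `MulByConst(query_scale, q, qkv_dim)` above, since the dot product is linear in `q`.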
@@ -965,6 +961,7 @@ HWY_NOINLINE void VitTransformerLayer(size_t num_tokens, size_t layer,
                    layer_weights->vit.layer_norm_0_scale.data_scale1(),
                    layer_weights->vit.layer_norm_0_bias.data_scale1(),
                    activations.pre_att_rms_out.All(), model_dim);
+
   // y = out["sa"] = nn.MultiHeadDotProductAttention(...)(y, y)
   // y ~ att_sums
   VitAttention<T>(num_tokens, layer, activations, layer_weights)();
@@ -1104,8 +1101,7 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image,
                                     const ModelWeightsPtrs<T>& weights,
                                     Activations& activations) {
   const size_t model_dim = weights.weights_config.vit_model_dim;
-  const size_t patch_width =
-      weights.weights_config.vit_layer_configs[0].patch_width;
+  const size_t patch_width = weights.weights_config.patch_width;
   const size_t seq_len = weights.weights_config.vit_seq_len;
   const size_t patch_size = patch_width * patch_width * 3;
   HWY_DASSERT(weights.vit_img_embedding_kernel.NumElements() ==
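For orientation, the sizes in `EmbedImagePatches` follow directly from the patch geometry: each patch contributes `patch_width * patch_width * 3` RGB values, and the token count is the number of non-overlapping patches. A tiny standalone check, assuming the 224-px PaliGemma SigLIP configuration (224x224 input, 14x14 patches; the real values come from `ModelConfig`, so these constants are illustrative):

```cpp
#include <cstddef>
#include <cstdio>

int main() {
  // Hypothetical numbers for a 224x224 input with 14x14 patches.
  const size_t image_size = 224, patch_width = 14;
  const size_t patch_size = patch_width * patch_width * 3;  // 588 floats
  const size_t seq_len = (image_size / patch_width) *
                         (image_size / patch_width);        // 256 tokens
  std::printf("patch_size=%zu seq_len=%zu\n", patch_size, seq_len);
  return 0;
}
```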
@@ -1483,17 +1479,16 @@ void GenerateImageTokensT(const ModelWeightsStorage& model,
                           const Image& image, ImageTokens& image_tokens,
                           PerClusterPools& pools) {
   if (model.Config().vit_layer_configs.empty()) {
-    return;
-  } else {
-    Activations prefill_activations(model.Config());
-    RuntimeConfig prefill_runtime_config = runtime_config;
-    prefill_runtime_config.prefill_tbatch_size = model.Config().vit_seq_len;
-    prefill_activations.Allocate(prefill_runtime_config.prefill_tbatch_size,
-                                 pools);
-    // Weights are for the full PaliGemma model, not just the ViT part.
-    PrefillVit(*model.GetWeightsOfType<T>(), prefill_runtime_config, image,
-               image_tokens, prefill_activations);
+    HWY_ABORT("Model does not support generating image tokens.");
   }
+  RuntimeConfig prefill_runtime_config = runtime_config;
+  ModelConfig vit_config = VitConfig(model.Config());
+  prefill_runtime_config.prefill_tbatch_size = vit_config.seq_len;
+  Activations prefill_activations(vit_config);
+  prefill_activations.Allocate(vit_config.seq_len, pools);
+  // Weights are for the full PaliGemma model, not just the ViT part.
+  PrefillVit(*model.GetWeightsOfType<T>(), prefill_runtime_config, image,
+             image_tokens, prefill_activations);
 }

 }  // namespace HWY_NAMESPACE
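`VitConfig()` is called here but not shown in this diff. A plausible sketch of such a helper, copying the ViT sub-configuration out of the full PaliGemma `ModelConfig` so that `Activations` are sized for the vision tower rather than the language model; field names beyond those visible in this diff are assumptions:

```cpp
// Illustrative only: not the definition from this change.
static ModelConfig VitConfig(const ModelConfig& config) {
  ModelConfig vit_config = config;  // start from the full config
  // Promote the ViT-specific fields to the top-level ones that
  // Activations reads (assumed field names).
  vit_config.model_dim = config.vit_model_dim;
  vit_config.seq_len = config.vit_seq_len;
  vit_config.layer_configs = config.vit_layer_configs;
  return vit_config;
}
```

The payoff is visible above: `prefill_activations` can be constructed from and allocated for `vit_config.seq_len` directly, instead of threading ViT-specific fields of the full config through the prefill path.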