25
25
26
26
import torch
27
27
import torch .nn as nn
28
+ import torch .nn .functional as F
28
29
29
30
from ...activations import ACT2FN
30
31
from ...cache_utils import Cache , HybridCache , StaticCache
44
45
from ...utils .deprecation import deprecate_kwarg
45
46
from ..gemma import GemmaPreTrainedModel
46
47
from ..siglip import SiglipVisionModel
47
- from .configuration_gemma3 import Gemma3Config , Gemma3RotaryEmbeddingConfig , Gemma3TextConfig
48
+ from .configuration_gemma3 import Gemma3Config , Gemma3RotaryEmbeddingConfig , Gemma3TextConfig , Gemma3VisionConfig
48
49
49
50
50
51
logger = logging .get_logger (__name__ )
@@ -71,6 +72,28 @@ def extra_repr(self):
71
72
return f"{ tuple (self .weight .shape )} , eps={ self .eps } "
72
73
73
74
75
+ class Gemma3VisionAvgPool2D (nn .Module ):
76
+ def __init__ (self , config : Gemma3VisionConfig ):
77
+ super ().__init__ ()
78
+ self .config = config
79
+
80
+ def forward (self , x ):
81
+ """
82
+ Applies average pooling on (B, channels, width, width)
83
+ to make it (B, channels, final_width, final_width).
84
+ """
85
+ batch_size , seq_len , channels = x .shape
86
+ width = int (seq_len ** 0.5 )
87
+ if width * width != seq_len :
88
+ raise ValueError (f"Sequence length { seq_len } is not a perfect square. Cannot reshape to a square image." )
89
+ final_width = int (self .config .pooled_seq_len ** 0.5 )
90
+ kernel_size = width // final_width
91
+ x = x .transpose (1 , 2 ).reshape (batch_size , channels , width , width )
92
+ x = F .avg_pool2d (x , kernel_size = kernel_size , stride = kernel_size )
93
+ x = x .flatten (2 ).transpose (1 , 2 )
94
+ return x
95
+
96
+
74
97
class Gemma3MultimodalInputProjection (nn .Module ):
75
98
def __init__ (self , vision_dim : int , text_dim : int ):
76
99
super ().__init__ ()
@@ -1029,9 +1052,7 @@ def __init__(self, config: Gemma3Config):
1029
1052
)
1030
1053
self .mm_soft_emb_norm = Gemma3RMSNorm (vision_config .hidden_size , eps = vision_config .layer_norm_eps )
1031
1054
1032
- patches_per_image = vision_config .image_size // vision_config .patch_size
1033
- avg_pool_k = patches_per_image ** 2 // text_config .mm_tokens_per_image
1034
- self .avg_pool = nn .AvgPool1d (kernel_size = avg_pool_k , stride = avg_pool_k )
1055
+ self .avg_pool = Gemma3VisionAvgPool2D (config .vision_config )
1035
1056
self .vocab_size = text_config .vocab_size
1036
1057
self .pad_token_id = pad_token_id if (pad_token_id := text_config .pad_token_id ) is not None else - 1
1037
1058
self .post_init ()
@@ -1076,12 +1097,7 @@ def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
1076
1097
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
1077
1098
"""
1078
1099
vision_outputs = self .vision_model (pixel_values = pixel_values ).last_hidden_state
1079
- b , n , l = vision_outputs .shape
1080
- reshaped_vision_outputs = vision_outputs .permute (0 , 2 , 1 )
1081
- reshaped_vision_outputs = reshaped_vision_outputs .contiguous ()
1082
- reshaped_vision_outputs = reshaped_vision_outputs .view (b , l , n )
1083
- pooled_vision_outputs = self .avg_pool (reshaped_vision_outputs )
1084
- pooled_vision_outputs = pooled_vision_outputs .permute (0 , 2 , 1 )
1100
+ pooled_vision_outputs = self .avg_pool (vision_outputs )
1085
1101
image_features = self .encode_vision (pooled_vision_outputs )
1086
1102
return image_features
1087
1103
0 commit comments