
Commit 88030d1

Merge pull request huggingface#8 from RyanMullins/gemma3pooling
Gemma3 average pooling changed from 1D to 2D
2 parents: 00af9a7 + 1a36187

3 files changed: +55 -19
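Note: this change replaces the nn.AvgPool1d over the flattened vision patch sequence with 2D average pooling over the patch grid. As a toy illustration (the 4x4 grid and kernel sizes here are illustrative only, not taken from the diff), both variants reduce 16 tokens to 4, but 1D pooling averages runs of consecutive tokens in raster order, while 2D pooling averages square spatial neighborhoods:

    import torch
    import torch.nn.functional as F

    x = torch.arange(16.0).reshape(1, 1, 4, 4)    # a 4x4 "patch grid" with one channel

    # 1D pooling over the flattened sequence (like the removed nn.AvgPool1d path):
    # each output averages 4 consecutive tokens in raster order.
    pooled_1d = F.avg_pool1d(x.flatten(2), kernel_size=4, stride=4)
    print(pooled_1d.flatten().tolist())           # [1.5, 5.5, 9.5, 13.5] -> whole rows

    # 2D pooling over the grid (like the new F.avg_pool2d path):
    # each output averages a 2x2 block of spatially adjacent patches.
    pooled_2d = F.avg_pool2d(x, kernel_size=2, stride=2).flatten(2)
    print(pooled_2d.flatten().tolist())           # [2.5, 4.5, 10.5, 12.5] -> 2x2 blocks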

src/transformers/models/gemma3/configuration_gemma3.py (+2)
@@ -256,6 +256,7 @@ def __init__(
         layer_norm_eps: float = 0.000001,
         vision_use_head: bool = False,
         torch_dtype: str = "bfloat16",
+        pooled_seq_len: int = 256,
         **kwargs,
     ):
         super().__init__(
@@ -273,6 +274,7 @@ def __init__(
             **kwargs,
         )

+        self.pooled_seq_len = pooled_seq_len
         self.vision_use_head = vision_use_head
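The new pooled_seq_len field is the target number of vision tokens after pooling; Gemma3VisionAvgPool2D derives its square kernel from it at runtime. A minimal sketch of that arithmetic, assuming the published Gemma 3 vision tower dimensions (896x896 images, 14x14 patches), which this diff does not state:

    # Only pooled_seq_len = 256 comes from this commit; the rest are assumed values.
    image_size, patch_size = 896, 14
    pooled_seq_len = 256

    width = image_size // patch_size            # 64 patches per side -> 4096 patch tokens
    final_width = int(pooled_seq_len ** 0.5)    # 16 tokens per side after pooling
    kernel_size = width // final_width          # 4: each output token averages a 4x4 patch block
    print(width, final_width, kernel_size)      # 64 16 4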

src/transformers/models/gemma3/modeling_gemma3.py (+26 -10)
@@ -25,6 +25,7 @@

 import torch
 import torch.nn as nn
+import torch.nn.functional as F

 from ...activations import ACT2FN
 from ...cache_utils import Cache, HybridCache, StaticCache
@@ -44,7 +45,7 @@
 from ...utils.deprecation import deprecate_kwarg
 from ..gemma import GemmaPreTrainedModel
 from ..siglip import SiglipVisionModel
-from .configuration_gemma3 import Gemma3Config, Gemma3RotaryEmbeddingConfig, Gemma3TextConfig
+from .configuration_gemma3 import Gemma3Config, Gemma3RotaryEmbeddingConfig, Gemma3TextConfig, Gemma3VisionConfig


 logger = logging.get_logger(__name__)
@@ -71,6 +72,28 @@ def extra_repr(self):
         return f"{tuple(self.weight.shape)}, eps={self.eps}"


+class Gemma3VisionAvgPool2D(nn.Module):
+    def __init__(self, config: Gemma3VisionConfig):
+        super().__init__()
+        self.config = config
+
+    def forward(self, x):
+        """
+        Average-pools a (B, seq_len, channels) patch sequence by viewing it as a square
+        (B, channels, width, width) grid, returning (B, pooled_seq_len, channels).
+        """
+        batch_size, seq_len, channels = x.shape
+        width = int(seq_len**0.5)
+        if width * width != seq_len:
+            raise ValueError(f"Sequence length {seq_len} is not a perfect square. Cannot reshape to a square image.")
+        final_width = int(self.config.pooled_seq_len**0.5)
+        kernel_size = width // final_width
+        x = x.transpose(1, 2).reshape(batch_size, channels, width, width)
+        x = F.avg_pool2d(x, kernel_size=kernel_size, stride=kernel_size)
+        x = x.flatten(2).transpose(1, 2)
+        return x
+
+
 class Gemma3MultimodalInputProjection(nn.Module):
     def __init__(self, vision_dim: int, text_dim: int):
         super().__init__()
@@ -1029,9 +1052,7 @@ def __init__(self, config: Gemma3Config):
         )
         self.mm_soft_emb_norm = Gemma3RMSNorm(vision_config.hidden_size, eps=vision_config.layer_norm_eps)

-        patches_per_image = vision_config.image_size // vision_config.patch_size
-        avg_pool_k = patches_per_image**2 // text_config.mm_tokens_per_image
-        self.avg_pool = nn.AvgPool1d(kernel_size=avg_pool_k, stride=avg_pool_k)
+        self.avg_pool = Gemma3VisionAvgPool2D(config.vision_config)
         self.vocab_size = text_config.vocab_size
         self.pad_token_id = pad_token_id if (pad_token_id := text_config.pad_token_id) is not None else -1
         self.post_init()
@@ -1076,12 +1097,7 @@ def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
             image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
         """
         vision_outputs = self.vision_model(pixel_values=pixel_values).last_hidden_state
-        b, n, l = vision_outputs.shape
-        reshaped_vision_outputs = vision_outputs.permute(0, 2, 1)
-        reshaped_vision_outputs = reshaped_vision_outputs.contiguous()
-        reshaped_vision_outputs = reshaped_vision_outputs.view(b, l, n)
-        pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)
-        pooled_vision_outputs = pooled_vision_outputs.permute(0, 2, 1)
+        pooled_vision_outputs = self.avg_pool(vision_outputs)
         image_features = self.encode_vision(pooled_vision_outputs)
         return image_features
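For a quick sanity check of the data path through the new module, the snippet below mirrors its forward pass on dummy data; the sizes (4096 patches, hidden size 1152, pooled_seq_len of 256) are assumed Gemma 3 values rather than anything asserted by this diff:

    import torch
    import torch.nn.functional as F

    vision_outputs = torch.randn(2, 4096, 1152)    # (B, seq_len, channels), sizes assumed

    batch_size, seq_len, channels = vision_outputs.shape
    width = int(seq_len**0.5)                      # 64 patches per side
    final_width = int(256**0.5)                    # 16, from the assumed pooled_seq_len
    kernel_size = width // final_width             # 4

    x = vision_outputs.transpose(1, 2).reshape(batch_size, channels, width, width)
    x = F.avg_pool2d(x, kernel_size=kernel_size, stride=kernel_size)
    pooled = x.flatten(2).transpose(1, 2)
    print(pooled.shape)                            # torch.Size([2, 256, 1152])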

src/transformers/models/gemma3/modular_gemma3.py (+27 -9)
@@ -25,6 +25,7 @@
 import PIL.Image
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 import torch.utils.checkpoint

 from ...activations import ACT2FN
@@ -332,6 +333,7 @@ def __init__(
         layer_norm_eps: float = 0.000001,
         vision_use_head: bool = False,
         torch_dtype: str = "bfloat16",
+        pooled_seq_len: int = 256,
         **kwargs,
     ):
         super().__init__(
@@ -349,6 +351,7 @@ def __init__(
             **kwargs,
         )

+        self.pooled_seq_len = pooled_seq_len
         self.vision_use_head = vision_use_head
@@ -710,6 +713,28 @@ def model_input_names(self):
 class Gemma3RMSNorm(GemmaRMSNorm):
     pass

+class Gemma3VisionAvgPool2D(nn.Module):
+    def __init__(self, config: Gemma3VisionConfig):
+        super().__init__()
+        self.config = config
+
+    def forward(self, x):
+        """
+        Average-pools a (B, seq_len, channels) patch sequence by viewing it as a square
+        (B, channels, width, width) grid, returning (B, pooled_seq_len, channels).
+        """
+        batch_size, seq_len, channels = x.shape
+        width = int(seq_len**0.5)
+        if width * width != seq_len:
+            raise ValueError(
+                f"Sequence length {seq_len} is not a perfect square. Cannot reshape to a square image."
+            )
+        final_width = int(self.config.pooled_seq_len**0.5)
+        kernel_size = width // final_width
+        x = x.transpose(1, 2).reshape(batch_size, channels, width, width)
+        x = F.avg_pool2d(x, kernel_size=kernel_size, stride=kernel_size)
+        x = x.flatten(2).transpose(1, 2)
+        return x

 class Gemma3MultimodalInputProjection(nn.Module):
@@ -1709,9 +1734,7 @@ def __init__(self, config: Gemma3Config):
             vision_config.hidden_size, eps=vision_config.layer_norm_eps
         )

-        patches_per_image = vision_config.image_size // vision_config.patch_size
-        avg_pool_k = patches_per_image ** 2 // text_config.mm_tokens_per_image
-        self.avg_pool = nn.AvgPool1d(kernel_size=avg_pool_k, stride=avg_pool_k)
+        self.avg_pool = Gemma3VisionAvgPool2D(config.vision_config)
         self.vocab_size = text_config.vocab_size
         self.pad_token_id = (
             pad_token_id
@@ -1760,12 +1783,7 @@ def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
             image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
         """
         vision_outputs = self.vision_model(pixel_values=pixel_values).last_hidden_state
-        b, n, l = vision_outputs.shape
-        reshaped_vision_outputs = vision_outputs.permute(0, 2, 1)
-        reshaped_vision_outputs = reshaped_vision_outputs.contiguous()
-        reshaped_vision_outputs = reshaped_vision_outputs.view(b, l, n)
-        pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)
-        pooled_vision_outputs = pooled_vision_outputs.permute(0, 2, 1)
+        pooled_vision_outputs = self.avg_pool(vision_outputs)
         image_features = self.encode_vision(pooled_vision_outputs)
         return image_features
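Both the removed AvgPool1d path and the new 2D path average the same number of patches per output token; what changes is which patches get grouped together. A small consistency check under the same assumed dimensions as above, taking mm_tokens_per_image to equal the new pooled_seq_len:

    # Assumed values; not asserted by this diff.
    patches_per_image = 896 // 14                  # 64
    mm_tokens_per_image = 256                      # assumed equal to pooled_seq_len

    # Old 1D kernel: 16 consecutive tokens per output token.
    avg_pool_k = patches_per_image**2 // mm_tokens_per_image
    # New 2D kernel: 4, i.e. a 4x4 block of spatially adjacent patches per output token.
    kernel_size = patches_per_image // int(mm_tokens_per_image**0.5)

    assert avg_pool_k == kernel_size**2 == 16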
