
Commit 35ff071

Merge pull request huggingface#9 from huggingface/raushan-working
Raushan address PR comments
2 parents ae80685 + aa9d141 commit 35ff071

16 files changed: +760 −389 lines changed

docs/source/en/model_doc/gemma3.md

+38
@@ -26,13 +26,51 @@ This model was contributed by [INSERT](INSERT).
 
 ## Usage tips
 
+
 - For image+text and image-only inputs use `Gemma3ForConditionalGeneration`.
 - For text-only inputs use `Gemma3ForCausalLM` for generation to avoid loading the vision tower.
 - Each sample can contain multiple images, and the number of images can vary between samples. However, make sure to pass correctly batched images to the processor, where each batch is a list of one or more images.
 - The text passed to the processor should have the `"<start_of_image>"` token where the images should be inserted.
 - The processor has its own `apply_chat_template` method to convert chat messages to text that can then be passed as text to the processor. You can also get a vectorized output from `apply_chat_template`. See the examples below for more details on how to use it.
 
 
+### Image cropping for high resolution images
+
+`do_pan_and_scan`
+
+The model supports cropping images into smaller patches when the image aspect ratio exceeds a certain value. By default the images are not cropped and only the base image is forwarded to the model. Users can set `do_pan_and_scan=True` to obtain several crops per image along with the base image, to improve the quality in DocVQA or similar tasks requiring higher-resolution images.
+
+```python
+processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it", padding_side="left")
+
+url = "https://media.istockphoto.com/id/1192867753/photo/cow-in-berchida-beach-siniscola.jpg?s=612x612&w=0&k=20&c=v0hjjniwsMNfJSuKWZuIn8pssmD5h5bSN1peBd1CmH4="
+messages = [
+    {
+        "role": "system",
+        "content": [
+            {"type": "text", "text": "You are a helpful assistant."}
+        ]
+    },
+    {
+        "role": "user", "content": [
+            {"type": "image", "url": url},
+            {"type": "text", "text": "What is shown in this image?"},
+        ]
+    },
+]
+inputs = processor.apply_chat_template(
+    messages,
+    tokenize=True,
+    return_dict=True,
+    return_tensors="pt",
+    add_generation_prompt=True,
+    do_pan_and_scan=True,
+).to(model.device)
+```
+
 ## Usage Example
 
 ### Single-image Inference
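
The doc snippet added above builds `inputs` but does not show the model being loaded or generation being run. Below is a minimal end-to-end sketch of how those pan-and-scan inputs would typically be consumed; the checkpoint name comes from the doc snippet, but the model-loading and `generate`/decode code is an illustrative assumption, not part of this commit.

```python
# Hedged sketch: consuming the `do_pan_and_scan` inputs shown in the doc diff above.
# Model loading and generation are assumptions for illustration, not part of this commit.
import torch
from transformers import AutoProcessor, Gemma3ForConditionalGeneration

model_id = "google/gemma-3-4b-it"  # same checkpoint as in the doc snippet
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)
processor = AutoProcessor.from_pretrained(model_id, padding_side="left")

messages = [
    {"role": "user", "content": [
        {"type": "image", "url": "https://media.istockphoto.com/id/1192867753/photo/cow-in-berchida-beach-siniscola.jpg?s=612x612&w=0&k=20&c=v0hjjniwsMNfJSuKWZuIn8pssmD5h5bSN1peBd1CmH4="},
        {"type": "text", "text": "What is shown in this image?"},
    ]},
]
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
    do_pan_and_scan=True,  # enables the extra crops described in the usage tip
).to(model.device)

with torch.inference_mode():
    generated = model.generate(**inputs, max_new_tokens=50)

# Decode only the newly generated tokens.
print(processor.batch_decode(generated[:, inputs["input_ids"].shape[-1]:], skip_special_tokens=True)[0])
```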

src/transformers/__init__.py

+4 −3
@@ -1260,7 +1260,7 @@
     _import_structure["models.emu3"].append("Emu3ImageProcessor")
     _import_structure["models.flava"].extend(["FlavaFeatureExtractor", "FlavaImageProcessor", "FlavaProcessor"])
     _import_structure["models.fuyu"].extend(["FuyuImageProcessor", "FuyuProcessor"])
-    _import_structure["models.gemma3"].append("Gemma3ImageProcessor")
+    _import_structure["models.gemma3"].extend(("Gemma3ImageProcessor", "Gemma3ImageProcessorFast"))
     _import_structure["models.glpn"].extend(["GLPNFeatureExtractor", "GLPNImageProcessor"])
     _import_structure["models.got_ocr2"].extend(["GotOcr2ImageProcessor"])
     _import_structure["models.grounding_dino"].extend(["GroundingDinoImageProcessor"])
@@ -2458,9 +2458,9 @@
         [
             "Gemma3ForCausalLM",
             "Gemma3ForConditionalGeneration",
-            "Gemma3Model",
             "Gemma3PreTrainedModel",
             "Gemma3Processor",
+            "Gemma3TextModel",
         ]
     )
     _import_structure["models.git"].extend(
@@ -6548,6 +6548,7 @@
         from .models.deit import DeiTImageProcessorFast
         from .models.depth_pro import DepthProImageProcessorFast
         from .models.detr import DetrImageProcessorFast
+        from .models.gemma3 import Gemma3ImageProcessorFast
         from .models.got_ocr2 import GotOcr2ImageProcessorFast
         from .models.llava import LlavaImageProcessorFast
         from .models.llava_next import LlavaNextImageProcessorFast
@@ -7477,9 +7478,9 @@
         from .models.gemma3 import (
             Gemma3ForCausalLM,
             Gemma3ForConditionalGeneration,
-            Gemma3Model,
             Gemma3PreTrainedModel,
             Gemma3Processor,
+            Gemma3TextModel,
         )
         from .models.git import (
             GitForCausalLM,
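
Taken together, the hunks above rename the exported text backbone from `Gemma3Model` to `Gemma3TextModel` and register the new `Gemma3ImageProcessorFast` alongside the slow processor. A minimal sketch of the resulting public imports, assuming a `transformers` build that includes this PR:

```python
# Sketch of the public API after this change (assumes a transformers build containing this PR).
from transformers import (
    Gemma3ForCausalLM,
    Gemma3ForConditionalGeneration,
    Gemma3ImageProcessor,
    Gemma3ImageProcessorFast,  # new export added in this hunk (requires torchvision)
    Gemma3TextModel,           # replaces the former Gemma3Model export
)
```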

src/transformers/models/auto/image_processing_auto.py

+1 −1
@@ -86,7 +86,7 @@
         ("flava", ("FlavaImageProcessor",)),
         ("focalnet", ("BitImageProcessor",)),
         ("fuyu", ("FuyuImageProcessor",)),
-        ("gemma3", ("SiglipImageProcessor", "SiglipImageProcessorFast")),
+        ("gemma3", ("Gemma3ImageProcessor", "Gemma3ImageProcessorFast")),
         ("git", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
         ("glpn", ("GLPNImageProcessor",)),
         ("got_ocr2", ("GotOcr2ImageProcessor", "GotOcr2ImageProcessorFast")),

src/transformers/models/auto/modeling_auto.py

+1 −1
@@ -118,7 +118,7 @@
         ("funnel", ("FunnelModel", "FunnelBaseModel")),
         ("gemma", "GemmaModel"),
         ("gemma2", "Gemma2Model"),
-        ("gemma3_text", "Gemma3Model"),
+        ("gemma3_text", "Gemma3TextModel"),
         ("git", "GitModel"),
         ("glm", "GlmModel"),
         ("glpn", "GLPNModel"),

src/transformers/models/gemma3/__init__.py

+1
@@ -20,6 +20,7 @@
 if TYPE_CHECKING:
     from .configuration_gemma3 import *
     from .image_processing_gemma3 import *
+    from .image_processing_gemma3_fast import *
     from .modeling_gemma3 import *
     from .processing_gemma3 import *
 else:

src/transformers/models/gemma3/configuration_gemma3.py

+98 −80
@@ -32,17 +32,16 @@
 
 
 class Gemma3TextConfig(PretrainedConfig):
     r"""
-    This is the configuration class to store the configuration of a [`Gemma3Model`]. It is used to instantiate a Gemma3
+    This is the configuration class to store the configuration of a [`Gemma3TextModel`]. It is used to instantiate a Gemma3Text
     model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
-    defaults will yield a similar configuration to that of the Gemma3-4B.
-    e.g. [google/gemma-3-4b](https://huggingface.co/google/gemma-3-4b)
+    defaults will yield a similar configuration to that of the Gemma3Text-7B.
+    e.g. [google/gemma3_text-7b](https://huggingface.co/google/gemma3_text-7b)
     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
     documentation from [`PretrainedConfig`] for more information.
-
     Args:
-        vocab_size (`int`, *optional*, defaults to 262144):
-            Vocabulary size of the Gemma3 model. Defines the number of different tokens that can be represented by the
-            `inputs_ids` passed when calling [`Gemma3Model`]
+        vocab_size (`int`, *optional*, defaults to 262208):
+            Vocabulary size of the Gemma3Text model. Defines the number of different tokens that can be represented by the
+            `inputs_ids` passed when calling [`Gemma3TextModel`]
         hidden_size (`int`, *optional*, defaults to 2304):
             Dimension of the hidden representations.
         intermediate_size (`int`, *optional*, defaults to 9216):
@@ -61,14 +60,43 @@ class Gemma3TextConfig(PretrainedConfig):
             `num_attention_heads`.
         head_dim (`int`, *optional*, defaults to 256):
             The attention head dimension.
-        sliding_window (`int`, *optional*, defaults to 4096): in Gemma3, every other layer uses sliding window
-            attention. This is the size of the sliding window.
-        query_pre_attn_scalar (`float`, *optional*):
-            The scaling factor used on the attention scores, not that
+        hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
+            The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
+            if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
+        max_position_embeddings (`int`, *optional*, defaults to 131072):
+            The maximum sequence length that this model might ever be used with.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon used by the rms normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        pad_token_id (`int`, *optional*, defaults to 0):
+            Padding token id.
+        eos_token_id (`int`, *optional*, defaults to 1):
+            End of stream token id.
+        bos_token_id (`int`, *optional*, defaults to 2):
+            Beginning of stream token id.
+        tie_word_embeddings (`bool`, *optional*, defaults to `True`):
+            Whether to tie weight embeddings
         rope_theta (`float`, *optional*, defaults to 1000000.0):
-            The base period of the RoPE embeddings used for global attention.
+            The base period of the RoPE embeddings.
+        attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
+            Whether to use a bias in the query, key, value and output projection layers during self-attention.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        query_pre_attn_scalar (`float`, *optional*, defaults to 256):
+            Scaling factor used on the attention scores
+        sliding_window (`int`, *optional*, defaults to 4096): in Gemma3Text, every other layer uses sliding window attention. This is the
+            size of the sliding window.
+        final_logit_softcapping (`float`, *optional*):
+            Scaling factor when applying tanh softcapping on the logits.
+        attn_logit_softcapping (`float`, *optional*):
+            Scaling factor when applying tanh softcapping on the attention scores.
+        cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`.
         rope_scaling (`Dict`, *optional*):
-            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
+            Dictionary containing the scaling configuration for the RoPE embeddings used in global attention. NOTE: if you apply new rope type
             and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
             accordingly.
             Expected contents:
@@ -108,79 +136,68 @@ class Gemma3TextConfig(PretrainedConfig):
             The base period of the RoPE embeddings for local attention.
         sliding_window_pattern (`int`, *optional*, defaults to 6):
             Pattern for the sliding window attention.
-        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
-            The epsilon used by the rms normalization layers.
-        hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
-            The non-linear activation function (function or string) in the decoder. Will default to
-            `"gelu_pytorch_tanh"` if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"`
-            activation function.
-        pad_token_id (`int`, *optional*, defaults to 0):
-            Padding token id.
-        eos_token_id (`int`, *optional*, defaults to 1):
-            End of stream token id.
-        bos_token_id (`int`, *optional*, defaults to 2):
-            Beginning of stream token id.
-        tie_word_embeddings (`bool`, *optional*, defaults to `True`):
-            Whether to tie weight embeddings
-        max_position_embeddings (`int`, *optional*, defaults to 131072):
-            The maximum sequence length that this model might ever be used with.
-        initializer_range (`float`, *optional*, defaults to 0.02):
-            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
-        attention_bias (`bool`, *optional*, defaults to `False`):
-            Whether to use a bias in the query, key, value and output projection layers during self-attention.
-        attention_dropout (`float`, *optional*, defaults to 0.0):
-            The dropout ratio for the attention probabilities.
-        use_cache (`bool`, *optional*, defaults to `True`):
-            Whether or not the model should return the last key/values attentions (not used by all models). Only
-            relevant if `config.is_decoder=True`.
-        final_logit_softcapping (`bool`, *optional*, defaults to `True`):
-            Whether to apply logit softcapping or nor
-        attn_logit_softcapping (`float`, *optional*, defaults to 50.0):
-            Scaling factor when applying tanh soft-capping on the attention scorexs.
-        cache_implementation (`str`, *optional*, defaults to `"hybrid"`):
-            The cache type to be used with `generate`.
 
     ```python
-    >>> from transformers import Gemma3Model, Gemma3TextConfig
-    >>> # Initializing a Gemma3 gemma3-4b style configuration
-    >>> configuration = Gemma3Config()
-    >>> # Initializing a model from the gemma3-4b style configuration
-    >>> model = Gemma3Model(configuration)
+    >>> from transformers import Gemma3TextModel, Gemma3TextConfig
+    >>> # Initializing a Gemma3Text gemma3_text-7b style configuration
+    >>> configuration = Gemma3TextConfig()
+    >>> # Initializing a model from the gemma3_text-7b style configuration
+    >>> model = Gemma3TextModel(configuration)
     >>> # Accessing the model configuration
     >>> configuration = model.config
-    ```"""
+    ```
+        rope_local_base_freq (float, *optional*, defaults to 10000.0):
+            The base period of the RoPE embeddings for local attention.
+        sliding_window_pattern (`int`, *optional*, defaults to 6):
+            Pattern for the sliding window attention.
+    """
 
     model_type = "gemma3_text"
+    keys_to_ignore_at_inference = ["past_key_values"]
+    base_model_tp_plan = {
+        "layers.*.self_attn.q_proj": "colwise",
+        "layers.*.self_attn.k_proj": "colwise",
+        "layers.*.self_attn.v_proj": "colwise",
+        "layers.*.self_attn.o_proj": "rowwise",
+        "layers.*.mlp.gate_proj": "colwise",
+        "layers.*.mlp.up_proj": "colwise",
+        "layers.*.mlp.down_proj": "rowwise",
+    }
+    base_model_pp_plan = {
+        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+        "norm": (["hidden_states"], ["hidden_states"]),
+    }
 
     def __init__(
         self,
-        vocab_size: int = 262_144,
-        hidden_size: int = 2304,
-        intermediate_size: int = 9216,
-        num_hidden_layers: int = 26,
-        num_attention_heads: int = 8,
-        num_key_value_heads: int = 4,
-        head_dim: int = 256,
-        sliding_window: int = 4096,
-        query_pre_attn_scalar: Optional[float] = None,
-        rope_theta: float = 1_000_000.0,
-        rope_scaling=None,
-        rope_local_base_freq: float = 10_000.0,
-        sliding_window_pattern: int = 6,
-        rms_norm_eps: float = 1e-6,
-        hidden_activation: str = "gelu_pytorch_tanh",
-        pad_token_id: int = 0,
-        eos_token_id: int = 1,
-        bos_token_id: int = 2,
-        tie_word_embeddings: bool = True,
-        max_position_embeddings: int = 131_072,
-        initializer_range: float = 0.02,
-        attention_bias: bool = False,
-        attention_dropout: float = 0.0,
-        use_cache: bool = True,
+        vocab_size=262_208,
+        hidden_size=2304,
+        intermediate_size=9216,
+        num_hidden_layers=26,
+        num_attention_heads=8,
+        num_key_value_heads=4,
+        head_dim=256,
+        hidden_activation="gelu_pytorch_tanh",
+        max_position_embeddings=131_072,
+        initializer_range=0.02,
+        rms_norm_eps=1e-6,
+        use_cache=True,
+        pad_token_id=0,
+        eos_token_id=1,
+        bos_token_id=2,
+        tie_word_embeddings=True,
+        rope_theta=1_000_000.0,
+        attention_bias=False,
+        attention_dropout=0.0,
+        query_pre_attn_scalar=256,
+        sliding_window=4096,
         final_logit_softcapping=None,
         attn_logit_softcapping=None,
-        cache_implementation: str = "hybrid",
+        cache_implementation="hybrid",
+        rope_scaling=None,
+        rope_local_base_freq=10_000.0,
+        sliding_window_pattern=6,
         **kwargs,
     ):
         super().__init__(
@@ -190,7 +207,6 @@ def __init__(
             tie_word_embeddings=tie_word_embeddings,
             **kwargs,
         )
-
         self.vocab_size = vocab_size
         self.max_position_embeddings = max_position_embeddings
         self.hidden_size = hidden_size
@@ -203,10 +219,6 @@ def __init__(
         self.rms_norm_eps = rms_norm_eps
         self.use_cache = use_cache
         self.rope_theta = rope_theta
-        self.rope_scaling = rope_scaling
-        self.rope_local_base_freq = rope_local_base_freq
-        # For configuring HybridCache to work with 5:1 attention pattern
-        self.sliding_window_pattern = sliding_window_pattern
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
         self.hidden_activation = hidden_activation
@@ -215,6 +227,11 @@ def __init__(
         self.final_logit_softcapping = final_logit_softcapping
         self.attn_logit_softcapping = attn_logit_softcapping
         self.cache_implementation = cache_implementation
+
+        self.rope_local_base_freq = rope_local_base_freq
+        # For configuring HybridCache to work with 5:1 attention pattern
+        self.sliding_window_pattern = sliding_window_pattern
+        self.rope_scaling = rope_scaling
         rope_config_validation(self)
 
 
@@ -245,6 +262,7 @@ class Gemma3Config(PretrainedConfig):
         initializer_range (`float`, *optional*, defaults to 0.02):
             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
 
+
     Example:
 
     ```python