
 class Gemma3TextConfig(PretrainedConfig):
     r"""
-    This is the configuration class to store the configuration of a [`Gemma3Model`]. It is used to instantiate a Gemma3
+    This is the configuration class to store the configuration of a [`Gemma3TextModel`]. It is used to instantiate a Gemma3Text
     model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
-    defaults will yield a similar configuration to that of the Gemma3-4B.
-    e.g. [google/gemma-3-4b](https://huggingface.co/google/gemma-3-4b)
+    defaults will yield a similar configuration to that of the Gemma3Text-7B.
+    e.g. [google/gemma3_text-7b](https://huggingface.co/google/gemma3_text-7b)
     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
     documentation from [`PretrainedConfig`] for more information.
-
     Args:
-        vocab_size (`int`, *optional*, defaults to 262144):
-            Vocabulary size of the Gemma3 model. Defines the number of different tokens that can be represented by the
-            `inputs_ids` passed when calling [`Gemma3Model`]
+        vocab_size (`int`, *optional*, defaults to 262208):
+            Vocabulary size of the Gemma3Text model. Defines the number of different tokens that can be represented by the
+            `inputs_ids` passed when calling [`Gemma3TextModel`]
         hidden_size (`int`, *optional*, defaults to 2304):
             Dimension of the hidden representations.
         intermediate_size (`int`, *optional*, defaults to 9216):
@@ -61,14 +60,43 @@ class Gemma3TextConfig(PretrainedConfig):
             `num_attention_heads`.
         head_dim (`int`, *optional*, defaults to 256):
             The attention head dimension.
-        sliding_window (`int`, *optional*, defaults to 4096): in Gemma3, every other layer uses sliding window
-            attention. This is the size of the sliding window.
-        query_pre_attn_scalar (`float`, *optional*):
-            The scaling factor used on the attention scores, not that
+        hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
+            The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
+            if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
+        max_position_embeddings (`int`, *optional*, defaults to 131072):
+            The maximum sequence length that this model might ever be used with.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon used by the rms normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        pad_token_id (`int`, *optional*, defaults to 0):
+            Padding token id.
+        eos_token_id (`int`, *optional*, defaults to 1):
+            End of stream token id.
+        bos_token_id (`int`, *optional*, defaults to 2):
+            Beginning of stream token id.
+        tie_word_embeddings (`bool`, *optional*, defaults to `True`):
+            Whether to tie weight embeddings
         rope_theta (`float`, *optional*, defaults to 1000000.0):
-            The base period of the RoPE embeddings used for global attention.
+            The base period of the RoPE embeddings.
+        attention_bias (`bool`, *optional*, defaults to `False`):
+            Whether to use a bias in the query, key, value and output projection layers during self-attention.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        query_pre_attn_scalar (`float`, *optional*, defaults to 256):
+            Scaling factor used on the attention scores.
+        sliding_window (`int`, *optional*, defaults to 4096): in Gemma3Text, every other layer uses sliding window attention. This is the
+            size of the sliding window.
+        final_logit_softcapping (`float`, *optional*):
+            Scaling factor when applying tanh softcapping on the logits.
+        attn_logit_softcapping (`float`, *optional*):
+            Scaling factor when applying tanh softcapping on the attention scores.
+        cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`.
         rope_scaling (`Dict`, *optional*):
-            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
+            Dictionary containing the scaling configuration for the RoPE embeddings used in global attention. NOTE: if you apply new rope type
             and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
             accordingly.
             Expected contents:
@@ -108,79 +136,68 @@ class Gemma3TextConfig(PretrainedConfig):
             The base period of the RoPE embeddings for local attention.
         sliding_window_pattern (`int`, *optional*, defaults to 6):
             Pattern for the sliding window attention.
-        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
-            The epsilon used by the rms normalization layers.
-        hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
-            The non-linear activation function (function or string) in the decoder. Will default to
-            `"gelu_pytorch_tanh"` if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"`
-            activation function.
-        pad_token_id (`int`, *optional*, defaults to 0):
-            Padding token id.
-        eos_token_id (`int`, *optional*, defaults to 1):
-            End of stream token id.
-        bos_token_id (`int`, *optional*, defaults to 2):
-            Beginning of stream token id.
-        tie_word_embeddings (`bool`, *optional*, defaults to `True`):
-            Whether to tie weight embeddings
-        max_position_embeddings (`int`, *optional*, defaults to 131072):
-            The maximum sequence length that this model might ever be used with.
-        initializer_range (`float`, *optional*, defaults to 0.02):
-            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
-        attention_bias (`bool`, *optional*, defaults to `False`):
-            Whether to use a bias in the query, key, value and output projection layers during self-attention.
-        attention_dropout (`float`, *optional*, defaults to 0.0):
-            The dropout ratio for the attention probabilities.
-        use_cache (`bool`, *optional*, defaults to `True`):
-            Whether or not the model should return the last key/values attentions (not used by all models). Only
-            relevant if `config.is_decoder=True`.
-        final_logit_softcapping (`bool`, *optional*, defaults to `True`):
-            Whether to apply logit softcapping or nor
-        attn_logit_softcapping (`float`, *optional*, defaults to 50.0):
-            Scaling factor when applying tanh soft-capping on the attention scorexs.
-        cache_implementation (`str`, *optional*, defaults to `"hybrid"`):
-            The cache type to be used with `generate`.

     ```python
-    >>> from transformers import Gemma3Model, Gemma3TextConfig
-    >>> # Initializing a Gemma3 gemma3-4b style configuration
-    >>> configuration = Gemma3Config()
-    >>> # Initializing a model from the gemma3-4b style configuration
-    >>> model = Gemma3Model(configuration)
+    >>> from transformers import Gemma3TextModel, Gemma3TextConfig
+    >>> # Initializing a Gemma3Text gemma3_text-7b style configuration
+    >>> configuration = Gemma3TextConfig()
+    >>> # Initializing a model from the gemma3_text-7b style configuration
+    >>> model = Gemma3TextModel(configuration)
     >>> # Accessing the model configuration
     >>> configuration = model.config
-    ```"""
+    ```
+    rope_local_base_freq (float, *optional*, defaults to 10000.0):
+        The base period of the RoPE embeddings for local attention.
+    sliding_window_pattern (`int`, *optional*, defaults to 6):
+        Pattern for the sliding window attention.
+    """

     model_type = "gemma3_text"
+    keys_to_ignore_at_inference = ["past_key_values"]
+    base_model_tp_plan = {
+        "layers.*.self_attn.q_proj": "colwise",
+        "layers.*.self_attn.k_proj": "colwise",
+        "layers.*.self_attn.v_proj": "colwise",
+        "layers.*.self_attn.o_proj": "rowwise",
+        "layers.*.mlp.gate_proj": "colwise",
+        "layers.*.mlp.up_proj": "colwise",
+        "layers.*.mlp.down_proj": "rowwise",
+    }
+    base_model_pp_plan = {
+        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+        "norm": (["hidden_states"], ["hidden_states"]),
+    }

     def __init__(
         self,
-        vocab_size: int = 262_144,
-        hidden_size: int = 2304,
-        intermediate_size: int = 9216,
-        num_hidden_layers: int = 26,
-        num_attention_heads: int = 8,
-        num_key_value_heads: int = 4,
-        head_dim: int = 256,
-        sliding_window: int = 4096,
-        query_pre_attn_scalar: Optional[float] = None,
-        rope_theta: float = 1_000_000.0,
-        rope_scaling=None,
-        rope_local_base_freq: float = 10_000.0,
-        sliding_window_pattern: int = 6,
-        rms_norm_eps: float = 1e-6,
-        hidden_activation: str = "gelu_pytorch_tanh",
-        pad_token_id: int = 0,
-        eos_token_id: int = 1,
-        bos_token_id: int = 2,
-        tie_word_embeddings: bool = True,
-        max_position_embeddings: int = 131_072,
-        initializer_range: float = 0.02,
-        attention_bias: bool = False,
-        attention_dropout: float = 0.0,
-        use_cache: bool = True,
+        vocab_size=262_208,
+        hidden_size=2304,
+        intermediate_size=9216,
+        num_hidden_layers=26,
+        num_attention_heads=8,
+        num_key_value_heads=4,
+        head_dim=256,
+        hidden_activation="gelu_pytorch_tanh",
+        max_position_embeddings=131_072,
+        initializer_range=0.02,
+        rms_norm_eps=1e-6,
+        use_cache=True,
+        pad_token_id=0,
+        eos_token_id=1,
+        bos_token_id=2,
+        tie_word_embeddings=True,
+        rope_theta=1_000_000.0,
+        attention_bias=False,
+        attention_dropout=0.0,
+        query_pre_attn_scalar=256,
+        sliding_window=4096,
         final_logit_softcapping=None,
         attn_logit_softcapping=None,
-        cache_implementation: str = "hybrid",
+        cache_implementation="hybrid",
+        rope_scaling=None,
+        rope_local_base_freq=10_000.0,
+        sliding_window_pattern=6,
         **kwargs,
     ):
         super().__init__(
@@ -190,7 +207,6 @@ def __init__(
             tie_word_embeddings=tie_word_embeddings,
             **kwargs,
         )
-
         self.vocab_size = vocab_size
         self.max_position_embeddings = max_position_embeddings
         self.hidden_size = hidden_size
@@ -203,10 +219,6 @@ def __init__(
         self.rms_norm_eps = rms_norm_eps
         self.use_cache = use_cache
         self.rope_theta = rope_theta
-        self.rope_scaling = rope_scaling
-        self.rope_local_base_freq = rope_local_base_freq
-        # For configuring HybridCache to work with 5:1 attention pattern
-        self.sliding_window_pattern = sliding_window_pattern
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
         self.hidden_activation = hidden_activation
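The assignments removed here reappear later in `__init__` (next hunk), together with the `# For configuring HybridCache to work with 5:1 attention pattern` comment. For illustration only, the sketch below derives that layout from the config alone, assuming a transformers build that ships the Gemma3 classes and assuming the rule that every `sliding_window_pattern`-th layer uses global attention (the actual per-layer rule lives in the modeling code, not in this configuration file):

```python
from transformers import Gemma3TextConfig

# Small hypothetical model just to show the 5:1 interleaving implied by the
# default sliding_window_pattern=6: five sliding-window layers, then one
# global-attention layer, repeating over the depth of the model.
config = Gemma3TextConfig(num_hidden_layers=12)

for layer_idx in range(config.num_hidden_layers):
    # Assumed rule: a layer is global when (layer_idx + 1) is a multiple of
    # sliding_window_pattern; all other layers use sliding-window attention.
    is_global = (layer_idx + 1) % config.sliding_window_pattern == 0
    kind = "global" if is_global else f"sliding window ({config.sliding_window})"
    print(f"layer {layer_idx:2d}: {kind}")
```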
@@ -215,6 +227,11 @@ def __init__(
         self.final_logit_softcapping = final_logit_softcapping
         self.attn_logit_softcapping = attn_logit_softcapping
         self.cache_implementation = cache_implementation
+
+        self.rope_local_base_freq = rope_local_base_freq
+        # For configuring HybridCache to work with 5:1 attention pattern
+        self.sliding_window_pattern = sliding_window_pattern
+        self.rope_scaling = rope_scaling
         rope_config_validation(self)

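With the reordered assignments above, `rope_config_validation(self)` runs after `self.rope_scaling` is set, so an invalid scaling dict fails at construction time. A minimal sketch of exercising the new signature, assuming the standard `rope_scaling` schema used elsewhere in transformers (a `rope_type` key plus type-specific entries such as `factor`):

```python
from transformers import Gemma3TextConfig

# Defaults changed by this diff: vocab_size=262_208, query_pre_attn_scalar=256,
# sliding_window=4096, sliding_window_pattern=6, cache_implementation="hybrid".
config = Gemma3TextConfig(
    rope_theta=1_000_000.0,          # RoPE base period for global-attention layers
    rope_local_base_freq=10_000.0,   # RoPE base period for local (sliding-window) layers
    rope_scaling={"rope_type": "linear", "factor": 8.0},  # checked by rope_config_validation
)

print(config.vocab_size)            # 262208
print(config.cache_implementation)  # hybrid
```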
@@ -245,6 +262,7 @@ class Gemma3Config(PretrainedConfig):
         initializer_range (`float`, *optional*, defaults to 0.02):
             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

+
     Example:

     ```python