
Commit 11b44b9

Revert #35589, keep rope_kwargs; rely on them in modular_modernbert

1 parent 385853a

30 files changed: 148 additions, 59 deletions
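Every hunk in this commit applies the same two-part change to a model's rotary-embedding class: re-add an empty `self.rope_kwargs` dict in `__init__` and splat it into both `rope_init_fn` calls, once at construction time and once inside `_dynamic_frequency_update`. Per the commit title, modular_modernbert relies on these kwargs. The sketch below is a minimal, self-contained illustration of the pattern, not the library code itself; `ToyConfig` and `toy_rope_init` are hypothetical stand-ins for the real model configs and the `ROPE_INIT_FUNCTIONS` entries.

import torch
from torch import nn


class ToyConfig:
    # Stand-in for a real model config; the values are illustrative only.
    hidden_size = 64
    num_attention_heads = 4
    rope_theta = 10000.0
    max_position_embeddings = 128


def toy_rope_init(config, device=None, seq_len=None, **rope_kwargs):
    # Stand-in for a ROPE_INIT_FUNCTIONS entry: returns (inv_freq, attention_scaling).
    dim = config.hidden_size // config.num_attention_heads
    base = rope_kwargs.get("base", config.rope_theta)  # rope_kwargs may override defaults
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float, device=device) / dim))
    return inv_freq, 1.0


class ToyRotaryEmbedding(nn.Module):
    def __init__(self, config, device=None):
        super().__init__()
        self.rope_kwargs = {}  # kept empty here; subclasses can pre-fill it
        self.config = config
        self.rope_init_fn = toy_rope_init
        self.max_seq_len_cached = config.max_position_embeddings

        # Init-time call: the kwargs are forwarded alongside the config.
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def _dynamic_frequency_update(self, position_ids, device):
        seq_len = int(torch.max(position_ids)) + 1
        if seq_len > self.max_seq_len_cached:  # growth: recompute with the longer seq_len
            inv_freq, self.attention_scaling = self.rope_init_fn(
                self.config, device, seq_len=seq_len, **self.rope_kwargs
            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)
            self.max_seq_len_cached = seq_len


rope = ToyRotaryEmbedding(ToyConfig())
rope.rope_kwargs = {"base": 500000.0}               # a subclass could set this instead
rope._dynamic_frequency_update(torch.arange(256), device="cpu")
print(rope.inv_freq.shape)                          # torch.Size([8])

Because the kwargs flow through both calls, pre-filling `rope_kwargs` changes the computed frequencies without overriding either method; that is the hook the commit keeps in place for the modular definitions.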

src/transformers/models/aria/modeling_aria.py (+5, -2)

@@ -725,6 +725,7 @@ def _init_weights(self, module):
 class AriaTextRotaryEmbedding(nn.Module):
     def __init__(self, config: AriaTextConfig, device=None):
         super().__init__()
+        self.rope_kwargs = {}
         # BC: "rope_type" was originally "type"
         if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
             self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
@@ -736,7 +737,7 @@ def __init__(self, config: AriaTextConfig, device=None):
         self.config = config
         self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

-        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
         self.original_inv_freq = self.inv_freq

@@ -748,7 +749,9 @@ def _dynamic_frequency_update(self, position_ids, device):
         """
         seq_len = torch.max(position_ids) + 1
         if seq_len > self.max_seq_len_cached:  # growth
-            inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
+            inv_freq, self.attention_scaling = self.rope_init_fn(
+                self.config, device, seq_len=seq_len, **self.rope_kwargs
+            )
             self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
             self.max_seq_len_cached = seq_len

src/transformers/models/bamba/modeling_bamba.py (+5, -2)

@@ -122,6 +122,7 @@ def __init__(self, config: BambaConfig, batch_size, dtype=torch.float16, device=
 class BambaRotaryEmbedding(nn.Module):
     def __init__(self, config: BambaConfig, device=None):
         super().__init__()
+        self.rope_kwargs = {}
         # BC: "rope_type" was originally "type"
         if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
             self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
@@ -133,7 +134,7 @@ def __init__(self, config: BambaConfig, device=None):
         self.config = config
         self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

-        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
         self.original_inv_freq = self.inv_freq

@@ -145,7 +146,9 @@ def _dynamic_frequency_update(self, position_ids, device):
         """
         seq_len = torch.max(position_ids) + 1
         if seq_len > self.max_seq_len_cached:  # growth
-            inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
+            inv_freq, self.attention_scaling = self.rope_init_fn(
+                self.config, device, seq_len=seq_len, **self.rope_kwargs
+            )
             self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
             self.max_seq_len_cached = seq_len

src/transformers/models/cohere/modeling_cohere.py (+5, -2)

@@ -75,6 +75,7 @@ def forward(self, hidden_states):
 class CohereRotaryEmbedding(nn.Module):
     def __init__(self, config: CohereConfig, device=None):
         super().__init__()
+        self.rope_kwargs = {}
         # BC: "rope_type" was originally "type"
         if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
             self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
@@ -86,7 +87,7 @@ def __init__(self, config: CohereConfig, device=None):
         self.config = config
         self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

-        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
         self.original_inv_freq = self.inv_freq

@@ -98,7 +99,9 @@ def _dynamic_frequency_update(self, position_ids, device):
         """
         seq_len = torch.max(position_ids) + 1
         if seq_len > self.max_seq_len_cached:  # growth
-            inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
+            inv_freq, self.attention_scaling = self.rope_init_fn(
+                self.config, device, seq_len=seq_len, **self.rope_kwargs
+            )
             self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
             self.max_seq_len_cached = seq_len

src/transformers/models/cohere2/modeling_cohere2.py (+5, -2)

@@ -55,6 +55,7 @@
 class Cohere2RotaryEmbedding(nn.Module):
     def __init__(self, config: Cohere2Config, device=None):
         super().__init__()
+        self.rope_kwargs = {}
         # BC: "rope_type" was originally "type"
         if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
             self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
@@ -66,7 +67,7 @@ def __init__(self, config: Cohere2Config, device=None):
         self.config = config
         self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

-        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
         self.original_inv_freq = self.inv_freq

@@ -78,7 +79,9 @@ def _dynamic_frequency_update(self, position_ids, device):
         """
         seq_len = torch.max(position_ids) + 1
         if seq_len > self.max_seq_len_cached:  # growth
-            inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
+            inv_freq, self.attention_scaling = self.rope_init_fn(
+                self.config, device, seq_len=seq_len, **self.rope_kwargs
+            )
             self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
             self.max_seq_len_cached = seq_len

src/transformers/models/diffllama/modeling_diffllama.py (+5, -2)

@@ -614,6 +614,7 @@ def _init_weights(self, module):
 class DiffLlamaRotaryEmbedding(nn.Module):
     def __init__(self, config: DiffLlamaConfig, device=None):
         super().__init__()
+        self.rope_kwargs = {}
         # BC: "rope_type" was originally "type"
         if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
             self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
@@ -625,7 +626,7 @@ def __init__(self, config: DiffLlamaConfig, device=None):
         self.config = config
         self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

-        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
         self.original_inv_freq = self.inv_freq

@@ -637,7 +638,9 @@ def _dynamic_frequency_update(self, position_ids, device):
         """
         seq_len = torch.max(position_ids) + 1
         if seq_len > self.max_seq_len_cached:  # growth
-            inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
+            inv_freq, self.attention_scaling = self.rope_init_fn(
+                self.config, device, seq_len=seq_len, **self.rope_kwargs
+            )
             self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
             self.max_seq_len_cached = seq_len

src/transformers/models/falcon/modeling_falcon.py (+5, -2)

@@ -112,6 +112,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
 class FalconRotaryEmbedding(nn.Module):
     def __init__(self, config: FalconConfig, device=None):
         super().__init__()
+        self.rope_kwargs = {}
         # BC: "rope_type" was originally "type"
         if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
             self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
@@ -123,7 +124,7 @@ def __init__(self, config: FalconConfig, device=None):
         self.config = config
         self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

-        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
         self.original_inv_freq = self.inv_freq

@@ -135,7 +136,9 @@ def _dynamic_frequency_update(self, position_ids, device):
         """
         seq_len = torch.max(position_ids) + 1
         if seq_len > self.max_seq_len_cached:  # growth
-            inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
+            inv_freq, self.attention_scaling = self.rope_init_fn(
+                self.config, device, seq_len=seq_len, **self.rope_kwargs
+            )
             self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
             self.max_seq_len_cached = seq_len

src/transformers/models/gemma/modeling_gemma.py (+5, -2)

@@ -94,6 +94,7 @@ def forward(self, x):
 class GemmaRotaryEmbedding(nn.Module):
     def __init__(self, config: GemmaConfig, device=None):
         super().__init__()
+        self.rope_kwargs = {}
         # BC: "rope_type" was originally "type"
         if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
             self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
@@ -105,7 +106,7 @@ def __init__(self, config: GemmaConfig, device=None):
         self.config = config
         self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

-        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
         self.original_inv_freq = self.inv_freq

@@ -117,7 +118,9 @@ def _dynamic_frequency_update(self, position_ids, device):
         """
         seq_len = torch.max(position_ids) + 1
         if seq_len > self.max_seq_len_cached:  # growth
-            inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
+            inv_freq, self.attention_scaling = self.rope_init_fn(
+                self.config, device, seq_len=seq_len, **self.rope_kwargs
+            )
             self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
             self.max_seq_len_cached = seq_len

src/transformers/models/gemma2/modeling_gemma2.py (+5, -2)

@@ -326,6 +326,7 @@ def forward(
 class Gemma2RotaryEmbedding(nn.Module):
     def __init__(self, config: Gemma2Config, device=None):
         super().__init__()
+        self.rope_kwargs = {}
         # BC: "rope_type" was originally "type"
         if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
             self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
@@ -337,7 +338,7 @@ def __init__(self, config: Gemma2Config, device=None):
         self.config = config
         self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

-        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
         self.original_inv_freq = self.inv_freq

@@ -349,7 +350,9 @@ def _dynamic_frequency_update(self, position_ids, device):
         """
         seq_len = torch.max(position_ids) + 1
         if seq_len > self.max_seq_len_cached:  # growth
-            inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
+            inv_freq, self.attention_scaling = self.rope_init_fn(
+                self.config, device, seq_len=seq_len, **self.rope_kwargs
+            )
             self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
             self.max_seq_len_cached = seq_len

src/transformers/models/glm/modeling_glm.py (+5, -2)

@@ -257,6 +257,7 @@ def extra_repr(self):
 class GlmRotaryEmbedding(nn.Module):
     def __init__(self, config: GlmConfig, device=None):
         super().__init__()
+        self.rope_kwargs = {}
         # BC: "rope_type" was originally "type"
         if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
             self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
@@ -268,7 +269,7 @@ def __init__(self, config: GlmConfig, device=None):
         self.config = config
         self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

-        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
         self.original_inv_freq = self.inv_freq

@@ -280,7 +281,9 @@ def _dynamic_frequency_update(self, position_ids, device):
         """
        seq_len = torch.max(position_ids) + 1
         if seq_len > self.max_seq_len_cached:  # growth
-            inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
+            inv_freq, self.attention_scaling = self.rope_init_fn(
+                self.config, device, seq_len=seq_len, **self.rope_kwargs
+            )
             self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
             self.max_seq_len_cached = seq_len

src/transformers/models/gpt_neox/modeling_gpt_neox.py (+5, -2)

@@ -493,6 +493,7 @@ def __init__(self, config, layer_idx=None):
 class GPTNeoXRotaryEmbedding(nn.Module):
     def __init__(self, config: GPTNeoXConfig, device=None):
         super().__init__()
+        self.rope_kwargs = {}
         # BC: "rope_type" was originally "type"
         if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
             self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
@@ -504,7 +505,7 @@ def __init__(self, config: GPTNeoXConfig, device=None):
         self.config = config
         self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

-        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
         self.original_inv_freq = self.inv_freq

@@ -516,7 +517,9 @@ def _dynamic_frequency_update(self, position_ids, device):
         """
         seq_len = torch.max(position_ids) + 1
         if seq_len > self.max_seq_len_cached:  # growth
-            inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
+            inv_freq, self.attention_scaling = self.rope_init_fn(
+                self.config, device, seq_len=seq_len, **self.rope_kwargs
+            )
             self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
             self.max_seq_len_cached = seq_len

src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py (+5, -2)

@@ -227,6 +227,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
 class GPTNeoXJapaneseRotaryEmbedding(nn.Module):
     def __init__(self, config: GPTNeoXJapaneseConfig, device=None):
         super().__init__()
+        self.rope_kwargs = {}
         # BC: "rope_type" was originally "type"
         if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
             self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
@@ -238,7 +239,7 @@ def __init__(self, config: GPTNeoXJapaneseConfig, device=None):
         self.config = config
         self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

-        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
         self.original_inv_freq = self.inv_freq

@@ -250,7 +251,9 @@ def _dynamic_frequency_update(self, position_ids, device):
         """
         seq_len = torch.max(position_ids) + 1
         if seq_len > self.max_seq_len_cached:  # growth
-            inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
+            inv_freq, self.attention_scaling = self.rope_init_fn(
+                self.config, device, seq_len=seq_len, **self.rope_kwargs
+            )
             self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
             self.max_seq_len_cached = seq_len

src/transformers/models/granite/modeling_granite.py (+5, -2)

@@ -311,6 +311,7 @@ def forward(
 class GraniteRotaryEmbedding(nn.Module):
     def __init__(self, config: GraniteConfig, device=None):
         super().__init__()
+        self.rope_kwargs = {}
         # BC: "rope_type" was originally "type"
         if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
             self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
@@ -322,7 +323,7 @@ def __init__(self, config: GraniteConfig, device=None):
         self.config = config
         self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

-        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
         self.original_inv_freq = self.inv_freq

@@ -334,7 +335,9 @@ def _dynamic_frequency_update(self, position_ids, device):
         """
         seq_len = torch.max(position_ids) + 1
         if seq_len > self.max_seq_len_cached:  # growth
-            inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
+            inv_freq, self.attention_scaling = self.rope_init_fn(
+                self.config, device, seq_len=seq_len, **self.rope_kwargs
+            )
             self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
             self.max_seq_len_cached = seq_len

src/transformers/models/granitemoe/modeling_granitemoe.py (+5, -2)

@@ -160,6 +160,7 @@ def extra_repr(self):
 class GraniteMoeRotaryEmbedding(nn.Module):
     def __init__(self, config: GraniteMoeConfig, device=None):
         super().__init__()
+        self.rope_kwargs = {}
         # BC: "rope_type" was originally "type"
         if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
             self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
@@ -171,7 +172,7 @@ def __init__(self, config: GraniteMoeConfig, device=None):
         self.config = config
         self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

-        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
         self.original_inv_freq = self.inv_freq

@@ -183,7 +184,9 @@ def _dynamic_frequency_update(self, position_ids, device):
         """
         seq_len = torch.max(position_ids) + 1
         if seq_len > self.max_seq_len_cached:  # growth
-            inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
+            inv_freq, self.attention_scaling = self.rope_init_fn(
+                self.config, device, seq_len=seq_len, **self.rope_kwargs
+            )
             self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
             self.max_seq_len_cached = seq_len
