
Commit 0839a51

ModernBert: reuse GemmaRotaryEmbedding via modular
1 parent b5b3fdc commit 0839a51

File tree

2 files changed: +52 / -44 lines changed

src/transformers/models/modernbert/modeling_modernbert.py

+45 / -14
@@ -30,6 +30,7 @@
 from ...activations import ACT2FN
 from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
 from ...modeling_outputs import BaseModelOutput, MaskedLMOutput, SequenceClassifierOutput, TokenClassifierOutput
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from ...modeling_utils import PreTrainedModel
 from ...utils import (
     add_code_sample_docstrings,
@@ -235,30 +236,62 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
 
 class ModernBertRotaryEmbedding(nn.Module):
-    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+    def __init__(self, config: ModernBertConfig, dim: int, base: float, device: Optional[torch.device] = None):
         super().__init__()
+        self.rope_kwargs = {"dim": dim, "base": base}
+        # BC: "rope_type" was originally "type"
+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+        else:
+            self.rope_type = "default"
+        self.max_seq_len_cached = config.max_position_embeddings
+        self.original_max_seq_len = config.max_position_embeddings
+        self.config = None
+        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.original_inv_freq = self.inv_freq
+
+    def _dynamic_frequency_update(self, position_ids, device):
+        """
+        dynamic RoPE layers should recompute `inv_freq` in the following situations:
+        1 - growing beyond the cached sequence length (allow scaling)
+        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
+        """
+        seq_len = torch.max(position_ids) + 1
+        if seq_len > self.max_seq_len_cached:  # growth
+            inv_freq, self.attention_scaling = self.rope_init_fn(
+                self.config, device, seq_len=seq_len, **self.rope_kwargs
+            )
+            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
+            self.max_seq_len_cached = seq_len
 
-        self.dim = dim
-        self.max_position_embeddings = max_position_embeddings
-        self.base = base
-        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
-        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
+        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
+            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+            self.max_seq_len_cached = self.original_max_seq_len
 
     @torch.no_grad()
-    def forward(self, x, position_ids, seq_len=None):
-        # x: [bs, num_attention_heads, seq_len, head_size]
-        self.inv_freq.to(x.device)
+    def forward(self, x, position_ids):
+        if "dynamic" in self.rope_type:
+            self._dynamic_frequency_update(position_ids, device=x.device)
+
+        # Core RoPE block
         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
         position_ids_expanded = position_ids[:, None, :].float()
-        # Force float32 since bfloat16 loses precision on long contexts
-        # See https://github.com/huggingface/transformers/pull/29285
+        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
         device_type = x.device.type
         device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
         with torch.autocast(device_type=device_type, enabled=False):
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos()
             sin = emb.sin()
+
+        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
+        cos = cos * self.attention_scaling
+        sin = sin * self.attention_scaling
+
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
 
@@ -462,9 +495,7 @@ def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None):
                 dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta
             )
         else:
-            self.rotary_emb = ModernBertRotaryEmbedding(
-                dim=self.head_dim, max_position_embeddings=max_position_embeddings, base=rope_theta
-            )
+            self.rotary_emb = ModernBertRotaryEmbedding(config=config, dim=self.head_dim, base=rope_theta)
 
         self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
         self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity()
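For reference, here is a minimal standalone sketch (plain PyTorch only, not part of the commit) of the cos/sin computation that the regenerated forward() performs for the "default" rope type, where attention_scaling is 1.0 and no dynamic frequency update happens. The dim, base, batch, and sequence-length values are illustrative assumptions, not ModernBERT defaults.

# Standalone sketch of the rotary cos/sin math in the regenerated forward().
# Assumed values (not from the diff): dim=64, base=10000.0, batch=1, seq_len=8.
import torch

dim, base = 64, 10000.0
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))          # [dim/2]

position_ids = torch.arange(8)[None, :]                                                         # [1, 8]
inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)        # [1, dim/2, 1]
position_ids_expanded = position_ids[:, None, :].float()                                        # [1, 1, 8]

freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)                             # [1, 8, dim/2]
emb = torch.cat((freqs, freqs), dim=-1)                                                         # [1, 8, dim]
cos, sin = emb.cos(), emb.sin()   # what forward() returns, scaled by attention_scaling (1.0 here)

The practical change in the generated class is that inv_freq now comes from ROPE_INIT_FUNCTIONS (driven by the dim/base passed as rope_kwargs) and can be recomputed for dynamic rope types, instead of being hard-coded in __init__.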

src/transformers/models/modernbert/modular_modernbert.py

+7 / -30
@@ -40,7 +40,7 @@
     logging,
 )
 from ...utils.import_utils import is_triton_available
-from ..gemma.modeling_gemma import apply_rotary_pos_emb
+from ..gemma.modeling_gemma import GemmaRotaryEmbedding, apply_rotary_pos_emb
 
 
 if is_flash_attn_2_available():
@@ -493,32 +493,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return self.Wo(self.drop(self.act(input) * gate))
 
 
-class ModernBertRotaryEmbedding(nn.Module):
-    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
-        super().__init__()
-
-        self.dim = dim
-        self.max_position_embeddings = max_position_embeddings
-        self.base = base
-        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
-        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
-
-    @torch.no_grad()
-    def forward(self, x, position_ids, seq_len=None):
-        # x: [bs, num_attention_heads, seq_len, head_size]
-        self.inv_freq.to(x.device)
-        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
-        position_ids_expanded = position_ids[:, None, :].float()
-        # Force float32 since bfloat16 loses precision on long contexts
-        # See https://github.com/huggingface/transformers/pull/29285
-        device_type = x.device.type
-        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):
-            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
-            emb = torch.cat((freqs, freqs), dim=-1)
-            cos = emb.cos()
-            sin = emb.sin()
-        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+class ModernBertRotaryEmbedding(GemmaRotaryEmbedding):
+    def __init__(self, config: ModernBertConfig, dim: int, base: float, device: Optional[torch.device] = None):
+        super().__init__(self, config=config, device=device)
+        self.rope_kwargs = {"dim": dim, "base": base}
+        self.config = None
 
 
 def eager_attention_forward(
@@ -687,9 +666,7 @@ def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None):
                 dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta
             )
         else:
-            self.rotary_emb = ModernBertRotaryEmbedding(
-                dim=self.head_dim, max_position_embeddings=max_position_embeddings, base=rope_theta
-            )
+            self.rotary_emb = ModernBertRotaryEmbedding(config=config, dim=self.head_dim, base=rope_theta)
 
         self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
         self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity()
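For context on the modular side: modeling_modernbert.py is auto-generated from modular_modernbert.py, so subclassing GemmaRotaryEmbedding here is what makes the converter emit the expanded class shown in the first file. Below is a hedged usage sketch of the new constructor signature; the config values are illustrative rather than ModernBERT defaults, and it assumes the transformers version this commit targets, where the class still accepts dim/base alongside config.

# Illustrative only: exercises ModernBertRotaryEmbedding(config=..., dim=..., base=...)
# as wired up by this commit. Config values below are made up for the example.
import torch
from transformers import ModernBertConfig
from transformers.models.modernbert.modeling_modernbert import ModernBertRotaryEmbedding

config = ModernBertConfig(hidden_size=256, num_attention_heads=4, max_position_embeddings=512)
head_dim = config.hidden_size // config.num_attention_heads              # 64

rotary_emb = ModernBertRotaryEmbedding(config=config, dim=head_dim, base=10000.0)

x = torch.randn(2, 16, config.num_attention_heads, head_dim)             # only dtype/device are read from x
position_ids = torch.arange(16)[None, :].expand(2, -1)                   # [batch=2, seq_len=16]
cos, sin = rotary_emb(x, position_ids)                                   # each of shape [2, 16, head_dim]

Keeping dim and base as explicit arguments (via rope_kwargs, with self.config set to None) preserves ModernBERT's per-layer local/global rope_theta values while still reusing Gemma's config-driven dynamic-RoPE machinery.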
