30 | 30 | from ...activations import ACT2FN
31 | 31 | from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
32 | 32 | from ...modeling_outputs import BaseModelOutput, MaskedLMOutput, SequenceClassifierOutput, TokenClassifierOutput
| 33 | +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS |
33 | 34 | from ...modeling_utils import PreTrainedModel
34 | 35 | from ...utils import (
35 | 36 | add_code_sample_docstrings,
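For reference, every entry in `ROPE_INIT_FUNCTIONS` returns a pair of inverse frequencies and an attention-scaling factor. A minimal sketch of what the `default` initializer computes, mirroring the inline `inv_freq` computation removed from `__init__` below (the function name and keyword handling here are illustrative, not the library's exact signature):

```python
import torch

# Illustrative stand-in for the "default" entry of ROPE_INIT_FUNCTIONS: the same
# formula as the removed inline inv_freq computation, plus a unit scaling factor.
def default_rope_init(config=None, device=None, seq_len=None, *, dim, base=10000.0, **kwargs):
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
    return inv_freq, 1.0  # (inverse frequencies, attention_scaling)
```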
@@ -235,30 +236,62 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
235 | 236 |
236 | 237 |
237 | 238 | class ModernBertRotaryEmbedding(nn.Module):
238 | | - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): |
| 239 | + def __init__(self, config: ModernBertConfig, dim: int, base: float, device: Optional[torch.device] = None): |
239 | 240 | super().__init__()
| 241 | + self.rope_kwargs = {"dim": dim, "base": base} |
| 242 | + # BC: "rope_type" was originally "type" |
| 243 | + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: |
| 244 | + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) |
| 245 | + else: |
| 246 | + self.rope_type = "default" |
| 247 | + self.max_seq_len_cached = config.max_position_embeddings |
| 248 | + self.original_max_seq_len = config.max_position_embeddings |
| 249 | + self.config = None |
| 250 | + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] |
| 251 | + |
| 252 | + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) |
| 253 | + self.register_buffer("inv_freq", inv_freq, persistent=False) |
| 254 | + self.original_inv_freq = self.inv_freq |
| 255 | + |
| 256 | + def _dynamic_frequency_update(self, position_ids, device): |
| 257 | + """ |
| 258 | + dynamic RoPE layers should recompute `inv_freq` in the following situations: |
| 259 | + 1 - growing beyond the cached sequence length (allow scaling) |
| 260 | + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) |
| 261 | + """ |
| 262 | + seq_len = torch.max(position_ids) + 1 |
| 263 | + if seq_len > self.max_seq_len_cached: # growth |
| 264 | + inv_freq, self.attention_scaling = self.rope_init_fn( |
| 265 | + self.config, device, seq_len=seq_len, **self.rope_kwargs |
| 266 | + ) |
| 267 | + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation |
| 268 | + self.max_seq_len_cached = seq_len |
240 | 269 |
241 | | - self.dim = dim |
242 | | - self.max_position_embeddings = max_position_embeddings |
243 | | - self.base = base |
244 | | - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) |
245 | | - self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) |
| 270 | + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset |
| 271 | + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) |
| 272 | + self.max_seq_len_cached = self.original_max_seq_len |
246 | 273 |
247 | 274 | @torch.no_grad()
248 | | - def forward(self, x, position_ids, seq_len=None): |
249 | | - # x: [bs, num_attention_heads, seq_len, head_size] |
250 | | - self.inv_freq.to(x.device) |
| 275 | + def forward(self, x, position_ids): |
| 276 | + if "dynamic" in self.rope_type: |
| 277 | + self._dynamic_frequency_update(position_ids, device=x.device) |
| 278 | + |
| 279 | + # Core RoPE block |
251 | 280 | inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
252 | 281 | position_ids_expanded = position_ids[:, None, :].float()
253 | | - # Force float32 since bfloat16 loses precision on long contexts |
254 | | - # See https://github.com/huggingface/transformers/pull/29285 |
| 282 | + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) |
255 | 283 | device_type = x.device.type
256 | 284 | device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
257 | 285 | with torch.autocast(device_type=device_type, enabled=False):
258 | 286 | freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
259 | 287 | emb = torch.cat((freqs, freqs), dim=-1)
260 | 288 | cos = emb.cos()
261 | 289 | sin = emb.sin()
| 290 | + |
| 291 | + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention |
| 292 | + cos = cos * self.attention_scaling |
| 293 | + sin = sin * self.attention_scaling |
| 294 | + |
262 | 295 | return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
263 | 296 |
264 | 297 |
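The added `_dynamic_frequency_update` gives the `dynamic` RoPE variants a simple caching policy: recompute `inv_freq` when the positions grow past the cached length, and restore the original table once positions fall back inside the original range; `attention_scaling` is then folded into `cos`/`sin` so variants like yarn can rescale attention without touching the attention code itself. A standalone toy of that grow/reset policy (plain Python; `recompute` stands in for `rope_init_fn` and is not a real API):

```python
# Toy illustration of the grow/reset logic in _dynamic_frequency_update above.
class RotaryFreqCache:
    def __init__(self, original_max_seq_len, recompute):
        self.recompute = recompute                      # stand-in for rope_init_fn(..., seq_len=...)
        self.original_max_seq_len = original_max_seq_len
        self.max_seq_len_cached = original_max_seq_len
        self.original_inv_freq = recompute(original_max_seq_len)
        self.inv_freq = self.original_inv_freq

    def update(self, seq_len):
        if seq_len > self.max_seq_len_cached:           # growth: rescale for a longer context
            self.inv_freq = self.recompute(seq_len)
            self.max_seq_len_cached = seq_len
        elif seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:
            self.inv_freq = self.original_inv_freq      # reset: precise table for short sequences
            self.max_seq_len_cached = self.original_max_seq_len
```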
@@ -462,9 +495,7 @@ def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None):
462 | 495 | dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta
463 | 496 | )
464 | 497 | else:
465 | | - self.rotary_emb = ModernBertRotaryEmbedding( |
466 | | - dim=self.head_dim, max_position_embeddings=max_position_embeddings, base=rope_theta |
467 | | - ) |
| 498 | + self.rotary_emb = ModernBertRotaryEmbedding(config=config, dim=self.head_dim, base=rope_theta) |
468 | 499 |
469 | 500 | self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
470 | 501 | self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity()
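With the new signature the attention layer only passes the config, the head dimension, and `rope_theta`; a hedged usage sketch of the resulting module, assuming the classes above are importable (shapes and config values are illustrative, not taken from this PR):

```python
import torch

# Hypothetical standalone call of the rewritten rotary module; forward() only reads
# dtype/device from x, the positions come entirely from position_ids.
config = ModernBertConfig()                              # default config, no rope_scaling set
rope = ModernBertRotaryEmbedding(config=config, dim=64, base=10000.0)

x = torch.randn(2, 12, 16, 64)                           # [batch, heads, seq_len, head_dim]
position_ids = torch.arange(16).unsqueeze(0).expand(2, -1)
cos, sin = rope(x, position_ids)                         # each [batch, seq_len, head_dim], in x.dtype
```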