Commit ccda19f

Merge pull request huggingface#36 from huggingface/sparse-llama4-moe
Add support for sparse `Llama4TextMoe` layer from the kernel hub
2 parents a515579 + a9045fc commit ccda19f

2 files changed: +9 −0 lines changed


src/transformers/integrations/hub_kernels.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -26,6 +26,13 @@
 _hub_kernels_available = True


 _KERNEL_MAPPING: Dict[str, Dict[Union[Device, str], LayerRepository]] = {
+    "Llama4TextMoe": {
+        "cuda": LayerRepository(
+            # Move to kernels-community/moe once we release.
+            repo_id="kernels-community/moe-new-models",
+            layer_name="Llama4TextMoe",
+        )
+    },
     "MultiScaleDeformableAttention": {
         "cuda": LayerRepository(
             repo_id="kernels-community/deformable-detr",
```
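For context, here is a minimal sketch of how another layer could be registered following the same `LayerRepository` pattern. It is not part of this commit: `my-org/my-kernels` and `MyCustomLayer` are placeholder names, and it assumes `Device` and `LayerRepository` are importable from the `kernels` package, as the transformers integration does.

```python
# Hypothetical sketch: registering an additional hub-hosted layer in a mapping
# shaped like _KERNEL_MAPPING above. Placeholder repo_id and layer_name.
from typing import Dict, Union

from kernels import Device, LayerRepository  # hub-kernels helper library (assumed importable)

_KERNEL_MAPPING: Dict[str, Dict[Union[Device, str], LayerRepository]] = {
    "MyCustomLayer": {
        "cuda": LayerRepository(
            repo_id="my-org/my-kernels",  # hub repo that hosts the kernelized layer
            layer_name="MyCustomLayer",   # class name exported by that repo
        )
    },
}
```

The mapping key is the transformers class name, and each device string points to the hub repository that provides a drop-in kernel implementation of that layer.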

src/transformers/models/llama4/modeling_llama4.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -33,6 +33,7 @@
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
+from ...integrations.hub_kernels import use_kernel_forward_from_hub
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_outputs import (
@@ -150,6 +151,7 @@ def extra_repr(self):
         return f"{tuple(self.weight.shape)}, eps={self.eps}"


+@use_kernel_forward_from_hub("Llama4TextMoe")
 class Llama4TextMoe(nn.Module):
     def __init__(self, config):
         super().__init__()
```