Skip to content

[FalconH1] Fix output dtype in RMSNorm fallback path for Falcon-H1 (e.g. 0.5B) #18500

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 57 commits into from
May 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
8e27206
implement falcon mamba 2 architecture
dhiaEddineRhaiem Feb 26, 2025
191a588
push changes
younesbelkada Feb 26, 2025
d54a010
more changes
younesbelkada Feb 27, 2025
d86dd28
more changes
younesbelkada Feb 27, 2025
a614d9b
more fixes
younesbelkada Feb 27, 2025
004d4e4
more fixes
younesbelkada Feb 27, 2025
d2e74be
remove prints
younesbelkada Feb 27, 2025
12f2dab
fix
younesbelkada Feb 27, 2025
4e131c5
some clean up
younesbelkada Feb 27, 2025
d882e9f
Merge remote-tracking branch 'upstream/main' into falcon_mamba2
younesbelkada Mar 3, 2025
331ea6d
fixes
younesbelkada Mar 3, 2025
6379856
final fixes
younesbelkada Mar 3, 2025
02e9d89
fix
younesbelkada Mar 5, 2025
615f234
add docs
younesbelkada Mar 5, 2025
76e4554
fix hard-coded tied weights
younesbelkada Mar 12, 2025
cfc959f
fix logits issue
younesbelkada Mar 12, 2025
3834a92
Merge remote-tracking branch 'upstream/main' into falcon_mamba2
younesbelkada Apr 7, 2025
536898c
chore: rename to `FalconH1`
younesbelkada Apr 7, 2025
54cc7af
chore: rename file
younesbelkada Apr 7, 2025
a8762f2
fix issue
younesbelkada Apr 8, 2025
4ede3a2
fix Tensor parallelism Falcon H1
dhiaEddineRhaiem Apr 8, 2025
85f4292
Merge branch 'falcon-h1-clean' of https://github.com/tiiuae/vllm-inte…
dhiaEddineRhaiem Apr 8, 2025
ab427c5
adress maintainer comment on the PR
dhiaEddineRhaiem Apr 9, 2025
d585425
formatting
dhiaEddineRhaiem Apr 9, 2025
087cf17
fix pp and residual
ilyasch2 Apr 16, 2025
82a365c
clean
ilyasch2 Apr 16, 2025
ce57226
Merge remote-tracking branch 'upstream/main' into final_fh1
ilyasch2 Apr 16, 2025
958a159
fix
ilyasch2 Apr 16, 2025
d0e3c31
clean
ilyasch2 Apr 16, 2025
6579ec8
remove unnecessary arguments
ilyasch2 Apr 16, 2025
32296aa
clean
ilyasch2 Apr 16, 2025
6f738ad
clean
ilyasch2 Apr 18, 2025
0189427
modify default rope theta
ilyasch2 Apr 18, 2025
37fa5c0
Merge remote-tracking branch 'origin/main' into final_fh1
younesbelkada May 6, 2025
4068ff0
fix
younesbelkada May 6, 2025
20a84ea
chore: update supported models
JingweiZuo May 13, 2025
9fdbf19
chore: clean CI issue
JingweiZuo May 13, 2025
2d7ce65
adress TODO tasks and fix formatting
dhiaEddineRhaiem May 13, 2025
91e3b2d
fix ruff formatting
dhiaEddineRhaiem May 14, 2025
55b132d
fix format
dhiaEddineRhaiem May 14, 2025
6df6407
style
dhiaEddineRhaiem May 14, 2025
bd08b81
fix pre-commit checks
dhiaEddineRhaiem May 14, 2025
57d0332
Update requirements/test.txt from pip-compile
dhiaEddineRhaiem May 14, 2025
e201b92
fix pre-compile issues
dhiaEddineRhaiem May 14, 2025
baa2722
adress comments
dhiaEddineRhaiem May 14, 2025
cd51b64
fix: ruff
dhiaEddineRhaiem May 17, 2025
4a58965
Merge remote-tracking branch 'public/main' into final_fh1
dhiaEddineRhaiem May 20, 2025
559457b
adress comments
dhiaEddineRhaiem May 20, 2025
79e94f1
small fix
dhiaEddineRhaiem May 20, 2025
1f890d3
adress comment b2
dhiaEddineRhaiem May 20, 2025
5f39764
add type annotations
dhiaEddineRhaiem May 20, 2025
647f93a
fix test model TYPO
dhiaEddineRhaiem May 20, 2025
14e14c7
Merge remote-tracking branch 'upstream/main' into final_fh1
dhiaEddineRhaiem May 21, 2025
f630452
add min_transformers_version
dhiaEddineRhaiem May 21, 2025
87be690
fix norm output dtype when use_rms_norm is not used
dhiaEddineRhaiem May 21, 2025
e0811a0
Merge branch 'main' into final_fh1
dhiaEddineRhaiem May 21, 2025
939927a
add comment for silu precision
dhiaEddineRhaiem May 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions vllm/model_executor/layers/mamba/mamba_mixer2.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def forward_native(
input_dtype = x.dtype
x = x * nn.functional.silu(gate.to(torch.float32))
if not self.use_rms_norm:
return x
return x.to(input_dtype)

if self.n_groups == 1:
if self.tp_size > 1:
Expand Down Expand Up @@ -117,9 +117,11 @@ def forward_cuda(
x: torch.Tensor,
gate: torch.Tensor,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:

input_dtype = x.dtype
if not self.use_rms_norm:
return x * nn.functional.silu(gate.to(torch.float32))
# Keep gate in float32 for numerical stability during silu
return x * nn.functional.silu(gate.to(
torch.float32)).to(input_dtype)

if self.tp_size > 1 or self.n_groups != 1:
return self.forward_native(x, gate)
Expand Down
1 change: 0 additions & 1 deletion vllm/model_executor/models/falcon_h1.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,6 @@ def forward(
attn_metadata = get_forward_context().attn_metadata
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.mamba_chunk_size,
input_ids=input_ids,
attn_metadata=attn_metadata,
)
if get_pp_group().is_first_rank:
Expand Down