Skip to content

Commit 3b17ea2

Browse files
authored
[TPU] Re-enable the Pallas MoE kernel (#18025)
Signed-off-by: Michael Goin <[email protected]>
1 parent 23baa21 commit 3b17ea2

File tree

3 files changed

+24
-9
lines changed

3 files changed

+24
-9
lines changed

requirements/tpu.txt

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@ setuptools==78.1.0
1818
--find-links https://storage.googleapis.com/libtpu-releases/index.html
1919
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
2020
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
21-
torch==2.8.0.dev20250430
22-
torchvision==0.22.0.dev20250430
23-
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
24-
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
25-
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
21+
torch==2.8.0.dev20250518
22+
torchvision==0.22.0.dev20250518
23+
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250518-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
24+
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250518-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
25+
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250518-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
2626

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,7 @@
5050
else:
5151
from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
5252
if current_platform.is_tpu():
53-
# the iterative moe implementation is used until the moe_pallas is fixed
54-
from .moe_torch_iterative import fused_moe as fused_moe_pallas
53+
from .moe_pallas import fused_moe as fused_moe_pallas
5554
else:
5655
fused_moe_pallas = None # type: ignore
5756
logger = init_logger(__name__)

vllm/model_executor/layers/fused_moe/moe_pallas.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,23 @@
22

33
import torch
44
import torch.nn.functional as F
5-
from torch_xla.experimental.custom_kernel import _histogram
5+
6+
7+
def _histogram(input: torch.Tensor, min: int, max: int) -> torch.Tensor:
8+
"""
9+
Compute the histogram of a int32 tensor. The bin edges are defined by the
10+
min and max values, with step = 1.
11+
"""
12+
assert input.dtype == torch.int32, "input must be of torch.int32 dtype."
13+
assert min <= max, "min must be less than or equal to max."
14+
15+
def searchsorted(sorted_sequence: torch.Tensor,
16+
values_to_search: torch.Tensor) -> torch.Tensor:
17+
return (sorted_sequence.unsqueeze(1) == values_to_search).sum(dim=1)
18+
19+
bin_edges = torch.linspace(min, max, max - min + 1,
20+
dtype=input.dtype).to(input.device)
21+
return searchsorted(bin_edges, input).to(torch.int32)
622

723

824
def fused_moe(
@@ -61,7 +77,7 @@ def fused_moe(
6177
x = torch.ops.xla.gmm(x, w2, group_sizes)
6278
x = x[topk_argsort_revert_indices].reshape(-1, topk, hidden_size)
6379

64-
x = x * topk_weights.unsqueeze_(dim=-1)
80+
x = x * topk_weights.unsqueeze(dim=-1)
6581
x = x.sum(dim=-2)
6682
x = x.reshape(orig_shape)
6783
return x

0 commit comments

Comments
 (0)