Commit 25f25fc

Fix PiecewiseCompileInterpreter
This PR fixes the other issue discovered in vllm-project#16859 when upgrading from PyTorch 2.6 to PyTorch 2.7. I don't know why the code used to work in PyTorch 2.6, but the explanation is:

- When we run PiecewiseCompileInterpreter, we end up doing FakeTensor propagation.
- FakeTensor propagation requires `enable_python_dispatcher` to work. The mechanism: some of PyTorch's C++ implementations of operations, like matmul, force specialization of dynamic shapes. torch.compile works around this by replacing the C++ implementation of matmul with a Python-based implementation that does not force specialization.

Test Plan:
- Ran `pytest -v tests/models/test_transformers.py -k test_models[meta-llama/Llama-3.2-1B-Instruct-transformers]` with PyTorch >= 2.7 and vllm-project#17330, and verified that the test passes.

Signed-off-by: rzou <[email protected]>

1 parent 79a1d25

File tree

1 file changed: +2 −1 lines


vllm/compilation/backends.py

Lines changed: 2 additions & 1 deletion

@@ -11,6 +11,7 @@

 import torch
 import torch.fx as fx
+from torch._dispatch.python import enable_python_dispatcher

 import vllm.envs as envs
 from vllm.config import CompilationConfig, VllmConfig
@@ -270,7 +271,7 @@ def run(self, *args):
            self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
            for t in args
        ]
-        with self.fake_mode:
+        with self.fake_mode, enable_python_dispatcher():
            return super().run(*fake_args)

    def call_module(self, target: torch.fx.node.Target,
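A standalone sketch of the pattern the patch applies (this is illustrative code, not vLLM's; the helper name `fake_propagate` is made up): FX-style FakeTensor propagation is run under both `FakeTensorMode` and `enable_python_dispatcher()`, so that ops like matmul hit PyTorch's Python implementations, which do not force specialization of dynamic shapes.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch._dispatch.python import enable_python_dispatcher

fake_mode = FakeTensorMode()

def fake_propagate(fn, *real_args):
    # Convert real tensor args to FakeTensors, mirroring what
    # PiecewiseCompileInterpreter.run does before interpretation.
    fake_args = [
        fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
        for t in real_args
    ]
    # The fix: enter the python dispatcher alongside fake mode, so
    # shape propagation uses Python op implementations instead of the
    # C++ kernels that would specialize dynamic shapes.
    with fake_mode, enable_python_dispatcher():
        return fn(*fake_args)

out = fake_propagate(torch.matmul, torch.randn(4, 8), torch.randn(8, 3))
print(out.shape)  # torch.Size([4, 3]) — shapes propagate without real compute
```

Without `enable_python_dispatcher()`, the same propagation can fail (or over-specialize) once the graph contains symbolic shapes, which is what broke under PyTorch 2.7.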
