Skip to content

Commit e48c3a8

Browse files
youkaichaomzusman
authored andcommitted
[torch.compile] fix sym_tensor_indices (vllm-project#12191)
Signed-off-by: youkaichao <[email protected]>
1 parent 04ba021 commit e48c3a8

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

vllm/compilation/backends.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -624,9 +624,13 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
624624
]
625625

626626
# index of tensors that have symbolic shapes (batch size)
627+
# for weights and static buffers, they will have concrete shapes.
628+
# symbolic shape only happens for input tensors.
629+
from torch.fx.experimental.symbolic_shapes import is_symbolic
627630
self.sym_tensor_indices = [
628631
i for i, x in enumerate(fake_args)
629-
if isinstance(x, torch._subclasses.fake_tensor.FakeTensor)
632+
if isinstance(x, torch._subclasses.fake_tensor.FakeTensor) and \
633+
any(is_symbolic(d) for d in x.size())
630634
]
631635

632636
# compiler managed cudagraph input buffers

0 commit comments

Comments
 (0)