
Commit dba4d9d

[v1][bugfix] fix cudagraph with inplace buffer assignment (#11596)
Signed-off-by: youkaichao <[email protected]>
1 parent 32b4c63 commit dba4d9d

File tree (2 files changed, +10 -11 lines):

  vllm/compilation/wrapper.py                      (+9 -1)
  vllm/model_executor/layers/rotary_embedding.py   (+1 -10)

vllm/compilation/wrapper.py (9 additions, 1 deletion)

@@ -28,11 +28,12 @@ def __init__(self,
                  compiled_callable: Optional[Callable] = None,
                  compilation_level: int = 0):

+        vllm_config = get_current_vllm_config()
+        self.vllm_config = vllm_config
         if compiled_callable is None:
             # default compilation settings
             # compiling the forward method

-            vllm_config = get_current_vllm_config()
             backend = vllm_config.compilation_config.init_backend(vllm_config)

             compiled_callable = torch.compile(

@@ -82,6 +83,13 @@ def bytecode_hook(self, old_code: CodeType, new_code: CodeType):

         self.compiled_codes.append(new_code)

+        if self.vllm_config.compilation_config.use_cudagraph and \
+            "update" in new_code.co_names:
+            import depyf
+            src = depyf.decompile(new_code)
+            msg = "Assigning / modifying buffers of nn.Module during forward pass is not allowed when using cudagraph inside the compiler because it will cause silent errors. Please use eager mode or fix the code. The following code contains clues about which buffer is being modified (please search for the usage of the function `update`):\n" + src  # noqa
+            raise RuntimeError(msg)
+
     @contextmanager
     def dispatch_to_code(self, index: int):
         """Context manager to dispatch to the compiled code.

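For context, the check added to bytecode_hook above relies on the fact that a Python code object lists every global, attribute, and method name it references in co_names. When Dynamo-generated bytecode replays a buffer mutation, it typically does so through a dict-style update call, which is why the commit's error message tells you to search the decompiled source for `update`. A minimal sketch of that heuristic (the function and key names below are illustrative, not vLLM code):

import dis

def replay_mutation(buffers: dict, new_cache):
    # stand-in for generated resume code that writes a module buffer back
    buffers.update({"long_short_cos_sin_cache": new_cache})
    return new_cache

# 'update' appears among the names referenced by the bytecode
print("update" in replay_mutation.__code__.co_names)  # True
dis.dis(replay_mutation)  # disassembly shows where `update` is referenced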
vllm/model_executor/layers/rotary_embedding.py (1 addition, 10 deletions)

@@ -541,19 +541,12 @@ def __init__(
         short_cache = self._compute_cos_sin_cache(
             original_max_position_embeddings, short_factor, short_mscale)
         short_cache = short_cache.to(dtype)
-        self.register_buffer("short_cos_sin_cache",
-                             short_cache,
-                             persistent=False)

         long_cache = self._compute_cos_sin_cache(max_position_embeddings,
                                                  long_factor, long_mscale)
         long_cache = long_cache.to(dtype)
-        self.register_buffer("long_cos_sin_cache",
-                             long_cache,
-                             persistent=False)

-        long_short_cache = torch.cat(
-            [self.short_cos_sin_cache, self.long_cos_sin_cache], dim=0)
+        long_short_cache = torch.cat([short_cache, long_cache], dim=0)
         self.register_buffer("long_short_cos_sin_cache",
                              long_short_cache,
                              persistent=False)

@@ -593,8 +586,6 @@ def forward(
             torch.full_like(positions, k)).long()
         idx = (torch.add(positions, long_prompt_offset)
                if long_prompt_offset is not None else positions)
-        self.long_short_cos_sin_cache: torch.Tensor = (
-            self.long_short_cos_sin_cache.to(idx.device))
         idx = torch.add(idx, offsets) if offsets is not None else idx
         cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)

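The rotary-embedding change above follows exactly the pattern the new check enforces: intermediate caches stay local to __init__, only the final concatenated cache is registered (once, non-persistent) as a buffer, and forward() only reads from it. A minimal standalone sketch of that pattern (the class and tensor names are illustrative, not the actual vLLM module):

import torch
import torch.nn as nn

class CosSinCacheLookup(nn.Module):
    def __init__(self, short_cache: torch.Tensor, long_cache: torch.Tensor):
        super().__init__()
        # intermediates remain local variables; only the combined cache
        # becomes a module buffer, registered exactly once
        long_short_cache = torch.cat([short_cache, long_cache], dim=0)
        self.register_buffer("long_short_cos_sin_cache",
                             long_short_cache,
                             persistent=False)

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        # read-only indexing; no `self.long_short_cos_sin_cache = ...`
        # inside forward, so the buffer's storage stays stable across
        # CUDA graph replays
        return torch.index_select(self.long_short_cos_sin_cache, 0, idx)

Since registered buffers follow the module when it is moved with module.to(device) at load time, the per-call .to(idx.device) reassignment that the old forward performed is no longer needed.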