Remove dummy forward path #3669

Merged (7 commits, Apr 18, 2025)

docs/source/torch/attention.md (1 change: 0 additions & 1 deletion)
@@ -65,7 +65,6 @@ It contains the following predefined fields:
| request_ids | List[int] | The request ID of each sequence in the batch. |
| prompt_lens | List[int] | The prompt length of each sequence in the batch. |
| kv_cache_params | KVCacheParams | The parameters for the KV cache. |
| is_dummy_attention | bool | Indicates whether this is a simulation-only attention operation used for KV cache memory estimation. Defaults to False. |

During `AttentionMetadata.__init__`, you can initialize additional fields for the new attention metadata.
For example, the Flashinfer metadata initializes `decode_wrapper` here.
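
As context for the documentation change above, a backend-specific metadata class extends `AttentionMetadata` and sets up its extra fields during initialization. The sketch below is illustrative only: `MyAttentionMetadata` and `workspace_buffer` are hypothetical names, not the actual Flashinfer implementation, and it assumes `tensorrt_llm` is installed and a CUDA device is available.

from dataclasses import dataclass
from typing import Optional

import torch

from tensorrt_llm._torch.attention_backend.interface import AttentionMetadata


@dataclass(kw_only=True)
class MyAttentionMetadata(AttentionMetadata):
    # Hypothetical backend-specific field added on top of the base metadata.
    workspace_buffer: Optional[torch.Tensor] = None

    def __post_init__(self) -> None:
        # Run the base-class initialization first (dataclasses use
        # __post_init__ rather than a hand-written __init__).
        super().__post_init__()
        if self.workspace_buffer is None:
            # Allocate a scratch buffer once; a real backend would size this
            # according to its kernel requirements.
            self.workspace_buffer = torch.empty(16 * 1024 * 1024,
                                                dtype=torch.uint8,
                                                device="cuda")
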
tensorrt_llm/_torch/attention_backend/flashinfer.py (10 changes: 1 addition & 9 deletions)
@@ -12,7 +12,7 @@

from ..utils import get_global_attrs, get_model_extra_attrs
from .interface import (AttentionBackend, AttentionMask, AttentionMetadata,
PredefinedAttentionMask, dummy_forward)
PredefinedAttentionMask)

try:
check_cuda_arch()
@@ -465,14 +465,6 @@ def forward_pattern(
else:
metadata = get_global_attrs().attention_metadata()

# This is only for memory estimation for now.
# NOTE: this method is not accurate while it works for most scenario.
if metadata is None or metadata.kv_cache_manager is None:
q = q.view(-1, num_heads, head_dim)
k = k.view(-1, num_kv_heads, head_dim)
v = v.view(-1, num_kv_heads, head_dim)
return dummy_forward(q, k, v)

assert isinstance(
metadata,
FlashInferAttentionMetadata,
tensorrt_llm/_torch/attention_backend/interface.py (35 changes: 0 additions & 35 deletions)
@@ -7,7 +7,6 @@
Union)

import torch
from transformers.modeling_flash_attention_utils import _flash_attention_forward
from typing_extensions import Self

from tensorrt_llm.functional import (PositionEmbeddingType, RopeEmbeddingUtils,
@@ -125,7 +124,6 @@ class AttentionMetadata:
_num_generations: int = field(init=False, default=0, repr=False)
_num_ctx_tokens: int = field(init=False, default=0, repr=False)
_num_tokens: int = field(init=False, default=0, repr=False)
is_dummy_attention: bool = False

def __post_init__(self) -> None:
if self.is_cross:
@@ -548,36 +546,3 @@ class MLAParams:
qk_nope_head_dim: int = 0
v_head_dim: int = 0
predicted_tokens_per_seq: int = 1


@torch.library.custom_op("trtllm::attn_dummy_fwd", mutates_args=())
def dummy_forward(q: torch.Tensor, k: torch.Tensor,
v: torch.Tensor) -> torch.Tensor:
"""
Dummy attention forward function to estimate memory usage.
Args:
q (torch.Tensor): Query tensor with shape (num_q_tokens, num_heads, head_dim),.
k (torch.Tensor): Key tensor with shape (num_new_kv_tokens, num_kv_heads, head_dim)
v (torch.Tensor): Value tensor with shape (num_new_kv_tokens, num_kv_heads, head_dim)
Returns:
torch.Tensor with shape (num_q_tokens, num_heads * head_dim)
"""
head_dim = q.shape[2]
assert q.dim() == 3
assert k.dim() == 3 and k.size(2) == head_dim
assert v.dim() == 3 and v.size(2) == head_dim
# This is only for memory estimation for now.
# NOTE: this method is not accurate while it works for most scenario.
o = _flash_attention_forward(q.unsqueeze(0),
k.unsqueeze(0),
v.unsqueeze(0),
attention_mask=None,
query_length=q.size(0),
is_causal=True)
return o.reshape(o.size(1), -1)


@dummy_forward.register_fake
def _(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
num_q_tokens = q.size(0)
return torch.empty_like(q).reshape(num_q_tokens, -1)
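
For reference, the removed `dummy_forward` op ran a real flash-attention kernel purely to estimate activation memory. The sketch below is not the removed implementation (which called `_flash_attention_forward`); it only reproduces the same shape contract with PyTorch's built-in scaled dot-product attention so it can run anywhere, and the sizes and the name `dummy_forward_reference` are illustrative.

import torch
import torch.nn.functional as F


def dummy_forward_reference(q: torch.Tensor, k: torch.Tensor,
                            v: torch.Tensor) -> torch.Tensor:
    """Causal GQA attention with the same in/out shapes as the removed op.

    q: (num_q_tokens, num_heads, head_dim)
    k, v: (num_new_kv_tokens, num_kv_heads, head_dim)
    returns: (num_q_tokens, num_heads * head_dim)
    """
    num_q_tokens, num_heads, _ = q.shape
    num_kv_heads = k.size(1)
    # Expand the KV heads so every query head has a matching KV head.
    rep = num_heads // num_kv_heads
    k = k.repeat_interleave(rep, dim=1)
    v = v.repeat_interleave(rep, dim=1)
    # SDPA expects (batch, heads, seq_len, head_dim).
    q = q.transpose(0, 1).unsqueeze(0)
    k = k.transpose(0, 1).unsqueeze(0)
    v = v.transpose(0, 1).unsqueeze(0)
    o = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    # Back to (num_q_tokens, num_heads * head_dim), matching the removed op.
    return o.squeeze(0).transpose(0, 1).reshape(num_q_tokens, -1)


q = torch.randn(16, 8, 64)
k = torch.randn(16, 2, 64)
v = torch.randn(16, 2, 64)
assert dummy_forward_reference(q, k, v).shape == (16, 8 * 64)
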
tensorrt_llm/_torch/attention_backend/star_flashinfer.py (8 changes: 1 addition & 7 deletions)
@@ -10,8 +10,7 @@

from ..distributed import allgather
from .flashinfer import FlashInferAttentionMetadata, PlanParams
from .interface import (AttentionBackend, AttentionMask,
PredefinedAttentionMask, dummy_forward)
from .interface import AttentionBackend, AttentionMask, PredefinedAttentionMask


# Please sync with flashinfer's DISPATCH_GQA_GROUP_SIZE in include/flashinfer/utils.cuh
@@ -326,11 +325,6 @@ def forward(self,
k = k.view(-1, self.num_kv_heads, self.head_dim)
v = v.view(-1, self.num_kv_heads, self.head_dim)

# This is only for memory estimation for now.
# NOTE: this method is not accurate while it works for most scenario.
if metadata is None or metadata.kv_cache_manager is None:
return dummy_forward(q, k, v)

num_contexts = metadata.num_contexts
num_queries = metadata.num_queries
num_generations = metadata.num_generations
tensorrt_llm/_torch/attention_backend/trtllm.py (29 changes: 2 additions & 27 deletions)
@@ -10,7 +10,7 @@
from .interface import (AttentionBackend, AttentionInputType, AttentionMask,
AttentionMetadata, KVCacheParams, MLAParams,
PositionalEmbeddingParams, PredefinedAttentionMask,
RopeParams, dummy_forward)
RopeParams)


@dataclass(kw_only=True, init=False)
@@ -489,7 +489,7 @@ def __post_init__(self) -> None:

def prepare(self) -> None:

if not self.is_dummy_attention and self.kv_cache_manager is None:
if self.kv_cache_manager is None:
# Convert the attention metadata to a TRT-LLM no cache attention metadata.
assert self.kv_cache_manager is None, "no cache attention should not have KV cache manager"
assert self._max_seq_len_storage is not None, "max_seq_len should be set for no cache attention"
@@ -641,31 +641,6 @@ def forward(
mrope_config: Optional[dict] = None,
**kwargs,
) -> torch.Tensor:
# This is only for memory estimation for now.
# NOTE: this method is not accurate while it works for most scenario.
if metadata.is_dummy_attention:
q_size = self.num_heads * self.head_dim
k_size = self.num_kv_heads * self.head_dim
v_size = self.num_kv_heads * self.v_head_dim
q, k, v = q.split([q_size, k_size, v_size], dim=-1)
q = q.view(-1, self.num_heads, self.head_dim)
k = k.view(-1, self.num_kv_heads, self.head_dim)
v = v.view(-1, self.num_kv_heads, self.v_head_dim)
if self.head_dim != self.v_head_dim:
# the dummy forward doesn't support head_dim != v_head_dim case
# so we use a tensor with supported shape to replace the v
# the memory estimation is not accurate in this case
v = torch.randn(q.shape[0],
self.num_kv_heads,
self.head_dim,
dtype=q.dtype,
device=q.device)
output = dummy_forward(q, k, v)
if self.head_dim != self.v_head_dim:
output = output[..., :self.num_kv_heads *
self.v_head_dim].contiguous()
return output

assert isinstance(
metadata,
TrtllmAttentionMetadata,
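
The block removed above receives `q` as a fused QKV tensor and splits it before calling the dummy op; the `torch.randn` workaround existed because the dummy kernel assumed `head_dim == v_head_dim`. A standalone sketch of that split, with illustrative sizes only (`v_head_dim != head_dim`, as in MLA-style backends):

import torch

num_tokens = 16
num_heads, num_kv_heads = 8, 2
head_dim, v_head_dim = 64, 32  # illustrative values

q_size = num_heads * head_dim
k_size = num_kv_heads * head_dim
v_size = num_kv_heads * v_head_dim

# Fused projection output: queries, keys, and values packed along the last dim.
qkv = torch.randn(num_tokens, q_size + k_size + v_size)

q, k, v = qkv.split([q_size, k_size, v_size], dim=-1)
q = q.view(-1, num_heads, head_dim)       # (num_tokens, num_heads, head_dim)
k = k.view(-1, num_kv_heads, head_dim)    # (num_tokens, num_kv_heads, head_dim)
v = v.view(-1, num_kv_heads, v_head_dim)  # (num_tokens, num_kv_heads, v_head_dim)
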
tensorrt_llm/_torch/attention_backend/vanilla.py (14 changes: 2 additions & 12 deletions)
@@ -11,7 +11,7 @@
AttentionMaskConverter = None

from .interface import (AttentionBackend, AttentionMask, AttentionMetadata,
PredefinedAttentionMask, dummy_forward)
PredefinedAttentionMask)


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -230,12 +230,7 @@ def forward(self,
*,
attention_mask: AttentionMask = PredefinedAttentionMask.CAUSAL,
**kwargs) -> torch.Tensor:

# This is only for memory estimation for now.
# NOTE: this method is not accurate while it works for most scenario.
if metadata.is_dummy_attention:
return dummy_forward(q, k, v)
elif metadata.kv_cache_manager is None:
if metadata.kv_cache_manager is None:
# NOTE: WAR for no kv cache attn e.g. BERT,
# try to separate the kv cache estimation path from no kv cache attn.
num_heads = self.num_heads
@@ -249,11 +244,6 @@
metadata=metadata,
attention_mask=attention_mask)

# This is only for memory estimation for now.
# NOTE: this method is not accurate while it works for most scenario.
if metadata is None or metadata.kv_cache_manager is None:
return dummy_forward(q, k, v)

past_seen_tokens = metadata.kv_cache_params.num_cached_tokens_per_seq
cache_indices = [
block_ids[0] for block_ids in metadata.block_ids_per_seq
tensorrt_llm/_torch/pyexecutor/model_engine.py (18 changes: 5 additions & 13 deletions)
@@ -576,21 +576,15 @@ def _create_extra_inputs(bs, num_tokens_per_request):
extra_model_inputs=_create_extra_inputs(bs, 1))
torch.cuda.synchronize()

def _set_up_attn_metadata(self,
kv_cache_manager: KVCacheManager,
is_dummy_forward: bool = False):
# is_dummy_forward is used to indicate whether the forward is
# a dummy forward for memory estimation OR
# a real forward w.o. kv cache
def _set_up_attn_metadata(self, kv_cache_manager: KVCacheManager):
if kv_cache_manager is None:
return self.attn_backend.Metadata(
max_num_requests=self.batch_size,
max_num_tokens=self.max_num_tokens,
kv_cache_manager=None,
mapping=self.mapping,
runtime_features=self.attn_runtime_features,
enable_flash_mla=self.model.model_config.enable_flash_mla,
is_dummy_attention=is_dummy_forward)
enable_flash_mla=self.model.model_config.enable_flash_mla)

if self.attn_metadata is not None:
# This assertion can be relaxed if needed: just create a new metadata
@@ -1282,7 +1276,7 @@ def _prepare_tp_inputs_no_cache(
all_rank_num_tokens = self.dist.allgather(attn_metadata.num_tokens)
attn_metadata.all_rank_num_tokens = all_rank_num_tokens
# this is for no cache attention, not for dummy attention
if not attn_metadata.is_dummy_attention and attn_metadata.kv_cache_manager is None:
if attn_metadata.kv_cache_manager is None:
assert isinstance(
attn_metadata,
(VanillaAttentionMetadata, TrtllmAttentionMetadata)
@@ -1596,14 +1590,12 @@ def forward(self,
scheduled_requests: ScheduledRequests,
resource_manager: ResourceManager,
new_tensors_device: Optional[Dict[str, torch.Tensor]] = None,
extra_model_inputs: Optional[Dict[str, Any]] = None,
is_dummy_forward: bool = False):
extra_model_inputs: Optional[Dict[str, Any]] = None):

kv_cache_manager = resource_manager.get_resource_manager(
self.kv_cache_manager_key)

attn_metadata = self._set_up_attn_metadata(kv_cache_manager,
is_dummy_forward)
attn_metadata = self._set_up_attn_metadata(kv_cache_manager)
if self.spec_config is not None:
spec_resource_manager = resource_manager.get_resource_manager(
'spec_resource_manager')
tests/unittest/_torch/test_attention_no_cache.py (4 changes: 1 addition & 3 deletions)
@@ -246,9 +246,7 @@ def _run_test_for_backend(backend_name, num_heads, num_kv_heads, num_layers,
max_num_tokens=8192,
kv_cache_manager=None,
mapping=None,
runtime_features=None,
is_dummy_attention=False,
)
runtime_features=None)

# NOTE: set up metadata
attn_metadata.seq_lens = torch.tensor(sequence_lengths, dtype=torch.int)