
[WIP][Bugfix] Minimax-VL-01 fix processing #17833

Draft: wants to merge 30 commits into main
7 changes: 7 additions & 0 deletions vllm/model_executor/models/minimax_cache.py
@@ -33,3 +33,10 @@ def _copy_cache(self, from_index: int, to_index: int):
for cache_t in self.cache:
cache_t[:, to_index].copy_(cache_t[:, from_index],
non_blocking=True)

def current_run_tensors(self, **kwargs) -> MinimaxCacheParams:
"""
Return the tensors for the current run as MinimaxCacheParams.
"""
cache_tensors, state_indices_tensor = super().current_run_tensors(**kwargs)
return MinimaxCacheParams(cache_tensors, state_indices_tensor)
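A minimal usage sketch of the new override, for reference: the base manager's `current_run_tensors` is assumed to return a `(cache_tensors, state_indices_tensor)` pair driven by the `request_ids_to_seq_ids` / `finished_requests_ids` kwargs that the model's `forward` defaults later in this diff; the concrete shape and dtype below are made up, and a GPU-backed vLLM install is assumed.

```python
# Usage sketch only; not part of this PR. Shape/dtype values are illustrative.
import torch

from vllm.model_executor.models.minimax_cache import MinimaxCacheManager

cache_shape = (4, 256, 16, 128, 128)  # (linear layers, slots, heads per TP rank, head_dim, head_dim)
manager = MinimaxCacheManager(dtype=torch.bfloat16, cache_shape=cache_shape)

# With this PR the per-step call returns a MinimaxCacheParams object, so
# callers read named fields instead of unpacking a positional tuple.
params = manager.current_run_tensors(request_ids_to_seq_ids={},
                                     finished_requests_ids=[])
slots = params.state_indices_tensor  # field name used elsewhere in this diff
```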
57 changes: 26 additions & 31 deletions vllm/model_executor/models/minimax_text_01.py
@@ -360,7 +360,6 @@ def __init__(
self.tp_heads = self.total_num_heads // self.tp_size
self.qkv_size = self.num_heads * self.head_dim
self.tp_hidden = self.head_dim * self.tp_heads

self.qkv_proj = ColumnParallelLinear(
hidden_size,
self.hidden_inner_size * 3,
@@ -427,17 +426,17 @@ def get_slopes_power_of_2(n):
n_attention_heads, 1, 1)
return slopes

def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor,
attn_metadata):
def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor):
hidden = []
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)):
if _prefill_idx >= len(attn_metadata.query_start_loc):
break
if _prefill_idx >= len(state_indices_tensor):
break
_start = attn_metadata.query_start_loc[_prefill_idx]
_end = attn_metadata.query_start_loc[_prefill_idx + 1]
slot_id = state_indices_tensor[_prefill_idx]
qs = q[_start:_end].transpose(0, 1).contiguous()
ks = k[_start:_end].transpose(0, 1).contiguous()
vs = v[_start:_end].transpose(0, 1).contiguous()
@@ -490,8 +489,7 @@ def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
decode_only = getattr(attn_metadata, "num_prefills", 0) == 0
if not decode_only:
hidden = self._prefill_and_mix_infer(q, k, v, kv_cache,
state_indices_tensor,
attn_metadata)
state_indices_tensor)
else:
hidden = self._decode_infer(q, k, v, kv_cache,
state_indices_tensor, attn_metadata)
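The `_prefill_and_mix_infer` signature change above follows the pattern of reading attention metadata from vLLM's forward context instead of threading it through the call chain; a small self-contained sketch of that pattern (the helper name is hypothetical):

```python
# Sketch of the forward-context pattern used above; helper name is made up.
from vllm.forward_context import get_forward_context


def num_prefills_in_current_step() -> int:
    forward_context = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    # attn_metadata can be None (the model forward in this diff guards for
    # this), and not every backend exposes num_prefills, hence the getattr.
    if attn_metadata is None:
        return 0
    return getattr(attn_metadata, "num_prefills", 0)
```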
@@ -845,19 +843,7 @@ def layer_fn(prefix):
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, layer_fn, prefix=f"{prefix}.layers")

linear_layer_nums = sum(1 for i in range(config.num_hidden_layers)
if self.decoder_attention_types[i] == 0)
max_slots_number = scheduler_config.max_num_seqs
self.cache_shape = (linear_layer_nums, max_slots_number,
config.num_attention_heads //
get_tensor_model_parallel_world_size(),
config.head_dim, config.head_dim)
_dummy = torch.zeros(1)
self._dtype = _dummy.dtype
del _dummy

self.minimax_cache = MinimaxCacheManager(dtype=self._dtype,
cache_shape=self.cache_shape)
self.minimax_cache: Optional[MinimaxCacheManager] = None

rope_theta = getattr(config, "rope_theta", 10000)
head_dim = getattr(config, "head_dim",
@@ -918,28 +904,23 @@ def get_input_embeddings(
def forward(self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
minimax_cache: MinimaxCacheManager,
minimax_cache_params: MinimaxCacheParams,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs) -> Union[torch.Tensor, IntermediateTensors]:
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
if attn_metadata is None:
return None
if "request_ids_to_seq_ids" not in kwargs:
kwargs["request_ids_to_seq_ids"] = {}
if "finished_requests_ids" not in kwargs:
kwargs["finished_requests_ids"] = []

(
minimax_cache_tensors,
state_indices_tensor,
) = self.minimax_cache.current_run_tensors(**kwargs)

self.minimax_cache = minimax_cache
minimax_cache_tensors = minimax_cache_params.minimax_cache

if getattr(attn_metadata, "num_prefills", 0) > 0:
self._clear_prefill_cache(attn_metadata, minimax_cache_tensors,
**kwargs)

minimax_cache_params = MinimaxCacheParams(minimax_cache_tensors,
state_indices_tensor)
if get_pp_group().is_first_rank:
if inputs_embeds is None:
hidden_states = self.embed_scale * self.embed_tokens(input_ids)
@@ -997,6 +978,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
config.sliding_window = None

self.CONCAT_FFN = True
self.minimax_cache: Optional[MinimaxCacheManager] = None

self.unpadded_vocab_size = self.config.vocab_size
if hasattr(vllm_config.model_config, "max_model_len"):
@@ -1024,6 +1006,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
flash_layer_count = sum(1 for attn_type in self.config.attn_type_list
if attn_type == 1)
self.kv_cache = [torch.tensor([]) for _ in range(flash_layer_count)]

linear_layer_nums = sum(1 for i in range(config.num_hidden_layers)
if config.attn_type_list[i] == 0)
max_slots_number = vllm_config.scheduler_config.max_num_seqs
self.cache_shape = (linear_layer_nums, max_slots_number,
config.num_attention_heads //
get_tensor_model_parallel_world_size(),
config.head_dim, config.head_dim)
return
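For clarity, the cache shape assembled above is (number of linear-attention layers, scheduler slots, attention heads per tensor-parallel rank, head_dim, head_dim); a short worked example with made-up config values:

```python
# Worked example with hypothetical config values (not from this PR).
attn_type_list = [0, 1, 0, 1, 0, 1, 0, 1]  # 0 = linear attention, 1 = flash attention
num_attention_heads = 64
head_dim = 128
tp_world_size = 4        # get_tensor_model_parallel_world_size() in the real code
max_num_seqs = 256       # vllm_config.scheduler_config.max_num_seqs

linear_layer_nums = sum(1 for attn_type in attn_type_list if attn_type == 0)
cache_shape = (linear_layer_nums, max_num_seqs,
               num_attention_heads // tp_world_size,
               head_dim, head_dim)
print(cache_shape)  # (4, 256, 16, 128, 128)
```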

def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
@@ -1046,7 +1036,12 @@ def forward(self,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, intermediate_tensors,
if self.minimax_cache is None:
self.minimax_cache = MinimaxCacheManager(
dtype=self._dtype, cache_shape=self.cache_shape)
minimax_cache_params = self.minimax_cache.current_run_tensors(**kwargs)
hidden_states = self.model(input_ids, positions, self.minimax_cache,
minimax_cache_params, intermediate_tensors,
inputs_embeds, **kwargs)

return hidden_states
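Putting the pieces together, the causal-LM `forward` now allocates the cache manager lazily on the first step and hands both the manager and the per-step `MinimaxCacheParams` to the inner model; a condensed sketch of that flow (assumes the owning module exposes `_dtype`, `cache_shape`, `minimax_cache`, and `model`, as the code above does):

```python
# Condensed sketch of the new forward flow; not the PR's exact code.
from typing import Any, Optional

import torch

from vllm.model_executor.models.minimax_cache import MinimaxCacheManager


def run_forward(module: Any, input_ids: torch.Tensor, positions: torch.Tensor,
                **kwargs) -> torch.Tensor:
    cache: Optional[MinimaxCacheManager] = module.minimax_cache
    if cache is None:
        # Deferred allocation: the linear-attention cache is created on the
        # first step instead of at module construction time.
        cache = MinimaxCacheManager(dtype=module._dtype,
                                    cache_shape=module.cache_shape)
        module.minimax_cache = cache
    # current_run_tensors now returns MinimaxCacheParams (see minimax_cache.py
    # above); both the manager and the params are passed to the inner model.
    params = cache.current_run_tensors(**kwargs)
    return module.model(input_ids, positions, cache, params, **kwargs)
```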