
"attn_bias is not correctly aligned" on A100 for MPT-30B #795


Closed
dlopes78 opened this issue Aug 18, 2023 · 3 comments · Fixed by #834
Labels
bug Something isn't working

Comments

@dlopes78

Hello,

I saw a similar issue to this for MPT-30B-chat on H100, but I see the same error on A100 80GB, using vllm 0.1.3. Is there currently any workaround for this? It happens on random prompts, so it's not straightforward to understand where it's coming from:

│ 96 │ │ │ │ prompt_template = PromptTemplate(input_variables=["text"] │

│ 97 │ │ │ │ answer_chain = LLMChain(llm=self.llm , prompt=prompt_temp │
│ 98 │ │ │ │ │
│ ❱ 99 │ │ │ │ response = answer_chain.run(query) │
│ 100 │ │ │ │
│ 101 │ │ │ else: │
│ 102 │
│ │
│ /root/miniconda3/envs/py311/lib/python3.11/site-packages/langchain/chains/base. │
│ py:440 in run │
│ │
│ 437 │ │ if args and not kwargs: │
│ 438 │ │ │ if len(args) != 1: │
│ 439 │ │ │ │ raise ValueError("run supports only one positional argu │
│ ❱ 440 │ │ │ return self(args[0], callbacks=callbacks, tags=tags, metadata │
│ 441 │ │ │ │ _output_key │
│ 442 │ │ │ ] │
│ 443 │
│ │
│ /root/miniconda3/envs/py311/lib/python3.11/site-packages/langchain/chains/base. │
│ py:243 in __call__ │
│ │
│ 240 │ │ │ ) │
│ 241 │ │ except (KeyboardInterrupt, Exception) as e: │
│ 242 │ │ │ run_manager.on_chain_error(e) │
│ ❱ 243 │ │ │ raise e │
│ 244 │ │ run_manager.on_chain_end(outputs) │
│ 245 │ │ final_outputs: Dict[str, Any] = self.prep_outputs( │
│ 246 │ │ │ inputs, outputs, return_only_outputs │
│ │
│ /root/miniconda3/envs/py311/lib/python3.11/site-packages/langchain/chains/base. │
│ py:237 in __call__ │
│ │
│ 234 │ │ ) │
│ 235 │ │ try: │
│ 236 │ │ │ outputs = ( │
│ ❱ 237 │ │ │ │ self._call(inputs, run_manager=run_manager) │
│ 238 │ │ │ │ if new_arg_supported │
│ 239 │ │ │ │ else self._call(inputs) │
│ 240 │ │ │ ) │
│ │
│ /root/miniconda3/envs/py311/lib/python3.11/site-packages/langchain/chains/llm.p │
│ y:92 in _call │
│ │
│ 89 │ │ inputs: Dict[str, Any], │
│ 90 │ │ run_manager: Optional[CallbackManagerForChainRun] = None, │
│ 91 │ ) -> Dict[str, str]: │
│ ❱ 92 │ │ response = self.generate([inputs], run_manager=run_manager) │
│ 93 │ │ return self.create_outputs(response)[0] │
│ 94 │ │
│ 95 │ def generate( │
│ │
│ /root/miniconda3/envs/py311/lib/python3.11/site-packages/langchain/chains/llm.p │
│ y:102 in generate │
│ │
│ 99 │ ) -> LLMResult: │
│ 100 │ │ """Generate LLM result from inputs.""" │
│ 101 │ │ prompts, stop = self.prep_prompts(input_list, run_manager=run_man │
│ ❱ 102 │ │ return self.llm.generate_prompt( │
│ 103 │ │ │ prompts, │
│ 104 │ │ │ stop, │
│ 105 │ │ │ callbacks=run_manager.get_child() if run_manager else None, │
│ │
│ /root/miniconda3/envs/py311/lib/python3.11/site-packages/langchain/llms/base.py │
│ :186 in generate_prompt │
│ │
│ 183 │ │ **kwargs: Any, │
│ 184 │ ) -> LLMResult: │
│ 185 │ │ prompt_strings = [p.to_string() for p in prompts] │
│ ❱ 186 │ │ return self.generate(prompt_strings, stop=stop, callbacks=callbac │
│ 187 │ │
│ 188 │ async def agenerate_prompt( │
│ 189 │ │ self, │
│ │
│ /root/miniconda3/envs/py311/lib/python3.11/site-packages/langchain/llms/base.py │
│ :279 in generate │
│ │
│ 276 │ │ │ run_managers = callback_manager.on_llm_start( │
│ 277 │ │ │ │ dumpd(self), prompts, invocation_params=params, options=o │
│ 278 │ │ │ ) │
│ ❱ 279 │ │ │ output = self._generate_helper( │
│ 280 │ │ │ │ prompts, stop, run_managers, bool(new_arg_supported), **k │
│ 281 │ │ │ ) │
│ 282 │ │ │ return output │
│ │
│ /root/miniconda3/envs/py311/lib/python3.11/site-packages/langchain/llms/base.py │
│ :223 in _generate_helper │
│ │
│ 220 │ │ except (KeyboardInterrupt, Exception) as e: │
│ 221 │ │ │ for run_manager in run_managers: │
│ 222 │ │ │ │ run_manager.on_llm_error(e) │
│ ❱ 223 │ │ │ raise e │
│ 224 │ │ flattened_outputs = output.flatten() │
│ 225 │ │ for manager, flattened_output in zip(run_managers, flattened_outp │
│ 226 │ │ │ manager.on_llm_end(flattened_output) │
│ │
│ /root/miniconda3/envs/py311/lib/python3.11/site-packages/langchain/llms/base.py │
│ :210 in _generate_helper │
│ │
│ 207 │ ) -> LLMResult: │
│ 208 │ │ try: │
│ 209 │ │ │ output = ( │
│ ❱ 210 │ │ │ │ self._generate( │
│ 211 │ │ │ │ │ prompts, │
│ 212 │ │ │ │ │ stop=stop, │
│ 213 │ │ │ │ │ # TODO: support multiple run managers │
│ │
│ /root/miniconda3/envs/py311/lib/python3.11/site-packages/langchain/llms/base.py │
│ :604 in _generate │
│ │
│ 601 │ │ │ text = ( │
│ 602 │ │ │ │ self._call(prompt, stop=stop, run_manager=run_manager, ** │
│ 603 │ │ │ │ if new_arg_supported │
│ ❱ 604 │ │ │ │ else self._call(prompt, stop=stop, **kwargs) │
│ 605 │ │ │ ) │
│ 606 │ │ │ generations.append([Generation(text=text)]) │
│ 607 │ │ return LLMResult(generations=generations) │
│ │
│ /root/xxx.py:64 in _call │
│ │
│ 61 │ │ │ │ │ │ │ max_tokens=300, │
│ 62 │ │ │ │ │ │ │ ) │
│ 63 │ │ │
│ ❱ 64 │ │ output = model.generate(prompt, sampling_params) │
│ 65 │ │ │
│ 66 │ │ return output[0].outputs[0].text │
│ 67 │
│ │
│ /root/vllm/vllm/entrypoints/llm.py:130 in generate │
│ │
│ 127 │ │ │ else: │
│ 128 │ │ │ │ token_ids = prompt_token_ids[i] │
│ 129 │ │ │ self._add_request(prompt, sampling_params, token_ids) │
│ ❱ 130 │ │ return self._run_engine(use_tqdm) │
│ 131 │ │
│ 132 │ def _add_request( │
│ 133 │ │ self, │
│ │
│ /root/vllm/vllm/entrypoints/llm.py:150 in _run_engine │
│ │
│ 147 │ │ # Run the engine. │
│ 148 │ │ outputs: List[RequestOutput] = [] │
│ 149 │ │ while self.llm_engine.has_unfinished_requests(): │
│ ❱ 150 │ │ │ step_outputs = self.llm_engine.step() │
│ 151 │ │ │ for output in step_outputs: │
│ 152 │ │ │ │ if output.finished: │
│ 153 │ │ │ │ │ outputs.append(output) │
│ │
│ /root/vllm/vllm/engine/llm_engine.py:313 in step │
│ │
│ 310 │ │ │ ] │
│ 311 │ │ │
│ 312 │ │ # Execute the model. │
│ ❱ 313 │ │ output = self._run_workers( │
│ 314 │ │ │ "execute_model", │
│ 315 │ │ │ seq_group_metadata_list=seq_group_metadata_list, │
│ 316 │ │ │ blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, │
│ │
│ /root/vllm/vllm/engine/llm_engine.py:470 in _run_workers │
│ │
│ 467 │ │ │ else: │
│ 468 │ │ │ │ executor = getattr(worker, method) │
│ 469 │ │ │ │
│ ❱ 470 │ │ │ output = executor(*args, **kwargs) │
│ 471 │ │ │ all_outputs.append(output) │
│ 472 │ │ │
│ 473 │ │ if self.parallel_config.worker_use_ray: │
│ │
│ /root/miniconda3/envs/py311/lib/python3.11/site-packages/torch/utils/_contextli │
│ b.py:115 in decorate_context │
│ │
│ 112 │ @functools.wraps(func) │
│ 113 │ def decorate_context(*args, **kwargs): │
│ 114 │ │ with ctx_factory(): │
│ ❱ 115 │ │ │ return func(*args, **kwargs) │
│ 116 │ │
│ 117 │ return decorate_context │
│ 118 │
│ │
│ /root/vllm/vllm/worker/worker.py:293 in execute_model │
│ │
│ 290 │ │ │ seq_group_metadata_list) │
│ 291 │ │ │
│ 292 │ │ # Execute the model. │
│ ❱ 293 │ │ output = self.model( │
│ 294 │ │ │ input_ids=input_tokens, │
│ 295 │ │ │ positions=input_positions, │
│ 296 │ │ │ kv_caches=self.gpu_cache, │
│ │
│ /root/miniconda3/envs/py311/lib/python3.11/site-packages/torch/nn/modules/modul │
│ e.py:1501 in _call_impl │
│ │
│ 1498 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self │
│ 1499 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks │
│ 1500 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
│ ❱ 1501 │ │ │ return forward_call(*args, **kwargs) │
│ 1502 │ │ # Do not call functions when jit is used │
│ 1503 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1504 │ │ backward_pre_hooks = [] │
│ │
│ /root/vllm/vllm/model_executor/models/mpt.py:234 in forward │
│ │
│ 231 │ │ input_metadata: InputMetadata, │
│ 232 │ │ cache_events: Optional[List[torch.cuda.Event]], │
│ 233 │ ) -> Dict[int, SequenceOutputs]: │
│ ❱ 234 │ │ hidden_states = self.transformer(input_ids, positions, kv_caches, │
│ 235 │ │ │ │ │ │ │ │ │ │ input_metadata, cache_events) │
│ 236 │ │ next_tokens = self.sampler(self.lm_head_weight, hidden_states, │
│ 237 │ │ │ │ │ │ │ │ input_metadata) │
│ │
│ /root/miniconda3/envs/py311/lib/python3.11/site-packages/torch/nn/modules/modul │
│ e.py:1501 in _call_impl │
│ │
│ 1498 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self │
│ 1499 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks │
│ 1500 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
│ ❱ 1501 │ │ │ return forward_call(*args, **kwargs) │
│ 1502 │ │ # Do not call functions when jit is used │
│ 1503 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1504 │ │ backward_pre_hooks = [] │
│ │
│ /root/vllm/vllm/model_executor/models/mpt.py:202 in forward │
│ │
│ 199 │ │ │ else: │
│ 200 │ │ │ │ cache_event = cache_events[i] │
│ 201 │ │ │ block = self.blocks[i] │
│ ❱ 202 │ │ │ hidden_states = block( │
│ 203 │ │ │ │ position_ids, │
│ 204 │ │ │ │ hidden_states, │
│ 205 │ │ │ │ kv_caches[i], │
│ │
│ /root/miniconda3/envs/py311/lib/python3.11/site-packages/torch/nn/modules/modul │
│ e.py:1501 in _call_impl │
│ │
│ 1498 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self │
│ 1499 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks │
│ 1500 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
│ ❱ 1501 │ │ │ return forward_call(*args, **kwargs) │
│ 1502 │ │ # Do not call functions when jit is used │
│ 1503 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1504 │ │ backward_pre_hooks = [] │
│ │
│ /root/vllm/vllm/model_executor/models/mpt.py:153 in forward │
│ │
│ 150 │ │ cache_event: Optional[torch.cuda.Event], │
│ 151 │ ) -> torch.Tensor: │
│ 152 │ │ x = self.norm_1(hidden_states) │
│ ❱ 153 │ │ x = self.attn( │
│ 154 │ │ │ position_ids=position_ids, │
│ 155 │ │ │ hidden_states=x, │
│ 156 │ │ │ kv_cache=kv_cache, │
│ │
│ /root/miniconda3/envs/py311/lib/python3.11/site-packages/torch/nn/modules/modul │
│ e.py:1501 in _call_impl │
│ │
│ 1498 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self │
│ 1499 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks │
│ 1500 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
│ ❱ 1501 │ │ │ return forward_call(*args, **kwargs) │
│ 1502 │ │ # Do not call functions when jit is used │
│ 1503 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1504 │ │ backward_pre_hooks = [] │
│ │
│ /root/vllm/vllm/model_executor/models/mpt.py:102 in forward │
│ │
│ 99 │ │ │ q = self.q_ln(q) │
│ 100 │ │ │ k = self.k_ln(k) │
│ 101 │ │ k_cache, v_cache = kv_cache │
│ ❱ 102 │ │ attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata │
│ 103 │ │ │ │ │ │ │ │ cache_event) │
│ 104 │ │ output, _ = self.out_proj(attn_output) │
│ 105 │ │ return output │
│ │
│ /root/miniconda3/envs/py311/lib/python3.11/site-packages/torch/nn/modules/modul │
│ e.py:1501 in _call_impl │
│ │
│ 1498 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self │
│ 1499 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks │
│ 1500 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
│ ❱ 1501 │ │ │ return forward_call(*args, **kwargs) │
│ 1502 │ │ # Do not call functions when jit is used │
│ 1503 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1504 │ │ backward_pre_hooks = [] │
│ │
│ /root/vllm/vllm/model_executor/layers/attention.py:202 in forward │
│ │
│ 199 │ │ │ # Prompt run. │
│ 200 │ │ │ assert input_metadata.num_generation_tokens == 0 │
│ 201 │ │ │ self.set_attn_bias(input_metadata) │
│ ❱ 202 │ │ │ self.multi_query_kv_attention( │
│ 203 │ │ │ │ output[:num_prompt_tokens], │
│ 204 │ │ │ │ query[:num_prompt_tokens], │
│ 205 │ │ │ │ key[:num_prompt_tokens], │
│ │
│ /root/vllm/vllm/model_executor/layers/attention.py:399 in │
│ multi_query_kv_attention │
│ │
│ 396 │ │ start = 0 │
│ 397 │ │ for i, prompt_len in enumerate(input_metadata.prompt_lens): │
│ 398 │ │ │ end = start + prompt_len │
│ ❱ 399 │ │ │ out = xops.memory_efficient_attention_forward( │
│ 400 │ │ │ │ query[None, start:end], │
│ 401 │ │ │ │ key[None, start:end], │
│ 402 │ │ │ │ value[None, start:end], │
│ │
│ /root/miniconda3/envs/py311/lib/python3.11/site-packages/xformers/ops/fmha/__in │
│ it__.py:213 in memory_efficient_attention_forward │
│ │
│ 210 │ """ │
│ 211 │ Calculates the forward pass of :attr:`xformers.ops.memory_efficient_a │
│ 212 │ """ │
│ ❱ 213 │ return _memory_efficient_attention_forward( │
│ 214 │ │ Inputs( │
│ 215 │ │ │ query=query, key=key, value=value, p=p, attn_bias=attn_bias, │
│ 216 │ │ ), │
│ │
│ /root/miniconda3/envs/py311/lib/python3.11/site-packages/xformers/ops/fmha/__in │
│ it__.py:310 in _memory_efficient_attention_forward │
│ │
│ 307 │ else: │
│ 308 │ │ _ensure_op_supports_or_raise(ValueError, "memory_efficient_attent │
│ 309 │ │
│ ❱ 310 │ out, *_ = op.apply(inp, needs_gradient=False) │
│ 311 │ return out.reshape(output_shape) │
│ 312 │
│ 313 │
│ │
│ /root/miniconda3/envs/py311/lib/python3.11/site-packages/xformers/ops/fmha/cutl │
│ ass.py:175 in apply │
│ │
│ 172 │ │ if type(inp.attn_bias) not in FwOp.SUPPORTED_ATTN_BIAS_TYPES: │
│ 173 │ │ │ raise NotImplementedError("Unsupported attn_bias type") │
│ 174 │ │ seqstart_k, seqstart_q, max_seqlen_q, _ = _get_seqlen_info(inp) │
│ ❱ 175 │ │ out, lse, rng_seed, rng_offset = cls.OPERATOR( │
│ 176 │ │ │ query=inp.query, │
│ 177 │ │ │ key=inp.key, │
│ 178 │ │ │ value=inp.value, │
│ │
│ /root/miniconda3/envs/py311/lib/python3.11/site-packages/torch/_ops.py:502 in │
│ __call__ │
│ │
│ 499 │ │ # is still callable from JIT │
│ 500 │ │ # We save the function ptr as the `op` attribute on │
│ 501 │ │ # OpOverloadPacket to access it here. │
│ ❱ 502 │ │ return self._op(*args, **kwargs or {}) │
│ 503 │ │
│ 504 │ # TODO: use this to make a __dir__ │
│ 505 │ def overloads(self): │
╰─────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: attn_bias is not correctly aligned

currently generating with:

    model = vllm.LLM(model="MPT-30B", trust_remote_code=True)
    sampling_params = vllm.SamplingParams(n=1,
                                          temperature=0.2,
                                          top_p=0.9,
                                          top_k=-1,
                                          best_of=1,
                                          use_beam_search=False,
                                          max_tokens=300,
                                          )

    output = model.generate(prompt, sampling_params)

Other library versions:
xformers==0.0.20
langchain==0.0.232
torch==2.0.1
pytorch-triton==2.1.0+440fd1bf20
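
For reference, the error should be reproducible without LangChain; a minimal standalone sketch along these lines exercises the same vLLM/xformers code path (the model identifier is the same placeholder as in the snippet above, and the prompt is the short example mentioned later in this thread):

    # Minimal standalone sketch that bypasses LangChain to isolate the vLLM/xformers path.
    # The model identifier "MPT-30B" is taken from the snippet above; adjust it to your
    # local checkpoint path or to the Hub name as appropriate.
    import vllm

    model = vllm.LLM(model="MPT-30B", trust_remote_code=True)
    sampling_params = vllm.SamplingParams(n=1,
                                          temperature=0.2,
                                          top_p=0.9,
                                          top_k=-1,
                                          best_of=1,
                                          use_beam_search=False,
                                          max_tokens=300)

    # Even a trivial prompt such as "Hello" reportedly triggers
    # "RuntimeError: attn_bias is not correctly aligned" during the prompt (prefill) pass.
    prompt = "Hello"
    output = model.generate(prompt, sampling_params)
    print(output[0].outputs[0].text)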

@WoosukKwon WoosukKwon added the bug Something isn't working label Aug 23, 2023
@WoosukKwon
Collaborator

Hi @dlopes78, thanks for reporting the bug. Could you provide more details, especially the prompts you used? I tested the same model on A100-80GB, and couldn't reproduce the bug.

@dlopes78
Author

Hello, this happens with random queries, but the crash should be reproducible with just "Hello", for example.

@WoosukKwon
Collaborator

@dlopes78 Thanks for providing the example. I reproduced the bug. This will be fixed by #834.
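
For context, the RuntimeError is raised by xformers' CUTLASS attention kernel, which expects the attn_bias tensor to satisfy an alignment constraint (in practice, a row stride that is a multiple of 8 elements). The sketch below only illustrates that general padding pattern; it is not the actual #834 patch, and the function name is hypothetical. See #834 for the fix actually applied in vLLM.

    # Illustration of the alignment pattern only, not the vLLM #834 fix.
    # xformers' memory-efficient attention rejects a bias whose rows are not
    # 8-element aligned; a common workaround is to allocate the bias with a
    # padded last dimension and hand the kernel a sliced view of it.
    import torch

    def make_aligned_alibi_bias(alibi_slopes: torch.Tensor,
                                seq_len: int,
                                dtype: torch.dtype = torch.float16) -> torch.Tensor:
        # alibi_slopes: per-head ALiBi slopes, shape (num_heads,)
        padded_len = (seq_len + 7) // 8 * 8  # round up to a multiple of 8
        # Relative distances (j - i) for an ALiBi-style additive bias.
        distances = (torch.arange(seq_len)[None, :] -
                     torch.arange(seq_len)[:, None]).to(dtype)
        bias = torch.zeros(alibi_slopes.numel(), seq_len, padded_len, dtype=dtype)
        bias[:, :, :seq_len] = alibi_slopes.to(dtype)[:, None, None] * distances
        # Slicing keeps the padded row stride, so the returned view stays aligned.
        return bias[:, :, :seq_len]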
