File tree: 2 files changed, +6 / -18 lines
  v1/attention/backends/mla: 2 files changed, +6 / -18 lines

@@ -161,13 +161,8 @@ def forward_cuda(
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         from vllm import _custom_ops as ops

-        # __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`)
-        # is expensive, so avoid calling it if possible
-        if self.cos_sin_cache.device != query.device or \
-                self.cos_sin_cache.dtype != query.dtype:
-            self.cos_sin_cache = self.cos_sin_cache.to(query.device,
-                                                       dtype=query.dtype)
-
+        self.cos_sin_cache = self.cos_sin_cache.to(query.device,
+                                                   dtype=query.dtype)
         # ops.rotary_embedding()/batched_rotary_embedding()
         # are in-place operations that update the query and key tensors.
         if offsets is not None:
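Note on the hunk above: the deleted comment records why the assignment used to be guarded. `Tensor.to()` already returns the same tensor when device and dtype match, so the numerical behaviour is unchanged; what the guard saved was the `nn.Module.__setattr__` call triggered by `self.cos_sin_cache = ...` on every forward. A minimal, self-contained sketch of the two patterns being traded off (class and method names are illustrative, not from the diff):

# Minimal sketch of the trade-off; RopeCacheDemo and its methods are
# hypothetical stand-ins, not vLLM code.
import torch
import torch.nn as nn


class RopeCacheDemo(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # cos/sin lookup table kept as a non-persistent buffer, as a rotary
        # embedding layer typically would.
        self.register_buffer("cos_sin_cache",
                             torch.randn(4096, 64),
                             persistent=False)

    def always_assign(self, query: torch.Tensor) -> None:
        # Pattern adopted by the diff: .to() is a no-op returning the same
        # tensor when device and dtype already match, but the assignment
        # still goes through nn.Module.__setattr__ on every call.
        self.cos_sin_cache = self.cos_sin_cache.to(query.device,
                                                   dtype=query.dtype)

    def guarded_assign(self, query: torch.Tensor) -> None:
        # Pattern removed by the diff: only assign when the cache actually
        # needs to move or be cast, skipping __setattr__ otherwise.
        if (self.cos_sin_cache.device != query.device
                or self.cos_sin_cache.dtype != query.dtype):
            self.cos_sin_cache = self.cos_sin_cache.to(query.device,
                                                       dtype=query.dtype)


if __name__ == "__main__":
    m = RopeCacheDemo()
    q = torch.randn(8, 64)
    m.always_assign(q)   # __setattr__ runs on every call
    m.guarded_assign(q)  # __setattr__ runs only on a device/dtype change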
@@ -222,8 +222,8 @@
     Fp8LinearGenericOp, current_platform_fp8_dtype, is_fp8)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     scaled_quantize)
-from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
-from vllm.platforms import current_platform
+from vllm.model_executor.layers.rotary_embedding import (
+    DeepseekScalingRotaryEmbedding, RotaryEmbedding)
 from vllm.utils import cdiv, round_down

 try:
@@ -627,15 +627,8 @@ def __init__(
         self.v_head_dim = v_head_dim

         self.rotary_emb = rotary_emb
-
-        if current_platform.is_cuda():
-            # Hack for V1 for now to avoid torch library overhead (since we are
-            # already inside an attention custom op), pull out the forward
-            # method from the rotary embedding and call it directly (and avoid
-            # calling forward_native, when we can call forward_cuda)
-            # TODO(lucas): we should probably find a cleaner way to do this
-            self.rotary_emb = rotary_emb.forward_cuda
-
+        self.use_yarn_rope = isinstance(rotary_emb,
+                                        DeepseekScalingRotaryEmbedding)
         self.q_proj = q_proj
         self.kv_b_proj = kv_b_proj
         self.o_proj = o_proj
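Note on the hunk above: instead of rebinding `self.rotary_emb` to `rotary_emb.forward_cuda` on CUDA, the constructor now records a boolean, `use_yarn_rope`, computed once with `isinstance()` against `DeepseekScalingRotaryEmbedding`. How the flag is consumed is not part of this diff; the sketch below only illustrates the general pattern with stand-in classes, under the assumption that the flag is checked later on the forward path.

# Hypothetical sketch; the real classes live in
# vllm.model_executor.layers.rotary_embedding, and MLAImplDemo is not the
# vLLM MLA implementation.
import torch
import torch.nn as nn


class RotaryEmbedding(nn.Module):
    """Stand-in for the base rotary embedding layer."""

    def forward(self, positions, q, k):
        return q, k  # the real layer rotates q/k using a cos/sin cache


class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
    """Stand-in for the YaRN-scaled RoPE variant used by DeepSeek models."""


class MLAImplDemo:
    """Hypothetical consumer of the flag."""

    def __init__(self, rotary_emb: RotaryEmbedding) -> None:
        # Keep the module itself (no method rebinding) ...
        self.rotary_emb = rotary_emb
        # ... and decide once, at construction time, whether the YaRN-scaled
        # variant is in use.
        self.use_yarn_rope = isinstance(rotary_emb,
                                        DeepseekScalingRotaryEmbedding)

    def apply_rope(self, positions, q, k):
        # On the hot path a plain attribute check suffices; no per-call
        # isinstance() and no swapped-out forward method.
        if self.use_yarn_rope:
            # Hypothetical: anything the YaRN-scaled variant needs done
            # differently would branch here.
            pass
        return self.rotary_emb(positions, q, k)


if __name__ == "__main__":
    demo = MLAImplDemo(DeepseekScalingRotaryEmbedding())
    q, k = torch.zeros(1, 8), torch.zeros(1, 8)
    demo.apply_rope(torch.arange(1), q, k)
    print(demo.use_yarn_rope)  # True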