1 parent f622dbc commit d20e261
vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py
@@ -115,6 +115,10 @@ def apply_weights(self,
                       layer: torch.nn.Module,
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        # marlin requires contiguous memory layout
+        # prefix caching may cause x to be non-contiguous
+        x = x.contiguous()  # no-op if already contiguous
+
         c = self.config
         w_q, w_s, w_zp, w_gidx = self._get_weight_params(layer)
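The fix is a one-liner, but the behavior it guards against is easy to reproduce. Below is a minimal PyTorch sketch (the tensor names are illustrative, not taken from vLLM) showing how a strided view becomes non-contiguous, and why calling .contiguous() unconditionally is cheap in the common case:

import torch

# Slicing along a non-leading dimension with a step yields a view whose
# elements are not laid out densely in memory; fused GEMM kernels such
# as Marlin's require a dense (contiguous) input layout.
x = torch.randn(4, 16)
view = x[:, ::2]              # strided view over x's storage
print(view.is_contiguous())   # False

fixed = view.contiguous()     # copies into a freshly allocated dense buffer
print(fixed.is_contiguous())  # True

# When the tensor is already contiguous, .contiguous() returns the tensor
# itself without copying, so the unconditional call in the patch costs
# nothing when prefix caching has not fragmented the input.
dense = torch.randn(4, 16)
print(dense.contiguous() is dense)  # True: same object, no copy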