1 parent fdbbfb6 commit 89373b8
bitsandbytes/backends/cpu_xpu_common.py
@@ -552,6 +552,8 @@ def gemm_4bit_impl(
         GEMM output tensor.
     """
     if getattr(state, "ipex", False):
+        # compute_dtype: 1 indicates fp16, 2 indicates bf16
+        compute_dtype = 2 if A.dtype == torch.bfloat16 else 1
         output = torch.ops.torch_ipex.woq_linear(
             A,
             B,
@@ -562,7 +564,7 @@ def gemm_4bit_impl(
             None,
             None,
             state.blocksize,
-            ipex_cpu.quantization.WoqLowpMode.BF16,
+            compute_dtype,
             1,
             state.compensation,
         )
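
The substance of the fix: the compute mode passed to torch.ops.torch_ipex.woq_linear was hard-coded to the bf16 low-precision mode, so fp16 activations were also computed in bf16; the patch selects the mode from the activation dtype instead. Below is a minimal standalone sketch of that selection logic, assuming the integer mapping stated in the diff's own comment (1 indicates fp16, 2 indicates bf16). The constant names and the select_woq_compute_dtype helper are hypothetical illustration, not part of the patch, which inlines the one-liner directly.

    import torch

    # Integer codes assumed from the diff's comment:
    # 1 indicates fp16, 2 indicates bf16.
    WOQ_LOWP_MODE_FP16 = 1
    WOQ_LOWP_MODE_BF16 = 2

    def select_woq_compute_dtype(activation: torch.Tensor) -> int:
        """Pick the woq_linear compute-mode code from the activation dtype.

        Hypothetical helper mirroring the patched logic in gemm_4bit_impl.
        """
        if activation.dtype == torch.bfloat16:
            return WOQ_LOWP_MODE_BF16
        return WOQ_LOWP_MODE_FP16

    # Usage: fp16 activations now take the fp16 compute path rather than
    # being forced through bf16, as they were before this commit.
    assert select_woq_compute_dtype(torch.ones(2, dtype=torch.float16)) == WOQ_LOWP_MODE_FP16
    assert select_woq_compute_dtype(torch.ones(2, dtype=torch.bfloat16)) == WOQ_LOWP_MODE_BF16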