Closed
Description
I want to clarify the semantics of the LogitsTransform function declared on:
The logits parameter is templated and can presumably be instantiated with __half. However, the multiplication that computes the new value of logits cannot resolve operator * when the operand types are __half and float. The assignment back to logits is also likely to break because it relies on an implicit float to __half conversion.
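To make the failure mode concrete, here is a minimal host-side sketch. It is a model, not the actual code: `HalfLike` is a hypothetical stand-in for `__half` that allows only explicit conversions, and `can_multiply` is a hypothetical detection trait. How the real `__half` behaves depends on how the CUDA headers are configured in the build, so treat this only as an illustration of why a mixed-type `operator *` may fail to resolve.

```cpp
#include <type_traits>
#include <utility>

// Hypothetical stand-in for __half: conversions are explicit only, and
// multiplication is defined only for two values of the same type.
struct HalfLike {
  float value;  // a real half stores 16 bits; float keeps the sketch simple
  explicit HalfLike(float v) : value(v) {}
  explicit operator float() const { return value; }
};

inline HalfLike operator*(HalfLike a, HalfLike b) {
  return HalfLike(a.value * b.value);
}

// Detects whether the expression `A * B` is well-formed.
template <typename A, typename B, typename = void>
struct can_multiply : std::false_type {};

template <typename A, typename B>
struct can_multiply<
    A, B, std::void_t<decltype(std::declval<A>() * std::declval<B>())>>
    : std::true_type {};
```

Under these assumptions, `can_multiply<HalfLike, HalfLike>` and `can_multiply<float, float>` hold, while `can_multiply<HalfLike, float>` does not, because neither operand can be implicitly converted to the other's type.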
The LogitsTransform template is invoked inside the prefill kernel on line
I discovered this issue while working on a fix for #806 and compiling kernels that were generated with the use_fp16_qk_reductions=true flag passed to aot_build_utils.generate.
I can apply a fix using a constexpr cast from fp16 to fp32 and vice versa, either at the call site or inside LogitsTransform. But before I make any mechanical changes, I wanted to clarify the intent of the implementation.
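As a rough sketch of what the cast-based fix could look like, the following does the arithmetic in fp32 and converts back to the storage type at the end. All names are hypothetical: `HalfLike` stands in for `__half` so the example compiles on the host, and `cast_to` and `scale_logits` are illustrative helpers, not the project's `LogitsTransform`. Whether the intended semantics are fp32 compute with fp16 storage is exactly the open question in this issue.

```cpp
#include <type_traits>

// Hypothetical stand-in for __half with explicit conversions only.
struct HalfLike {
  float value;
  explicit HalfLike(float v) : value(v) {}
  explicit operator float() const { return value; }
};

// Hypothetical helper: a no-op when the types already match,
// an explicit static_cast otherwise.
template <typename To, typename From>
inline To cast_to(From v) {
  if constexpr (std::is_same_v<To, From>) {
    return v;
  } else {
    return static_cast<To>(v);
  }
}

// Hypothetical transform: widen to float, scale, then narrow back to T.
// This avoids both the mixed-type operator* and the implicit
// float-to-half assignment.
template <typename T>
inline T scale_logits(T logits, float scale) {
  float wide = cast_to<float>(logits);
  wide *= scale;
  return cast_to<T>(wide);
}
```

The same helper works for `T = float`, where both casts compile away, so the call site would not need a separate code path per type.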