
Kernels generated with use_fp16_qk_reductions=true break the LogitsTransform implementation used by prefill kernels #936

Closed
@diptorupd

Description


I want to clarify the semantics of the LogitsTransform function, which is declared as follows:

REGISTER_LOGITS_TRANSFORM(params, logits, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, {

The logits parameter is templated and can presumably be __half. However, the following computation of logits

logits = float(math::tanh(logits * soft_cap_pre_tanh_scale));
fails to compile because operator* cannot be resolved when the operands are __half and float. The assignment is also likely to break because of the implicit float-to-__half conversion.
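A minimal host-side C++ sketch of the overload-resolution failure. It assumes a hypothetical stand-in type Half that, like __half built without its implicit conversions, converts to and from float only explicitly. Whether the real __half * float is ill-formed or ambiguous depends on build macros such as __CUDA_NO_HALF_CONVERSIONS__, so this models only the no-implicit-conversion case. The detection trait can_multiply is also hypothetical and not part of FlashInfer or CUDA.

```cpp
#include <type_traits>
#include <utility>

// Hypothetical stand-in for __half: stores a float on the host and converts
// to/from float only explicitly (no implicit conversions in either direction).
struct Half {
  float v;
  explicit Half(float f) : v(f) {}
  explicit operator float() const { return v; }
  // Half * Half is defined, mirroring __half's own operator*.
  friend Half operator*(Half a, Half b) { return Half(a.v * b.v); }
};

// Hypothetical detection trait: true iff `a * b` is well-formed for A and B.
template <typename A, typename B, typename = void>
struct can_multiply : std::false_type {};

template <typename A, typename B>
struct can_multiply<A, B, std::void_t<decltype(std::declval<A>() * std::declval<B>())>>
    : std::true_type {};

// `logits * soft_cap_pre_tanh_scale` with Half logits: no viable operator*.
static_assert(!can_multiply<Half, float>::value, "Half * float must not resolve");
// After an explicit static_cast<float>, the operands are float * float.
static_assert(can_multiply<float, float>::value, "float * float resolves");
// Same-type arithmetic on the half type still works.
static_assert(can_multiply<Half, Half>::value, "Half * Half resolves");
```

The failure comes from mixing the two types, not from either type alone, which is why converting one operand explicitly is enough to make the expression compile.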

The LogitsTransform template is invoked inside the prefill kernel as follows:

variant.LogitsTransform(params, s_frag[mma_q][mma_kv][reg_id], batch_idx, q_idx, kv_idx,

I discovered this issue while working on a fix for #806 and compiling kernels that were generated by passing the use_fp16_qk_reductions=true flag to aot_build_utils.generate.

I can apply a fix using a constexpr cast from fp16 to fp32 and back, either at the call site or inside LogitsTransform. However, before making any mechanical changes, I wanted to clarify the intent of the implementation.
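A host-side C++ sketch of the two placements proposed above, under stated assumptions: Half is a hypothetical stand-in for __half that holds a float and converts only explicitly (so it does not reproduce fp16 rounding), and soft_cap_transform, soft_cap_float and apply_at_call_site are hypothetical names, not FlashInfer's API. In device code the two explicit conversions would presumably become __half2float and __float2half.

```cpp
#include <cmath>
#include <type_traits>

// Hypothetical stand-in for __half: holds a float and converts to/from float
// only explicitly. It does not reproduce fp16 rounding.
struct Half {
  float v;
  explicit Half(float f) : v(f) {}
  explicit operator float() const { return v; }
};

// Option 1 (inside the transform): type-generic soft-cap. The tanh is always
// evaluated in fp32; if constexpr selects the conversion path at compile time.
template <typename T>
T soft_cap_transform(T logits, float soft_cap_pre_tanh_scale) {
  if constexpr (std::is_same_v<T, Half>) {
    float x = static_cast<float>(logits);                 // half -> float
    return Half(std::tanh(x * soft_cap_pre_tanh_scale));  // float -> half
  } else {
    return std::tanh(logits * soft_cap_pre_tanh_scale);
  }
}

// Option 2 (at the call site): the transform itself stays fp32-only ...
inline float soft_cap_float(float logits, float soft_cap_pre_tanh_scale) {
  return std::tanh(logits * soft_cap_pre_tanh_scale);
}

// ... and the caller converts the fragment element around the call.
template <typename T>
void apply_at_call_site(T& frag_elem, float soft_cap_pre_tanh_scale) {
  if constexpr (std::is_same_v<T, Half>) {
    frag_elem = Half(soft_cap_float(static_cast<float>(frag_elem), soft_cap_pre_tanh_scale));
  } else {
    frag_elem = soft_cap_float(frag_elem, soft_cap_pre_tanh_scale);
  }
}
```

Converting inside the transform leaves every call site unchanged, while converting at the call site keeps LogitsTransform fp32-only; which one fits depends on the intended contract of LogitsTransform that this issue asks about.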
