[Feature Request] Add an argument to control the number of CTAs used in attention APIs #591

Open
@yzh119

Description

Nanoflow overlaps decode, prefill, and communication by limiting the number of SMs each kernel uses (in practice, this is controlled by the grid size). The current Nanoflow implementation modifies flashinfer kernels so that they can be launched with a specified grid size.

As flashinfer migrates all kernel implementations to persistent kernels, we can support specifying the number of SMs on the flashinfer side. More specifically, we can add a `num_ctas` argument to our plan functions to specify the grid size, so that users can control it directly from Python, as sketched below.
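A minimal sketch of the proposed usage, assuming a hypothetical `num_ctas` keyword on the existing decode wrapper's `plan()`; the surrounding wrapper/plan/run calls are paraphrased from flashinfer's Python API and may differ in detail across versions:

```python
import torch
import flashinfer

batch_size, page_size, head_dim = 4, 16, 128
num_qo_heads, num_kv_heads, max_num_pages = 32, 8, 64

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace, "NHD")

# One KV page per request, all pages full, for simplicity.
kv_indptr = torch.arange(batch_size + 1, dtype=torch.int32, device="cuda")
kv_indices = torch.arange(batch_size, dtype=torch.int32, device="cuda")
kv_last_page_len = torch.full((batch_size,), page_size, dtype=torch.int32, device="cuda")

wrapper.plan(
    kv_indptr, kv_indices, kv_last_page_len,
    num_qo_heads, num_kv_heads, head_dim, page_size,
    num_ctas=16,  # proposed argument (does not exist yet): cap the persistent
                  # kernel's grid at 16 CTAs so a concurrent prefill or
                  # communication kernel on another stream gets the other SMs
)

q = torch.randn(batch_size, num_qo_heads, head_dim, dtype=torch.float16, device="cuda")
kv_cache = torch.randn(max_num_pages, 2, page_size, num_kv_heads, head_dim,
                       dtype=torch.float16, device="cuda")
out = wrapper.run(q, kv_cache)
```

Since a persistent kernel loops over work tiles inside each CTA, capping the grid at `num_ctas` bounds the kernel's SM footprint without changing its output, which is exactly the knob Nanoflow needs.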

The benefits of this feature include:

  1. Keeping Nanoflow's development in step with the latest flashinfer features (JIT/FA3/customization/etc.).
  2. Making it possible to port Nanoflow to PyTorch. This may sacrifice some performance, but I think it is good for Nanoflow's adoption overall.
  3. Making it possible to use Nanoflow-style parallelism in other LLM serving frameworks such as vLLM/SGLang/MLC-LLM/etc.

We also need to support such an argument in the GEMM APIs by wrapping CUTLASS GEMM implementations; this is left for future work.
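Purely as an illustration of what the GEMM side could look like (nothing below exists yet; the entry point `gemm_capped` is a placeholder): CUTLASS 3.x persistent kernels already accept an SM budget via `KernelHardwareInfo::sm_count`, which a wrapper could forward `num_ctas` to.

```python
# Hypothetical future-work sketch; `flashinfer.gemm_capped` is a placeholder.
import torch
import flashinfer

A = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
B = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")

# Run the GEMM on at most 64 CTAs so a concurrent attention kernel on
# another stream can occupy the remaining SMs.
C = flashinfer.gemm_capped(A, B, num_ctas=64)  # hypothetical API
```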

cc @serendipity-zk @happierpig
