Skip to content

Commit 433d129

Browse files
author
Varun Sundar Rabindranath
committed
add comments
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
1 parent 06127d3 commit 433d129

File tree

1 file changed

+20
-7
lines changed

1 file changed

+20
-7
lines changed

benchmarks/kernels/benchmark_lora.py

Lines changed: 20 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -722,6 +722,10 @@ def bench_torch_mm(ctx: BenchmarkContext,
722722
cuda_graph_nops: Optional[int] = None) -> TMeasurement:
723723
"""
724724
Benchmark basic torch.mm as a roofline.
725+
726+
When all the input tokens have the same LoRA ID, the LoRA kernels are just
727+
a matmul. This torch.mm benchmark serves as a roofline for that case.
728+
725729
input op_type is used in determining the m, k, n dimensions for the matmul.
726730
"""
727731

@@ -746,9 +750,10 @@ def bench_torch_mm(ctx: BenchmarkContext,
746750
# Make torch.mm kwargs
747751
mm_kwargs = {'input': ArgPool(As), 'mat2': ArgPool(Bs), 'out': ArgPool(Cs)}
748752

749-
description = (f"torch.mm({dtype_to_str(dtype)}"
750-
f"x{dtype_to_str(dtype)}"
751-
f"=>{dtype_to_str(dtype)})")
753+
description = (
754+
f"single-lora roofline using torch.mm ({dtype_to_str(dtype)}"
755+
f"x{dtype_to_str(dtype)}"
756+
f"=>{dtype_to_str(dtype)})")
752757
cuda_graph_params = None
753758
if cuda_graph_nops:
754759
cuda_graph_params = CudaGraphBenchParams(cuda_graph_nops)
@@ -777,10 +782,18 @@ def print_timers(timers: List[TMeasurement],
777782
compare.print()
778783

779784
if args and args.cuda_graph_nops:
780-
print(f"The timings reported above is for {args.cuda_graph_nops} "
781-
"consecutive invocations of the benchmarking functions. "
782-
f"Please divide by {args.cuda_graph_nops} for single invocation "
783-
"timings ")
785+
print(
786+
f"Note : The timings reported above is for {args.cuda_graph_nops} "
787+
"consecutive invocations of the benchmarking functions. "
788+
f"Please divide by {args.cuda_graph_nops} for single invocation "
789+
"timings.")
790+
791+
print("Note on Comparison with torch.mm : The torch.mm numbers are "
792+
"benchmark numbers of a simple matmul emulating the single lora "
793+
"case. It is provided as a roofline for comparing our LoRA Kernel "
794+
"implementations. It is expected that the LoRA kernels will be "
795+
"slower than torch.mm in cases where num_loras is big. But for "
796+
"small num_loras the goal should be to match the torch.mm numbers.")
784797

785798

786799
def run(args: argparse.Namespace, bench_ctxs: List[BenchmarkContext]):

0 commit comments

Comments
 (0)