@@ -722,6 +722,10 @@ def bench_torch_mm(ctx: BenchmarkContext,
                    cuda_graph_nops: Optional[int] = None) -> TMeasurement:
     """
     Benchmark basic torch.mm as a roofline.
+
+    When all the input tokens have the same LoRA ID, the LoRA kernels are just
+    a matmul. This torch.mm benchmark serves as a roofline for that case.
+
     input op_type is used in determining the m, k, n dimensions for the matmul.
     """

@@ -746,9 +750,10 @@ def bench_torch_mm(ctx: BenchmarkContext,
     # Make torch.mm kwargs
     mm_kwargs = {'input': ArgPool(As), 'mat2': ArgPool(Bs), 'out': ArgPool(Cs)}

-    description = (f"torch.mm({dtype_to_str(dtype)}"
-                   f"x{dtype_to_str(dtype)}"
-                   f"=>{dtype_to_str(dtype)})")
+    description = (
+        f"single-lora roofline using torch.mm ({dtype_to_str(dtype)}"
+        f"x{dtype_to_str(dtype)}"
+        f"=>{dtype_to_str(dtype)})")
     cuda_graph_params = None
     if cuda_graph_nops:
         cuda_graph_params = CudaGraphBenchParams(cuda_graph_nops)
@@ -777,10 +782,18 @@ def print_timers(timers: List[TMeasurement],
     compare.print()

     if args and args.cuda_graph_nops:
-        print(f"The timings reported above is for {args.cuda_graph_nops} "
-              "consecutive invocations of the benchmarking functions. "
-              f"Please divide by {args.cuda_graph_nops} for single invocation "
-              "timings ")
+        print(
+            f"Note: The timings reported above are for {args.cuda_graph_nops} "
+            "consecutive invocations of the benchmarking functions. "
+            f"Please divide by {args.cuda_graph_nops} for single-invocation "
+            "timings.")
+
+        print("Note on comparison with torch.mm: The torch.mm numbers come "
+              "from a simple matmul emulating the single-LoRA case. They are "
+              "provided as a roofline for comparing our LoRA kernel "
+              "implementations. The LoRA kernels are expected to be slower "
+              "than torch.mm when num_loras is large, but for small "
+              "num_loras the goal should be to match the torch.mm numbers.")


 def run(args: argparse.Namespace, bench_ctxs: List[BenchmarkContext]):
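As background for the roofline argument in the new docstring and the printed note above, here is a minimal standalone sketch (not part of this PR; shapes and names are illustrative) of why torch.mm is a fair roofline: when every input token maps to the same LoRA ID, the LoRA "shrink" step has no per-token gather over adapters and collapses to a single dense matmul.

import torch

num_tokens, hidden_dim, lora_rank = 1024, 4096, 16
x = torch.randn(num_tokens, hidden_dim)      # activations for the whole batch
lora_a = torch.randn(hidden_dim, lora_rank)  # A matrix of the single active adapter

# Single-LoRA shrink: one plain matmul, hence torch.mm as the roofline.
shrunk = torch.mm(x, lora_a)
print(shrunk.shape)  # torch.Size([1024, 16])

The same reduction holds for the expand direction; as the docstring notes, the op_type only determines which m, k, n dimensions the emulated matmul uses.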