Skip to content

Commit a83467c

Browse files
mgoin authored and lulmer committed
Fix benchmark_moe.py tuning for CUDA devices (vllm-project#14164)
Signed-off-by: Louis Ulmer <[email protected]>
1 parent 5959ca9 commit a83467c

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

benchmarks/kernels/benchmark_moe.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22

33
import argparse
44
import time
5+
from contextlib import nullcontext
56
from datetime import datetime
67
from itertools import product
78
from typing import Any, TypedDict
@@ -412,7 +413,8 @@ def tune(
412413
hidden_size, search_space,
413414
is_fp16, topk)
414415

415-
with torch.cuda.device(self.device_id):
416+
with torch.cuda.device(self.device_id) if current_platform.is_rocm(
417+
) else nullcontext():
416418
for config in tqdm(search_space):
417419
try:
418420
kernel_time = benchmark_config(

0 commit comments

Comments
 (0)