Skip to content

Commit 6dfc9d8

Browse files
authored
bugfix: fix unittests not yield error in AOT mode (#657)
The unittests in AOT mode failed since #629 because we didn't use return instead of yield in warmup functions, this PR fixes the issue.
1 parent 4c15777 commit 6dfc9d8

10 files changed

+10
-10
lines changed

tests/test_alibi.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
@pytest.fixture(autouse=True, scope="module")
2727
def warmup_jit():
2828
if flashinfer.jit.has_prebuilt_ops:
29-
return
29+
yield
3030
try:
3131
flashinfer.jit.parallel_load_modules(
3232
jit_decode_attention_func_args(

tests/test_batch_decode_kernels.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
@pytest.fixture(autouse=True, scope="module")
2525
def warmup_jit():
2626
if flashinfer.jit.has_prebuilt_ops:
27-
return
27+
yield
2828
try:
2929
flashinfer.jit.parallel_load_modules(
3030
jit_decode_attention_func_args(

tests/test_batch_prefill_kernels.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
@pytest.fixture(autouse=True, scope="module")
2525
def warmup_jit():
2626
if flashinfer.jit.has_prebuilt_ops:
27-
return
27+
yield
2828
try:
2929
flashinfer.jit.parallel_load_modules(
3030
jit_prefill_attention_func_args(

tests/test_block_sparse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
@pytest.fixture(autouse=True, scope="module")
2727
def warmup_jit():
2828
if flashinfer.jit.has_prebuilt_ops:
29-
return
29+
yield
3030
try:
3131
flashinfer.jit.parallel_load_modules(
3232
jit_decode_attention_func_args(

tests/test_logits_cap.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
@pytest.fixture(autouse=True, scope="module")
2727
def warmup_jit():
2828
if flashinfer.jit.has_prebuilt_ops:
29-
return
29+
yield
3030
try:
3131
flashinfer.jit.parallel_load_modules(
3232
jit_decode_attention_func_args(

tests/test_non_contiguous_decode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
@pytest.fixture(autouse=True, scope="module")
99
def warmup_jit():
1010
if flashinfer.jit.has_prebuilt_ops:
11-
return
11+
yield
1212
try:
1313
flashinfer.jit.parallel_load_modules(
1414
jit_decode_attention_func_args(

tests/test_non_contiguous_prefill.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
@pytest.fixture(autouse=True, scope="module")
2525
def warmup_jit():
2626
if flashinfer.jit.has_prebuilt_ops:
27-
return
27+
yield
2828
try:
2929
flashinfer.jit.parallel_load_modules(
3030
jit_prefill_attention_func_args(

tests/test_shared_prefix_kernels.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
@pytest.fixture(autouse=True, scope="module")
2525
def warmup_jit():
2626
if flashinfer.jit.has_prebuilt_ops:
27-
return
27+
yield
2828
try:
2929
flashinfer.jit.parallel_load_modules(
3030
jit_decode_attention_func_args(

tests/test_sliding_window.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
@pytest.fixture(autouse=True, scope="module")
2525
def warmup_jit():
2626
if flashinfer.jit.has_prebuilt_ops:
27-
return
27+
yield
2828
try:
2929
flashinfer.jit.parallel_load_modules(
3030
jit_decode_attention_func_args(

tests/test_tensor_cores_decode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
@pytest.fixture(autouse=True, scope="module")
2525
def warmup_jit():
2626
if flashinfer.jit.has_prebuilt_ops:
27-
return
27+
yield
2828
try:
2929
flashinfer.jit.parallel_load_modules(
3030
jit_decode_attention_func_args(

0 commit comments

Comments
 (0)