Commit db5a29b

[Bugfix] Fix LoRA test (#18518)
Signed-off-by: Jee Jee Li <[email protected]>
1 parent 5179777 commit db5a29b

2 files changed: +73 additions, -65 deletions


tests/lora/test_lora_functions.py
Lines changed: 1 addition & 1 deletion

@@ -69,7 +69,7 @@ def run_check(fn, args, expected: list):
     run_check(llm.add_lora, make_lora_request(12), [12, 9, 10, 11])
     run_check(llm.add_lora, make_lora_request(13), [12, 13, 10, 11])
 
-    # Remove all LoRAs
+    # Remove all LoRAs.
     run_check(llm.remove_lora, 13, [12, 10, 11])
     run_check(llm.remove_lora, 12, [10, 11])
     run_check(llm.remove_lora, 11, [10])
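The expected lists in these calls record which LoRA adapter IDs should remain resident after each add or remove: adding adapter 13 replaces adapter 9 in the expected set, and the trailing removals drain the remaining adapters one by one. The run_check helper itself is not shown in this hunk; a minimal sketch of what such a helper could look like, assuming a list_loras() accessor on the engine (an assumption for illustration, not confirmed by this diff), is:

    def run_check(fn, args, expected: list):
        # Apply the add/remove operation, then verify exactly which
        # adapter IDs are still loaded (order-insensitive comparison).
        fn(args)
        assert set(llm.llm_engine.list_loras()) == set(expected)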

tests/v1/sample/test_topk_topp_sampler.py
Lines changed: 72 additions & 64 deletions

@@ -16,31 +16,40 @@
 FLASHINFER_ENABLED = current_platform.is_cuda() and is_flashinfer_available
 
 
+@pytest.fixture(autouse=True)
+def reset_default_device():
+    """
+    Explicitly set the default device, which can affect subsequent tests.
+    Adding this fixture helps avoid this problem.
+    """
+    original_device = torch.get_default_device()
+    yield
+    torch.set_default_device(original_device)
+
+
 def test_topk_impl_equivalance():
 
-    with torch.device(DEVICE):
-        generator = Generator(device=DEVICE).manual_seed(33)
+    torch.set_default_device(DEVICE)
+    generator = Generator(device=DEVICE).manual_seed(33)
 
-        logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)
+    logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)
 
-        # Random top-k values between 1 and 9.
-        k = torch.randint(1, 10, (BATCH_SIZE, ), generator=generator)
+    # Random top-k values between 1 and 9.
+    k = torch.randint(1, 10, (BATCH_SIZE, ), generator=generator)
 
-        # Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
-        k.masked_fill_(
-            torch.randint(0,
-                          2, (BATCH_SIZE, ),
-                          generator=generator,
-                          dtype=bool), VOCAB_SIZE)
+    # Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
+    k.masked_fill_(
+        torch.randint(0, 2, (BATCH_SIZE, ), generator=generator, dtype=bool),
+        VOCAB_SIZE)
 
-        # Top-k only implementation
-        result1 = apply_top_k_top_p(logits=logits.clone(), k=k, p=None)
+    # Top-k only implementation
+    result1 = apply_top_k_top_p(logits=logits.clone(), k=k, p=None)
 
-        # Top-p + top-k
-        no_op_top_p = torch.tensor([1.0])
-        result2 = apply_top_k_top_p(logits=logits.clone(), k=k, p=no_op_top_p)
+    # Top-p + top-k
+    no_op_top_p = torch.tensor([1.0])
+    result2 = apply_top_k_top_p(logits=logits.clone(), k=k, p=no_op_top_p)
 
-        assert torch.allclose(result1, result2)
+    assert torch.allclose(result1, result2)
 
 
 def test_flashinfer_sampler():
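As the new fixture's docstring notes, these tests now call torch.set_default_device(DEVICE) directly instead of wrapping their bodies in a with torch.device(DEVICE): block. set_default_device changes process-global state, so without cleanup it would leak into tests collected after this file; the autouse fixture snapshots the current default device before each test and restores it afterwards. A minimal, self-contained sketch of the same pattern follows (the test body is illustrative only and assumes a CUDA device is available):

    import pytest
    import torch


    @pytest.fixture(autouse=True)
    def reset_default_device():
        # Remember whatever the default device was before the test ran...
        original_device = torch.get_default_device()
        yield
        # ...and restore it afterwards, so a test that changes the default
        # cannot affect tests that run later in the same session.
        torch.set_default_device(original_device)


    def test_changes_default_device():
        torch.set_default_device("cuda")
        x = torch.zeros(4)  # allocated on CUDA because of the new default
        assert x.device.type == "cuda"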
@@ -58,50 +67,49 @@ def test_flashinfer_sampler():
         pytest.skip(
             "FlashInfer not installed or not available on this platform.")
 
-    with torch.device(DEVICE):
-        generator = Generator(device=DEVICE).manual_seed(42)
-
-        # Generate random logits
-        logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)
-
-        # Generate various top-k and top-p values
-        k_values = torch.randint(1, 1000, (BATCH_SIZE, ), generator=generator)
-        p_values = torch.rand(
-            (BATCH_SIZE, ),
-            generator=generator) * 0.5 + 0.5  # range in [0.5, 1.0]
-
-        # Sometimes disable top-k (k=vocab_size)
-        k_values.masked_fill_(
-            torch.randint(0,
-                          2, (BATCH_SIZE, ),
-                          generator=generator,
-                          dtype=torch.bool), VOCAB_SIZE)
-
-        # Sometimes disable top-p (p=1.0)
-        p_values.masked_fill_(
-            torch.randint(0,
-                          2, (BATCH_SIZE, ),
-                          generator=generator,
-                          dtype=torch.bool), 1.0)
-
-        python_logits = apply_top_k_top_p(
-            logits=logits.clone(),
-            k=k_values,
-            p=p_values,
-        )
-        python_probs = torch.softmax(python_logits, dim=-1)
-
-        # FlashInfer only exposed renorm interfaces for probs so convert first
-        flashinfer_probs = torch.softmax(logits.clone(), dim=-1)
-        flashinfer_probs = top_k_renorm_probs(
-            probs=flashinfer_probs,
-            top_k=k_values,
-        )
-        flashinfer_probs = top_p_renorm_probs(
-            probs=flashinfer_probs,
-            top_p=p_values,
-        )
-
-        # Compare the results
-        assert torch.allclose(python_probs, flashinfer_probs, atol=2e-2), \
-            "FlashInfer and Python sampling implementations do not match!"
+    torch.set_default_device(DEVICE)
+    generator = Generator(device=DEVICE).manual_seed(42)
+
+    # Generate random logits
+    logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)
+
+    # Generate various top-k and top-p values
+    k_values = torch.randint(1, 1000, (BATCH_SIZE, ), generator=generator)
+    p_values = torch.rand(
+        (BATCH_SIZE, ), generator=generator) * 0.5 + 0.5  # range in [0.5, 1.0]
+
+    # Sometimes disable top-k (k=vocab_size)
+    k_values.masked_fill_(
+        torch.randint(0,
+                      2, (BATCH_SIZE, ),
+                      generator=generator,
+                      dtype=torch.bool), VOCAB_SIZE)
+
+    # Sometimes disable top-p (p=1.0)
+    p_values.masked_fill_(
+        torch.randint(0,
+                      2, (BATCH_SIZE, ),
+                      generator=generator,
+                      dtype=torch.bool), 1.0)
+
+    python_logits = apply_top_k_top_p(
+        logits=logits.clone(),
+        k=k_values,
+        p=p_values,
+    )
+    python_probs = torch.softmax(python_logits, dim=-1)
+
+    # FlashInfer only exposed renorm interfaces for probs so convert first
+    flashinfer_probs = torch.softmax(logits.clone(), dim=-1)
+    flashinfer_probs = top_k_renorm_probs(
+        probs=flashinfer_probs,
+        top_k=k_values,
+    )
+    flashinfer_probs = top_p_renorm_probs(
+        probs=flashinfer_probs,
+        top_p=p_values,
+    )
+
+    # Compare the results
+    assert torch.allclose(python_probs, flashinfer_probs, atol=2e-2), \
+        "FlashInfer and Python sampling implementations do not match!"
