Commit dc0f7cc

[BugFix] Enhance test_pos_encoding to support execution on multi-devices (#13187)
Signed-off-by: wchen61 <[email protected]>
1 parent d3d547e · commit dc0f7cc

File tree: 1 file changed (+3, -3 lines)

tests/kernels/test_pos_encoding.py

Lines changed: 3 additions & 3 deletions
@@ -70,7 +70,7 @@ def test_rotary_embedding(
     if rotary_dim is None:
         rotary_dim = head_size
     rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style)
-    rope = rope.to(dtype=dtype)
+    rope = rope.to(dtype=dtype, device=torch.get_default_device())

     positions = torch.randint(0, max_position, (batch_size, seq_len))
     query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
@@ -125,7 +125,7 @@ def test_batched_rotary_embedding(
         "rope_type": "linear",
         "factor": (1, )
     })
-    rope = rope.to(dtype=dtype)
+    rope = rope.to(dtype=dtype, device=torch.get_default_device())

     positions = torch.randint(0, max_position, (batch_size, seq_len))
     query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
@@ -184,7 +184,7 @@ def test_batched_rotary_embedding_multi_lora(
         "rope_type": "linear",
         "factor": tuple(scaling_factors)
     })
-    rope = rope.to(dtype=dtype)
+    rope = rope.to(dtype=dtype, device=torch.get_default_device())

     positions = torch.randint(0, max_position, (batch_size, seq_len))
     query = torch.randn(batch_size,
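The fix is the same one-liner in all three tests: the RoPE module is moved not only to the test dtype but also to torch's current default device, keeping it on the same device as the tensors the tests create. Below is a minimal sketch, not part of the commit, of the mismatch this avoids; it assumes a host with at least two CUDA devices and a PyTorch release that provides `torch.get_default_device()`, and the `rope_like` / `nn.Linear` stand-in is purely illustrative.

```python
import torch

# Stand-in for a RotaryEmbedding instance constructed (or cached) while
# the default device was still the CPU; its parameters live on the CPU.
rope_like = torch.nn.Linear(8, 8)

# A multi-device test run may later point the default device at a
# non-zero GPU, so tensors created without an explicit device land there.
torch.set_default_device("cuda:1")
positions = torch.randint(0, 16, (2, 8))
assert positions.device == torch.device("cuda:1")

# The module still holds CPU parameters, so using it with `positions`
# would raise a device-mismatch error; moving it to the current default
# device keeps module and inputs aligned, which is what the added
# `device=` argument does in the tests above.
rope_like = rope_like.to(dtype=torch.float16,
                         device=torch.get_default_device())
assert rope_like.weight.device == positions.device
```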
