@@ -70,7 +70,7 @@ def test_rotary_embedding(
     if rotary_dim is None:
         rotary_dim = head_size
     rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style)
-    rope = rope.to(dtype=dtype)
+    rope = rope.to(dtype=dtype, device=torch.get_default_device())
 
     positions = torch.randint(0, max_position, (batch_size, seq_len))
     query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
@@ -125,7 +125,7 @@ def test_batched_rotary_embedding(
         "rope_type": "linear",
         "factor": (1, )
     })
-    rope = rope.to(dtype=dtype)
+    rope = rope.to(dtype=dtype, device=torch.get_default_device())
 
     positions = torch.randint(0, max_position, (batch_size, seq_len))
     query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
@@ -184,7 +184,7 @@ def test_batched_rotary_embedding_multi_lora(
         "rope_type": "linear",
         "factor": tuple(scaling_factors)
     })
-    rope = rope.to(dtype=dtype)
+    rope = rope.to(dtype=dtype, device=torch.get_default_device())
 
     positions = torch.randint(0, max_position, (batch_size, seq_len))
     query = torch.randn(batch_size,
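Aside (not part of the commit): torch.get_default_device() returns whatever device torch.set_default_device() has configured, so moving the RoPE module there keeps its cached buffers on the same device as tensors the tests create with factory functions such as torch.randint. A minimal standalone sketch of that interaction, using only public PyTorch APIs; DummyRope is a hypothetical stand-in, not the module vLLM's get_rope returns.

# Sketch: default-device-aware module placement, assuming PyTorch >= 2.3
# (torch.get_default_device was added in 2.3).
import torch
import torch.nn as nn

# Factory functions now allocate on this device by default.
torch.set_default_device("cuda" if torch.cuda.is_available() else "cpu")

# Test inputs created after set_default_device land on the default device.
positions = torch.randint(0, 16, (2, 8))

class DummyRope(nn.Module):
    # Hypothetical stand-in: a module whose buffer was materialized on CPU.
    def __init__(self):
        super().__init__()
        self.register_buffer("cos_sin_cache", torch.randn(16, 4, device="cpu"))

# Passing device=torch.get_default_device() moves the buffer alongside dtype,
# so the module and the test tensors agree on the device.
rope = DummyRope().to(dtype=torch.float16, device=torch.get_default_device())
assert rope.cos_sin_cache.device == positions.device

Spelling the device out in .to() covers buffers that were created on CPU before the default device took effect, rather than relying on where the module happened to be built.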