@@ -303,15 +303,17 @@ def test(model, criterion, input, target, mask):
                         help='dimension of the hidden representation (default: 64)')
     parser.add_argument('--num-heads', type=int, default=8,
                         help='number of the attention heads (default: 4)')
-    parser.add_argument('--concat-heads', action='store_true', default=False,
+    parser.add_argument('--concat-heads', action='store_true',
                         help='wether to concatinate attention heads, or average over them (default: False)')
     parser.add_argument('--val-every', type=int, default=20,
                         help='epochs to wait for print training and validation evaluation (default: 20)')
-    parser.add_argument('--no-cuda', action='store_true', default=False,
+    parser.add_argument('--no-cuda', action='store_true',
                         help='disables CUDA training')
-    parser.add_argument('--no-mps', action='store_true', default=False,
+    parser.add_argument('--no-xpu', action='store_true',
+                        help='disables XPU training')
+    parser.add_argument('--no-mps', action='store_true',
                         help='disables macOS GPU training')
-    parser.add_argument('--dry-run', action='store_true', default=False,
+    parser.add_argument('--dry-run', action='store_true',
                         help='quickly check a single pass')
     parser.add_argument('--seed', type=int, default=13, metavar='S',
                         help='random seed (default: 13)')
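For context on the dropped default=False arguments in this hunk: argparse's store_true action already makes an omitted flag evaluate to False, so the explicit default is redundant and can be removed without changing behavior. A minimal standalone sketch, using the --no-xpu flag that mirrors the one added above:

    import argparse

    # store_true flags default to False when not passed, so an explicit
    # default=False is redundant.
    parser = argparse.ArgumentParser()
    parser.add_argument('--no-xpu', action='store_true',
                        help='disables XPU training')

    args = parser.parse_args([])            # no flags passed
    print(args.no_xpu)                      # False
    args = parser.parse_args(['--no-xpu'])
    print(args.no_xpu)                      # True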
@@ -320,12 +322,15 @@ def test(model, criterion, input, target, mask):
     torch.manual_seed(args.seed)
     use_cuda = not args.no_cuda and torch.cuda.is_available()
     use_mps = not args.no_mps and torch.backends.mps.is_available()
+    use_xpu = not args.no_xpu and torch.xpu.is_available()
 
     # Set the device to run on
     if use_cuda:
         device = torch.device('cuda')
     elif use_mps:
         device = torch.device('mps')
+    elif use_xpu:
+        device = torch.device('xpu')
     else:
         device = torch.device('cpu')
     print(f'Using {device} device')
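Taken together, the changes keep the existing device-selection order and slot XPU in after CUDA and MPS: CUDA is preferred when available, then MPS, then XPU, and finally CPU. A minimal standalone sketch of that fallback (the no_cuda/no_mps/no_xpu names mirror the flags above; torch.xpu.is_available() assumes a PyTorch build with XPU support, so the sketch guards it with hasattr):

    import torch

    def pick_device(no_cuda=False, no_mps=False, no_xpu=False):
        # Same priority as the script: CUDA, then MPS, then XPU, then CPU.
        if not no_cuda and torch.cuda.is_available():
            return torch.device('cuda')
        if not no_mps and torch.backends.mps.is_available():
            return torch.device('mps')
        if not no_xpu and hasattr(torch, 'xpu') and torch.xpu.is_available():
            return torch.device('xpu')
        return torch.device('cpu')

    print(f'Using {pick_device()} device')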