Skip to content

Commit 21b1756

Browse files
committed
Add support for Intel GPU to GAT example
Signed-off-by: jafraustro <[email protected]>
1 parent 5dfeb46 commit 21b1756

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

Diff for: gat/README.md

+1
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ options:
8989
epochs to wait for print training and validation evaluation (default: 20)
9090
--no-cuda disables CUDA training
9191
--no-mps disables macOS GPU training
92+
--no-xpu disables XPU training
9293
--dry-run quickly check a single pass
9394
--seed S random seed (default: 13)
9495
```

Diff for: gat/main.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -303,15 +303,17 @@ def test(model, criterion, input, target, mask):
303303
help='dimension of the hidden representation (default: 64)')
304304
parser.add_argument('--num-heads', type=int, default=8,
305305
help='number of the attention heads (default: 4)')
306-
parser.add_argument('--concat-heads', action='store_true', default=False,
306+
parser.add_argument('--concat-heads', action='store_true',
307307
help='wether to concatinate attention heads, or average over them (default: False)')
308308
parser.add_argument('--val-every', type=int, default=20,
309309
help='epochs to wait for print training and validation evaluation (default: 20)')
310-
parser.add_argument('--no-cuda', action='store_true', default=False,
310+
parser.add_argument('--no-cuda', action='store_true',
311311
help='disables CUDA training')
312-
parser.add_argument('--no-mps', action='store_true', default=False,
312+
parser.add_argument('--no-xpu', action='store_true',
313+
help='disables XPU training')
314+
parser.add_argument('--no-mps', action='store_true',
313315
help='disables macOS GPU training')
314-
parser.add_argument('--dry-run', action='store_true', default=False,
316+
parser.add_argument('--dry-run', action='store_true',
315317
help='quickly check a single pass')
316318
parser.add_argument('--seed', type=int, default=13, metavar='S',
317319
help='random seed (default: 13)')
@@ -320,12 +322,15 @@ def test(model, criterion, input, target, mask):
320322
torch.manual_seed(args.seed)
321323
use_cuda = not args.no_cuda and torch.cuda.is_available()
322324
use_mps = not args.no_mps and torch.backends.mps.is_available()
325+
use_xpu = not args.no_xpu and torch.xpu.is_available()
323326

324327
# Set the device to run on
325328
if use_cuda:
326329
device = torch.device('cuda')
327330
elif use_mps:
328331
device = torch.device('mps')
332+
elif use_xpu:
333+
device = torch.device('xpu')
329334
else:
330335
device = torch.device('cpu')
331336
print(f'Using {device} device')

0 commit comments

Comments
 (0)