"""Reported (reproduced, xgboost) results of of TabTransformer model based on
Table 1 of original paper https://arxiv.org/abs/2012.06678.
adult: 73.8 (88.86)
bank-marketing: 93.4 (90.83, 81.00)
dota2: 63.3 (52.44, 53.75)
"""
import argparse
import os.path as osp

import torch
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score
from torch.optim.lr_scheduler import ExponentialLR
from tqdm import tqdm

from torch_frame.data import DataLoader
from torch_frame.datasets import AdultCensusIncome, BankMarketing, Dota2
from torch_frame.nn import TabTransformer
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='dota2',
                    choices=["adult", "dota2", "bank-marketing"])
parser.add_argument('--channels', type=int, default=32)
parser.add_argument('--num_heads', type=int, default=8)
parser.add_argument('--num_layers', type=int, default=6)
parser.add_argument('--encoder_pad_size', type=int, default=2)
parser.add_argument('--attention_dropout', type=float, default=0.3)
parser.add_argument('--ffn_dropout', type=float, default=0.3)
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--lr', type=float, default=0.0001)
parser.add_argument('--epochs', type=int, default=50)
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--compile', action='store_true')
args = parser.parse_args()
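
# Example invocation (all flags above have defaults, so any subset may be
# passed):
#   python tab_transformer.py --dataset adult --epochs 50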
torch.manual_seed(args.seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Prepare datasets
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data',
                args.dataset)
if args.dataset == "adult":
    dataset = AdultCensusIncome(root=path)
elif args.dataset == "bank-marketing":
    dataset = BankMarketing(root=path)
elif args.dataset == "dota2":
    dataset = Dota2(root=path)
else:
    raise ValueError(f"Unsupported dataset called {args.dataset}")
dataset.materialize()
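# materialize() converts the raw data frame into a TensorFrame and computes
# the per-column statistics consumed by the model's encoder.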
assert dataset.task_type.is_classification
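# All three supported datasets are binary classification tasks; the AUC
# evaluation below relies on this, scoring the positive class only.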
dataset = dataset.shuffle()
# Split ratios follow https://arxiv.org/abs/2012.06678:
# 65% is used for training, 15% for validation, and the
# remaining 20% for testing.
train_dataset, val_dataset, test_dataset = (dataset[:0.65],
                                            dataset[0.65:0.80],
                                            dataset[0.80:])
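# torch_frame datasets support fractional slicing, so dataset[:0.65] selects
# the first 65% of rows of the shuffled dataset.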
# Set up data loaders
train_tensor_frame = train_dataset.tensor_frame
val_tensor_frame = val_dataset.tensor_frame
test_tensor_frame = test_dataset.tensor_frame
train_loader = DataLoader(train_tensor_frame, batch_size=args.batch_size,
                          shuffle=True)
val_loader = DataLoader(val_tensor_frame, batch_size=args.batch_size)
test_loader = DataLoader(test_tensor_frame, batch_size=args.batch_size)
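# Each loader yields mini-batch TensorFrame objects rather than plain tensors.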
# Set up model and optimizer
model = TabTransformer(
    channels=args.channels,
    out_channels=dataset.num_classes,
    num_layers=args.num_layers,
    num_heads=args.num_heads,
    encoder_pad_size=args.encoder_pad_size,
    attn_dropout=args.attention_dropout,
    ffn_dropout=args.ffn_dropout,
    col_stats=dataset.col_stats,
    col_names_dict=train_tensor_frame.col_names_dict,
).to(device)
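# dynamic=True lets torch.compile handle varying batch sizes (e.g., the last,
# smaller batch of an epoch) without recompiling.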
model = torch.compile(model, dynamic=True) if args.compile else model
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
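# Decay the learning rate by a factor of 0.95 after every epoch; the scheduler
# is stepped at the end of each epoch in the training loop below.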
lr_scheduler = ExponentialLR(optimizer, gamma=0.95)


def train(epoch: int) -> float:
    model.train()
    loss_accum = total_count = 0
    for tf in tqdm(train_loader, desc=f'Epoch: {epoch}'):
        tf = tf.to(device)
        pred = model(tf)
        loss = F.cross_entropy(pred, tf.y)
        optimizer.zero_grad()
        loss.backward()
        # Accumulate the loss weighted by batch size so that the return value
        # is a proper per-sample average over the epoch.
        loss_accum += float(loss) * len(tf.y)
        total_count += len(tf.y)
        optimizer.step()
    return loss_accum / total_count


@torch.no_grad()
def test(loader: DataLoader) -> float:
    model.eval()
    all_preds = []
    all_labels = []
    for tf in loader:
        tf = tf.to(device)
        pred = model(tf)
        all_labels.append(tf.y.cpu())
        # For binary classification, keep the score of the positive class
        # (index 1) for AUC computation.
        all_preds.append(pred[:, 1].detach().cpu())
    all_labels = torch.cat(all_labels).numpy()
    all_preds = torch.cat(all_preds).numpy()
    # Compute the overall AUC.
    overall_auc = roc_auc_score(all_labels, all_preds)
    return overall_auc


metric = 'AUC'
best_val_metric = 0
best_test_metric = 0
for epoch in range(1, args.epochs + 1):
    train_loss = train(epoch)
    train_metric = test(train_loader)
    val_metric = test(val_loader)
    test_metric = test(test_loader)
    if val_metric > best_val_metric:
        best_val_metric = val_metric
        best_test_metric = test_metric
    print(f'Train Loss: {train_loss:.4f}, Train {metric}: {train_metric:.4f}, '
          f'Val {metric}: {val_metric:.4f}, Test {metric}: {test_metric:.4f}')
    lr_scheduler.step()

print(f'Best Val {metric}: {best_val_metric:.4f}, '
      f'Best Test {metric}: {best_test_metric:.4f}')