Commit cb76e8d

Merge branch 'master' into feat/skorch-compatible
2 parents aa5484d + a10ba71 commit cb76e8d

11 files changed: +171 -103 lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -13,6 +13,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
+- Updated `ExcelFormer` implementation and related scripts ([#391](https://github.com/pyg-team/pytorch-frame/pull/391))
+
 ### Deprecated
 
 ### Removed

README.md

Lines changed: 3 additions & 3 deletions
@@ -28,7 +28,7 @@
 
 </div>
 
-**[Documentation](https://pytorch-frame.readthedocs.io)**
+**[Documentation](https://pytorch-frame.readthedocs.io)** | **[Paper](https://arxiv.org/abs/2404.00776)**
 
 PyTorch Frame is a deep learning extension for [PyTorch](https://pytorch.org/), designed for heterogeneous tabular data with different column types, including numerical, categorical, time, text, and images. It offers a modular framework for implementing existing and future methods. The library features methods from state-of-the-art models, user-friendly mini-batch loaders, benchmark datasets, and interfaces for custom data integration.
 
@@ -80,7 +80,7 @@ PyTorch Frame democratizes deep learning research for tabular data, catering to
 PyTorch Frame builds directly upon PyTorch, ensuring a smooth transition for existing PyTorch users. Key features include:
 
 * **Diverse column types**:
-PyTorch Frame supports learning across various column types: `numerical`, `categorical`, `multicategorical`, `text_embedded`, `text_tokenized`, `timestamp`, and `embedding`. See [here](https://pytorch-frame.readthedocs.io/en/latest/handling_advanced_stypes/handle_heterogeneous_stypes.html) for the detailed tutorial.
+PyTorch Frame supports learning across various column types: `numerical`, `categorical`, `multicategorical`, `text_embedded`, `text_tokenized`, `timestamp`, `image_embedded`, and `embedding`. See [here](https://pytorch-frame.readthedocs.io/en/latest/handling_advanced_stypes/handle_heterogeneous_stypes.html) for the detailed tutorial.
 * **Modular model design**:
 Enables modular deep learning model implementations, promoting reusability, clear coding, and experimentation flexibility. Further details in the [architecture overview](#architecture-overview).
 * **Models**
@@ -96,7 +96,7 @@ PyTorch Frame builds directly upon PyTorch, ensuring a smooth transition for exi
 Models in PyTorch Frame follow a modular design of `FeatureEncoder`, `TableConv`, and `Decoder`, as shown in the figure below:
 
 <p align="center">
-<img width="100%" src="https://raw.githubusercontent.com/pyg-team/pytorch-frame/master/docs/source/_figures/modular.png" />
+<img width="50%" src="https://raw.githubusercontent.com/pyg-team/pytorch-frame/master/docs/source/_figures/architecture.png" />
 </p>
 
 In essence, this modular setup empowers users to effortlessly experiment with myriad architectures:

benchmark/data_frame_benchmark.py

Lines changed: 4 additions & 1 deletion
@@ -229,6 +229,8 @@
         'diam_dropout': [0, 0.2],
         'residual_dropout': [0, 0.2],
         'aium_dropout': [0, 0.2],
+        'mixup': [None, 'feature', 'hidden'],
+        'beta': [0.5],
         'num_cols': [train_tensor_frame.num_cols],
     }
     train_search_space = {
@@ -257,7 +259,8 @@ def train(
         tf = tf.to(device)
         y = tf.y
         if isinstance(model, ExcelFormer):
-            pred, y = model.forward_mixup(tf)
+            # Train with FEAT-MIX or HIDDEN-MIX
+            pred, y = model(tf, mixup_encoded=True)
         elif isinstance(model, Trompt):
             # Trompt uses the layer-wise loss
             pred = model.forward_stacked(tf)
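
In the `ExcelFormer` branch above, `model(tf, mixup_encoded=True)` returns both predictions and mixed targets. As a rough, standalone illustration of what mask-based feature mixing with soft labels can look like, the sketch below mixes two rows in pure Python. It is not the pytorch-frame implementation: the function names, the Bernoulli mask, and the MI-weighted label share are assumptions made for illustration. Only the `beta` parameter and the use of mutual-information scores come from the diff.

```python
import random


def one_hot(label, num_classes):
    return [1.0 if c == label else 0.0 for c in range(num_classes)]


def feature_mixup(x_a, x_b, y_a, y_b, mi_scores, num_classes, beta=0.5,
                  rng=random):
    """Hypothetical feature-wise mixup of two rows (illustration only)."""
    # Assumption: the mixing ratio is drawn from Beta(beta, beta).
    lam = rng.betavariate(beta, beta)
    # Keep each feature from row A with probability lam, else take row B's.
    mask = [rng.random() < lam for _ in x_a]
    x_mix = [a if keep else b for a, b, keep in zip(x_a, x_b, mask)]
    # Assumption: label A is weighted by the MI share of features kept from A.
    w_a = sum(s for s, keep in zip(mi_scores, mask) if keep) / sum(mi_scores)
    y_mix = [
        w_a * p + (1.0 - w_a) * q
        for p, q in zip(one_hot(y_a, num_classes), one_hot(y_b, num_classes))
    ]
    return x_mix, y_mix


row_a, row_b = [1.0, 2.0, 3.0], [10.0, 20.0, 30.0]
x_mix, y_mix = feature_mixup(row_a, row_b, y_a=0, y_b=1,
                             mi_scores=[0.2, 0.3, 0.5], num_classes=2,
                             rng=random.Random(0))
```

Each entry of `x_mix` comes from exactly one of the two rows, and `y_mix` is a soft label that sums to one, which is the form `F.cross_entropy` accepts as class probabilities.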

docs/source/_figures/architecture.png

273 KB

examples/excelformer.py

Lines changed: 42 additions & 24 deletions
@@ -1,20 +1,26 @@
-"""Reported (reproduced) accuracy(rmse for regression task) of ExcelFormer
+"""Reported (reproduced) accuracy (for multi-classification task), auc
+(for binary classification task) and rmse (for regression task)
 based on Table 1 of the paper https://arxiv.org/abs/2301.02819.
 ExcelFormer uses the same train-validation-test split as the Yandex paper.
-
-california_housing: 0.4587 (0.4733) num_layers=5, num_heads=4, num_layers=5,
-channels=32, lr: 0.001,
-jannis : 72.51 (72.38) num_heads=32, lr: 0.0001
-covtype: 97.17 (95.37)
-helena: 38.20 (36.80)
-higgs_small: 80.75 (65.17) lr: 0.0001
+The reproduced results are based on Z-score Normalization, and the
+reported ones are based on :class:`QuantileTransformer` preprocessing
+in the Sklearn Python package. The above preprocessing is applied
+to numerical features.
+
+california_housing: 0.4587 (0.4550) mixup: feature, num_layers: 3,
+gamma: 1.00, epochs: 300
+jannis : 72.51 (72.80) mixup: feature
+covtype: 97.17 (97.02) mixup: hidden
+helena: 38.20 (37.68) mixup: feature
+higgs_small: 80.75 (79.27) mixup: hidden
 """
 import argparse
 import os.path as osp
 
 import torch
 import torch.nn.functional as F
 from torch.optim.lr_scheduler import ExponentialLR
+from torchmetrics import AUROC, Accuracy, MeanSquaredError
 from tqdm import tqdm
 
 from torch_frame.data.loader import DataLoader
@@ -23,14 +29,16 @@
 from torch_frame.transforms import CatToNumTransform, MutualInformationSort
 
 parser = argparse.ArgumentParser()
-parser.add_argument('--dataset', type=str, default='higgs_small')
+parser.add_argument('--dataset', type=str, default='california_housing')
+parser.add_argument('--mixup', type=str, default=None,
+                    choices=[None, 'feature', 'hidden'])
 parser.add_argument('--channels', type=int, default=256)
 parser.add_argument('--batch_size', type=int, default=512)
 parser.add_argument('--num_heads', type=int, default=4)
 parser.add_argument('--num_layers', type=int, default=5)
 parser.add_argument('--lr', type=float, default=0.001)
+parser.add_argument('--gamma', type=float, default=0.95)
 parser.add_argument('--epochs', type=int, default=100)
-parser.add_argument('--mixup', type=bool, default=True)
 parser.add_argument('--compile', action='store_true')
 args = parser.parse_args()
 
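
The hunk above replaces a `--mixup` flag declared with `type=bool` by a string option restricted with `choices`. One practical difference, shown in the sketch below, is that `argparse` converts values by calling the `type` callable, so `bool('False')` yields `True`. Whether this pitfall motivated the change is an assumption, since the commit message does not say. The parser variables here are illustrative and separate from the script's own `parser`.

```python
import argparse

# Old declaration: any non-empty string is truthy, so 'False' parses as True.
old = argparse.ArgumentParser()
old.add_argument('--mixup', type=bool, default=True)
old_args = old.parse_args(['--mixup', 'False'])

# New declaration from the diff: a mode name, or None when the flag is omitted.
new = argparse.ArgumentParser()
new.add_argument('--mixup', type=str, default=None,
                 choices=[None, 'feature', 'hidden'])
default_args = new.parse_args([])
hidden_args = new.parse_args(['--mixup', 'hidden'])

print(old_args.mixup)      # True
print(default_args.mixup)  # None
print(hidden_args.mixup)   # hidden
```

With `choices` in place, a value outside the list makes `parse_args` exit with a usage error rather than silently accepting it.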
@@ -76,6 +84,16 @@
 else:
     out_channels = 1
 
+is_binary_class = is_classification and out_channels == 2
+
+if is_binary_class:
+    metric_computer = AUROC(task='binary')
+elif is_classification:
+    metric_computer = Accuracy(task='multiclass', num_classes=out_channels)
+else:
+    metric_computer = MeanSquaredError()
+metric_computer = metric_computer.to(device)
+
 model = ExcelFormer(
     in_channels=args.channels,
     out_channels=out_channels,
@@ -85,12 +103,13 @@
     residual_dropout=0.,
     diam_dropout=0.3,
     aium_dropout=0.,
+    mixup=args.mixup,
     col_stats=mutual_info_sort.transformed_stats,
     col_names_dict=train_tensor_frame.col_names_dict,
 ).to(device)
 model = torch.compile(model, dynamic=True) if args.compile else model
 optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
-lr_scheduler = ExponentialLR(optimizer, gamma=0.95)
+lr_scheduler = ExponentialLR(optimizer, gamma=args.gamma)
 
 
 def train(epoch: int) -> float:
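
`ExponentialLR` multiplies the learning rate by `gamma` on every scheduler step, so exposing `--gamma` controls how fast the rate decays. The plain-Python arithmetic below compares the script default (`gamma=0.95`, 100 epochs) with the `gamma: 1.00, epochs: 300` setting listed for `california_housing`. It assumes one scheduler step per epoch, which this hunk does not show, and `lr_after` is a hypothetical helper rather than part of the script.

```python
def lr_after(epochs: int, base_lr: float, gamma: float) -> float:
    """Closed-form learning rate after `epochs` exponential decay steps."""
    return base_lr * gamma**epochs


default_lr = lr_after(100, base_lr=0.001, gamma=0.95)   # decays to roughly 6e-6
constant_lr = lr_after(300, base_lr=0.001, gamma=1.00)  # no decay at all

print(f'{default_lr:.2e}', constant_lr)
```

With `gamma=1.00` the scheduler leaves the learning rate unchanged, so the longer 300-epoch run trains at the base rate throughout.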
@@ -99,8 +118,10 @@ def train(epoch: int) -> float:
 
     for tf in tqdm(train_loader, desc=f'Epoch: {epoch}'):
         tf = tf.to(device)
-        pred_mixedup, y_mixedup = model.forward_mixup(tf)
+        # Train with FEAT-MIX or HIDDEN-MIX
+        pred_mixedup, y_mixedup = model(tf, mixup_encoded=True)
         if is_classification:
+            # Softly mixed one-hot labels
             loss = F.cross_entropy(pred_mixedup, y_mixedup)
         else:
             loss = F.mse_loss(pred_mixedup.view(-1), y_mixedup.view(-1))
@@ -115,29 +136,26 @@ def train(epoch: int) -> float:
 @torch.no_grad()
 def test(loader: DataLoader) -> float:
     model.eval()
-    accum = total_count = 0
-
+    metric_computer.reset()
     for tf in loader:
         tf = tf.to(device)
         pred = model(tf)
-        if is_classification:
+        if is_binary_class:
+            metric_computer.update(pred[:, 1], tf.y)
+        elif is_classification:
             pred_class = pred.argmax(dim=-1)
-            accum += float((tf.y == pred_class).sum())
+            metric_computer.update(pred_class, tf.y)
         else:
-            accum += float(
-                F.mse_loss(pred.view(-1), tf.y.view(-1), reduction='sum'))
-        total_count += len(tf.y)
+            metric_computer.update(pred.view(-1), tf.y.view(-1))
 
     if is_classification:
-        accuracy = accum / total_count
-        return accuracy
+        return metric_computer.compute().item()
     else:
-        rmse = (accum / total_count)**0.5
-        return rmse
+        return metric_computer.compute().item()**0.5
 
 
 if is_classification:
-    metric = 'Acc'
+    metric = 'Acc' if not is_binary_class else 'AUC'
     best_val_metric = 0
     best_test_metric = 0
 else:
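
The rewritten `test` function swaps the manual `accum` / `total_count` bookkeeping for a metric object driven by `reset`, `update`, and `compute`. To make the equivalence concrete for the regression case, here is a dependency-free sketch of that pattern. `RunningMSE` is a hypothetical stand-in written for this illustration, not the torchmetrics class.

```python
class RunningMSE:
    """Hypothetical streaming mean-squared-error metric (illustration only)."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        # Same role as `accum = total_count = 0` in the removed code.
        self.sum_sq_err = 0.0
        self.count = 0

    def update(self, preds: list[float], targets: list[float]) -> None:
        # Accumulate the summed squared error and the number of rows per batch.
        self.sum_sq_err += sum((p - t)**2 for p, t in zip(preds, targets))
        self.count += len(targets)

    def compute(self) -> float:
        return self.sum_sq_err / self.count


metric = RunningMSE()
metric.update([1.0, 2.0], [1.0, 4.0])  # batch 1: squared errors 0 and 4
metric.update([3.0], [5.0])            # batch 2: squared error 4
rmse = metric.compute()**0.5           # square root of 8 / 3
```

Taking the square root only after `compute` mirrors the `compute().item()**0.5` line in the diff: the mean is formed over all rows first, so the result does not depend on batch boundaries.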

test/nn/models/test_excelformer.py

Lines changed: 6 additions & 2 deletions
@@ -15,7 +15,8 @@
     TaskType.MULTICLASS_CLASSIFICATION,
 ])
 @pytest.mark.parametrize('batch_size', [0, 5])
-def test_excelformer(task_type, batch_size):
+@pytest.mark.parametrize('mixup', [None, 'feature', 'hidden'])
+def test_excelformer(task_type, batch_size, mixup):
     in_channels = 8
     num_heads = 2
     num_layers = 6
@@ -35,6 +36,7 @@ def test_excelformer(task_type, batch_size):
         num_cols=num_cols,
         num_layers=num_layers,
         num_heads=num_heads,
+        mixup=mixup,
         col_stats=dataset.col_stats,
         col_names_dict=tensor_frame.col_names_dict,
     )
@@ -46,7 +48,9 @@ def test_excelformer(task_type, batch_size):
 
     # Test the mixup forward pass
     feat_num = copy.copy(tensor_frame.feat_dict[stype.numerical])
-    out_mixedup, y_mixedup = model.forward_mixup(tensor_frame)
+    # Set lazy mutual information scores for `feature` mixup
+    tensor_frame.mi_scores = torch.rand(torch.Size((feat_num.shape[1], )))
+    out_mixedup, y_mixedup = model(tensor_frame, mixup_encoded=True)
     assert out_mixedup.shape == (batch_size, out_channels)
     # Make sure the numerical feature is not modified.
     assert torch.allclose(feat_num, tensor_frame.feat_dict[stype.numerical])

test/transforms/test_cat_to_num_transform.py

Lines changed: 5 additions & 0 deletions
@@ -65,6 +65,11 @@ def test_cat_to_num_transform_on_categorical_only_dataset(with_nan):
         out.col_names_dict[stype.numerical]) == ((dataset.num_classes - 1) *
                                                  total_cols))
 
+    tensor_frame.feat_dict[stype.categorical] += 1
+    with pytest.raises(RuntimeError, match="contains new category"):
+        # Raise informative error when input tensor frame contains new category
+        out = transform(tensor_frame)
+
 
 @pytest.mark.parametrize('task_type', [
     TaskType.MULTICLASS_CLASSIFICATION, TaskType.REGRESSION,
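
The test added in the hunk above shifts every categorical index by one and expects the transform to fail with a `RuntimeError` whose message matches "contains new category". The transform's own check is not part of the diff shown here, so the following is only a guess at the kind of guard that would satisfy such a test. The function name, its arguments, and the exact message wording are assumptions.

```python
import re


def check_known_categories(indices: list[int], num_seen: int) -> None:
    """Hypothetical guard: fail if an index was not seen during fitting."""
    largest = max(indices, default=-1)
    if largest >= num_seen:
        raise RuntimeError(
            f"Input contains new category {largest}, but only {num_seen} "
            f"categories were seen during fit.")


check_known_categories([0, 2, 1], num_seen=3)  # all indices known: no error

try:
    # Same idea as `feat_dict[stype.categorical] += 1` in the test.
    check_known_categories([1, 3, 2], num_seen=3)
except RuntimeError as err:
    message = str(err)
```

`pytest.raises(..., match=...)` applies `re.search` to the string form of the exception, so any message containing that phrase would pass.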

torch_frame/datasets/fake.py

Lines changed: 9 additions & 3 deletions
@@ -81,16 +81,22 @@ def __init__(
         elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
             labels = np.random.randint(0, 3, size=(num_rows, ))
             if num_rows < 3:
-                raise ValueError("Number of rows needs to be at"
-                                 " least 3 for multiclass classification")
+                raise ValueError("Number of rows needs to be at "
+                                 "least 3 for multiclass classification")
             # make sure every label exists
             labels[0] = 0
             labels[1] = 1
             labels[2] = 2
             df_dict = {'target': labels}
             col_to_stype = {'target': stype.categorical}
         elif task_type == TaskType.BINARY_CLASSIFICATION:
-            df_dict = {'target': np.random.randint(0, 2, size=(num_rows, ))}
+            labels = np.random.randint(0, 2, size=(num_rows, ))
+            if num_rows < 2:
+                raise ValueError("Number of rows needs to be at "
+                                 "least 2 for binary classification")
+            labels[0] = 0
+            labels[1] = 1
+            df_dict = {'target': labels}
             col_to_stype = {'target': stype.categorical}
         else:
             raise ValueError(
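
The binary branch now follows the multiclass one: draw random labels, then overwrite the first entries so that every class appears at least once. A plausible reason, stated here as an assumption, is that a binary metric such as AUROC is not meaningful when only one class is present. The pattern generalizes to any number of classes, as in this sketch, which uses the standard-library `random` module in place of NumPy. `make_labels` is a hypothetical helper, not part of `torch_frame`.

```python
import random


def make_labels(num_rows: int, num_classes: int,
                rng: random.Random) -> list[int]:
    """Random labels in [0, num_classes) with every class guaranteed present."""
    if num_rows < num_classes:
        raise ValueError(
            f"Number of rows needs to be at least {num_classes}")
    labels = [rng.randrange(num_classes) for _ in range(num_rows)]
    # Overwrite the first entries so that no class is missing.
    for c in range(num_classes):
        labels[c] = c
    return labels


binary_labels = make_labels(5, 2, random.Random(0))
```

Validating `num_rows` before the overwrite gives a readable error instead of an index error on very small fake datasets.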
