diff --git a/README.md b/README.md index ec267e3d9..c677b46c9 100644 --- a/README.md +++ b/README.md @@ -231,9 +231,9 @@ The following chart shows the performance of various models on small regression | Model Name | dataset_0 | dataset_1 | dataset_2 | dataset_3 | dataset_4 | dataset_5 | dataset_6 | dataset_7 | dataset_8 | dataset_9 | dataset_10 | dataset_11 | dataset_12 | |:--------------------|:----------------|:----------------|:----------------|:----------------|:----------------|:----------------|:----------------|:----------------|:----------------|:----------------|:----------------|:----------------|:----------------| -| XGBoost | **0.250±0.000** | 0.038±0.000 | 0.187±0.000 | 0.475±0.000 | 0.328±0.000 | 0.401±0.000 | **0.249±0.000** | 0.363±0.000 | 0.904±0.000 | 0.056±0.000 | 0.820±0.000 | **0.857±0.000** | 0.418±0.000 | +| XGBoost | **0.247±0.000** | 0.077±0.000 | 0.167±0.000 | 1.119±0.000 | 0.328±0.000 | 1.024±0.000 | **0.292±0.000** | 0.606±0.000 | **0.876±0.000** | 0.023±0.000 | **0.697±0.000** | 0.865±0.000 | 0.435±0.000 | | CatBoost | 0.265±0.000 | 0.062±0.000 | 0.128±0.000 | 0.336±0.000 | 0.346±0.000 | 0.443±0.000 | 0.375±0.000 | 0.273±0.000 | 0.881±0.000 | 0.040±0.000 | 0.756±0.000 | 0.876±0.000 | 0.439±0.000 | -| LightGBM | 0.253±0.000 | 0.054±0.000 | **0.112±0.000** | 0.302±0.000 | 0.325±0.000 | **0.384±0.000** | 0.295±0.000 | **0.272±0.000** | **0.877±0.000** | 0.011±0.000 | **0.702±0.000** | 0.863±0.000 | **0.395±0.000** | +| LightGBM | 0.253±0.000 | 0.054±0.000 | **0.112±0.000** | 0.302±0.000 | 0.325±0.000 | **0.384±0.000** | 0.295±0.000 | **0.272±0.000** | 0.877±0.000 | 0.011±0.000 | 0.702±0.000 | **0.863±0.000** | **0.395±0.000** | | Trompt | 0.261±0.003 | **0.015±0.005** | 0.118±0.001 | **0.262±0.001** | **0.323±0.001** | 0.418±0.003 | 0.329±0.009 | 0.312±0.002 | OOM | **0.008±0.001** | 0.779±0.006 | 0.874±0.004 | 0.424±0.005 | | ResNet | 0.288±0.006 | 0.018±0.003 | 0.124±0.001 | 0.268±0.001 | 0.335±0.001 | 0.434±0.004 | 0.325±0.012 | 
0.324±0.004 | 0.895±0.005 | 0.036±0.002 | 0.794±0.006 | 0.875±0.004 | 0.468±0.004 | | FTTransformerBucket | 0.325±0.008 | 0.096±0.005 | 0.360±0.354 | 0.284±0.005 | 0.342±0.004 | 0.441±0.003 | 0.345±0.007 | 0.339±0.003 | OOM | 0.105±0.011 | 0.807±0.010 | 0.885±0.008 | 0.468±0.006 | diff --git a/benchmark/README.md b/benchmark/README.md index b72ec1d20..c234a4277 100644 --- a/benchmark/README.md +++ b/benchmark/README.md @@ -57,18 +57,18 @@ Metric: ROC-AUC, higher the better. Experimental setting: 20 Optuna search trials. 50 epochs of training. -| | dataset_0 | dataset_1 | dataset_2 | dataset_3 | dataset_4 | dataset_5 | dataset_6 | dataset_7 | dataset_8 | dataset_9 | dataset_10 | dataset_11 | dataset_12 | dataset_13 | -|:--------------------|:-----------------------|:------------------------|:-------------------------|:----------------------|:------------------------|:-----------------------|:----------------------|:-----------------------|:-----------------------|:-----------------------|:----------------------|:------------------------|:------------------------|:------------------------| -| XGBoost | **0.931±0.000 (41s)** | **1.000±0.000 (3s)** | 0.940±0.000 (389s) | **0.947±0.000 (42s)** | 0.885±0.000 (109s) | 0.966±0.000 (14s) | **0.862±0.000 (10s)** | **0.779±0.000 (79s)** | **0.984±0.000 (376s)** | 0.714±0.000 (10s) | 0.787±0.000 (9s) | 0.951±0.000 (103s) | **0.999±0.000 (434s)** | 0.925±0.000 (848s) | -| CatBoost | 0.930±0.000 (152s) | **1.000±0.000 (9s)** | 0.938±0.000 (164s) | 0.924±0.000 (29s) | 0.881±0.000 (27s) | 0.963±0.000 (48s) | 0.861±0.000 (12s) | 0.772±0.000 (10s) | 0.930±0.000 (91s) | 0.628±0.000 (10s) | **0.796±0.000 (15s)** | 0.948±0.000 (46s) | **0.998±0.000 (38s)** | 0.926±0.000 (115s) | -| LightGBM | **0.931±0.000 (15s)** | 0.999±0.000 (1s) | 0.943±0.000 (23s) | 0.943±0.000 (14s) | **0.887±0.000 (5s)** | **0.972±0.000 (11s)** | **0.862±0.000 (6s)** | 0.774±0.000 (3s) | 0.979±0.000 (41s) | **0.732±0.000** (13s) | 0.787±0.000 (3s) | 0.951±0.000 
(13s) | 0.999±0.000 (10s) | **0.927±0.000 (24s)** | -| Trompt | 0.919±0.000 (9627s) | **1.000±0.000 (5341s)** | **0.945±0.000 (14679s)** | 0.942±0.001 (2752s) | 0.881±0.001 (2640s) | 0.964±0.001 (5173s) | 0.855±0.002 (4249s) | 0.778±0.002 (8789s) | 0.933±0.001 (9353s) | 0.686±0.008 (3105s) | 0.793±0.002 (8255s) | **0.952±0.001 (4876s)** | **1.000±0.000 (3558s)** | 0.916±0.001 (30002s) | -| ResNet | 0.917±0.000 (615s) | **1.000±0.000 (71s)** | 0.937±0.001 (787s) | 0.938±0.002 (230s) | 0.865±0.001 (183s) | 0.960±0.001 (349s) | 0.828±0.001 (248s) | 0.768±0.002 (205s) | 0.925±0.002 (958s) | 0.665±0.006 (140s) | **0.794±0.002 (76s)** | 0.946±0.002 (145s) | **1.000±0.000 (93s)** | 0.911±0.001 (880s) | -| MLP | 0.913±0.001 (112s) | **1.000±0.000 (45s)** | 0.934±0.001 (274s) | 0.938±0.001 (66s) | 0.863±0.002 (61s) | 0.953±0.000 (92s) | 0.830±0.001 (68s) | 0.769±0.002 (56s) | 0.903±0.002 (159s) | 0.666±0.015 (58s) | 0.789±0.001 (48s) | 0.940±0.002 (107s) | **1.000±0.000 (48s)** | 0.910±0.001 (149s) -| FTTransformerBucket | 0.915±0.001 (690s) | **0.999±0.001 (354s)** | 0.936±0.002 (1705s) | 0.939±0.002 (484s) | 0.876±0.002 (321s) | 0.960±0.001 (746s) | 0.857±0.000 (549s) | 0.771±0.003 (654s) | 0.909±0.002 (1177s) | 0.636±0.012 (244s) | 0.788±0.002 (710s) | 0.950±0.001 (510s) | **0.999±0.000 (634s)** | 0.913±0.001 (1164s) | +| | dataset_0 | dataset_1 | dataset_2 | dataset_3 | dataset_4 | dataset_5 | dataset_6 | dataset_7 | dataset_8 | dataset_9 | dataset_10 | dataset_11 | dataset_12 | dataset_13 | +|:--------------------|:----------------------|:------------------------|:-------------------------|:----------------------|:---------------------|:----------------------|:----------------------|:-----------------------|:-----------------------|:-----------------------|:----------------------|:------------------------|:------------------------|:-------------------------| +| XGBoost | **0.931±0.000 (41s)** | **1.000±0.000 (4s)** | 0.935±0.000 (16s) | **0.946±0.000 (26s)** | 
0.881±0.000 (10s) | 0.951±0.000 (16s) | **0.862±0.000 (26s)** | **0.780±0.000 (11s)** | **0.983±0.000 (584s)** | **0.763±0.000 (240s)** | **0.795±0.000 (11s)** | 0.950±0.000 (479s) | **0.999±0.000 (148s)** | 0.926±0.000 (3042s) | +| CatBoost | 0.930±0.000 (152s) | **1.000±0.000 (9s)** | 0.938±0.000 (164s) | 0.924±0.000 (29s) | 0.881±0.000 (27s) | 0.963±0.000 (48s) | 0.861±0.000 (12s) | 0.772±0.000 (10s) | 0.930±0.000 (91s) | 0.628±0.000 (10s) | **0.796±0.000 (15s)** | 0.948±0.000 (46s) | **0.998±0.000 (38s)** | 0.926±0.000 (115s) | +| LightGBM | **0.931±0.000 (15s)** | 0.999±0.000 (1s) | 0.943±0.000 (23s) | 0.943±0.000 (14s) | **0.887±0.000 (5s)** | **0.972±0.000 (11s)** | **0.862±0.000 (6s)** | 0.774±0.000 (3s) | 0.979±0.000 (41s) | 0.732±0.000 (13s) | 0.787±0.000 (3s) | 0.951±0.000 (13s) | 0.999±0.000 (10s) | **0.927±0.000 (24s)** | +| Trompt | 0.919±0.000 (9627s) | **1.000±0.000 (5341s)** | **0.945±0.000 (14679s)** | 0.942±0.001 (2752s) | 0.881±0.001 (2640s) | 0.964±0.001 (5173s) | 0.855±0.002 (4249s) | 0.778±0.002 (8789s) | 0.933±0.001 (9353s) | 0.686±0.008 (3105s) | 0.793±0.002 (8255s) | **0.952±0.001 (4876s)** | **1.000±0.000 (3558s)** | 0.916±0.001 (30002s) | +| ResNet | 0.917±0.000 (615s) | **1.000±0.000 (71s)** | 0.937±0.001 (787s) | 0.938±0.002 (230s) | 0.865±0.001 (183s) | 0.960±0.001 (349s) | 0.828±0.001 (248s) | 0.768±0.002 (205s) | 0.925±0.002 (958s) | 0.665±0.006 (140s) | **0.794±0.002 (76s)** | 0.946±0.002 (145s) | **1.000±0.000 (93s)** | 0.911±0.001 (880s) | +| MLP | 0.913±0.001 (112s) | **1.000±0.000 (45s)** | 0.934±0.001 (274s) | 0.938±0.001 (66s) | 0.863±0.002 (61s) | 0.953±0.000 (92s) | 0.830±0.001 (68s) | 0.769±0.002 (56s) | 0.903±0.002 (159s) | 0.666±0.015 (58s) | 0.789±0.001 (48s) | 0.940±0.002 (107s) | **1.000±0.000 (48s)** | 0.910±0.001 (149s) | +| FTTransformerBucket | 0.915±0.001 (690s) | **0.999±0.001 (354s)** | 0.936±0.002 (1705s) | 0.939±0.002 (484s) | 0.876±0.002 (321s) | 0.960±0.001 (746s) | 0.857±0.000 (549s) | 0.771±0.003 (654s) | 
0.909±0.002 (1177s) | 0.636±0.012 (244s) | 0.788±0.002 (710s) | 0.950±0.001 (510s) | **0.999±0.000 (634s)** | 0.913±0.001 (1164s) | | ExcelFormer | 0.918±0.001 (1587s) | **1.000±0.000 (634s)** | 0.939±0.001 (1827s) | 0.939±0.002 (378s) | 0.883±0.001 (289s) | 0.969±0.000 (678s) | 0.833±0.011 (435s) | **0.780±0.002 (938s)** | 0.940±0.003 (919s) | 0.670±0.017 (464s) | 0.794±0.003 (683s) | 0.950±0.001 (405s) | **0.999±0.000 (1169s)** | 0.919±0.001 (1798s) | -| FTTransformer | 0.918±0.001 (871s) | **1.000±0.000 (571s)** | 0.940±0.001 (1371s) | 0.936±0.001 (458s) | 0.874±0.002 (200s) | 0.959±0.001 (622s) | 0.828±0.001 (339s) | 0.773±0.002 (521s) | 0.909±0.002 (1488s) | 0.635±0.011 (392s) | 0.790±0.001 (556s) | 0.949±0.002 (374s) | **1.000±0.000 (713s)** | 0.912±0.000 (1855s) | -| TabNet | 0.911±0.001 (150s) | **1.000±0.000 (35s)** | 0.931±0.005 (254s) | 0.937±0.003 (125s) | 0.864±0.002 (52s) | 0.944±0.001 (116s) | 0.828±0.001 (79s) | 0.771±0.005 (93s) | 0.913±0.005 (177s) | 0.606±0.014 (65s) | 0.790±0.003 (41s) | 0.936±0.003 (104s) | **1.000±0.000 (64s)** | 0.910±0.001 (294s) | +| FTTransformer | 0.918±0.001 (871s) | **1.000±0.000 (571s)** | 0.940±0.001 (1371s) | 0.936±0.001 (458s) | 0.874±0.002 (200s) | 0.959±0.001 (622s) | 0.828±0.001 (339s) | 0.773±0.002 (521s) | 0.909±0.002 (1488s) | 0.635±0.011 (392s) | 0.790±0.001 (556s) | 0.949±0.002 (374s) | **1.000±0.000 (713s)** | 0.912±0.000 (1855s) | +| TabNet | 0.911±0.001 (150s) | **1.000±0.000 (35s)** | 0.931±0.005 (254s) | 0.937±0.003 (125s) | 0.864±0.002 (52s) | 0.944±0.001 (116s) | 0.828±0.001 (79s) | 0.771±0.005 (93s) | 0.913±0.005 (177s) | 0.606±0.014 (65s) | 0.790±0.003 (41s) | 0.936±0.003 (104s) | **1.000±0.000 (64s)** | 0.910±0.001 (294s) | | TabTransformer | 0.910±0.001 (2044s) | **1.000±0.000 (1321s)** | 0.928±0.001 (2519s) | 0.918±0.003 (134s) | 0.829±0.002 (64s) | 0.928±0.001 (105s) | 0.816±0.002 (99s) | 0.757±0.003 (645s) | 0.885±0.001 (1167s) | 0.652±0.006 (282s) | 0.780±0.002 (112s) | 0.937±0.001 (117s) | 
0.996±0.000 (76s) | 0.905±0.001 (2283s) | #### `scale: medium` @@ -77,9 +77,9 @@ Experimental setting: 20 Optuna search trials for XGBoost, CatBoost and LightGBM | | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | |:--------------------|:------------------------|:------------------------|:------------------------|:------------------------|:-----------------------|:--------------------------|:-------------------------|:------------------------|:------------------------| -| XGBoost | 0.594±0.000 (466s) | **0.955±0.000 (6340s)** | **0.653±0.000 (19s)** | **0.986±0.000 (195s)** | 0.721±0.000 (62s) | **0.998±0.000 (70626s)** | 0.868±0.000 (159s) | 0.888±0.000 (2945s) | 0.803±0.000 (371s) | +| XGBoost | 0.597±0.000 (5717s) | **0.955±0.000 (468s)** | **0.653±0.000 (21s)** | **0.986±0.000 (2020s)** | 0.722±0.000 (358s) | **0.997±0.000 (165577s)** | 0.878±0.000 (690s) | **0.917±0.000 (6431s)** | 0.808±0.000 (7673s) | | CatBoost | 0.631±0.000 (1201s) | **0.956±0.000 (2963s)** | 0.649±0.000 (26s) | **0.986±0.000 (352s)** | 0.719±0.000 (244s) | 0.987±0.000 (2561s) | 0.863±0.000 (212s) | 0.896±0.000 (740s) | 0.803±0.000 (140s) | -| LightGBM | **0.639±0.000 (49s)** | 0.955±0.000 (126s) | 0.652±0.000 (7s) | **0.986±0.000 (99s)** | **0.723±0.000 (16s)** | 0.997±0.000 (172s) | 0.881±0.000 (83s) | **0.914±0.000 (86s)** | **0.809±0.000 (76s)** | +| LightGBM | **0.639±0.000 (49s)** | 0.955±0.000 (126s) | 0.652±0.000 (7s) | **0.986±0.000 (99s)** | **0.723±0.000 (16s)** | **0.997±0.000 (172s)** | 0.881±0.000 (83s) | 0.914±0.000 (86s) | **0.809±0.000 (76s)** | | Trompt | OOM | 0.950±0.000 (28212s) | **0.652±0.000 (5962s)** | 0.982±0.000 (19936s) | 0.716±0.000 (7110s) | 0.966±0.000 (106916s) | **0.882±0.000 (13644s)** | 0.883±0.000 (17863s) | 0.705±0.006 (11563s) | | ResNet | 0.637±0.000 (810s) | 0.948±0.000 (1051s) | 0.649±0.000 (185s) | 0.983±0.000 (239s) | 0.705±0.001 (226s) | 0.989±0.000 (1967s) | 0.871±0.001 (173s) | 0.890±0.001 (315s) | 0.719±0.001 (245s) | | MLP | 0.634±0.002 (392s) | 
0.946±0.001 (2306s) | 0.650±0.000 (263s) | 0.978±0.000 (468s) | 0.699±0.001 (357s) | 0.991±0.000 (2491s) | 0.869±0.001 (449s) | 0.883±0.001 (695s) | 0.727±0.002 (368s) | @@ -115,19 +115,19 @@ Metric: RMSE, lower the better. Experimental setting: 20 Optuna search trials. 50 epochs of training. -| | dataset_0 | dataset_1 | dataset_2 | dataset_3 | dataset_4 | dataset_5 | dataset_6 | dataset_7 | dataset_8 | dataset_9 | dataset_10 | dataset_11 | dataset_12 | -|:--------------------|:-----------------------|:------------------------|:------------------------|:-------------------------|:------------------------|:-----------------------|:-----------------------|:-----------------------|:------------------------|:------------------------|:-----------------------|:---------------------|:------------------------| -| XGBoost | **0.250±0.000 (22s)** | 0.038±0.000 (1011s) | 0.187±0.000 (19s) | 0.475±0.000 (439s) | 0.328±0.000 (32s) | 0.401±0.000 (375s) | **0.249±0.000 (340s)** | 0.363±0.000 (378s) | 0.904±0.000 (2400s) | 0.056±0.000 (250s) | 0.820±0.000 (721s) | **0.857±0.000 (487s)** | 0.418±0.000 (46s) | -| CatBoost | 0.265±0.000 (116s) | 0.062±0.000 (129s) | 0.128±0.000 (97s) | 0.336±0.000 (103s) | 0.346±0.000 (110s) | 0.443±0.000 (97s) | 0.375±0.000 (46s) | 0.273±0.000 (693s) | 0.881±0.000 (660s) | 0.040±0.000 (80s) | 0.756±0.000 (44s) | 0.876±0.000 (110s) | 0.439±0.000 (101s) | -| LightGBM | 0.253±0.000 (38s) | 0.054±0.000 (24s) | **0.112±0.000 (10s)** | 0.302±0.000 (30s) | 0.325±0.000 (30s) | **0.384±0.000 (23s)** | 0.295±0.000 (15s) | **0.272±0.000 (26s)** | **0.877±0.000 (16s)** | 0.011±0.000 (12s) | **0.702±0.000 (13s)** | 0.863±0.000 (5s) | **0.395±0.000 (40s)** | -| Trompt | 0.261±0.003 (8390s) | **0.015±0.005 (3792s)** | 0.118±0.001 (3836s) | **0.262±0.001 (10037s)** | **0.323±0.001 (9255s)** | 0.418±0.003 (9071s) | 0.329±0.009 (2977s) | 0.312±0.002 (21967s) | OOM | **0.008±0.001 (1889s)** | 0.779±0.006 (775s) | 0.874±0.004 (3723s) | 0.424±0.005 (3185s) | -| ResNet | 
0.288±0.006 (220s) | 0.018±0.003 (187s) | 0.124±0.001 (135s) | 0.268±0.001 (330s) | 0.335±0.001 (471s) | 0.434±0.004 (345s) | 0.325±0.012 (178s) | 0.324±0.004 (365s) | 0.895±0.005 (142s) | 0.036±0.002 (172s) | 0.794±0.006 (120s) | 0.875±0.004 (122s) | 0.468±0.004 (303s) | -| MLP | 0.300±0.002 (108s) | 0.141±0.015 (76s) | 0.125±0.001 (44s) | 0.272±0.002 (69s) | 0.348±0.001 (103s) | 0.435±0.002 (33s) | 0.331±0.008 (43s) | 0.380±0.004 (125s) | 0.893±0.002 (69s) | 0.017±0.001 (48s) | 0.784±0.007 (29s) | 0.881±0.005 (30s) | 0.467±0.003 (92s) -| FTTransformerBucket | 0.325±0.008 (619s) | 0.096±0.005 (290s) | 0.360±0.354 (332s) | 0.284±0.005 (768s) | 0.342±0.004 (757s) | 0.441±0.003 (835s) | 0.345±0.007 (191s) | 0.339±0.003 (3321s) | OOM | 0.105±0.011 (199s) | 0.807±0.010 (156s) | 0.885±0.008 (820s) | 0.468±0.006 (706s) | +| | dataset_0 | dataset_1 | dataset_2 | dataset_3 | dataset_4 | dataset_5 | dataset_6 | dataset_7 | dataset_8 | dataset_9 | dataset_10 | dataset_11 | dataset_12 | +|:--------------------|:-----------------------|:------------------------|:----------------------|:-------------------------|:------------------------|:----------------------|:-----------------------|:----------------------|:------------------------|:------------------------|:-----------------------|:---------------------|:----------------------| +| XGBoost | **0.247±0.000 (516s)** | 0.077±0.000 (14s) | 0.167±0.000 (423s) | 1.119±0.000 (1063s) | 0.328±0.000 (2044s) | 1.024±0.000 (47s) | **0.292±0.000 (844s)** | 0.606±0.000 (1765s) | **0.876±0.000 (2288s)** | 0.023±0.000 (1170s) | **0.697±0.000 (248s)** | 0.865±0.000 (8s) | 0.435±0.000 (22s) | +| CatBoost | 0.265±0.000 (116s) | 0.062±0.000 (129s) | 0.128±0.000 (97s) | 0.336±0.000 (103s) | 0.346±0.000 (110s) | 0.443±0.000 (97s) | 0.375±0.000 (46s) | 0.273±0.000 (693s) | 0.881±0.000 (660s) | 0.040±0.000 (80s) | 0.756±0.000 (44s) | 0.876±0.000 (110s) | 0.439±0.000 (101s) | +| LightGBM | 0.253±0.000 (38s) | 0.054±0.000 (24s) | **0.112±0.000 
(10s)** | 0.302±0.000 (30s) | 0.325±0.000 (30s) | **0.384±0.000 (23s)** | 0.295±0.000 (15s) | **0.272±0.000 (26s)** | 0.877±0.000 (16s) | 0.011±0.000 (12s) | 0.702±0.000 (13s) | **0.863±0.000 (5s)** | **0.395±0.000 (40s)** | +| Trompt | 0.261±0.003 (8390s) | **0.015±0.005 (3792s)** | 0.118±0.001 (3836s) | **0.262±0.001 (10037s)** | **0.323±0.001 (9255s)** | 0.418±0.003 (9071s) | 0.329±0.009 (2977s) | 0.312±0.002 (21967s) | OOM | **0.008±0.001 (1889s)** | 0.779±0.006 (775s) | 0.874±0.004 (3723s) | 0.424±0.005 (3185s) | +| ResNet | 0.288±0.006 (220s) | 0.018±0.003 (187s) | 0.124±0.001 (135s) | 0.268±0.001 (330s) | 0.335±0.001 (471s) | 0.434±0.004 (345s) | 0.325±0.012 (178s) | 0.324±0.004 (365s) | 0.895±0.005 (142s) | 0.036±0.002 (172s) | 0.794±0.006 (120s) | 0.875±0.004 (122s) | 0.468±0.004 (303s) | +| MLP | 0.300±0.002 (108s) | 0.141±0.015 (76s) | 0.125±0.001 (44s) | 0.272±0.002 (69s) | 0.348±0.001 (103s) | 0.435±0.002 (33s) | 0.331±0.008 (43s) | 0.380±0.004 (125s) | 0.893±0.002 (69s) | 0.017±0.001 (48s) | 0.784±0.007 (29s) | 0.881±0.005 (30s) | 0.467±0.003 (92s) | +| FTTransformerBucket | 0.325±0.008 (619s) | 0.096±0.005 (290s) | 0.360±0.354 (332s) | 0.284±0.005 (768s) | 0.342±0.004 (757s) | 0.441±0.003 (835s) | 0.345±0.007 (191s) | 0.339±0.003 (3321s) | OOM | 0.105±0.011 (199s) | 0.807±0.010 (156s) | 0.885±0.008 (820s) | 0.468±0.006 (706s) | | ExcelFormer | 0.262±0.004 (770s) | 0.099±0.003 (490s) | 0.128±0.000 (362s) | 0.264±0.003 (796s) | 0.331±0.003 (1121s) | 0.411±0.005 (469s) | 0.298±0.012 (222s) | 0.308±0.007 (5522s) | OOM | 0.011±0.001 (227) | 0.785±0.011 (314s) | 0.890±0.003 (1186s) | 0.431±0.006 (682s) | -| FTTransformer | 0.335±0.010 (338s) | 0.161±0.022 (370s) | 0.140±0.002 (244s) | 0.277±0.004 (516s) | 0.335±0.003 (973s) | 0.445±0.003 (599s) | 0.361±0.018 (286s) | 0.345±0.005 (2443s) | OOM | 0.106±0.012 (150s) | 0.826±0.005 (121s) | 0.896±0.007 (832s) | 0.461±0.003 (647s) | -| TabNet | 0.279±0.003 (68s) | 0.224±0.016 (53s) | 0.141±0.010 (34s) | 
0.275±0.002 (61s) | 0.348±0.003 (110s) | 0.451±0.007 (82s) | 0.355±0.030 (49s) | 0.332±0.004 (168s) | 0.992±0.182 (53s) | 0.015±0.002 (57s) | 0.805±0.014 (27s) | 0.885±0.013 (46s) | 0.544±0.011 (112s) | -| TabTransformer | 0.624±0.003 (1225s) | 0.229±0.003 (1200s) | 0.369±0.005 (52s) | 0.340±0.004 (163s) | 0.388±0.002 (1137s) | 0.539±0.003 (100s) | 0.619±0.005 (73s) | 0.351±0.001 (125s) | 0.893±0.005 (389s) | 0.431±0.001 (489s) | 0.819±0.002 (52s) | 0.886±0.005 (46s) | 0.545±0.004 (95s) | +| FTTransformer | 0.335±0.010 (338s) | 0.161±0.022 (370s) | 0.140±0.002 (244s) | 0.277±0.004 (516s) | 0.335±0.003 (973s) | 0.445±0.003 (599s) | 0.361±0.018 (286s) | 0.345±0.005 (2443s) | OOM | 0.106±0.012 (150s) | 0.826±0.005 (121s) | 0.896±0.007 (832s) | 0.461±0.003 (647s) | +| TabNet | 0.279±0.003 (68s) | 0.224±0.016 (53s) | 0.141±0.010 (34s) | 0.275±0.002 (61s) | 0.348±0.003 (110s) | 0.451±0.007 (82s) | 0.355±0.030 (49s) | 0.332±0.004 (168s) | 0.992±0.182 (53s) | 0.015±0.002 (57s) | 0.805±0.014 (27s) | 0.885±0.013 (46s) | 0.544±0.011 (112s) | +| TabTransformer | 0.624±0.003 (1225s) | 0.229±0.003 (1200s) | 0.369±0.005 (52s) | 0.340±0.004 (163s) | 0.388±0.002 (1137s) | 0.539±0.003 (100s) | 0.619±0.005 (73s) | 0.351±0.001 (125s) | 0.893±0.005 (389s) | 0.431±0.001 (489s) | 0.819±0.002 (52s) | 0.886±0.005 (46s) | 0.545±0.004 (95s) | #### `scale: medium` @@ -135,12 +135,12 @@ Experimental setting: 20 Optuna search trials for XGBoost and CatBoost. 
5 Optuna | | 0 | 1 | 2 | 3 | 4 | 5 | |:--------------------|:-------------------------|:-------------------------|:------------------------|:------------------------|:-------------------------|:-----------------------| -| XGBoost | 0.663±0.000 (18528s) | **0.014±0.000 (380s)** | 0.089±0.000 (2441s) | 0.140±0.000 (1632s) | 0.539±0.000 (22047s) | 0.900±0.000 (1420s) | +| XGBoost | 1.100±0.000 (739s) | 0.020±0.000 (8341s) | 0.090±0.000 (3530s) | 0.142±0.000 (3664s) | 0.540±0.000 (21566s) | **0.895±0.000 (948s)** | | CatBoost | 0.669±0.000 (2037s) | 0.018±0.000 (649s) | 0.092±0.000 (391s) | 0.145±0.000 (271s) | 0.549±0.000 (1347s) | 0.898±0.000 (122s) | | LightGBM | **0.660±0.000 (199s)** | 0.015±0.000 (86s) | **0.085±0.000 (39s)** | 0.141±0.000 (35s) | **0.524±0.000 (148s)** | **0.895±0.000 (7s)** | | Trompt | OOM | **0.014±0.000 (19976s)** | 0.092±0.001 (4060s) | **0.140±0.000 (3487s)** | 0.537±0.000 (26520s) | 0.901±0.000 (2333s) | -| ResNet | 0.676±0.000 (894s) | 0.016±0.000 (548s) | 0.101±0.001 (176s) | 0.147±0.000 (503s) | 0.555±0.003 (1121s) | 0.903±0.000 (116s) | -| MLP | 0.680±0.001 (907s) | 0.016±0.000 (1015s) | 0.105±0.000 (254s) | **0.140±0.000 (313s)** | 0.558±0.001 (1756s) | 0.905±0.001 (240s) | +| ResNet | 0.676±0.000 (894s) | **0.016±0.000 (548s)** | 0.101±0.001 (176s) | 0.147±0.000 (503s) | 0.555±0.003 (1121s) | 0.903±0.000 (116s) | +| MLP | 0.680±0.001 (907s) | **0.016±0.000 (1015s)** | 0.105±0.000 (254s) | **0.140±0.000 (313s)** | 0.558±0.001 (1756s) | 0.905±0.001 (240s) | | FTTransformerBucket | 0.738±0.029 (17223s) | 0.023±0.000 (2573s) | 0.113±0.002 (645s) | 0.147±0.000 (970s) | 0.545±0.000 (3009s) | 0.908±0.000 (360s) | | ExcelFormer | **0.667±0.000 (35946s)** | 0.064±0.019 (2355s) | 0.119±0.003 (594s) | 0.220±0.009 (1285s) | 0.563±0.002 (2772s) | 0.902±0.000 (288s) | | FTTransformer | 0.673±0.000 (18524s) | 0.056±0.003 (3348s) | 0.119±0.003 (396s) | **0.141±0.000 (1049s)** | 0.561±0.001 (2403s) | 0.907±0.002 (302s) | diff --git a/pyproject.toml 
b/pyproject.toml index a611688ec..74e6ea6c7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,6 +56,7 @@ full=[ "lightgbm", "datasets", "torchmetrics", + "openml", ] [project.urls] diff --git a/test/datasets/test_data_frame_openml.py b/test/datasets/test_data_frame_openml.py new file mode 100644 index 000000000..c8fbdb141 --- /dev/null +++ b/test/datasets/test_data_frame_openml.py @@ -0,0 +1,20 @@ +import pytest + +from torch_frame.datasets import OpenMLDataset +from torch_frame.typing import TaskType + + +@pytest.mark.parametrize("dataset_id", [8, 31, 455]) +def test_data_frame_openml(dataset_id): + dataset = OpenMLDataset(dataset_id) + if dataset_id == 8: + assert dataset.task_type == TaskType.REGRESSION + assert dataset.target_col == "drinks" + if dataset_id == 31: + assert dataset.task_type == TaskType.BINARY_CLASSIFICATION + assert dataset.num_classes == 2 + assert dataset.target_col == "class" + if dataset_id == 455: + assert dataset.task_type == TaskType.MULTICLASS_CLASSIFICATION + assert dataset.num_classes == 3 + assert dataset.target_col == "origin" diff --git a/torch_frame/datasets/__init__.py b/torch_frame/datasets/__init__.py index 832a477a2..3c3a1d967 100644 --- a/torch_frame/datasets/__init__.py +++ b/torch_frame/datasets/__init__.py @@ -19,6 +19,7 @@ from .amazon_fine_food_reviews import AmazonFineFoodReviews from .diamond_images import DiamondImages from .huggingface_dataset import HuggingFaceDatasetDict +from .openml_dataset import OpenMLDataset real_world_datasets = [ 'Titanic', @@ -38,6 +39,7 @@ 'Movielens1M', 'AmazonFineFoodReviews', 'DiamondImages', + 'OpenMLDataset', ] synthetic_datasets = [ diff --git a/torch_frame/datasets/openml_dataset.py b/torch_frame/datasets/openml_dataset.py new file mode 100644 index 000000000..565a8a527 --- /dev/null +++ b/torch_frame/datasets/openml_dataset.py @@ -0,0 +1,100 @@ +import os +from typing import Optional + +import pandas as pd + +import torch_frame +from torch_frame import stype +from 
class OpenMLDataset(torch_frame.data.Dataset):
    r"""A tabular dataset downloaded from `OpenML <https://www.openml.org>`_.

    The dataset's default target attribute is used as the target column.
    Feature columns are typed from the categorical indicator returned by
    OpenML, and the task type is derived from the inferred semantic type of
    the target column together with the ``NumberOfClasses`` dataset quality.

    Args:
        dataset_id (int): The ID of the dataset to be loaded from OpenML.
        cache_dir (str, optional): The directory where the dataset is cached.
            If None, the default cache directory is used.

    Raises:
        ImportError: If the ``openml`` package is not installed.
        ValueError: If the dataset has no default target attribute, or if its
            ``NumberOfClasses`` quality is missing or inconsistent with the
            target column.
    """
    def __init__(self, dataset_id: int, cache_dir: Optional[str] = None):
        try:
            import openml
        except ImportError as err:
            # Chain the original error so the real import failure stays
            # visible in the traceback.
            raise ImportError(
                "The OpenML library is required by OpenMLDataset class. "
                "You can install it using `pip install openml`.") from err
        # Imported locally so this class does not rely on a module-level name.
        from torch_frame.utils.infer_stype import infer_series_stype

        if cache_dir is not None:
            openml.config.set_root_cache_directory(
                os.path.expanduser(cache_dir))
        self.dataset_id = dataset_id
        self._openml_dataset = openml.datasets.get_dataset(
            self.dataset_id,
            download_data=True,
            download_qualities=True,
            download_features_meta_data=True,
        )
        # Get dataset info from OpenML
        self.dataset_info = self._openml_dataset.qualities
        target_col = self._openml_dataset.default_target_attribute
        if target_col is None:
            # Without a target there is no column to predict; fail early
            # with a clear message rather than later on a column lookup.
            raise ValueError(
                f"OpenML dataset {dataset_id} does not define a default "
                f"target attribute.")
        X, y, self.categorical_indicator, _ = self._openml_dataset.get_data(
            target=target_col)
        df = pd.concat([X, y], axis=1)

        # The column type can be inferred from the categorical_indicator
        col_to_stype = {
            col: stype.categorical if is_categorical else stype.numerical
            for col, is_categorical in zip(X.columns,
                                           self.categorical_indicator)
        }

        num_classes = (self.dataset_info or {}).get("NumberOfClasses")
        if num_classes is None:
            raise ValueError(
                f"OpenML dataset {dataset_id} does not provide the "
                f"'NumberOfClasses' quality.")

        # Infer the stype of the target column
        target_col_type = infer_series_stype(df[target_col])
        if target_col_type == torch_frame.categorical:
            # `not (x > 0)` also rejects NaN, which `x <= 0` would let pass.
            if not (num_classes > 0):
                raise ValueError(
                    f"Target column '{target_col}' is categorical but "
                    f"'NumberOfClasses' is {num_classes}.")
            num_unique = df[target_col].nunique()
            if num_unique != num_classes:
                raise ValueError(
                    f"Target column '{target_col}' has {num_unique} unique "
                    f"values but 'NumberOfClasses' is {num_classes}.")
            if num_classes == 2:
                self._task_type = torch_frame.TaskType.BINARY_CLASSIFICATION
            else:
                self._task_type = (
                    torch_frame.TaskType.MULTICLASS_CLASSIFICATION)
            self._num_classes = int(num_classes)
            col_to_stype[target_col] = torch_frame.categorical
        else:
            if num_classes != 0:
                raise ValueError(
                    f"Target column '{target_col}' is not categorical but "
                    f"'NumberOfClasses' is {num_classes}.")
            self._task_type = torch_frame.TaskType.REGRESSION
            self._num_classes = 0
            col_to_stype[target_col] = torch_frame.numerical

        super().__init__(df=df, col_to_stype=col_to_stype,
                         target_col=target_col)

    # NOTE: Overriding the `task_type()` and `num_classes` property method
    @property
    def task_type(self) -> torch_frame.TaskType:
        """Returns the task type of the dataset.

        Returns:
            torch_frame.TaskType: The task type of the dataset.
        """
        return self._task_type

    @property
    def num_classes(self) -> int:
        """Returns the number of classes in the dataset.

        Returns:
            int: The number of classes in the dataset (0 for regression).
        """
        return self._num_classes