Skip to content

Commit df8ecc4

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent b8e8ae4 commit df8ecc4

File tree

2 files changed

+38
-28
lines changed

2 files changed

+38
-28
lines changed

examples/tutorial.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -262,16 +262,17 @@ def test(loader: DataLoader) -> float:
262262
if best_val_acc < val_acc:
263263
best_val_acc = val_acc
264264
best_test_acc = test_acc
265-
print(
266-
f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
267-
f"Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}"
268-
)
265+
print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
266+
f"Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}")
269267

270-
print(f"Best Val Acc: {best_val_acc:.4f}, Best Test Acc: {best_test_acc:.4f}")
268+
print(
269+
f"Best Val Acc: {best_val_acc:.4f}, Best Test Acc: {best_test_acc:.4f}"
270+
)
271271
elif args.framework == "skorch":
272-
from torch_frame.utils.skorch import NeuralNetClassifierPytorchFrame
273272
import torch.nn as nn
274273

274+
from torch_frame.utils.skorch import NeuralNetClassifierPytorchFrame
275+
275276
net = NeuralNetClassifierPytorchFrame(
276277
module=model,
277278
criterion=nn.CrossEntropyLoss,

torch_frame/utils/skorch.py

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,35 @@
1-
from skorch import NeuralNetClassifier, NeuralNet
2-
from skorch.dataset import Dataset as SkorchDataset
3-
import torch.nn as nn
4-
import torch_frame
5-
from torch_frame.data.tensor_frame import TensorFrame
6-
from torch_frame.utils import infer_df_stype
7-
from torch_frame.data.dataset import DataFrameToTensorFrameConverter, Dataset
8-
from torch_frame.data.loader import DataLoader
9-
import torch
10-
from torch_frame.typing import IndexSelectType
11-
from torch import Tensor
12-
from pandas import DataFrame
131
from typing import Any
2+
143
import pandas as pd
4+
import torch
5+
import torch.nn as nn
156
from numpy.typing import ArrayLike
7+
from pandas import DataFrame
8+
from skorch import NeuralNet, NeuralNetClassifier
9+
from skorch.dataset import Dataset as SkorchDataset
10+
from torch import Tensor
11+
12+
import torch_frame
1613
from torch_frame.config import (
1714
ImageEmbedderConfig,
1815
TextEmbedderConfig,
1916
TextTokenizerConfig,
2017
)
18+
from torch_frame.data.dataset import DataFrameToTensorFrameConverter, Dataset
19+
from torch_frame.data.loader import DataLoader
20+
from torch_frame.data.tensor_frame import TensorFrame
21+
from torch_frame.typing import IndexSelectType
22+
from torch_frame.utils import infer_df_stype
23+
2124

2225
class NeuralNetPytorchFrameDataLoader(DataLoader):
23-
def __init__(
24-
self, dataset: Dataset | TensorFrame, *args, device: torch.device, **kwargs
25-
):
26+
def __init__(self, dataset: Dataset | TensorFrame, *args,
27+
device: torch.device, **kwargs):
2628
super().__init__(dataset, *args, **kwargs)
2729
self.device = device
2830

29-
def collate_fn(self, index: IndexSelectType) -> tuple[TensorFrame, Tensor | None]:
31+
def collate_fn(
32+
self, index: IndexSelectType) -> tuple[TensorFrame, Tensor | None]:
3033
index = torch.tensor(index)
3134
res = super().collate_fn(index).to(self.device)
3235
return res, res.y
@@ -112,14 +115,18 @@ def create_dataset(self, df: DataFrame, _: Any) -> Dataset:
112115
dataset_.materialize()
113116
return dataset_
114117

115-
def split_dataset(self, dataset: Dataset) -> tuple[TensorFrame, TensorFrame]:
118+
def split_dataset(self,
119+
dataset: Dataset) -> tuple[TensorFrame, TensorFrame]:
116120
datasets = dataset.split()[:2]
117121
return datasets[0].tensor_frame, datasets[1].tensor_frame
118122

119-
def iterator_train_valid(self, dataset: Dataset, **kwargs: Any) -> DataLoader:
120-
return NeuralNetPytorchFrameDataLoader(dataset, device=self.device, **kwargs)
123+
def iterator_train_valid(self, dataset: Dataset,
124+
**kwargs: Any) -> DataLoader:
125+
return NeuralNetPytorchFrameDataLoader(dataset, device=self.device,
126+
**kwargs)
121127

122-
def fit(self, X: Dataset | DataFrame, y: ArrayLike | None=None, **fit_params):
128+
def fit(self, X: Dataset | DataFrame, y: ArrayLike | None = None,
129+
**fit_params):
123130
if isinstance(X, DataFrame):
124131
if y is not None:
125132
X["target_col"] = y
@@ -138,9 +145,11 @@ def fit(self, X: Dataset | DataFrame, y: ArrayLike | None=None, **fit_params):
138145
self.dataset_ = X
139146
return super().fit(self.dataset_.df, None, **fit_params)
140147

148+
141149
# TODO: make this behave more like NeuralNetClassifier
142150
class NeuralNetClassifierPytorchFrame(NeuralNetPytorchFrame):
143-
def fit(self, X: Dataset | DataFrame, y: ArrayLike | None=None, **fit_params):
151+
def fit(self, X: Dataset | DataFrame, y: ArrayLike | None = None,
152+
**fit_params):
144153
fit_result = super().fit(X, y, **fit_params)
145154
self.classes = self.dataset_.df["target_col"].unique()
146-
return fit_result
155+
return fit_result

0 commit comments

Comments
 (0)