1
- from skorch import NeuralNetClassifier , NeuralNet
2
- from skorch .dataset import Dataset as SkorchDataset
3
- import torch .nn as nn
4
- import torch_frame
5
- from torch_frame .data .tensor_frame import TensorFrame
6
- from torch_frame .utils import infer_df_stype
7
- from torch_frame .data .dataset import DataFrameToTensorFrameConverter , Dataset
8
- from torch_frame .data .loader import DataLoader
9
- import torch
10
- from torch_frame .typing import IndexSelectType
11
- from torch import Tensor
12
- from pandas import DataFrame
13
1
from typing import Any
2
+
14
3
import pandas as pd
4
+ import torch
5
+ import torch .nn as nn
15
6
from numpy .typing import ArrayLike
7
+ from pandas import DataFrame
8
+ from skorch import NeuralNet , NeuralNetClassifier
9
+ from skorch .dataset import Dataset as SkorchDataset
10
+ from torch import Tensor
11
+
12
+ import torch_frame
16
13
from torch_frame .config import (
17
14
ImageEmbedderConfig ,
18
15
TextEmbedderConfig ,
19
16
TextTokenizerConfig ,
20
17
)
18
+ from torch_frame .data .dataset import DataFrameToTensorFrameConverter , Dataset
19
+ from torch_frame .data .loader import DataLoader
20
+ from torch_frame .data .tensor_frame import TensorFrame
21
+ from torch_frame .typing import IndexSelectType
22
+ from torch_frame .utils import infer_df_stype
23
+
21
24
22
25
class NeuralNetPytorchFrameDataLoader(DataLoader):
    """Data loader that moves every batch to a device and pairs it with its target.

    Each collated batch is returned as ``(batch, batch.y)``.
    """

    def __init__(self, dataset: Dataset | TensorFrame, *args,
                 device: torch.device, **kwargs):
        """Create the loader.

        Args:
            dataset: Data source, forwarded unchanged to the parent loader.
            *args: Extra positional arguments forwarded to the parent loader.
            device: Device each collated batch is moved to (keyword-only).
            **kwargs: Extra keyword arguments forwarded to the parent loader.
        """
        super().__init__(dataset, *args, **kwargs)
        self.device = device

    def collate_fn(
            self, index: IndexSelectType) -> tuple[TensorFrame, Tensor | None]:
        """Collate the rows selected by ``index`` and move them to ``self.device``.

        Args:
            index: Row selection handed to the parent ``collate_fn`` after
                conversion to a tensor.

        Returns:
            A ``(batch, batch.y)`` pair, where ``batch`` is the parent's
            collated result moved to ``self.device``.
        """
        # as_tensor converts lists exactly like torch.tensor, but reuses an
        # index that is already a tensor instead of copy-constructing it.
        index = torch.as_tensor(index)
        res = super().collate_fn(index).to(self.device)
        return res, res.y
@@ -112,14 +115,18 @@ def create_dataset(self, df: DataFrame, _: Any) -> Dataset:
112
115
dataset_ .materialize ()
113
116
return dataset_
114
117
115
def split_dataset(self,
                  dataset: Dataset) -> tuple[TensorFrame, TensorFrame]:
    """Return the tensor frames of the first two splits of ``dataset``.

    Args:
        dataset: Object whose ``split()`` yields at least two parts, each
            exposing a ``tensor_frame`` attribute.

    Returns:
        ``(first.tensor_frame, second.tensor_frame)``; any further splits
        returned by ``dataset.split()`` are ignored.
    """
    # Only the first two parts are used; presumably these are the train and
    # validation splits -- TODO confirm against Dataset.split().
    first, second = dataset.split()[:2]
    return first.tensor_frame, second.tensor_frame
118
122
119
def iterator_train_valid(self, dataset: Dataset,
                         **kwargs: Any) -> DataLoader:
    """Build the data loader for ``dataset``.

    Args:
        dataset: Data source handed to the loader.
        **kwargs: Extra keyword arguments forwarded to the loader.

    Returns:
        A ``NeuralNetPytorchFrameDataLoader`` configured with ``self.device``
        as the device its batches are moved to.
    """
    return NeuralNetPytorchFrameDataLoader(dataset, device=self.device,
                                           **kwargs)
121
127
122
- def fit (self , X : Dataset | DataFrame , y : ArrayLike | None = None , ** fit_params ):
128
+ def fit (self , X : Dataset | DataFrame , y : ArrayLike | None = None ,
129
+ ** fit_params ):
123
130
if isinstance (X , DataFrame ):
124
131
if y is not None :
125
132
X ["target_col" ] = y
@@ -138,9 +145,11 @@ def fit(self, X: Dataset | DataFrame, y: ArrayLike | None=None, **fit_params):
138
145
self .dataset_ = X
139
146
return super ().fit (self .dataset_ .df , None , ** fit_params )
140
147
148
+
141
149
# TODO: make this behave more like NeuralNetClassifier
142
150
class NeuralNetClassifierPytorchFrame(NeuralNetPytorchFrame):
    """Classifier variant that records the distinct target values after fitting."""

    def fit(self, X: Dataset | DataFrame, y: ArrayLike | None = None,
            **fit_params):
        """Fit via the parent class, then store the observed classes.

        Args:
            X: Training data, forwarded to the parent ``fit``.
            y: Optional targets, forwarded to the parent ``fit``.
            **fit_params: Extra keyword arguments forwarded to the parent.

        Returns:
            Whatever the parent ``fit`` returns.
        """
        fit_result = super().fit(X, y, **fit_params)
        # The parent fit leaves the fitted dataset on ``self.dataset_``; the
        # unique values of its "target_col" column are the classes seen.
        self.classes = self.dataset_.df["target_col"].unique()
        return fit_result
0 commit comments