@@ -24,10 +24,12 @@ Datasets come with two splits. Refer to the dataset documentation to find the *t
24
24
25
25
import torch.utils.data as data
26
26
from torchvision import datasets
27
+ import torchvision.transforms as transforms
27
28
28
29
# Load data sets
29
- train_set = datasets.MNIST(root = " MNIST" , download = True , train = True )
30
- test_set = datasets.MNIST(root = " MNIST" , download = True , train = False )
30
+ transform = transforms.ToTensor()
31
+ train_set = datasets.MNIST(root = " MNIST" , download = True , train = True , transform = transform)
32
+ test_set = datasets.MNIST(root = " MNIST" , download = True , train = False , transform = transform)
31
33
32
34
----
33
35
@@ -107,8 +109,8 @@ To add a validation loop, implement the **validation_step** method of the Lightn
107
109
x = x.view(x.size(0 ), - 1 )
108
110
z = self .encoder(x)
109
111
x_hat = self .decoder(z)
110
- test_loss = F.mse_loss(x_hat, x)
111
- self .log(" val_loss" , test_loss )
112
+ val_loss = F.mse_loss(x_hat, x)
113
+ self .log(" val_loss" , val_loss )
112
114
113
115
----
114
116
@@ -120,9 +122,9 @@ To run the validation loop, pass in the validation set to **.fit**
120
122
121
123
from torch.utils.data import DataLoader
122
124
123
- train_set = DataLoader(train_set)
124
- val_set = DataLoader(val_set )
125
+ train_loader = DataLoader(train_set)
126
+ valid_loader = DataLoader(valid_set )
125
127
126
128
# train with both splits
127
129
trainer = Trainer()
128
- trainer.fit(model, train_set, val_set )
130
+ trainer.fit(model, train_loader, valid_loader )
0 commit comments