Skip to content

Commit abb6049

Browse files
authored
Update documentation for the basic skills tutorial level 2 on how to validate and test a model (#14874)
1 parent 633d14e commit abb6049

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

docs/source-pytorch/common/evaluation_basic.rst

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,12 @@ Datasets come with two splits. Refer to the dataset documentation to find the *t
2424
2525
import torch.utils.data as data
2626
from torchvision import datasets
27+
import torchvision.transforms as transforms
2728
2829
# Load data sets
29-
train_set = datasets.MNIST(root="MNIST", download=True, train=True)
30-
test_set = datasets.MNIST(root="MNIST", download=True, train=False)
30+
transform = transforms.ToTensor()
31+
train_set = datasets.MNIST(root="MNIST", download=True, train=True, transform=transform)
32+
test_set = datasets.MNIST(root="MNIST", download=True, train=False, transform=transform)
3133
3234
----
3335

@@ -107,8 +109,8 @@ To add a validation loop, implement the **validation_step** method of the Lightn
107109
x = x.view(x.size(0), -1)
108110
z = self.encoder(x)
109111
x_hat = self.decoder(z)
110-
test_loss = F.mse_loss(x_hat, x)
111-
self.log("val_loss", test_loss)
112+
val_loss = F.mse_loss(x_hat, x)
113+
self.log("val_loss", val_loss)
112114
113115
----
114116

@@ -120,9 +122,9 @@ To run the validation loop, pass in the validation set to **.fit**
120122
121123
from torch.utils.data import DataLoader
122124
123-
train_set = DataLoader(train_set)
124-
val_set = DataLoader(val_set)
125+
train_loader = DataLoader(train_set)
126+
valid_loader = DataLoader(valid_set)
125127
126128
# train with both splits
127129
trainer = Trainer()
128-
trainer.fit(model, train_set, val_set)
130+
trainer.fit(model, train_loader, valid_loader)

0 commit comments

Comments
 (0)