Skip to content

Commit db5de3a

Browse files
speediedanBorda
andcommitted
update broken gan/datamodules tutorial links (#164)
* update both datamodules and basic-gan tutorials to reference stable doc version * fix show progress Co-authored-by: Jirka Borovec <[email protected]>
1 parent 22717e7 commit db5de3a

File tree

4 files changed

+12
-18
lines changed

4 files changed

+12
-18
lines changed

lightning_examples/augmentation_kornia/.meta.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,12 @@ description: |
1111
to perform efficient data augmentation to train a simpple model using the GPU in batch
1212
mode without additional effort.
1313
requirements:
14-
- pytorch-lightning
1514
- kornia
1615
- torchmetrics
1716
- torchvision
1817
- matplotlib
1918
- pandas
19+
- seaborn
2020
accelerator:
2121
- CPU
2222
- GPU

lightning_examples/augmentation_kornia/augmentation.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import matplotlib.pyplot as plt
55
import numpy as np
66
import pandas as pd
7+
import seaborn as sn
78
import torch
89
import torch.nn as nn
910
import torchmetrics
@@ -18,6 +19,8 @@
1819
from torch.utils.data import DataLoader
1920
from torchvision.datasets import CIFAR10
2021

22+
sn.set()
23+
2124
# %% [markdown]
2225
# ## Define Data Augmentations module
2326
#
@@ -100,11 +103,8 @@ def __init__(self):
100103
super().__init__()
101104
# not the best model: expereiment yourself
102105
self.model = torchvision.models.resnet18(pretrained=True)
103-
104106
self.preprocess = Preprocess() # per sample transforms
105-
106107
self.transform = DataAugmentation() # per batch augmentation_kornia
107-
108108
self.train_accuracy = torchmetrics.Accuracy()
109109
self.val_accuracy = torchmetrics.Accuracy()
110110

@@ -201,18 +201,12 @@ def val_dataloader(self):
201201

202202
# %%
203203
metrics = pd.read_csv(f"{trainer.logger.log_dir}/metrics.csv")
204-
print(metrics.head())
205-
206-
aggreg_metrics = []
207-
agg_col = "epoch"
208-
for i, dfg in metrics.groupby(agg_col):
209-
agg = dict(dfg.mean())
210-
agg[agg_col] = i
211-
aggreg_metrics.append(agg)
212-
213-
df_metrics = pd.DataFrame(aggreg_metrics)
214-
df_metrics[["train_loss", "valid_loss"]].plot(grid=True, legend=True)
215-
df_metrics[["valid_acc", "train_acc"]].plot(grid=True, legend=True)
204+
del metrics["step"]
205+
metrics.set_index("epoch", inplace=True)
206+
print(metrics.dropna(axis=1, how="all").head())
207+
g = sn.relplot(data=metrics, kind="line")
208+
plt.gcf().set_size_inches(12, 4)
209+
plt.grid()
216210

217211
# %% [markdown]
218212
# ## Tensorboard

lightning_examples/basic-gan/gan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
# ### MNIST DataModule
2121
#
2222
# Below, we define a DataModule for the MNIST Dataset. To learn more about DataModules, check out our tutorial
23-
# on them or see the [latest docs](https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html).
23+
# on them or see the [latest docs](https://pytorch-lightning.readthedocs.io/en/stable/extensions/datamodules.html).
2424

2525

2626
# %%

lightning_examples/datamodules/.meta.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build: 3
77
description: This notebook will walk you through how to start using Datamodules. With
88
the release of `pytorch-lightning` version 0.9.0, we have included a new class called
99
`LightningDataModule` to help you decouple data related hooks from your `LightningModule`.
10-
The most up to date documentation on datamodules can be found
10+
The most up-to-date documentation on datamodules can be found
1111
[here](https://pytorch-lightning.readthedocs.io/en/stable/extensions/datamodules.html).
1212
requirements:
1313
- torchvision

0 commit comments

Comments
 (0)