Skip to content

Commit b1b4bdd

Browse files
committed
plot reconstructions
1 parent f1f3f3d commit b1b4bdd

File tree

1 file changed

+45
-1
lines changed

1 file changed

+45
-1
lines changed

examples/autoencoder_fsq.py

100644100755
+45-1
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def iterate_dataset(data_loader):
8080
[transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
8181
)
8282
train_dataset = DataLoader(
83-
datasets.FashionMNIST(
83+
datasets.MNIST(
8484
root="~/data/fashion_mnist", train=True, download=True, transform=transform
8585
),
8686
batch_size=256,
@@ -92,3 +92,47 @@ def iterate_dataset(data_loader):
9292
model = SimpleFSQAutoEncoder(levels).to(device)
9393
opt = torch.optim.AdamW(model.parameters(), lr=lr)
9494
train(model, train_dataset, train_iterations=train_iter)
95+
96+
# ---- 8< -----
97+
98+
batch = next(iter(train_dataset))
99+
img, _ = batch
100+
img = img.to(device)
101+
rec_x2 = model(img)
102+
103+
# Extracting recorded information
104+
temp = rec_x2[0].cpu().detach().numpy()
105+
106+
import matplotlib.pyplot as plt
107+
108+
# Initializing subplot counter
109+
counter = 1
110+
111+
# Plotting first five images of the last batch
112+
for idx in range(5):
113+
plt.subplot(2, 5, counter)
114+
plt.title(f"index {idx}")
115+
plt.imshow(temp[idx].reshape(28,28), cmap= 'gray')
116+
plt.axis('off')
117+
118+
# Incrementing the subplot counter
119+
counter+=1
120+
121+
# Iterating over first five
122+
# images of the last batch
123+
124+
# Obtaining image from the dictionary
125+
val = img.cpu()
126+
127+
for idx in range(5):
128+
# Plotting image
129+
plt.subplot(2,5,counter)
130+
plt.imshow(val[idx].reshape(28, 28), cmap = 'gray')
131+
plt.title("Original Image")
132+
plt.axis('off')
133+
134+
# Incrementing subplot counter
135+
counter+=1
136+
137+
plt.tight_layout()
138+
plt.savefig('figgy2.png')

0 commit comments

Comments
 (0)