@@ -80,7 +80,7 @@ def iterate_dataset(data_loader):
80
80
[transforms .ToTensor (), transforms .Normalize ((0.5 ,), (0.5 ,))]
81
81
)
82
82
train_dataset = DataLoader (
83
- datasets .FashionMNIST (
83
+ datasets .MNIST (
84
84
root = "~/data/fashion_mnist" , train = True , download = True , transform = transform
85
85
),
86
86
batch_size = 256 ,
@@ -92,3 +92,47 @@ def iterate_dataset(data_loader):
92
92
model = SimpleFSQAutoEncoder (levels ).to (device )
93
93
opt = torch .optim .AdamW (model .parameters (), lr = lr )
94
94
train (model , train_dataset , train_iterations = train_iter )
95
+
96
+ # ---- 8< -----
97
+
98
+ batch = next (iter (train_dataset ))
99
+ img , _ = batch
100
+ img = img .to (device )
101
+ rec_x2 = model (img )
102
+
103
+ # Extracting recorded information
104
+ temp = rec_x2 [0 ].cpu ().detach ().numpy ()
105
+
106
+ import matplotlib .pyplot as plt
107
+
108
+ # Initializing subplot counter
109
+ counter = 1
110
+
111
+ # Plotting first five images of the last batch
112
+ for idx in range (5 ):
113
+ plt .subplot (2 , 5 , counter )
114
+ plt .title (f"index { idx } " )
115
+ plt .imshow (temp [idx ].reshape (28 ,28 ), cmap = 'gray' )
116
+ plt .axis ('off' )
117
+
118
+ # Incrementing the subplot counter
119
+ counter += 1
120
+
121
+ # Iterating over first five
122
+ # images of the last batch
123
+
124
+ # Obtaining image from the dictionary
125
+ val = img .cpu ()
126
+
127
+ for idx in range (5 ):
128
+ # Plotting image
129
+ plt .subplot (2 ,5 ,counter )
130
+ plt .imshow (val [idx ].reshape (28 , 28 ), cmap = 'gray' )
131
+ plt .title ("Original Image" )
132
+ plt .axis ('off' )
133
+
134
+ # Incrementing subplot counter
135
+ counter += 1
136
+
137
+ plt .tight_layout ()
138
+ plt .savefig ('figgy2.png' )
0 commit comments