Skip to content

Commit ee615e4

Browse files
authored
Add files via upload
1 parent cb41a72 commit ee615e4

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

67 files changed

+2819
-0
lines changed

Diff for: 4.2 Comparison between random and learned illumination patterns.ipynb

+237
Large diffs are not rendered by default.

Diff for: 4.3 Effect of number of iterations.ipynb

+220
Large diffs are not rendered by default.

Diff for: 4.4 Generalization of learned patterns on different datasets.ipynb

+342
Large diffs are not rendered by default.

Diff for: 4.5 Robustness to noise.ipynb

+157
Large diffs are not rendered by default.

Diff for: CDP_test_diff_K_real.py

+72
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
from IPython.display import clear_output
2+
import torch
3+
import torch.nn.functional as F
4+
import numpy as np
5+
import matplotlib.pyplot as plt
6+
from skimage.measure import compare_ssim,compare_psnr, compare_mse
7+
from utils import compute_psnr, plot_test
8+
from tqdm import tqdm
9+
10+
11+
def test_diff_k(mask,alpha,x_test,n_test,n_batch,n_steps,rec_save_step = 50):
12+
torch.cuda.set_device(0)
13+
14+
n_masks = mask.shape[1]
15+
x_test = np.expand_dims(x_test, axis=3)
16+
_, height, width, nc = x_test.shape
17+
18+
x_test = x_test[:n_test,:,:,:].reshape(-1,nc,height,width)
19+
x_test_rec = np.zeros([n_steps//rec_save_step,*x_test.shape])
20+
n_iter = int(np.ceil(n_test/n_batch))
21+
eps_tensor = torch.cuda.FloatTensor([1e-15])
22+
epoch_idx = np.arange(n_test)
23+
24+
# image loss and measurement loss
25+
loss_x = np.zeros([n_test,n_steps])
26+
loss_y = np.zeros([n_test,n_steps])
27+
psnr_x = np.zeros([n_test,n_steps])
28+
29+
for iters in tqdm(range(n_iter)):
30+
# for iters in range(n_iter):
31+
x = x_test[epoch_idx[iters*n_batch:np.min([(iters+1)*n_batch,n_test])],:,:,:]
32+
x_gt = torch.cuda.FloatTensor(x).view(-1, 1, nc, height, width).cuda()
33+
mask_k = torch.cuda.FloatTensor(mask).view(-1,n_masks,nc,height,width)
34+
35+
# masked signal z = mask * x
36+
z = x_gt * mask_k
37+
z_complex = F.pad(z.unsqueeze(5), (0,1), mode="constant") # pad last dim on the right
38+
Fz = torch.fft(z_complex, 2, normalized=True)
39+
# measurement y = |Fz|
40+
y = torch.norm(Fz, dim=5)
41+
42+
x_est = x_test_rec[0,epoch_idx[iters*n_batch:np.min([(iters+1)*n_batch,n_test])],:,:,:]
43+
x_est = torch.cuda.FloatTensor(x_est.reshape(-1,1,nc,height,width)).cuda()
44+
45+
for k in range(n_steps):
46+
# z_est = x_est * mask_k # would fail without eps_tensor
47+
z_est = x_est * mask_k + eps_tensor
48+
z_est_complex = F.pad(z_est.unsqueeze(5), (0,1), mode="constant")
49+
Fz_est = torch.fft(z_est_complex,2, normalized=True)
50+
y_est = torch.norm(Fz_est,dim=5)
51+
# angle Fz
52+
Fz_est_phase = Fz_est / (y_est.unsqueeze(5) + eps_tensor)
53+
54+
# update x
55+
x_grad = mask_k * torch.ifft( Fz_est - torch.mul(Fz_est_phase, y.unsqueeze(5)), 2, normalized=True )[:,:,:,:,:,0]
56+
x_grad = torch.sum(x_grad,dim=1)
57+
x_est = x_est - alpha * x_grad.view(x_est.shape)
58+
x_est = torch.clamp(x_est, 0, 1)
59+
60+
x_est_np = x_est.cpu().detach().numpy().reshape(-1,nc,height,width)
61+
y_np = y.cpu().detach().numpy().reshape(-1,n_masks,height,width)
62+
y_est_np = y_est.cpu().detach().numpy().reshape(-1,n_masks,height,width)
63+
64+
# loss_x is image reconstruction loss, loss_y is the measurement loss (MSE)
65+
loss_x[epoch_idx[iters*n_batch:np.min([(iters+1)*n_batch,n_test])],k] = np.array([compare_mse(x1,x2) for x1,x2 in zip(x,x_est_np)])
66+
psnr_x[epoch_idx[iters*n_batch:np.min([(iters+1)*n_batch,n_test])],k] = np.array([compute_psnr(x1,x2) for x1,x2 in zip(x,x_est_np)])
67+
loss_y[epoch_idx[iters*n_batch:np.min([(iters+1)*n_batch,n_test])],k] = np.array([compare_mse(y1,y2) for y1,y2 in zip(y_np,y_est_np)])
68+
69+
if (k+1)%rec_save_step == 0:
70+
x_test_rec[k//rec_save_step,epoch_idx[iters*n_batch:np.min([(iters+1)*n_batch,n_test])],:,:,:] = x_est.cpu().detach().numpy().reshape(-1,nc,height,width)
71+
72+
return loss_x,psnr_x,loss_y,x_test_rec

Diff for: CDP_test_real.py

+89
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
from IPython.display import clear_output
2+
import torch
3+
from torch import nn
4+
import numpy as np
5+
import matplotlib.pyplot as plt
6+
from skimage.measure import compare_ssim,compare_psnr, compare_mse
7+
from utils import compute_psnr, plot_test
8+
9+
from pathlib import Path
10+
from dataset import *
11+
from tqdm import tqdm
12+
13+
14+
def test(u,alpha,x_test,n_test,n_batch,n_steps,plot_loss=False):
15+
torch.cuda.set_device(0)
16+
N_mask = u.shape[1]
17+
x_test = np.expand_dims(x_test, axis=3)
18+
_, height, width, nc = x_test.shape
19+
x_test = x_test[:n_test,:,:,:].reshape(-1,nc,height,width)
20+
21+
N_iter = int(np.ceil(n_test/n_batch))
22+
x_test_rec = np.zeros_like(x_test)
23+
24+
eps_tensor = torch.cuda.FloatTensor([1e-15])
25+
epoch_idx = np.arange(n_test)
26+
27+
for iters in tqdm(range(N_iter)):
28+
x = x_test[epoch_idx[iters*n_batch:np.min([(iters+1)*n_batch,n_test])],:,:,:]
29+
x_gt = torch.cuda.FloatTensor(x).view(-1, 1, nc, height, width).cuda()
30+
uk = torch.cuda.FloatTensor(u).view(-1,N_mask,nc,height,width)
31+
32+
# z = x * u, multiplicative masks
33+
z = x_gt * uk
34+
dummy_zeros = torch.zeros_like(z).cuda()
35+
z_complex = torch.cat((z.unsqueeze(5), dummy_zeros.unsqueeze(5)), 5)
36+
37+
Fz = torch.fft(z_complex, 2, normalized=True)
38+
# y = |F(x*u)| = |Fz|
39+
y = torch.norm(Fz, dim=5)
40+
y_dual = torch.cat((y.unsqueeze(5), y.unsqueeze(5)), 5)
41+
42+
x_est = x_test_rec[epoch_idx[iters*n_batch:np.min([(iters+1)*n_batch,n_test])],:,:,:]
43+
x_est = torch.cuda.FloatTensor(x_est.reshape(-1,1,nc,height,width)).cuda()
44+
45+
# image loss and measurement loss
46+
loss_x_pr=[]
47+
loss_y_pr=[]
48+
for kx in range(n_steps):
49+
50+
z_est = x_est * uk + eps_tensor
51+
z_est_complex = torch.cat((z_est.unsqueeze(5), dummy_zeros.unsqueeze(5)), 5)
52+
Fz_est = torch.fft(z_est_complex,2, normalized=True)
53+
y_est = torch.norm(Fz_est,dim=5)
54+
y_est_dual = torch.cat((y_est.unsqueeze(5), y_est.unsqueeze(5)), 5)
55+
# angle Fz
56+
Fz_est_phase = Fz_est / (y_est_dual + eps_tensor)
57+
# update x
58+
x_grad_complex = torch.ifft( Fz_est - torch.mul(Fz_est_phase, y_dual), 2, normalized=True)
59+
x_grad = uk * x_grad_complex[:,:,:,:,:,0]
60+
x_grad = torch.sum(x_grad,dim=1)
61+
x_est = x_est - alpha * x_grad.view(x_est.shape)
62+
x_est = torch.clamp(x_est, 0, 1)
63+
64+
# loss_x is image reconstruction loss, loss_y is the measurement loss
65+
loss_x_pr.append(np.mean((x-x_est.cpu().detach().numpy())**2))
66+
loss_y_pr.append(height*width*np.mean((y.cpu().detach().numpy() - y_est.cpu().detach().numpy())**2))
67+
68+
x_test_rec[epoch_idx[iters*n_batch:np.min([(iters+1)*n_batch,n_test])],:,:,:] = x_est.cpu().detach().numpy().reshape(-1,nc,height,width)
69+
70+
if plot_loss:
71+
plt.figure(figsize = (12,4))
72+
plt.subplot(121)
73+
plt.plot(loss_x_pr)
74+
plt.yscale('log')
75+
plt.title(f'loss x @ iter {iters}')
76+
plt.subplot(122)
77+
plt.plot(loss_y_pr)
78+
plt.yscale('log')
79+
plt.title(f'loss y @ iter {iters}')
80+
plt.show()
81+
82+
83+
mse_list = [compare_mse(x_test[i,0,:,:],x_test_rec[i,0,:,:]) for i in range(n_test)]
84+
psnr_list = [compute_psnr(x_test[i,0,:,:],x_test_rec[i,0,:,:]) for i in range(n_test)]
85+
ssim_list = [compare_ssim(x_test[i,0,:,:],x_test_rec[i,0,:,:]) for i in range(n_test)]
86+
mean_of_psnr = np.mean(psnr_list)
87+
psnr_of_mean = 20*np.log10((np.max(x_test)-np.min(x_test))/np.sqrt(np.mean(mse_list)))
88+
89+
return x_test_rec,mse_list,psnr_list

Diff for: CDP_test_real_noisy.py

+110
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
from IPython.display import clear_output
2+
import torch
3+
from torch import nn
4+
import numpy as np
5+
import matplotlib.pyplot as plt
6+
from skimage.measure import compare_ssim,compare_psnr, compare_mse
7+
from utils import compute_psnr, plot_test
8+
9+
from pathlib import Path
10+
from dataset import *
11+
from tqdm import tqdm
12+
13+
14+
def test(u,alpha,x_test,n_test,n_batch,n_steps,noise_type='Poisson', noise_snr=20,plot_loss=False):
15+
torch.cuda.set_device(0)
16+
N_mask = u.shape[1]
17+
x_test = np.expand_dims(x_test, axis=3)
18+
_, height, width, nc = x_test.shape
19+
x_test = x_test[:n_test,:,:,:].reshape(-1,nc,height,width)
20+
21+
N_iter = int(np.ceil(n_test/n_batch))
22+
x_test_rec = np.zeros_like(x_test)
23+
24+
eps_tensor = torch.cuda.FloatTensor([1e-15])
25+
epoch_idx = np.arange(n_test)
26+
27+
for iters in tqdm(range(N_iter)):
28+
# for iters in range(N_iter):
29+
x = x_test[epoch_idx[iters*n_batch:np.min([(iters+1)*n_batch,n_test])],:,:,:]
30+
x_gt = torch.cuda.FloatTensor(x).view(-1, 1, nc, height, width).cuda()
31+
uk = torch.cuda.FloatTensor(u).view(-1,N_mask,nc,height,width)
32+
33+
# z = x * u, multiplicative masks
34+
z = x_gt * uk
35+
dummy_zeros = torch.zeros_like(z).cuda()
36+
z_complex = torch.cat((z.unsqueeze(5), dummy_zeros.unsqueeze(5)), 5)
37+
38+
Fz = torch.fft(z_complex, 2, normalized=True)
39+
# y = |F(x*u)| = |Fz|
40+
y = torch.norm(Fz, dim=5)
41+
42+
if noise_type=='Poisson':
43+
true_meas=y.cpu().detach()
44+
noise=np.random.normal(0,true_meas,(y.shape))
45+
46+
noise_tensor=torch.cuda.FloatTensor(noise)
47+
noise_coeff=(y.pow(2).mean()/noise_tensor.pow(2).mean()/np.power(10,noise_snr/10.0)).pow(0.5)
48+
y=y+noise_coeff*noise_tensor
49+
y=torch.relu(y)
50+
elif noise_type=='Gaussian':
51+
noise=np.random.normal(0,0.1,(y.shape))
52+
53+
noise_tensor=torch.cuda.FloatTensor(noise)
54+
noise_coeff=(y.pow(2).mean()/noise_tensor.pow(2).mean()/np.power(10,noise_snr/10.0)).pow(0.5)
55+
y_=y+noise_coeff*noise_tensor
56+
y=torch.relu(y)
57+
else:
58+
print('Unsupported noise type')
59+
60+
61+
y_dual = torch.cat((y.unsqueeze(5), y.unsqueeze(5)), 5)
62+
63+
x_est = x_test_rec[epoch_idx[iters*n_batch:np.min([(iters+1)*n_batch,n_test])],:,:,:]
64+
x_est = torch.cuda.FloatTensor(x_est.reshape(-1,1,nc,height,width)).cuda()
65+
66+
# image loss and measurement loss
67+
loss_x_pr=[]
68+
loss_y_pr=[]
69+
for kx in range(n_steps):
70+
71+
z_est = x_est * uk + eps_tensor
72+
z_est_complex = torch.cat((z_est.unsqueeze(5), dummy_zeros.unsqueeze(5)), 5)
73+
Fz_est = torch.fft(z_est_complex,2, normalized=True)
74+
y_est = torch.norm(Fz_est,dim=5)
75+
y_est_dual = torch.cat((y_est.unsqueeze(5), y_est.unsqueeze(5)), 5)
76+
# angle Fz
77+
Fz_est_phase = Fz_est / (y_est_dual + eps_tensor)
78+
# update x
79+
x_grad_complex = torch.ifft( Fz_est - torch.mul(Fz_est_phase, y_dual), 2, normalized=True)
80+
x_grad = uk * x_grad_complex[:,:,:,:,:,0]
81+
x_grad = torch.sum(x_grad,dim=1)
82+
x_est = x_est - alpha * x_grad.view(x_est.shape)
83+
x_est = torch.clamp(x_est, 0, 1)
84+
85+
# loss_x is image reconstruction loss, loss_y is the measurement loss
86+
loss_x_pr.append(np.mean((x-x_est.cpu().detach().numpy())**2))
87+
loss_y_pr.append(height*width*np.mean((y.cpu().detach().numpy() - y_est.cpu().detach().numpy())**2))
88+
89+
x_test_rec[epoch_idx[iters*n_batch:np.min([(iters+1)*n_batch,n_test])],:,:,:] = x_est.cpu().detach().numpy().reshape(-1,nc,height,width)
90+
91+
if plot_loss:
92+
plt.figure(figsize = (12,4))
93+
plt.subplot(121)
94+
plt.plot(loss_x_pr)
95+
plt.yscale('log')
96+
plt.title(f'loss x @ iter {iters}')
97+
plt.subplot(122)
98+
plt.plot(loss_y_pr)
99+
plt.yscale('log')
100+
plt.title(f'loss y @ iter {iters}')
101+
plt.show()
102+
103+
104+
mse_list = [compare_mse(x_test[i,0,:,:],x_test_rec[i,0,:,:]) for i in range(n_test)]
105+
psnr_list = [compute_psnr(x_test[i,0,:,:],x_test_rec[i,0,:,:]) for i in range(n_test)]
106+
ssim_list = [compare_ssim(x_test[i,0,:,:],x_test_rec[i,0,:,:]) for i in range(n_test)]
107+
mean_of_psnr = np.mean(psnr_list)
108+
psnr_of_mean = 20*np.log10((np.max(x_test)-np.min(x_test))/np.sqrt(np.mean(mse_list)))
109+
110+
return x_test_rec,mse_list,psnr_list

0 commit comments

Comments
 (0)