
Commit 1c4cf7c

Richard Zhang authored and committed
fix ret per layer bug
1 parent 8db312a commit 1c4cf7c

3 files changed: 40 additions & 44 deletions

lpips/__init__.py

Lines changed: 0 additions & 29 deletions
@@ -10,35 +10,6 @@
 from lpips.trainer import *
 from lpips.lpips import *

-# class PerceptualLoss(torch.nn.Module):
-#     def __init__(self, model='lpips', net='alex', spatial=False, use_gpu=False, gpu_ids=[0], version='0.1'): # VGG using our perceptually-learned weights (LPIPS metric)
-#     # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss
-#         super(PerceptualLoss, self).__init__()
-#         print('Setting up Perceptual loss...')
-#         self.use_gpu = use_gpu
-#         self.spatial = spatial
-#         self.gpu_ids = gpu_ids
-#         self.model = dist_model.DistModel()
-#         self.model.initialize(model=model, net=net, use_gpu=use_gpu, spatial=self.spatial, gpu_ids=gpu_ids, version=version)
-#         print('...[%s] initialized'%self.model.name())
-#         print('...Done')
-
-#     def forward(self, pred, target, normalize=False):
-#         """
-#         Pred and target are Variables.
-#         If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1]
-#         If normalize is False, assumes the images are already between [-1,+1]
-
-#         Inputs pred and target are Nx3xHxW
-#         Output pytorch Variable N long
-#         """
-
-#         if normalize:
-#             target = 2 * target - 1
-#             pred = 2 * pred - 1
-
-#         return self.model.forward(target, pred)
-
 def normalize_tensor(in_feat,eps=1e-10):
     norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True))
     return in_feat/(norm_factor+eps)
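
The surviving context lines show `normalize_tensor`, which rescales a feature map so that each spatial position has unit L2 norm across channels; this is the per-layer normalization step behind the LPIPS distance. A minimal sketch of what it computes (the tensor shape is an illustrative assumption, not from the repo):

    import torch
    from lpips import normalize_tensor  # defined in lpips/__init__.py

    feat = torch.randn(1, 64, 32, 32)  # NxCxHxW feature activations
    unit = normalize_tensor(feat)      # feat / (sqrt(sum over channels of feat**2) + eps)
    print(unit.norm(dim=1).mean())     # ~1.0: each spatial location is unit length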

lpips/lpips.py

Lines changed: 36 additions & 14 deletions

@@ -22,8 +22,40 @@ def upsample(in_tens, out_HW=(64,64)): # assumes scale factor is same for H and
 class LPIPS(nn.Module):
     def __init__(self, pretrained=True, net='alex', version='0.1', lpips=True, spatial=False,
         pnet_rand=False, pnet_tune=False, use_dropout=True, model_path=None, eval_mode=True, verbose=True):
-        # lpips - [True] means with linear calibration on top of base network
-        # pretrained - [True] means load linear weights
+        """ Initializes a perceptual loss torch.nn.Module
+
+        Parameters (default listed first)
+        ---------------------------------
+        lpips : bool
+            [True] use linear layers on top of the base/trunk network
+            [False] no linear layers; layer activations are simply averaged together
+        pretrained : bool
+            Controls the linear layers; only in effect when lpips=True above
+            [True] linear layers are calibrated with human perceptual judgments
+            [False] linear layers are randomly initialized
+        pnet_rand : bool
+            [False] trunk loaded with ImageNet classification weights
+            [True] randomly initialized trunk
+        net : str
+            ['alex','vgg','squeeze'] are the base/trunk networks available
+        version : str
+            ['0.1'] is the default and latest
+            ['0.0'] contained a normalization bug; corresponds to the old arxiv v1 (https://arxiv.org/abs/1801.03924v1)
+        model_path : str
+            [None] is default and loads the pretrained weights from the paper (https://arxiv.org/abs/1801.03924v1)
+
+        The following parameters should only be changed if training the network:
+
+        eval_mode : bool
+            [True] is for test mode (default)
+            [False] is for training mode
+        pnet_tune : bool
+            [False] keep the base/trunk frozen
+            [True] tune the base/trunk network
+        use_dropout : bool
+            [True] use dropout when training the linear layers
+            [False] no dropout when training the linear layers
+        """

         super(LPIPS, self).__init__()
         if(verbose):
@@ -102,19 +134,9 @@ def forward(self, in0, in1, retPerLayer=False, normalize=False):
         else:
             res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)]

-        val = res[0]
-        for l in range(1,self.L):
+        val = 0
+        for l in range(self.L):
             val += res[l]
-
-        # a = spatial_average(self.lins[kk](diffs[kk]), keepdim=True)
-        # b = torch.max(self.lins[kk](feats0[kk]**2))
-        # for kk in range(self.L):
-        #     a += spatial_average(self.lins[kk](diffs[kk]), keepdim=True)
-        #     b = torch.max(b,torch.max(self.lins[kk](feats0[kk]**2)))
-        # a = a/self.L
-        # from IPython import embed
-        # embed()
-        # return 10*torch.log10(b/a)

         if(retPerLayer):
             return (val, res)
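
This hunk is the "ret per layer" fix named in the commit message. In the old code, `val = res[0]` made `val` an alias of the first per-layer tensor, and since `+=` on a torch tensor adds in place, accumulating into `val` also overwrote `res[0]` with the running total; with `retPerLayer=True`, the returned per-layer list therefore carried a corrupted first entry. Starting from `val = 0` makes the first `+=` build a fresh tensor and leaves `res` untouched. A minimal sketch of the aliasing pitfall (values are illustrative):

    import torch

    res = [torch.tensor([1.0]), torch.tensor([2.0]), torch.tensor([3.0])]

    # Old pattern: val aliases res[0]; in-place += mutates it.
    val = res[0]
    for l in range(1, len(res)):
        val += res[l]
    print(res[0])  # tensor([6.]) -- layer 0 now holds the total

    # Fixed pattern: 0 + tensor returns a new tensor, so res is preserved.
    res = [torch.tensor([1.0]), torch.tensor([2.0]), torch.tensor([3.0])]
    val = 0
    for l in range(len(res)):
        val += res[l]
    print(res[0])  # tensor([1.]) -- per-layer results intact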

lpips_2imgs.py

Lines changed: 4 additions & 1 deletion

@@ -24,5 +24,8 @@
 	img1 = img1.cuda()

 # Compute distance
-dist01 = loss_fn.forward(img0,img1)
+dist01 = loss_fn.forward(img0, img1)
 print('Distance: %.3f'%dist01)
+
+dist01 = loss_fn.forward(img0, img1, retPerLayer=True)
+print(dist01)
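
With the fix in place, calling the metric with `retPerLayer=True` returns the aggregate distance together with an uncorrupted per-layer list, which is what the added demo lines print. A hedged sketch of unpacking that return value (the dummy inputs are illustrative, not part of the repo's demo):

    import torch
    import lpips

    loss_fn = lpips.LPIPS(net='alex')      # linear-calibrated AlexNet trunk

    img0 = torch.zeros(1, 3, 64, 64)       # dummy Nx3xHxW images in [-1,+1]
    img1 = 0.5 * torch.ones(1, 3, 64, 64)

    val, per_layer = loss_fn.forward(img0, img1, retPerLayer=True)
    print('Total distance: %.4f' % val)
    for l, r in enumerate(per_layer):
        print('Layer %d: %.4f' % (l, r.item()))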
