@@ -22,8 +22,40 @@ def upsample(in_tens, out_HW=(64,64)): # assumes scale factor is same for H and
 class LPIPS(nn.Module):
     def __init__(self, pretrained=True, net='alex', version='0.1', lpips=True, spatial=False,
         pnet_rand=False, pnet_tune=False, use_dropout=True, model_path=None, eval_mode=True, verbose=True):
-        # lpips - [True] means with linear calibration on top of base network
-        # pretrained - [True] means load linear weights
+ """ Initializes a perceptual loss torch.nn.Module
26
+
27
+ Parameters (default listed first)
28
+ ---------------------------------
29
+ lpips : bool
30
+ [True] use linear layers on top of base/trunk network
31
+ [False] means no linear layers; each layer is averaged together
32
+ pretrained : bool
33
+ This flag controls the linear layers, which are only in effect when lpips=True above
34
+ [True] means linear layers are calibrated with human perceptual judgments
35
+ [False] means linear layers are randomly initialized
36
+ pnet_rand : bool
37
+ [False] means trunk loaded with ImageNet classification weights
38
+ [True] means randomly initialized trunk
39
+ net : str
40
+ ['alex','vgg','squeeze'] are the base/trunk networks available
41
+ version : str
42
+ ['v0.1'] is the default and latest
43
+ ['v0.0'] contained a normalization bug; corresponds to old arxiv v1 (https://arxiv.org/abs/1801.03924v1)
44
+ model_path : 'str'
45
+ [None] is default and loads the pretrained weights from paper https://arxiv.org/abs/1801.03924v1
46
+
47
+ The following parameters should only be changed if training the network
48
+
49
+ eval_mode : bool
50
+ [True] is for test mode (default)
51
+ [False] is for training mode
52
+        pnet_tune : bool
+            [False] keep the base/trunk network frozen
+            [True] tune the base/trunk network
+        use_dropout : bool
+            [True] to use dropout when training linear layers
+            [False] for no dropout when training linear layers
+        """

         super(LPIPS, self).__init__()
         if(verbose):
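
For reference, a minimal usage sketch of the constructor documented above. It assumes the module is consumed through the published `lpips` package (`import lpips`) and that inputs are RGB tensors scaled to [-1, 1], as described in the repository README; adjust the import if running from a local checkout.

# Usage sketch (assumes the published `lpips` package layout; dummy inputs in [-1, 1]).
import torch
import lpips

loss_fn = lpips.LPIPS(net='alex')        # AlexNet trunk + calibrated linear layers (lpips=True, pretrained=True)

img0 = torch.rand(1, 3, 64, 64) * 2 - 1  # dummy images, scaled to [-1, 1]
img1 = torch.rand(1, 3, 64, 64) * 2 - 1

d = loss_fn(img0, img1)                  # perceptual distance, shape (1, 1, 1, 1)
d_per, per_layer = loss_fn(img0, img1, retPerLayer=True)  # also returns the per-layer terms summed in forward()
print(d.item())
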
@@ -102,19 +134,9 @@ def forward(self, in0, in1, retPerLayer=False, normalize=False):
             else:
                 res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)]

-        val = res[0]
-        for l in range(1, self.L):
+        val = 0
+        for l in range(self.L):
             val += res[l]
-
-        # a = spatial_average(self.lins[kk](diffs[kk]), keepdim=True)
-        # b = torch.max(self.lins[kk](feats0[kk]**2))
-        # for kk in range(self.L):
-        #     a += spatial_average(self.lins[kk](diffs[kk]), keepdim=True)
-        #     b = torch.max(b,torch.max(self.lins[kk](feats0[kk]**2)))
-        # a = a/self.L
-        # from IPython import embed
-        # embed()
-        # return 10*torch.log10(b/a)

         if(retPerLayer):
             return (val, res)
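
A side note on the accumulation change above (my reading; the commit itself states no rationale): seeding the sum with `val = 0` avoids the aliasing of the old version, where `val = res[0]` followed by the in-place `val += res[l]` also overwrote `res[0]`, so the per-layer results returned with `retPerLayer=True` already contained the total. A small sketch of that behaviour:

# Sketch of the aliasing the new accumulation avoids (assumed rationale, not stated in the diff).
import torch

res = [torch.ones(1) for _ in range(3)]
val = res[0]                    # old version: val aliases res[0]
for l in range(1, len(res)):
    val += res[l]               # in-place add also writes into res[0]
print(res[0].item())            # 3.0 -- per-layer value was overwritten

res = [torch.ones(1) for _ in range(3)]
val = 0                         # new version: 0 + res[0] creates a fresh tensor
for l in range(len(res)):
    val += res[l]
print(res[0].item())            # 1.0 -- per-layer values preserved
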