Skip to content

Commit 6b29211

Browse files
amisevAlexey MisevAlexey Misevawaelchli
authored
Fixed bug: replaced bce_loss_with_logits with bce_loss (#7096)
* Fixed bug: replaced bce_loss_with_logits with bec_loss * Fixed bug: removed sigmoid activation from forward pass * switched names for scores and logits Co-authored-by: Alexey Misev <[email protected]> Co-authored-by: Alexey Misev <[email protected]> Co-authored-by: Adrian Wälchli <[email protected]>
1 parent 78d45a1 commit 6b29211

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

pl_examples/domain_templates/computer_vision_fine_tuning.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ def forward(self, x):
225225
# 2. Classifier (returns logits):
226226
x = self.fc(x)
227227

228-
return torch.sigmoid(x)
228+
return x
229229

230230
def loss(self, logits, labels):
231231
return self.loss_func(input=logits, target=labels)
@@ -234,27 +234,29 @@ def training_step(self, batch, batch_idx):
234234
# 1. Forward pass:
235235
x, y = batch
236236
y_logits = self.forward(x)
237+
y_scores = torch.sigmoid(y_logits)
237238
y_true = y.view((-1, 1)).type_as(x)
238239

239240
# 2. Compute loss
240241
train_loss = self.loss(y_logits, y_true)
241242

242243
# 3. Compute accuracy:
243-
self.log("train_acc", self.train_acc(y_logits, y_true.int()), prog_bar=True)
244+
self.log("train_acc", self.train_acc(y_scores, y_true.int()), prog_bar=True)
244245

245246
return train_loss
246247

247248
def validation_step(self, batch, batch_idx):
248249
# 1. Forward pass:
249250
x, y = batch
250251
y_logits = self.forward(x)
252+
y_scores = torch.sigmoid(y_logits)
251253
y_true = y.view((-1, 1)).type_as(x)
252254

253255
# 2. Compute loss
254256
self.log("val_loss", self.loss(y_logits, y_true), prog_bar=True)
255257

256258
# 3. Compute accuracy:
257-
self.log("val_acc", self.valid_acc(y_logits, y_true.int()), prog_bar=True)
259+
self.log("val_acc", self.valid_acc(y_scores, y_true.int()), prog_bar=True)
258260

259261
def configure_optimizers(self):
260262
parameters = list(self.parameters())

0 commit comments

Comments
 (0)