Commit bdd7516
Author: KevinMusgrave

Moved the loss wrappers into losses

1 parent 4cd2376 · commit bdd7516

12 files changed: +92 −99 lines changed

docs/losses.md (+78 −4)
@@ -297,6 +297,44 @@ loss_optimizer.step()
* **loss**: The loss per element in the batch. Reduction type is ```"element"```.

## CrossBatchMemory
This wraps a loss function and implements [Cross-Batch Memory for Embedding Learning](https://arxiv.org/pdf/1912.06798.pdf){target=_blank}. It stores embeddings from previous iterations in a queue, and uses them to form more pairs/triplets with the current iteration's embeddings.

```python
losses.CrossBatchMemory(loss, embedding_size, memory_size=1024, miner=None)
```

**Parameters**:

* **loss**: The loss function to be wrapped. For example, you could pass in ```ContrastiveLoss()```.
* **embedding_size**: The size of the embeddings that you pass into the loss function. For example, if your batch size is 128 and your network outputs 512-dimensional embeddings, then set ```embedding_size``` to 512.
* **memory_size**: The size of the memory queue.
* **miner**: An optional [tuple miner](miners.md), which will be used to mine pairs/triplets from the memory queue.

**Forward function**
```python
loss_fn(embeddings, labels, indices_tuple=None, enqueue_mask=None)
```

As shown above, CrossBatchMemory comes with a 4th argument in its ```forward``` function:

* **enqueue_mask**: A boolean tensor where `enqueue_mask[i]` is True if `embeddings[i]` should be added to the memory queue. This enables CrossBatchMemory to be used in self-supervision frameworks like [MoCo](https://arxiv.org/pdf/1911.05722.pdf). Check out the [MoCo on CIFAR100](https://github.com/KevinMusgrave/pytorch-metric-learning/tree/master/examples#simple-examples) notebook to see how this works.
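A minimal usage sketch follows; the embedding size, batch size, and choice of MultiSimilarityMiner are illustrative assumptions, not values fixed by the docs:

```python
import torch
from pytorch_metric_learning import losses, miners

loss_fn = losses.CrossBatchMemory(
    loss=losses.ContrastiveLoss(),
    embedding_size=512,
    memory_size=1024,
    miner=miners.MultiSimilarityMiner(),  # optional: mines pairs/triplets against the queue
)

embeddings = torch.randn(128, 512)     # current batch of network outputs (hypothetical)
labels = torch.randint(0, 10, (128,))  # dummy class labels
loss = loss_fn(embeddings, labels)     # the current embeddings are also enqueued
```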
322+
323+
324+
**Supported Loss Functions**:
325+
- [AngularLoss](losses.md#AngularLoss)
326+
- [CircleLoss](losses.md#CircleLoss)
327+
- [ContrastiveLoss](losses.md#ContrastiveLoss)
328+
- [GeneralizedLiftedStructureLoss](losses.md#GeneralizedLiftedStructureLoss)
329+
- [IntraPairVarianceLoss](losses.md#IntraPairVarianceLoss)
330+
- [LiftedStructureLoss](losses.md#LiftedStructureLoss)
331+
- [MultiSimilarityLoss](losses.md#MultiSimilarityLoss)
332+
- [NTXentLoss](losses.md#NTXentLoss)
333+
- [SignalToNoiseRatioContrastiveLoss](losses.md#SignalToNoiseRatioContrastiveLoss)
334+
- [SupConLoss](losses.md#SupConLoss)
335+
- [TripletMarginLoss](losses.md#TripletMarginLoss)
336+
- [TupletMarginLoss](losses.md#TupletMarginLoss)
337+
300338

301339
**Reset queue**
302340

@@ -401,11 +439,11 @@ losses.IntraPairVarianceLoss(pos_eps=0.01, neg_eps=0.01, **kwargs)
* **pos_eps**: The epsilon in the L<sub>pos</sub> equation. The paper uses 0.01.
* **neg_eps**: The epsilon in the L<sub>neg</sub> equation. The paper uses 0.01.

- You should probably use this in conjunction with another loss, as described in the paper. You can accomplish this by using [MultipleLosses](wrappers.md#multiplelosses):
+ You should probably use this in conjunction with another loss, as described in the paper. You can accomplish this by using [MultipleLosses](losses.md#multiplelosses):
```python
main_loss = losses.TupletMarginLoss()
var_loss = losses.IntraPairVarianceLoss()
- complete_loss = wrappers.MultipleLosses([main_loss, var_loss], weights=[1, 0.5])
+ complete_loss = losses.MultipleLosses([main_loss, var_loss], weights=[1, 0.5])
```

**Default distance**:
@@ -579,6 +617,18 @@ losses.MultiSimilarityLoss(alpha=2, beta=50, base=0.5, **kwargs)

* **loss**: The loss per element in the batch. Reduction type is ```"element"```.

## MultipleLosses
This is a simple wrapper for multiple losses. Pass in a list of already-initialized loss functions. Then, when you call forward on this object, it will return the weighted sum of all wrapped losses.
```python
losses.MultipleLosses(losses, miners=None, weights=None)
```
**Parameters**:

* **losses**: A list or dictionary of initialized loss functions. On the forward call of MultipleLosses, each wrapped loss will be computed, and their weighted sum will be returned.
* **miners**: Optional. A list or dictionary of mining functions. This allows you to pair mining functions with loss functions. For example, if ```losses = [loss_A, loss_B]``` and ```miners = [None, miner_B]```, then no mining will be done for ```loss_A```, but the output of ```miner_B``` will be passed to ```loss_B```. The same logic applies if ```losses = {"loss_A": loss_A, "loss_B": loss_B}``` and ```miners = {"loss_B": miner_B}```.
* **weights**: Optional. A list or dictionary of loss weights, which will be multiplied by the corresponding losses obtained by the loss functions. The default is to multiply each loss by 1. If ```losses``` is a list, then ```weights``` must be a list. If ```losses``` is a dictionary, ```weights``` must contain the same keys as ```losses```.
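A short sketch of both the list and dictionary forms; the particular losses, miner, and weights are illustrative assumptions:

```python
from pytorch_metric_learning import losses, miners

# List form: position i in miners/weights pairs with position i in losses.
loss_fn = losses.MultipleLosses(
    losses=[losses.ContrastiveLoss(), losses.TripletMarginLoss()],
    miners=[None, miners.MultiSimilarityMiner()],  # no mining for the contrastive loss
    weights=[1, 0.5],
)

# Dictionary form: keys in miners/weights pair with keys in losses.
loss_fn = losses.MultipleLosses(
    losses={"contrastive": losses.ContrastiveLoss(), "triplet": losses.TripletMarginLoss()},
    miners={"triplet": miners.MultiSimilarityMiner()},
    weights={"contrastive": 1, "triplet": 0.5},
)
```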
## NCALoss
[Neighbourhood Components Analysis](https://www.cs.toronto.edu/~hinton/absps/nca.pdf){target=_blank}
```python
@@ -787,6 +837,30 @@ loss_optimizer.step()
* **loss**: The loss per element in the batch that results in a non-zero exponent in the cross entropy expression. Reduction type is ```"element"```.

## SelfSupervisedLoss

A common use case is to have ```embeddings``` and ```ref_emb``` be augmented versions of each other. For most losses, you currently have to create labels to indicate which ```embeddings``` correspond with which ```ref_emb```. ```SelfSupervisedLoss``` automates this.

```python
loss_fn = losses.TripletMarginLoss()
loss_fn = losses.SelfSupervisedLoss(loss_fn)
loss = loss_fn(embeddings, ref_emb)
```
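For context, a self-contained sketch of the self-supervised flow; the encoder and the noise-based "augmentation" are hypothetical stand-ins for a real network and real augmentations:

```python
import torch
from pytorch_metric_learning import losses

encoder = torch.nn.Linear(32, 64)  # stand-in for a real embedding network
loss_fn = losses.SelfSupervisedLoss(losses.TripletMarginLoss())

data = torch.randn(128, 32)
embeddings = encoder(data)                               # one view of the data
ref_emb = encoder(data + 0.01 * torch.randn_like(data))  # an "augmented" view
loss = loss_fn(embeddings, ref_emb)                      # labels are created internally
```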
849+
850+
**Supported Loss Functions**:
851+
- [AngularLoss](losses.md#AngularLoss)
852+
- [CircleLoss](losses.md#CircleLoss)
853+
- [ContrastiveLoss](losses.md#ContrastiveLoss)
854+
- [IntraPairVarianceLoss](losses.md#IntraPairVarianceLoss)
855+
- [MultiSimilarityLoss](losses.md#MultiSimilarityLoss)
856+
- [NTXentLoss](losses.md#NTXentLoss)
857+
- [SignalToNoiseRatioContrastiveLoss](losses.md#SignalToNoiseRatioContrastiveLoss)
858+
- [SupConLoss](losses.md#SupConLoss)
859+
- [TripletMarginLoss](losses.md#TripletMarginLoss)
860+
- [TupletMarginLoss](losses.md#TupletMarginLoss)
861+
862+
863+
790864
## SignalToNoiseRatioContrastiveLoss
791865
[Signal-to-Noise Ratio: A Robust Distance Metric for Deep Metric Learning](http://openaccess.thecvf.com/content_CVPR_2019/papers/Yuan_Signal-To-Noise_Ratio_A_Robust_Distance_Metric_for_Deep_Metric_Learning_CVPR_2019_paper.pdf){target=_blank}
792866
```python
@@ -1023,11 +1097,11 @@ losses.TupletMarginLoss(margin=5.73, scale=64, **kwargs)
* **margin**: The angular margin (in degrees) applied to positive pairs. This is beta in the above equation. The paper uses a value of 5.73 degrees (0.1 radians).
* **scale**: This is ```s``` in the above equation.

- The paper combines this loss with [IntraPairVarianceLoss](losses.md#intrapairvarianceloss). You can accomplish this by using [MultipleLosses](wrappers.md#multiplelosses):
+ The paper combines this loss with [IntraPairVarianceLoss](losses.md#intrapairvarianceloss). You can accomplish this by using [MultipleLosses](losses.md#multiplelosses):
```python
main_loss = losses.TupletMarginLoss()
var_loss = losses.IntraPairVarianceLoss()
- complete_loss = wrappers.MultipleLosses([main_loss, var_loss], weights=[1, 0.5])
+ complete_loss = losses.MultipleLosses([main_loss, var_loss], weights=[1, 0.5])
```

**Default distance**:

docs/wrappers.md (−85)

This file was deleted.

src/pytorch_metric_learning/losses/__init__.py (+4)

@@ -1,9 +1,11 @@
  from .angular_loss import AngularLoss
  from .arcface_loss import ArcFaceLoss
+ from .base_loss_wrapper import BaseLossWrapper
  from .base_metric_loss_function import BaseMetricLossFunction
  from .circle_loss import CircleLoss
  from .contrastive_loss import ContrastiveLoss
  from .cosface_loss import CosFaceLoss
+ from .cross_batch_memory import CrossBatchMemory
  from .fast_ap_loss import FastAPLoss
  from .generic_pair_loss import GenericPairLoss
  from .instance_loss import InstanceLoss
@@ -13,12 +15,14 @@
  from .margin_loss import MarginLoss
  from .mixins import EmbeddingRegularizerMixin, WeightRegularizerMixin
  from .multi_similarity_loss import MultiSimilarityLoss
+ from .multiple_losses import MultipleLosses
  from .n_pairs_loss import NPairsLoss
  from .nca_loss import NCALoss
  from .normalized_softmax_loss import NormalizedSoftmaxLoss
  from .ntxent_loss import NTXentLoss
  from .proxy_anchor_loss import ProxyAnchorLoss
  from .proxy_losses import ProxyNCALoss
+ from .self_supervised_loss import SelfSupervisedLoss
  from .signal_to_noise_ratio_losses import SignalToNoiseRatioContrastiveLoss
  from .soft_triple_loss import SoftTripleLoss
  from .sphereface_loss import SphereFaceLoss
src/pytorch_metric_learning/utils/distributed.py (+1 −2)

@@ -1,10 +1,9 @@
  import torch

- from ..losses import BaseMetricLossFunction
+ from ..losses import BaseMetricLossFunction, CrossBatchMemory
  from ..miners import BaseMiner
  from ..utils import common_functions as c_f
  from ..utils import loss_and_miner_utils as lmu
- from ..wrappers import CrossBatchMemory


  # modified from https://github.com/allenai/allennlp
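A hedged sketch of what this change means for callers of the distributed utilities; the `DistributedLossWrapper` usage is an assumption about the surrounding module, not part of this diff:

```python
from pytorch_metric_learning.losses import ContrastiveLoss, CrossBatchMemory
from pytorch_metric_learning.utils import distributed

# CrossBatchMemory now comes from losses rather than wrappers
# before being handed to the distributed wrapper.
loss_fn = CrossBatchMemory(ContrastiveLoss(), embedding_size=128)
loss_fn = distributed.DistributedLossWrapper(loss_fn)
```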

tests/utils/test_distributed.py (+1 −2)

@@ -8,10 +8,9 @@
  import torch.optim as optim
  from torch.nn.parallel import DistributedDataParallel as DDP

- from pytorch_metric_learning.losses import ContrastiveLoss
+ from pytorch_metric_learning.losses import ContrastiveLoss, CrossBatchMemory
  from pytorch_metric_learning.miners import PairMarginMiner
  from pytorch_metric_learning.utils import distributed
- from pytorch_metric_learning.wrappers import CrossBatchMemory

  from .. import TEST_DEVICE, TEST_DTYPES

tests/wrappers/test_cross_batch_memory_wrapper.py (+1 −1)

@@ -5,6 +5,7 @@
  import pytorch_metric_learning.losses as losses
  from pytorch_metric_learning.losses import (
      ContrastiveLoss,
+     CrossBatchMemory,
      MultiSimilarityLoss,
      NTXentLoss,
  )
@@ -15,7 +16,6 @@
      TripletMarginMiner,
  )
  from pytorch_metric_learning.utils import loss_and_miner_utils as lmu
- from pytorch_metric_learning.wrappers import CrossBatchMemory

  from .. import TEST_DEVICE, TEST_DTYPES
  from ..zzz_testing_utils.testing_utils import angle_to_coord

tests/wrappers/test_multiple_losses_wrapper.py (+5 −2)

@@ -2,9 +2,12 @@

  import torch

- from pytorch_metric_learning.losses import ContrastiveLoss, TripletMarginLoss
+ from pytorch_metric_learning.losses import (
+     ContrastiveLoss,
+     MultipleLosses,
+     TripletMarginLoss,
+ )
  from pytorch_metric_learning.miners import MultiSimilarityMiner
- from pytorch_metric_learning.wrappers import MultipleLosses

  from .. import TEST_DEVICE, TEST_DTYPES
  from ..zzz_testing_utils.testing_utils import angle_to_coord

tests/wrappers/test_self_supervised_loss_wrapper.py (+2 −3)

@@ -3,7 +3,6 @@
  import torch

  import pytorch_metric_learning.losses as losses
- from pytorch_metric_learning.wrappers import SelfSupervisedLoss

  from .. import TEST_DEVICE, TEST_DTYPES

@@ -60,14 +59,14 @@ def run_all_loss_fns_wrapped(self, embeddings, ref_emb):
      loss_fns = dict()
      for loss_fn in loss_fns_list:
          loss_name = type(loss_fn).__name__
-         loss_fn = SelfSupervisedLoss(loss_fn)
+         loss_fn = losses.SelfSupervisedLoss(loss_fn)
          loss_value = loss_fn(embeddings=embeddings, ref_emb=ref_emb)
          loss_fns[loss_name] = loss_value

      return loss_fns

  def load_valid_loss_fns(self):
-     supported_losses = SelfSupervisedLoss.supported_losses()
+     supported_losses = losses.SelfSupervisedLoss.supported_losses()

      loss_fns = [
          losses.AngularLoss(),
