@@ -7,33 +7,137 @@ neighbors, similarity search, transfer learning, or data analytics.
Additionally, you can use the Lightly framework to directly interact with the `lightly platform <https://www.lightly.ai>`_.

- Walk-through of an example using Lightly
- ----------------------------------------
- In this short example, we will train a model using self-supervision and use it to
- create embeddings.
+ How Lightly Works
+ -----------------
+ The flexible design of Lightly makes it easy to integrate into your Python code. Lightly is built completely around the
+ PyTorch framework, and the different pieces can be put together to fit *your* requirements.
+
+ Data and Transformations
+ ^^^^^^^^^^^^^^^^^^^^^^^^
+ The basic building blocks of self-supervised methods
+ such as `SimCLR <https://arxiv.org/abs/2002.05709>`_ are image transformations. Each image is transformed into
+ two new images by randomly applied augmentations. The task of the self-supervised model is then to identify the
+ images which come from the same original image among a set of negative examples.
+
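+ Conceptually, each image is passed twice through the same randomized pipeline to obtain two different views. Below is a
+ minimal sketch of this idea using plain torchvision (the file name and augmentation parameters are illustrative only,
+ not Lightly defaults):
+
+ .. code-block:: python
+
+ import torchvision.transforms as T
+ from PIL import Image
+
+ # a randomized augmentation pipeline (illustrative values)
+ augment = T.Compose([
+     T.RandomResizedCrop(32),
+     T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
+     T.ToTensor(),
+ ])
+
+ image = Image.open('cat.jpg')
+ view_a = augment(image)  # first random view
+ view_b = augment(image)  # second random view of the same image
+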
+ Lightly implements these transformations
+ as torchvision transforms in the collate function of the dataloader. For example, the collate
+ function below will apply two different, randomized transforms to each image: a randomized resized crop and a
+ random color jitter.

.. code-block:: python

- from lightly import train_embedding_model
- from lightly import embed_images
+ import lightly.data as data

- # first we train our model for 1 epoch using a folder of cat images 'cats'
- checkpoint = train_embedding_model(input_dir='cats', trainer={'max_epochs': 1})
+ # the collate function applies random transforms to the input images
+ collate_fn = data.ImageCollateFunction(input_size=32, cj_prob=0.5)

- # let's embed our 'cats' dataset using our trained model
- embeddings, labels, filenames = embed_images(input_dir='cats', checkpoint=checkpoint)
+ Let's now load an image dataset and create a PyTorch dataloader with the collate function from above.

- # now, let's inspect the shape of our embeddings
- print(embeddings.shape)
+ .. code-block:: python
+
+ import torch
+
+ # create a dataset from your image folder
+ dataset = data.LightlyDataset(from_folder='./my/cute/cats/dataset/')
+
+ # build a PyTorch dataloader
+ dataloader = torch.utils.data.DataLoader(
+     dataset,                # pass the dataset to the dataloader
+     batch_size=128,         # a large batch size helps with the learning
+     shuffle=True,           # shuffling is important!
+     collate_fn=collate_fn)  # apply transformations to the input images
+
+ Head to the next section to see how you can train a ResNet on the data you just prepared.
+
+ Training
+ ^^^^^^^^
+
+ Now, we need an embedding model, an optimizer, and a loss function. We use a ResNet together
+ with the normalized temperature-scaled cross-entropy loss and simple stochastic gradient descent.
+
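+ For reference, this is the NT-Xent loss introduced in the SimCLR paper linked above. For a positive pair of embeddings
+ :math:`(z_i, z_j)` it is defined as
+
+ .. math::
+
+     \ell_{i,j} = -\log \frac{\exp(\mathrm{sim}(z_i, z_j) / \tau)}{\sum_{k=1}^{2N} \mathbf{1}_{[k \neq i]} \exp(\mathrm{sim}(z_i, z_k) / \tau)}
+
+ where :math:`\mathrm{sim}` denotes cosine similarity, :math:`\tau` is the temperature, and :math:`2N` is the number of
+ augmented views in a batch of :math:`N` images.
+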
+ .. code-block:: python
+
+ import lightly.models as models
+ import lightly.loss as loss
+
+ # build a resnet-34 with 32 embedding neurons
+ model = models.ResNetSimCLR(name='resnet-34', num_ftrs=32)
+
+ # use a criterion for self-supervised learning
+ # (normalized temperature-scaled cross entropy loss)
+ criterion = loss.NTXentLoss(temperature=0.5)
+
+ # get a PyTorch optimizer
+ optimizer = torch.optim.SGD(model.parameters(), lr=1e-0, weight_decay=1e-5)
+
+ Put everything together in an embedding model and train it for 10 epochs on a single GPU.
+
+ .. code-block:: python
+
+ import lightly.embedding as embedding
+
+ # put all the pieces together in a single pytorch_lightning trainable!
+ embedding_model = embedding.SelfSupervisedEmbedding(
+     model,
+     criterion,
+     optimizer,
+     dataloader)
+
+ # do self-supervised learning for 10 epochs
+ embedding_model.train_embedding(gpus=1, max_epochs=10)

Congrats, you just trained your first model using self-supervised learning!

+ Embeddings
+ ^^^^^^^^^^
+ You can use the trained model to embed your images or even access the embedding
+ model directly.
+
+ .. code-block:: python
+
+ # make a new dataloader without the transformations
+ dataloader = torch.utils.data.DataLoader(
+     dataset,          # use the same dataset as before
+     batch_size=1,     # we can use batch size 1 for inference
+     shuffle=False,    # don't shuffle your data during inference
+ )
+
+ # embed your image dataset
+ embeddings, labels, filenames = embedding_model.embed(dataloader)
+
+ # access the ResNet backbone
+ resnet = embedding_model.model.features
+
+ Done! You can continue to use the embeddings to find nearest neighbors or do similarity search.
+ Furthermore, the ResNet backbone can be used for transfer and few-shot learning.
+
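+ As an illustration of the similarity-search idea, the embeddings can be indexed with an off-the-shelf nearest-neighbor
+ implementation. The sketch below uses scikit-learn, which is not part of Lightly; the number of neighbors and the query
+ image are arbitrary choices:
+
+ .. code-block:: python
+
+ from sklearn.neighbors import NearestNeighbors
+
+ # index the embeddings (one row per image)
+ nn_index = NearestNeighbors(n_neighbors=5, metric='cosine').fit(embeddings)
+
+ # retrieve the five nearest neighbors of the first image (the closest one is the image itself)
+ distances, indices = nn_index.kneighbors(embeddings[0:1])
+ print([filenames[i] for i in indices[0]])
+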
.. note::
Self-supervised learning does not require labels for a model to be trained on. Lightly,
however, supports the use of additional labels. For example, if you train a model
on a folder 'cats' with subfolders 'Maine Coon', 'Bengal' and 'British Shorthair',
Lightly automatically returns the enumerated labels as a list.

+ Lightly in Three Lines
+ ----------------------
+
+ Lightly also offers an easy-to-use interface. The following lines show how the package can
+ be used to train a model with self-supervision and create embeddings with only three lines
+ of code.
+
+ .. code-block:: python
+
+ from lightly import train_embedding_model, embed_images
+
+ # first we train our model for 10 epochs
+ checkpoint = train_embedding_model(input_dir='./my/cute/cats/dataset/', trainer={'max_epochs': 10})
+
+ # let's embed our 'cats' dataset using our trained model
+ embeddings, labels, filenames = embed_images(input_dir='./my/cute/cats/dataset/', checkpoint=checkpoint)
+
+ # now, let's inspect the shape of our embeddings
+ print(embeddings.shape)
+
+
What's next?
------------
Get started by :ref:`rst-installing` and follow through the tutorial to learn how to get the most out of using Lightly.