
Commit e76bae4

Merge pull request #5 from lightly-ai/lightly_at_a_glance_pw
Add section "How Lightly Works" to "Getting Started"
2 parents eac43b3 + e8db50a commit e76bae4

1 file changed: +116, -12 lines
docs/source/getting_started/lightly_at_a_glance.rst

@@ -7,33 +7,137 @@ neighbors, similarity search, transfer learning, or data analytics.

Additionally, you can use the Lightly framework to directly interact with the `lightly platform <https://www.lightly.ai>`_.

How Lightly Works
-----------------
The flexible design of Lightly makes it easy to integrate into your Python code. Lightly is built
completely around the PyTorch framework, and the different pieces can be put together to fit *your*
requirements.

Data and Transformations
^^^^^^^^^^^^^^^^^^^^^^^^
The basic building blocks of self-supervised methods such as `SimCLR <https://arxiv.org/abs/2002.05709>`_
are image transformations. Each image is transformed into two new images by randomly applied augmentations.
The task of the self-supervised model is then to identify the images which come from the same original
among a set of negative examples.
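
To make the idea of two views per image concrete, below is a minimal, illustrative sketch using plain
torchvision transforms (the file name ``cat.jpg`` and the augmentation parameters are made up for this
example; it is not Lightly's internal implementation).

.. code-block:: python

    import torchvision.transforms as T
    from PIL import Image

    # an illustrative augmentation pipeline
    augment = T.Compose([
        T.RandomResizedCrop(32),
        T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
        T.ToTensor(),
    ])

    image = Image.open('cat.jpg')

    # two independently augmented views of the same image form a positive pair;
    # views of different images act as negative examples
    view_0 = augment(image)
    view_1 = augment(image)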

Lightly implements these transformations as torchvision transforms in the collate function of the
dataloader. For example, the collate function below will apply two different, randomized transforms
to each image: a randomized resized crop and a random color jitter.

.. code-block:: python

    import lightly.data as data

    # the collate function applies random transforms to the input images
    collate_fn = data.ImageCollateFunction(input_size=32, cj_prob=0.5)

Let's now load an image dataset and create a PyTorch dataloader with the collate function from above.

.. code-block:: python

    import torch

    # create a dataset from your image folder
    dataset = data.LightlyDataset(from_folder='./my/cute/cats/dataset/')

    # build a PyTorch dataloader
    dataloader = torch.utils.data.DataLoader(
        dataset,                # pass the dataset to the dataloader
        batch_size=128,         # a large batch size helps with the learning
        shuffle=True,           # shuffling is important!
        collate_fn=collate_fn)  # apply transformations to the input images

Head to the next section to see how you can train a ResNet on the data you just prepared.

Training
^^^^^^^^

Now, we need an embedding model, an optimizer and a loss function. We use a ResNet together with the
normalized temperature-scaled cross entropy loss and simple stochastic gradient descent.

.. code-block:: python

    import lightly.models as models
    import lightly.loss as loss

    # build a resnet-34 with 32 embedding neurons
    model = models.ResNetSimCLR(name='resnet-34', num_ftrs=32)

    # use a criterion for self-supervised learning
    # (normalized temperature-scaled cross entropy loss)
    criterion = loss.NTXentLoss(temperature=0.5)

    # get a PyTorch optimizer
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-0, weight_decay=1e-5)
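
For reference, the normalized temperature-scaled cross entropy (NT-Xent) loss from the SimCLR paper can
be written roughly as follows for a positive pair of augmented views (i, j), where sim is the cosine
similarity between embeddings z, tau is the temperature (0.5 in the snippet above), and N is the batch
size:

.. math::

    \ell_{i,j} = -\log \frac{\exp\big(\mathrm{sim}(z_i, z_j) / \tau\big)}
    {\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp\big(\mathrm{sim}(z_i, z_k) / \tau\big)}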

Put everything together in an embedding model and train it for 10 epochs on a single GPU.

.. code-block:: python

    import lightly.embedding as embedding

    # put all the pieces together in a single pytorch_lightning trainable!
    embedding_model = embedding.SelfSupervisedEmbedding(
        model,
        criterion,
        optimizer,
        dataloader)

    # do self-supervised learning for 10 epochs
    embedding_model.train_embedding(gpus=1, max_epochs=10)

Congrats, you just trained your first model using self-supervised learning!

Embeddings
^^^^^^^^^^
You can use the trained model to embed your images or even access the embedding model directly.

.. code-block:: python

    # make a new dataloader without the transformations
    dataloader = torch.utils.data.DataLoader(
        dataset,        # use the same dataset as before
        batch_size=1,   # we can use batch size 1 for inference
        shuffle=False,  # don't shuffle your data during inference
    )

    # embed your image dataset
    embeddings, labels, filenames = embedding_model.embed(dataloader)

    # access the ResNet backbone
    resnet = embedding_model.model.features

Done! You can continue to use the embeddings to find nearest neighbors or do similarity search.
Furthermore, the ResNet backbone can be used for transfer and few-shot learning.
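
As an illustrative sketch of such a similarity search, you could, for example, look up the nearest
neighbors of an image directly on the ``embeddings`` returned above with plain NumPy (assuming the
embeddings convert to a NumPy array; the query index is arbitrary):

.. code-block:: python

    import numpy as np

    embeddings = np.asarray(embeddings)

    # L2-normalize so that the dot product equals the cosine similarity
    normalized = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)

    # pick a query image and compute its similarity to all other images
    query = 0
    similarities = normalized @ normalized[query]

    # filenames of the 5 most similar images (the query itself comes first)
    nearest = np.argsort(-similarities)[:5]
    print([filenames[i] for i in nearest])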

.. note::
    Self-supervised learning does not require labels for a model to be trained on. Lightly,
    however, supports the use of additional labels. For example, if you train a model
    on a folder 'cats' with subfolders 'Maine Coon', 'Bengal' and 'British Shorthair',
    Lightly automatically returns the enumerated labels as a list.

Lightly in Three Lines
----------------------

Lightly also offers an easy-to-use interface. The following lines show how the package can be used to
train a model with self-supervision and create embeddings with only three lines of code.

.. code-block:: python

    from lightly import train_embedding_model, embed_images

    # first we train our model for 10 epochs
    checkpoint = train_embedding_model(input_dir='./my/cute/cats/dataset/', trainer={'max_epochs': 10})

    # let's embed our 'cats' dataset using our trained model
    embeddings, labels, filenames = embed_images(input_dir='./my/cute/cats/dataset/', checkpoint=checkpoint)

    # now, let's inspect the shape of our embeddings
    print(embeddings.shape)

What's next?
------------
Get started by :ref:`rst-installing` and work through the tutorial to learn how to get the most out of using Lightly.
