Parametrizations tutorial #1444

Merged: 19 commits merged into pytorch:master on Apr 19, 2021

Conversation

@lezcano lezcano commented Mar 25, 2021

Creates the tutorial for the parametrizations functionality. This was discussed in issue pytorch/pytorch#7313 and implemented in PR pytorch/pytorch#33344.

cc @albanD @IvanYashchuk @toshas

netlify bot commented Mar 25, 2021

Deploy preview for pytorch-tutorials-preview ready!

Built with commit 496327c

https://deploy-preview-1444--pytorch-tutorials-preview.netlify.app

@IvanYashchuk IvanYashchuk changed the title from "Parametrizaitons tutorial" to "Parameterizations tutorial" Mar 25, 2021
@IvanYashchuk IvanYashchuk changed the title from "Parameterizations tutorial" to "Parametrizations tutorial" Mar 25, 2021

brianjo commented Mar 26, 2021

Looks like this requires PyTorch 1.9, which we don't build on yet. You can remove the _tutorial from the file name to add it without running the code. If you want to wait for 1.9 to publish, I have some other options for testing this.


###############################################################################
# We can then use this idea to implement a linear layer with symmetric weights:
class LinearSymmetric(nn.Module):

Less of a straw man (for this simple case) would be a module deriving from Linear and adding a weight property.

Contributor Author

That one's tricky. If you think about it, it'd go something like this, but with a bit more flair:

import torch.nn as nn

class MyLin(nn.Linear):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._weight = self.weight  # Why does this line work?
        delattr(self, "weight")

    @property
    def weight(self):
        return self._weight

lin = MyLin(3, 4)

The line above works because of how nn.Module implements __getattr__ and __setattr__. One could almost say that it works "out of pure chance". What happens in that line is:

  • It calls __getattribute__
  • __getattribute__ finds the property self.weight and calls it
  • The property looks for self._weight. At that time it does not exist, so it raises an AttributeError.
  • Since __getattribute__ got an AttributeError, nn.Module.__getattr__ is called
  • nn.Module.__getattr__ finds the self.weight that was created in nn.Linear.__init__ and returns it

This is quite a mess really. That's why I went for the simpler method here.
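
A quick check of the claim above (a small, hypothetical inspection using the MyLin class from the snippet):

lin = MyLin(3, 4)
# The class-level property now serves `weight`...
print(type(lin).weight)                                    # <property object at ...>
# ...while the underlying tensor ended up registered under `_weight`
print(sorted(name for name, _ in lin.named_parameters()))  # ['_weight', 'bias']
# The property simply forwards to the re-registered parameter
print(lin.weight is lin._weight)                           # True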

@vadimkantorov vadimkantorov Apr 6, 2021

I had used register_parameter('weight', None) instead: https://gist.github.com/vadimkantorov/4f34fe60d2ef00e72dcad16512d224af; it seems to work for Conv1d. And maybe even a direct delattr(self, 'weight') could work?

Ah, the trick is that you still want to access the old _weight, while I didn't need it.

Contributor Author

That works for exactly the same convoluted reason, just going through __setattr__ instead of __getattr__ and __getattribute__. Again, it works, but it's tricky to know why it works.
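
For readers following along, a rough sketch of how the register_parameter(..., None) plus property route can look (hypothetical code, not the gist's):

import torch
import torch.nn as nn

class ReLUWeightLinear(nn.Linear):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Keep the initialized weight around under a new name...
        self.raw = nn.Parameter(self.weight.detach().clone())
        # ...and blank out the parameter registered by nn.Linear, so the
        # class-level property below provides `weight` from now on.
        self.register_parameter("weight", None)

    @property
    def weight(self):
        return torch.relu(self.raw)

lin = ReLUWeightLinear(3, 4)
out = lin(torch.randn(2, 3))  # forward goes through the property, i.e. relu(raw)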

# are properly registered as submodules of the original module. As such, the same rules
# for registering parameters in a module apply to register a parametrization.
# For example, if a parametrization has parameters, these will be moved from CPU
# to CUDA when calling ``model = model.cuda()``.

It may be good to add a note explaining how the reparametrization magic is implemented under the hood.

Contributor Author

I think that would be too technical (it uses dynamically generated classes). Something that could be done instead is to stress more what happens after you call register_parametrization: in particular, the fact that it creates a ModuleDict under module.parametrizations, and that each entry in it is a ParametrizationList, and so on.
In fact, that's all register_parametrization does, modulo the dynamically-generated-classes magic. This would also help clarify why it can be used with nn.Modules but not with plain old functions.
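
For reference, a minimal sketch of what registration leaves behind on the module (assuming PyTorch >= 1.9 and a toy Symmetric parametrization):

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class Symmetric(nn.Module):
    def forward(self, X):
        # Build a symmetric matrix from the upper-triangular part
        return X.triu() + X.triu(1).transpose(-1, -2)

layer = nn.Linear(3, 3)
parametrize.register_parametrization(layer, "weight", Symmetric())

print(type(layer).__name__)                               # ParametrizedLinear (a dynamically generated class)
print(isinstance(layer.parametrizations, nn.ModuleDict))  # True
print(type(layer.parametrizations.weight).__name__)       # ParametrizationList
print(layer.parametrizations.weight.original.shape)       # torch.Size([3, 3]), the unparametrized tensor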

@vadimkantorov vadimkantorov Apr 6, 2021

I think it's still better to mention it (even if technical); then there wouldn't be the surprise of a generated class name showing up in an exception, and especially when users debug interactively (does it affect interactive debugging in an IDE? That's a valid question). It would also be a useful reminder for users not to do any similar magic of their own if they fear interference.

Contributor

I do agree that such an explanation would be nice, but I'm not sure this tutorial is the right place for it.
A note in the docs, similar to the ones we have about autograd here for example, sounds more appropriate.

# matrices. Using these two facts, we may reuse the parametrizations
class MatrixExponential(nn.Module):
    def forward(self, X):
        return torch.matrix_exp(X)

@vadimkantorov vadimkantorov Apr 6, 2021

Could torch.matrix_exp be used directly instead of MatrixExponential()? Either way, I think this should be discussed explicitly.

Contributor Author

No, it could not. At the moment, parametrizations are defined to be nn.Modules, so they do not support the functional API.

I do not know whether it is necessary to discuss this, as throughout the tutorial it has been made clear that a parametrization is just a plain nn.Module. What do you think, @albanD?
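
For what it's worth, a plain function can always be wrapped in a module. A hypothetical sketch (the Lambda helper is made up for illustration):

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class Lambda(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, X):
        return self.fn(X)

layer = nn.Linear(3, 3)
parametrize.register_parametrization(layer, "weight", Lambda(torch.matrix_exp))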

My reasoning: in many places in PyTorch one expects both functions and module objects to work (especially in older areas of PyTorch). Whenever this is violated (quantization, module tracing, etc.), I'm always suspicious of undeclared magic :)

In this case it raises the question, since the module here is just a wrapper.

Contributor

That is indeed an interesting question.
But for now everything is expected to be an nn.Module, and that sounds like enough to me.
If we want to relax that in the future we might be able to do so, but that would be out of scope for this tutorial.

return A

###############################################################################
# In this case, it is not true that ``forward(right_inverse(X)) == X``. This is

Would reparametrizations work for double backward? A related question was asked here: pytorch/pytorch#55368

@lezcano lezcano Apr 6, 2021

That is a good point. At the moment, a variant of that example using

import torch.nn.utils.parametrize as P
model = P.register_parametrization(torch.nn.Linear(2, 5), "weight", torch.nn.ReLU())

also breaks with the same error message. That being said, it smells like there's a problem in the implementation of register_parametrization (?). Perhaps @albanD can give a bit more insight into what's going on...

Contributor

There is no reason it wouldn't.
Your implementation of the reparametrization will need to be double differentiable though.

I guess it's also worth an explicit discussion somewhere

Contributor

Not sure what you mean about the ReLU part, though. This works fine on Colab for me:

# !pip uninstall --y torch
# !pip install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html

import torch
import torch.nn.utils.parametrize as P
model = P.register_parametrization(torch.nn.Linear(2, 5), "weight", torch.nn.ReLU())


inp = torch.rand(1, 2, requires_grad=True)

out = model(inp)

g = torch.autograd.grad(out.sum(), model.parameters(), create_graph=True)

print(g)

g[1].exp().sum().backward()

print(inp.grad)

@lezcano lezcano Apr 7, 2021

I meant that this example (which is a modification of the one in that PR) breaks:

import torch
import torch.nn.utils.parametrize as P
model = P.register_parametrization(torch.nn.Linear(2, 5), "weight", torch.nn.ReLU())
opt1 = torch.optim.SGD(model.parameters(), lr=1e-3)
opt2 = torch.optim.SGD(model.parameters(), lr=1e-3)

output = model(torch.randn(7, 2))
loss = output.abs().mean()

opt1.zero_grad(); loss.backward(retain_graph=True); opt1.step() # first propagation
opt2.zero_grad(); loss.backward(); opt2.step() # second

@albanD albanD left a comment

Looks quite good to me.
Only phrasing and minor comments.

lezcano commented Apr 7, 2021

@albanD I addressed the points that you raised and corrected a few other things (the code does not break now... I had forgotten to check that...). Even so, there were no major changes to the text.

@albanD albanD left a comment

Just a minor update and it looks mostly good to me.

@brianjo brianjo merged commit dc5c41c into pytorch:master Apr 19, 2021
rodrigo-techera pushed a commit to Experience-Monks/tutorials that referenced this pull request Nov 29, 2021
* Parametrizaitons tutorial

* Add remove_parametrization

* Correct name

* minor

* Proper version number

* Fuzzy spellcheck

* version

* Remove _tutorial from name

* Forgot to add the file...

* Rename parametrizations_tutorial by parametrizations everywhere
Add Alban's suggestions
Correct the code
Beter spacing after enumeration

* Minor

* Add more comments

* Minor

* Prefer unicode over math

* Minor

* minor

* Corrections

Co-authored-by: Brian Johnson <[email protected]>