Parametrizations tutorial #1444
Looks like this requires PyTorch 1.9, which we don't build on yet. You can remove the `_tutorial` from the file name to add it without running the code. If you want to wait for 1.9 to publish, I have some other options for testing this.
```python
###############################################################################
# We can then use this idea to implement a linear layer with symmetric weights:
class LinearSymmetric(nn.Module):
```
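(The quoted hunk is cut off here. For orientation, a hedged sketch of what such a layer could look like follows; the `symmetric` helper and the constructor signature are assumptions, not necessarily the tutorial's exact code.)

```python
import torch
import torch.nn as nn

def symmetric(X):
    # Build a symmetric matrix from the upper-triangular part of X
    return X.triu() + X.triu(1).transpose(-1, -2)

class LinearSymmetric(nn.Module):
    def __init__(self, n_features):
        super().__init__()
        # Unconstrained parameter; only its upper triangle is effectively used
        self.weight = nn.Parameter(torch.rand(n_features, n_features))

    def forward(self, x):
        A = symmetric(self.weight)
        return x @ A

layer = LinearSymmetric(3)
out = layer(torch.rand(8, 3))
```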
Less of a straw-man (for this simple case) would be a module deriving from `Linear` and adding a `weight` property.
That one's tricky. If you think about it, it'd go as something like this, but with a bit more flair:

```python
import torch.nn as nn

class MyLin(nn.Linear):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._weight = self.weight  # Why does this line work?
        delattr(self, "weight")

    @property
    def weight(self):
        return self._weight

lin = MyLin(3, 4)
```
The line above works because of how `nn.Module` handles `__getattr__` and `__setattr__`. One could almost say that it works "out of pure chance". What happens in that line is:

1. It calls `__getattribute__`.
2. `__getattribute__` finds the property `self.weight` and calls it.
3. The property looks for `self._weight`. At that time it does not exist, so it raises an `AttributeError`.
4. Since `__getattribute__` got an `AttributeError`, `nn.Module.__getattr__` is called.
5. `nn.Module.__getattr__` finds the `self.weight` that was created in `nn.Linear.__init__` and returns it.
This is quite a mess really. That's why I went for the simpler method here.
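As an aside, the same fallback chain can be reproduced in plain Python, without PyTorch; the `Fallback` class below is a made-up sketch for illustration only:

```python
class Fallback:
    @property
    def weight(self):
        # self._weight does not exist, so this raises AttributeError
        return self._weight

    def __getattr__(self, name):
        # Called only after normal attribute lookup (__getattribute__) fails,
        # mimicking what nn.Module.__getattr__ does for registered parameters
        print(f"__getattr__ called for {name!r}")
        if name == "weight":
            return "parameter stored elsewhere"
        raise AttributeError(name)

obj = Fallback()
print(obj.weight)
# __getattr__ called for '_weight'
# __getattr__ called for 'weight'
# parameter stored elsewhere
```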
I had used `register_parameter('weight', None)` instead: https://gist.github.com/vadimkantorov/4f34fe60d2ef00e72dcad16512d224af, seems to work for `Conv1d`. And maybe even directly `delattr(self, 'weight')` could work?
Ah, the trick is that you still want to access the old `_weight`, while I didn't need it.
That works for exactly the same convoluted reason, but with `__setattr__` playing the role of `__getattr__` and `__getattribute__`. Again, it works, but it's tricky to know why it works.
```python
# are properly registered as submodules of the original module. As such, the same rules
# for registering parameters in a module apply to register a parametrization.
# For example, if a parametrization has parameters, these will be moved from CPU
# to CUDA when calling ``model = model.cuda()``.
```
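A hedged sketch of what that quoted paragraph means in practice (assuming PyTorch 1.9+ with `torch.nn.utils.parametrize` and an available GPU; `Scale` is a made-up parametrization, not from the tutorial):

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as P

class Scale(nn.Module):
    """Toy parametrization that owns a learnable parameter of its own."""
    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.tensor(2.0))

    def forward(self, X):
        return self.scale * X

model = nn.Linear(4, 4)
P.register_parametrization(model, "weight", Scale())

if torch.cuda.is_available():
    model = model.cuda()
    # The parametrization is a submodule, so its own parameter moved to the GPU too
    print(model.parametrizations.weight[0].scale.device)  # cuda:0
```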
It may be good to add a note explaining how reparametrization magic is implemented under the hood
I think that would be too technical (it uses dynamically generated classes). Something that could be done is to stress more what happens after you call `register_parametrization`: in particular, the fact that it creates a `ModuleDict` under `module.parametrizations`, that each of those entries is a `ParametrizationList`, and so on.

In fact, that is all `register_parametrization` does, modulo the dynamically-generated-classes magic. This would also help clarify why it can be used with `nn.Module`s but not with plain old functions.
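To make that concrete, here is a minimal sketch (assuming PyTorch 1.9+; `Symmetric` is just an example parametrization) of what `register_parametrization` leaves behind on the module:

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as P

class Symmetric(nn.Module):
    def forward(self, X):
        # Reflect the upper triangle to produce a symmetric matrix
        return X.triu() + X.triu(1).transpose(-1, -2)

layer = nn.Linear(3, 3)
P.register_parametrization(layer, "weight", Symmetric())

print(type(layer.parametrizations))         # an nn.ModuleDict
print(type(layer.parametrizations.weight))  # a ParametrizationList
print(layer.parametrizations.weight[0])     # the Symmetric() module we registered
print(P.is_parametrized(layer, "weight"))   # True
```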
I think it's still better to mention it (even if technical); then there wouldn't be a surprise at some generated class name if there is an exception (and especially when users debug interactively: does it affect interactive debugging in an IDE? That's a valid question). It would also be a useful reminder for users not to do any similar magic of their own if they fear interference.
I do agree that such an explanation would be nice, but I'm not sure this tutorial is the right place for it.
A note in the docs, similar to the ones we have about autograd here for example, sounds more appropriate.
```python
# matrices. Using these two facts, we may reuse the parametrizations
class MatrixExponential(nn.Module):
    def forward(self, X):
        return torch.matrix_exp(X)
```
Could `torch.matrix_exp` be directly used instead of `MatrixExponential()`? In both cases, I think this should be discussed explicitly.
No, it could not. At the moment, parametrizations are defined to be `nn.Module`s, so they do not support the functional API.

I do not know whether it is necessary to discuss this, as the whole tutorial makes it clear that a parametrization is just a plain `nn.Module`. What do you think @albanD?
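(If one really wanted to use `torch.matrix_exp` here, a small wrapper module would do. This is only a sketch under the current API; the `Wrap` helper below is made up, not part of PyTorch.)

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as P

class Wrap(nn.Module):
    """Hypothetical adapter that turns a plain function into a parametrization module."""
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, X):
        return self.fn(X)

layer = nn.Linear(3, 3)
# torch.matrix_exp cannot be registered directly, but a module wrapping it can
P.register_parametrization(layer, "weight", Wrap(torch.matrix_exp))
print(layer.weight)  # the matrix exponential of the underlying unconstrained parameter
```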
My reasoning: in many places in PyTorch one expects both functions and module objects to work (especially in older areas of PyTorch). Whenever this is violated (quantization, module tracing, etc.) I'm always suspicious of undeclared magic :)

In this case it begs the question, since the module here is just a wrapper.
That is indeed an interesting question.
But for now everything is expected to be an `nn.Module`, and that sounds good enough to me.
If we want to relax that in the future we might be able to do so, but that would be out of scope for this tutorial.
```python
        return A

###############################################################################
# In this case, it is not true that ``forward(right_inverse(X)) == X``. This is
```
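(The quoted hunk cuts off mid-sentence. For context, the situation it describes could look like the following sketch, where an exact right inverse, the matrix logarithm, is skipped in favour of the identity; this is an assumption about the surrounding tutorial code, not a verbatim quote.)

```python
import torch
import torch.nn as nn

class MatrixExponential(nn.Module):
    def forward(self, X):
        return torch.matrix_exp(X)

    def right_inverse(self, A):
        # A true right inverse would be the matrix logarithm, which is expensive.
        # Returning A unchanged is cheap, but then
        # forward(right_inverse(A)) == matrix_exp(A) != A in general.
        return A
```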
Would reparametrizations work for double-backward? A related question was asked here: pytorch/pytorch#55368
That is a good point. At the moment, that example also breaks with the same error message when using

```python
import torch.nn.utils.parametrize as P

model = P.register_parametrization(torch.nn.Linear(2, 5), "weight", torch.nn.ReLU())
```

That being said, it smells like there's a problem in the implementation of `register_parametrization` (?). Perhaps @albanD can give a bit more insight into what's going on.
There is no reason it wouldn't.
Your implementation of the reparametrization will need to be double differentiable though.
I guess it's also worth an explicit discussion somewhere
Not sure what you mean about the ReLU part though. This works fine on Colab for me:

```python
# !pip uninstall --y torch
# !pip install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
import torch
import torch.nn.utils.parametrize as P

model = P.register_parametrization(torch.nn.Linear(2, 5), "weight", torch.nn.ReLU())

inp = torch.rand(1, 2, requires_grad=True)
out = model(inp)
g = torch.autograd.grad(out.sum(), model.parameters(), create_graph=True)
print(g)
g[1].exp().sum().backward()
print(inp.grad)
```
I meant that this example (which is a modification of the one in that PR) breaks:

```python
import torch
import torch.nn.utils.parametrize as P

model = P.register_parametrization(torch.nn.Linear(2, 5), "weight", torch.nn.ReLU())
opt1 = torch.optim.SGD(model.parameters(), lr=1e-3)
opt2 = torch.optim.SGD(model.parameters(), lr=1e-3)

output = model(torch.randn(7, 2))
loss = output.abs().mean()
opt1.zero_grad(); loss.backward(retain_graph=True); opt1.step()  # first propagation
opt2.zero_grad(); loss.backward(); opt2.step()  # second
```
Looks quite good to me.
Only phrasing and minor comments.
@albanD I addressed the points that you raised and I corrected a few other things (the code does not break now... I had forgotten to check that...). Even so, there were no major changes in the text.
Just a minor update and it looks mostly good to me.
Add Alban's suggestions. Correct the code. Better spacing after enumeration.

* Parametrizations tutorial
* Add remove_parametrization
* Correct name
* minor
* Proper version number
* Fuzzy spellcheck
* version
* Remove _tutorial from name
* Forgot to add the file...
* Rename parametrizations_tutorial to parametrizations everywhere; add Alban's suggestions; correct the code; better spacing after enumeration
* Minor
* Add more comments
* Minor
* Prefer unicode over math
* Minor
* minor
* Corrections

Co-authored-by: Brian Johnson <[email protected]>
Creates the tutorial for the parametrizations functionality. This was discussed in issue pytorch/pytorch#7313 and implemented in PR pytorch/pytorch#33344.
cc @albanD @IvanYashchuk @toshas