
Commit 8bef3ff

lezcano authored and Sacha Refshauge committed
Parametrization Functionality (pytorch#33344)
Summary:
Provides the implementation for feature request issue pytorch#28937. Adds the `Parametrization` functionality and implements `Pruning` on top of it.

It adds the `auto` mode, in which the parametrization is computed just once per forward pass. The previous implementation recomputed the pruning on every forward, which is not optimal when pruning RNNs, for example. It implements a caching mechanism for parameters, following the approach proposed at the end of the discussion in pytorch#7313. In particular, it assumes that the user will not manually change the updated parameters between the call to `backward()` and `optimizer.step()`. If they do so, they need to manually call the `.invalidate()` function provided in the implementation. This could be made into a function that takes a model and invalidates all the parameters in it. It might be the case that this function has to be called from `.cuda()`, `.to`, and related functions.

As described in pytorch#7313, this could be used to implement `weight_norm` and `spectral_norm` in a cleaner way. It also allows, as described in pytorch#28937, the implementation of constrained optimization on manifolds (e.g. orthogonal constraints, positive definite matrices, invertible matrices, weights on the sphere or the hyperbolic space...).

TODO (when implementation is validated):
- More thorough tests
- Documentation

Resolves pytorch#28937

albanD

Pull Request resolved: pytorch#33344

Reviewed By: zhangguanheng66

Differential Revision: D26816708

Pulled By: albanD

fbshipit-source-id: 07c8f0da661f74e919767eae31335a9c60d9e8fe
1 parent 9dc6c53 commit 8bef3ff
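For context, here is a minimal usage sketch of the API added by this commit, distilled from the tests below. The `Symmetric` module and the `layer` variable are illustrative only and are not part of the commit; whether `leave_parametrized=True` is appropriate depends on the parametrization (here dtype and size are preserved, so it should be):

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

# Hypothetical parametrization: maps an unconstrained matrix to a symmetric one
class Symmetric(nn.Module):
    def forward(self, X):
        return X.triu() + X.triu(1).T

layer = nn.Linear(4, 4)
parametrize.register_parametrization(layer, "weight", Symmetric())
assert parametrize.is_parametrized(layer, "weight")
# layer.weight is now recomputed on access from layer.parametrizations.weight.original
assert torch.allclose(layer.weight, layer.weight.T)

# Inside the context manager the parametrized tensor is computed once and cached
with parametrize.cached():
    W1 = layer.weight
    W2 = layer.weight
    assert W1 is W2

# Remove the parametrization, keeping the current (symmetric) value in the parameter
parametrize.remove_parametrizations(layer, "weight", leave_parametrized=True)

The calls mirror those exercised in test/test_nn.py below (register_parametrization, is_parametrized, cached, remove_parametrizations).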

File tree

3 files changed: +723, -0 lines changed


docs/source/nn.rst

Lines changed: 10 additions & 0 deletions
@@ -346,11 +346,21 @@ From the ``torch.nn.utils`` module
     parameters_to_vector
     vector_to_parameters
 
+.. autosummary::
+    :toctree: generated
+    :nosignatures:
+
+    parametrize.register_parametrization
+    parametrize.remove_parametrizations
+    parametrize.cached
+    parametrize.is_parametrized
+
 .. autosummary::
     :toctree: generated
     :nosignatures:
     :template: classtemplate.rst
 
+    parametrize.ParametrizationList
     prune.BasePruningMethod
 
 .. autosummary::

test/test_nn.py

Lines changed: 349 additions & 0 deletions
@@ -27,6 +27,7 @@
 import torch.nn.init as init
 import torch.nn.utils.rnn as rnn_utils
 from torch.nn.utils import clip_grad_norm_, clip_grad_value_
+import torch.nn.utils.parametrize as parametrize
 import torch.nn.utils.prune as prune
 from torch.nn.utils import parameters_to_vector, vector_to_parameters
 from torch.nn import Parameter
@@ -1939,6 +1940,354 @@ def test_vector_to_parameters(self):
         sample = next(model.parameters())[0, 0, 0]
         self.assertTrue(torch.equal(sample.data, vec.data[:5]))
 
+    # torch/nn/utils/parametrize
+    def test_register_and_remove_parametrization(self):
+        r"""Test that it is possible to add a few parametrizations
+        on a parameter or a buffer and that removing them restores the initial state.
+        It also tests that backpropagating through them works as expected.
+        """
+        # Define a couple of matrix parametrizations
+        class Skew(nn.Module):
+            def forward(self, X):
+                X = X.tril(-1)
+                return X - X.T
+
+        class Orthogonal(nn.Module):
+            def forward(self, X):
+                # Cayley map
+                # If X is skew-symmetric it returns an orthogonal matrix
+                Id = torch.eye(X.size(0), device=X.device)
+                return torch.solve(Id - X, Id + X).solution
+
+        # Define a couple of vector parametrizations
+        class FirstZero(nn.Module):
+            def forward(self, x):
+                return torch.cat([x.new_zeros(1), x[1:]])
+
+        class LastZero(nn.Module):
+            def forward(self, x):
+                return torch.cat([x[:-1], x.new_zeros(1)])
+
+        model = nn.Linear(8, 8)
+        initial_weight_id = id(model.weight)
+        initial_bias_id = id(model.bias)
+        initial_model = deepcopy(model)
+
+        # Test one parametrization
+        parametrize.register_parametrization(model, "weight", Skew())
+        self.assertTrue(hasattr(model, "parametrizations"))
+        self.assertTrue(parametrize.is_parametrized(model))
+        self.assertTrue(parametrize.is_parametrized(model, "weight"))
+        self.assertFalse(parametrize.is_parametrized(model, "bias"))
+        self.assertNotIn("weight", model._parameters)
+        # Result should be skew-symmetric
+        A = model.weight
+        self.assertTrue(torch.allclose(A, -A.T))
+        # Remove and check consistency
+        parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
+        self.assertFalse(hasattr(model, "parametrizations"))
+        self.assertEqual(model.weight, initial_model.weight)
+        self.assertEqual(id(model.weight), initial_weight_id)
+        self.assertEqual(model.__class__, nn.Linear)
+
+        # Test two parametrizations at the same time and removing them
+        parametrize.register_parametrization(model, "weight", Skew())
+        parametrize.register_parametrization(model, "weight", Orthogonal())
+        # Result should be orthogonal
+        X = model.weight
+        Id = torch.eye(X.size(0), device=X.device)
+        self.assertTrue(torch.allclose(X.T @ X, Id))
+        # Structure tests
+        self.assertTrue(hasattr(model, "parametrizations"))
+        self.assertTrue(parametrize.is_parametrized(model))
+        self.assertTrue(parametrize.is_parametrized(model, "weight"))
+        self.assertFalse(parametrize.is_parametrized(model, "bias"))
+        self.assertIn("weight", model.parametrizations)
+        self.assertNotIn("weight", model._parameters)
+        # Remove
+        parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
+        self.assertEqual(model.weight, initial_model.weight)
+        self.assertEqual(id(model.weight), initial_weight_id)
+        self.assertFalse(hasattr(model, "parametrizations"))
+        self.assertEqual(model.__class__, nn.Linear)
+
+        # Add everything
+        parametrize.register_parametrization(model, "weight", Skew())
+        parametrize.register_parametrization(model, "weight", Orthogonal())
+        parametrize.register_parametrization(model, "bias", FirstZero())
+        parametrize.register_parametrization(model, "bias", LastZero())
+
+        # Basic tests
+        self.assertTrue(parametrize.is_parametrized(model))
+        self.assertTrue(parametrize.is_parametrized(model, "weight"))
+        self.assertTrue(parametrize.is_parametrized(model, "bias"))
+        self.assertEqual(model.bias[0].item(), 0.)
+        self.assertEqual(model.bias[-1].item(), 0.)
+        self.assertEqual(len(list(model.parameters())), 2)  # Nothing weird has happened
+        # Should not throw
+        (model.weight.T @ model.bias).sum().backward()
+        with torch.no_grad():
+            for p in model.parameters():
+                p.add_(-p.grad, alpha=0.01)
+
+        # Remove the first parametrization.
+        # Check that the model is still parametrized and so is the second parameter
+        parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
+        self.assertTrue(parametrize.is_parametrized(model))  # Still parametrized
+        self.assertFalse(parametrize.is_parametrized(model, "weight"))  # Parametrization removed
+        self.assertTrue(parametrize.is_parametrized(model, "bias"))  # Still parametrized
+        self.assertEqual(model.bias[0].item(), 0.)  # Still parametrized
+        self.assertEqual(model.bias[-1].item(), 0.)  # Still parametrized
+        self.assertNotEqual(model.weight, initial_model.weight)  # Has been updated
+        self.assertEqual(id(model.weight), initial_weight_id)  # Keeps the same id
+        self.assertEqual(len(list(model.parameters())), 2)  # Nothing weird has happened
+        # Should not throw
+        (model.weight.T @ model.bias).sum().backward()
+        with torch.no_grad():
+            for p in model.parameters():
+                p.add_(-p.grad, alpha=0.01)
+
+        # Remove the second parametrization.
+        # Check that the module is not parametrized anymore
+        parametrize.remove_parametrizations(model, "bias", leave_parametrized=False)
+        self.assertFalse(parametrize.is_parametrized(model))  # Not parametrized anymore
+        self.assertNotEqual(model.bias, initial_model.bias)  # Has been updated
+        self.assertNotEqual(model.bias[0].item(), 0.)  # Not parametrized anymore
+        self.assertNotEqual(model.bias[-1].item(), 0.)  # Not parametrized anymore
+        self.assertEqual(id(model.bias), initial_bias_id)
+        self.assertFalse(hasattr(model, "parametrizations"))
+        self.assertEqual(model.__class__, nn.Linear)
+        self.assertEqual(len(list(model.parameters())), 2)
+        # Should not throw
+        (model.weight.T @ model.bias).sum().backward()
+        with torch.no_grad():
+            for p in model.parameters():
+                p.add_(-p.grad, alpha=0.01)
+
+    def test_register_and_remove_buffer_parametrization(self):
+        r"""Test that it is possible to add and remove parametrizations on buffers"""
+        # Define a couple vector parametrizations
+        class FirstZero(nn.Module):
+            def forward(self, x):
+                return torch.cat([x.new_zeros(1), x[1:]])
+
+        class LastZero(nn.Module):
+            def forward(self, x):
+                return torch.cat([x[:-1], x.new_zeros(1)])
+
+        model = nn.Linear(8, 8)
+
+        # Instantiate parametrizations on buffers. It should work as expected
+        delattr(model, "bias")
+        model.register_buffer("bias", torch.ones(8))
+        parametrize.register_parametrization(model, "bias", FirstZero())
+        parametrize.register_parametrization(model, "bias", LastZero())
+        self.assertTrue(parametrize.is_parametrized(model))
+        self.assertTrue(parametrize.is_parametrized(model, "bias"))
+        self.assertEqual(model.bias[0].item(), 0.)
+        self.assertEqual(model.bias[-1].item(), 0.)
+        self.assertTrue((model.bias[1:-1] == torch.ones(6)).all())
+        self.assertEqual(len(list(model.parameters())), 1)
+
+        # Remove parametrizations on buffers. It should work as expected
+        parametrize.remove_parametrizations(model, "bias", leave_parametrized=True)
+        self.assertFalse(parametrize.is_parametrized(model))
+        self.assertFalse(parametrize.is_parametrized(model, "bias"))
+        self.assertEqual(model.bias[0].item(), 0.)
+        self.assertEqual(model.bias[-1].item(), 0.)
+        self.assertTrue((model.bias[1:-1] == torch.ones(6)).all())
+        self.assertEqual(len(list(model.parameters())), 1)
+
+    def test_serialization_parametrization(self):
+        r"""Test that it is possible to serialize a parametrized model via state_dict"""
+        # A stateful parametrization
+        class Orthogonal(nn.Module):
+            def __init__(self, n):
+                super().__init__()
+                self.register_buffer("id", torch.eye(n))
+                self.register_buffer("B", torch.empty(n, n))
+                init.orthogonal_(self.B)
+
+            def forward(self, X):
+                A = X.triu(1)
+                A = A - A.T
+                return self.B @ torch.solve(self.id - A, self.id + A).solution
+
+        def get_model():
+            model = torch.nn.Sequential(
+                torch.nn.Linear(5, 5),
+                torch.nn.ReLU(),
+                torch.nn.Linear(5, 1),
+            )
+
+            parametrize.register_parametrization(model[0], "weight", Orthogonal(5))
+            return model
+
+        model = get_model()
+
+        prev_weight = model[0].weight
+        prev_B = model[0].parametrizations.weight[0].B
+
+        new_model = get_model()
+        with TemporaryFileName() as fname:
+            torch.save(model.state_dict(), fname)
+            new_model.load_state_dict(torch.load(fname))
+
+        # Integrity tests
+        self.assertTrue(parametrize.is_parametrized(new_model[0], "weight"))
+        self.assertEqual(prev_weight, new_model[0].weight)
+        self.assertEqual(prev_B, new_model[0].parametrizations.weight[0].B)
+
+        # Trying to save the whole parametrized model raises
+        with self.assertRaisesRegex(RuntimeError, "state_dict"):
+            with TemporaryFileName() as fname:
+                torch.save(model, fname)
+
+    def test_initialization_parametrization(self):
+        r"""Test that it is possible to initialize a parametrization when it
+        implements a `right_inverse` method
+        """
+        class Skew(nn.Module):
+            def forward(self, X):
+                A = X.triu(1)
+                return A - A.T
+
+            def is_skew(self, A):
+                return torch.allclose(A, -A.T, atol=1e-6)
+
+            def right_inverse(self, X):
+                if not self.is_skew(X):
+                    raise ValueError("The matrix is not skew-symmetric.")
+                return X.triu(1)
+
+        # Implements a Cayley map where right_inverse is not quite the inverse of forward
+        class Orthogonal(nn.Module):
+            def __init__(self, n):
+                super().__init__()
+                self.register_buffer("B", torch.eye(n))
+
+            def forward(self, A):
+                Id = torch.eye(A.size(0))
+                return self.B @ torch.solve(Id - A, Id + A).solution
+
+            def is_orthogonal(self, X):
+                Id = torch.eye(X.size(0))
+                return torch.allclose(X.T @ X, Id, atol=1e-4)
+
+            def right_inverse(self, X):
+                if not self.is_orthogonal(X):
+                    raise ValueError("The input is not orthogonal.")
+                # cayley(0) == Id, so B @ cayley(0) == B
+                self.B = X
+                return torch.zeros_like(X)
+
+        N = 5
+        model = nn.Linear(N, N)
+        # Register the skew-symmetric constraint. The result is now skew-symmetric
+        parametrize.register_parametrization(model, "weight", Skew())
+        X = torch.rand(N, N)
+        # X is not skew-symmetric, so assigning it throws an error
+        with self.assertRaises(ValueError):
+            model.weight = X
+        # Make X skew-symmetric
+        X = X - X.T
+        model.weight = X
+        self.assertEqual(model.parametrizations.weight.original, X.triu(1))
+        self.assertEqual(model.weight, X)
+
+        # Having several parametrizations registered should work in the same way
+        # Register now the Cayley map. The result is now orthogonal
+        parametrize.register_parametrization(model, "weight", Orthogonal(N))
+        X = torch.rand(N, N)
+        # X is not orthogonal, so assigning it throws an error
+        with self.assertRaises(ValueError):
+            model.weight = X
+        init.orthogonal_(X)
+        model.weight = X
+        self.assertEqual(model.weight, X)
+        self.assertEqual(model.parametrizations.weight.original, torch.zeros_like(X))
+
+    def test_errors_parametrization(self):
+        # A parametrization shall not change the size of the parameter
+        class ChangeSize(nn.Module):
+            def forward(self, x):
+                return x[:-1]
+
+        # A simple parametrization that does not implement a right_inverse
+        class Double(nn.Module):
+            def forward(self, x):
+                return 2 * x
+
+        module = nn.Linear(3, 4)
+        # This should not throw when registering
+        parametrize.register_parametrization(module, "weight", ChangeSize())
+        # It throws in the forward
+        with self.assertRaisesRegex(RuntimeError, "may not change the size"):
+            module(torch.rand(2))
+        # Undo
+        parametrize.remove_parametrizations(module, "weight", leave_parametrized=False)
+        self.assertFalse(parametrize.is_parametrized(module))
+
+        # Removing a parametrization from an unparametrized tensor throws
+        with self.assertRaisesRegex(ValueError, "does not have a parametrization"):
+            parametrize.remove_parametrizations(module, "bias")
+        # Nothing odd happens
+        self.assertFalse(parametrize.is_parametrized(module))
+
+        # Registering a parametrization on a non-existing parameter throws
+        with self.assertRaisesRegex(ValueError, "does not have a parameter"):
+            parametrize.register_parametrization(module, "foo", ChangeSize())
+        self.assertFalse(parametrize.is_parametrized(module))
+
+        # Trying to assign to a parametrization that does not implement `right_inverse` throws
+        parametrize.register_parametrization(module, "weight", Double())
+        with self.assertRaisesRegex(RuntimeError, "right_inverse"):
+            module.weight = torch.rand(4, 3)
+        # Undo
+        parametrize.remove_parametrizations(module, "weight", leave_parametrized=False)
+        self.assertFalse(parametrize.is_parametrized(module))
+
+    def test_caching_parametrization(self):
+        r"""Test the caching system of a parametrization"""
+        # Define a couple matrix parametrizations
+        class Skew(nn.Module):
+            def forward(self, X):
+                X = X.tril(-1)
+                return X - X.T
+
+        class Orthogonal(nn.Module):
+            def forward(self, X):
+                Id = torch.eye(X.size(0), device=X.device)
+                return torch.solve(Id - X, Id + X).solution
+
+        model = nn.Linear(5, 5)
+        parametrize.register_parametrization(model, "weight", Skew())
+        parametrize.register_parametrization(model, "weight", Orthogonal())
+
+        # Test that the caching system works
+        with parametrize.cached():
+            X = model.weight
+            Y = model.weight
+            self.assertEqual(id(X), id(Y))
+
+    def test_dtype_parametrization(self):
+        r"""Test a case that is not allowed when removing a parametrization"""
+        class ChangeType(nn.Module):
+            def forward(self, X):
+                return X.double()
+
+        module = nn.Linear(4, 4).float()
+        input_ = torch.rand(4).double()
+        # It is allowed to register a parametrization that changes the dtype
+        parametrize.register_parametrization(module, "weight", ChangeType())
+        module(input_)
+        # We can remove it leaving the original tensor
+        parametrize.remove_parametrizations(module, "weight", leave_parametrized=False)
+        # But leaving it parametrized breaks
+        parametrize.register_parametrization(module, "weight", ChangeType())
+        with self.assertRaisesRegex(ValueError, "changes the dtype"):
+            parametrize.remove_parametrizations(module, "weight", leave_parametrized=True)
+
     # torch/nn/utils/prune.py
     @unittest.skipIf(not TEST_NUMPY, "numpy not found")
     def test_validate_pruning_amount_init(self):
