|
27 | 27 | import torch.nn.init as init
|
28 | 28 | import torch.nn.utils.rnn as rnn_utils
|
29 | 29 | from torch.nn.utils import clip_grad_norm_, clip_grad_value_
|
| 30 | +import torch.nn.utils.parametrize as parametrize |
30 | 31 | import torch.nn.utils.prune as prune
|
31 | 32 | from torch.nn.utils import parameters_to_vector, vector_to_parameters
|
32 | 33 | from torch.nn import Parameter
|
@@ -1939,6 +1940,354 @@ def test_vector_to_parameters(self):
|
1939 | 1940 | sample = next(model.parameters())[0, 0, 0]
|
1940 | 1941 | self.assertTrue(torch.equal(sample.data, vec.data[:5]))
|
1941 | 1942 |
|
| 1943 | + # torch/nn/utils/parametrize.py |
| 1944 | + def test_register_and_remove_parametrization(self): |
| 1945 | + r"""Test that it is possible to add a few parametrizations |
| 1946 | + on a parameter or a buffer and that removing them restores the initial state. |
| 1947 | + It also tests that backpropagating through them works as expected |
| 1948 | + """ |
| 1949 | + # Define a couple matrix parametrizations |
| 1950 | + class Skew(nn.Module): |
| 1951 | + def forward(self, X): |
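| | + # Keep the strictly lower triangle so that X - X.T is skew-symmetric |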
| 1952 | + X = X.tril(-1) |
| 1953 | + return X - X.T |
| 1954 | + |
| 1955 | + class Orthogonal(nn.Module): |
| 1956 | + def forward(self, X): |
| 1957 | + # Cayley map |
| 1958 | + # If X is skew-symmetric it returns an orthogonal matrix |
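| | + # torch.solve(B, A).solution is A^{-1} B, so this computes (I + X)^{-1} (I - X) |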
| 1959 | + Id = torch.eye(X.size(0), device=X.device) |
| 1960 | + return torch.solve(Id - X, Id + X).solution |
| 1961 | + |
| 1962 | + # Define a couple vector parametrizations |
| 1963 | + class FirstZero(nn.Module): |
| 1964 | + def forward(self, x): |
| 1965 | + return torch.cat([x.new_zeros(1), x[1:]]) |
| 1966 | + |
| 1967 | + class LastZero(nn.Module): |
| 1968 | + def forward(self, x): |
| 1969 | + return torch.cat([x[:-1], x.new_zeros(1)]) |
| 1970 | + |
| 1971 | + model = nn.Linear(8, 8) |
| 1972 | + initial_weight_id = id(model.weight) |
| 1973 | + initial_bias_id = id(model.bias) |
| 1974 | + initial_model = deepcopy(model) |
| 1975 | + |
| 1976 | + # Test one parametrization |
| 1977 | + parametrize.register_parametrization(model, "weight", Skew()) |
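| | + # The original tensor now lives in model.parametrizations.weight.original and model.weight is recomputed from it on access |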
| 1978 | + self.assertTrue(hasattr(model, "parametrizations")) |
| 1979 | + self.assertTrue(parametrize.is_parametrized(model)) |
| 1980 | + self.assertTrue(parametrize.is_parametrized(model, "weight")) |
| 1981 | + self.assertFalse(parametrize.is_parametrized(model, "bias")) |
| 1982 | + self.assertNotIn("weight", model._parameters) |
| 1983 | + # Result should be skew-symmetric |
| 1984 | + A = model.weight |
| 1985 | + self.assertTrue(torch.allclose(A, -A.T)) |
| 1986 | + # Remove and check consistency |
| 1987 | + parametrize.remove_parametrizations(model, "weight", leave_parametrized=False) |
| 1988 | + self.assertFalse(hasattr(model, "parametrizations")) |
| 1989 | + self.assertEqual(model.weight, initial_model.weight) |
| 1990 | + self.assertEqual(id(model.weight), initial_weight_id) |
| 1991 | + self.assertEqual(model.__class__, nn.Linear) |
| 1992 | + |
| 1993 | + # Test two parametrizations at the same time and removing them |
| 1994 | + parametrize.register_parametrization(model, "weight", Skew()) |
| 1995 | + parametrize.register_parametrization(model, "weight", Orthogonal()) |
| 1996 | + # Result should be orthogonal |
| 1997 | + X = model.weight |
| 1998 | + Id = torch.eye(X.size(0), device=X.device) |
| 1999 | + self.assertTrue(torch.allclose(X.T @ X, Id)) |
| 2000 | + # Structure tests |
| 2001 | + self.assertTrue(hasattr(model, "parametrizations")) |
| 2002 | + self.assertTrue(parametrize.is_parametrized(model)) |
| 2003 | + self.assertTrue(parametrize.is_parametrized(model, "weight")) |
| 2004 | + self.assertFalse(parametrize.is_parametrized(model, "bias")) |
| 2005 | + self.assertIn("weight", model.parametrizations) |
| 2006 | + self.assertNotIn("weight", model._parameters) |
| 2007 | + # Remove |
| 2008 | + parametrize.remove_parametrizations(model, "weight", leave_parametrized=False) |
| 2009 | + self.assertEqual(model.weight, initial_model.weight) |
| 2010 | + self.assertEqual(id(model.weight), initial_weight_id) |
| 2011 | + self.assertFalse(hasattr(model, "parametrizations")) |
| 2012 | + self.assertEqual(model.__class__, nn.Linear) |
| 2013 | + |
| 2014 | + # Add everything |
| 2015 | + parametrize.register_parametrization(model, "weight", Skew()) |
| 2016 | + parametrize.register_parametrization(model, "weight", Orthogonal()) |
| 2017 | + parametrize.register_parametrization(model, "bias", FirstZero()) |
| 2018 | + parametrize.register_parametrization(model, "bias", LastZero()) |
| 2019 | + |
| 2020 | + # Basic tests |
| 2021 | + self.assertTrue(parametrize.is_parametrized(model)) |
| 2022 | + self.assertTrue(parametrize.is_parametrized(model, "weight")) |
| 2023 | + self.assertTrue(parametrize.is_parametrized(model, "bias")) |
| 2024 | + self.assertEqual(model.bias[0].item(), 0.) |
| 2025 | + self.assertEqual(model.bias[-1].item(), 0.) |
| 2026 | + self.assertEqual(len(list(model.parameters())), 2) # Nothing weird has happened |
| 2027 | + # Should not throw |
| 2028 | + (model.weight.T @ model.bias).sum().backward() |
| 2029 | + with torch.no_grad(): |
| 2030 | + for p in model.parameters(): |
| 2031 | + p.add_(- p.grad, alpha=0.01) |
| 2032 | + |
| 2033 | + # Remove first parametrization. |
| 2034 | + # Check that the model is still parametrized and so is the second parameter |
| 2035 | + parametrize.remove_parametrizations(model, "weight", leave_parametrized=False) |
| 2036 | + self.assertTrue(parametrize.is_parametrized(model)) # Still parametrized |
| 2037 | + self.assertFalse(parametrize.is_parametrized(model, "weight")) # Parametrization removed |
| 2038 | + self.assertTrue(parametrize.is_parametrized(model, "bias")) # Still parametrized |
| 2039 | + self.assertEqual(model.bias[0].item(), 0.) # Still parametrized |
| 2040 | + self.assertEqual(model.bias[-1].item(), 0.) # Still parametrized |
| 2041 | + self.assertNotEqual(model.weight, initial_model.weight) # Has been updated |
| 2042 | + self.assertEqual(id(model.weight), initial_weight_id) # Keeps the same id |
| 2043 | + self.assertEqual(len(list(model.parameters())), 2) # Nothing weird has happened |
| 2044 | + # Should not throw |
| 2045 | + (model.weight.T @ model.bias).sum().backward() |
| 2046 | + with torch.no_grad(): |
| 2047 | + for p in model.parameters(): |
| 2048 | + p.add_(- p.grad, alpha=0.01) |
| 2049 | + |
| 2050 | + # Remove the second parametrization. |
| 2051 | + # Check that the module is not parametrized |
| 2052 | + parametrize.remove_parametrizations(model, "bias", leave_parametrized=False) |
| 2053 | + self.assertFalse(parametrize.is_parametrized(model)) # Not parametrized anymore |
| 2054 | + self.assertNotEqual(model.bias, initial_model.bias) # Has been updated |
| 2055 | + self.assertNotEqual(model.bias[0].item(), 0.) # No longer constrained to zero |
| 2056 | + self.assertNotEqual(model.bias[-1].item(), 0.) # No longer constrained to zero |
| 2057 | + self.assertEqual(id(model.bias), initial_bias_id) |
| 2058 | + self.assertFalse(hasattr(model, "parametrizations")) |
| 2059 | + self.assertEqual(model.__class__, nn.Linear) |
| 2060 | + self.assertEqual(len(list(model.parameters())), 2) |
| 2061 | + # Should not throw |
| 2062 | + (model.weight.T @ model.bias).sum().backward() |
| 2063 | + with torch.no_grad(): |
| 2064 | + for p in model.parameters(): |
| 2065 | + p.add_(- p.grad, alpha=0.01) |
| 2066 | + |
| 2067 | + def test_register_and_remove_buffer_parametrization(self): |
| 2068 | + r"""Test that it is possible to add and remove parametrizations on buffers""" |
| 2069 | + # Define a couple vector parametrizations |
| 2070 | + class FirstZero(nn.Module): |
| 2071 | + def forward(self, x): |
| 2072 | + return torch.cat([x.new_zeros(1), x[1:]]) |
| 2073 | + |
| 2074 | + class LastZero(nn.Module): |
| 2075 | + def forward(self, x): |
| 2076 | + return torch.cat([x[:-1], x.new_zeros(1)]) |
| 2077 | + |
| 2078 | + model = nn.Linear(8, 8) |
| 2079 | + |
| 2080 | + # Instantiate parametrizations on buffers. It should work as expected |
| 2081 | + delattr(model, "bias") |
| 2082 | + model.register_buffer("bias", torch.ones(8)) |
| 2083 | + parametrize.register_parametrization(model, "bias", FirstZero()) |
| 2084 | + parametrize.register_parametrization(model, "bias", LastZero()) |
| 2085 | + self.assertTrue(parametrize.is_parametrized(model)) |
| 2086 | + self.assertTrue(parametrize.is_parametrized(model, "bias")) |
| 2087 | + self.assertEqual(model.bias[0].item(), 0.) |
| 2088 | + self.assertEqual(model.bias[-1].item(), 0.) |
| 2089 | + self.assertTrue((model.bias[1:-1] == torch.ones(6)).all()) |
| 2090 | + self.assertEqual(len(list(model.parameters())), 1) |
| 2091 | + |
| 2092 | + # Remove parametrizations on buffers. It should work as expected |
| 2093 | + parametrize.remove_parametrizations(model, "bias", leave_parametrized=True) |
| 2094 | + self.assertFalse(parametrize.is_parametrized(model)) |
| 2095 | + self.assertFalse(parametrize.is_parametrized(model, "bias")) |
| 2096 | + self.assertEqual(model.bias[0].item(), 0.) |
| 2097 | + self.assertEqual(model.bias[-1].item(), 0.) |
| 2098 | + self.assertTrue((model.bias[1:-1] == torch.ones(6)).all()) |
| 2099 | + self.assertEqual(len(list(model.parameters())), 1) |
| 2100 | + |
| 2101 | + def test_serialization_parametrization(self): |
| 2102 | + r"""Test that it is possible to serialize a parametrized model via state_dict""" |
| 2103 | + # A stateful parametrization |
| 2104 | + class Orthogonal(nn.Module): |
| 2105 | + def __init__(self, n): |
| 2106 | + super().__init__() |
| 2107 | + self.register_buffer("id", torch.eye(n)) |
| 2108 | + self.register_buffer("B", torch.empty(n, n)) |
| 2109 | + init.orthogonal_(self.B) |
| 2110 | + |
| 2111 | + def forward(self, X): |
| 2112 | + A = X.triu(1) |
| 2113 | + A = A - A.T |
| 2114 | + return self.B @ torch.solve(self.id - A, self.id + A).solution |
| 2115 | + |
| 2116 | + def get_model(): |
| 2117 | + model = torch.nn.Sequential( |
| 2118 | + torch.nn.Linear(5, 5), |
| 2119 | + torch.nn.ReLU(), |
| 2120 | + torch.nn.Linear(5, 1), |
| 2121 | + ) |
| 2122 | + |
| 2123 | + parametrize.register_parametrization(model[0], "weight", Orthogonal(5)) |
| 2124 | + return model |
| 2125 | + |
| 2126 | + model = get_model() |
| 2127 | + |
| 2128 | + prev_weight = model[0].weight |
| 2129 | + prev_B = model[0].parametrizations.weight[0].B |
| 2130 | + |
| 2131 | + new_model = get_model() |
| 2132 | + with TemporaryFileName() as fname: |
| 2133 | + torch.save(model.state_dict(), fname) |
| 2134 | + new_model.load_state_dict(torch.load(fname)) |
| 2135 | + |
| 2136 | + # Integrity tests |
| 2137 | + self.assertTrue(parametrize.is_parametrized(new_model[0], "weight")) |
| 2138 | + self.assertEqual(prev_weight, new_model[0].weight) |
| 2139 | + self.assertEqual(prev_B, new_model[0].parametrizations.weight[0].B) |
| 2140 | + |
| 2141 | + # Trying to save the whole parametrized model raises |
| 2142 | + with self.assertRaisesRegex(RuntimeError, "state_dict"): |
| 2143 | + with TemporaryFileName() as fname: |
| 2144 | + torch.save(model, fname) |
| 2145 | + |
| 2146 | + def test_initialization_parametrization(self): |
| 2147 | + r"""Test that it is possible to initialize a parametrization when it |
| 2148 | + implements a `right_inverse` method |
| 2149 | + """ |
| 2150 | + class Skew(nn.Module): |
| 2151 | + def forward(self, X): |
| 2152 | + A = X.triu(1) |
| 2153 | + return A - A.T |
| 2154 | + |
| 2155 | + def is_skew(self, A): |
| 2156 | + return torch.allclose(A, -A.T, atol=1e-6) |
| 2157 | + |
| 2158 | + def right_inverse(self, X): |
| 2159 | + if not self.is_skew(X): |
| 2160 | + raise ValueError("The matrix is not skew-symmetric.") |
| 2161 | + return X.triu(1) |
| 2162 | + |
| 2163 | + # Implements a Cayley map where right_inverse is not quite the inverse of forward |
| 2164 | + class Orthogonal(nn.Module): |
| 2165 | + def __init__(self, n): |
| 2166 | + super().__init__() |
| 2167 | + self.register_buffer("B", torch.eye(n)) |
| 2168 | + |
| 2169 | + def forward(self, A): |
| 2170 | + Id = torch.eye(A.size(0)) |
| 2171 | + return self.B @ torch.solve(Id - A, Id + A).solution |
| 2172 | + |
| 2173 | + def is_orthogonal(self, X): |
| 2174 | + Id = torch.eye(X.size(0)) |
| 2175 | + return torch.allclose(X.T @ X, Id, atol=1e-4) |
| 2176 | + |
| 2177 | + def right_inverse(self, X): |
| 2178 | + if not self.is_orthogonal(X): |
| 2179 | + raise ValueError("The input is not orthogonal.") |
| 2180 | + # cayley(0) == Id, so B @ cayley(0) == B |
| 2181 | + self.B = X |
| 2182 | + return torch.zeros_like(X) |
| 2183 | + |
| 2184 | + N = 5 |
| 2185 | + model = nn.Linear(N, N) |
| 2186 | + # Register the skew-symmetric constraint. The result is now skew-symmetric |
| 2187 | + parametrize.register_parametrization(model, "weight", Skew()) |
| 2188 | + X = torch.rand(N, N) |
| 2189 | + # X is not skew-symmetric, so it throws an error |
| 2190 | + with self.assertRaises(ValueError): |
| 2191 | + model.weight = X |
| 2192 | + # Make X skew-symmetric |
| 2193 | + X = X - X.T |
| 2194 | + model.weight = X |
| 2195 | + self.assertEqual(model.parametrizations.weight.original, X.triu(1)) |
| 2196 | + self.assertEqual(model.weight, X) |
| 2197 | + |
| 2198 | + # Having several parametrizations registered should work in the same way |
| 2199 | + parametrize.register_parametrization(model, "weight", Orthogonal(N)) |
| 2200 | + # With the Cayley map registered on top, the weight is now orthogonal |
| 2201 | + X = torch.rand(N, N) |
| 2202 | + # X is not orthogonal, so it throws an error |
| 2203 | + with self.assertRaises(ValueError): |
| 2204 | + model.weight = X |
| 2205 | + init.orthogonal_(X) |
| 2206 | + model.weight = X |
| 2207 | + self.assertEqual(model.weight, X) |
| 2208 | + self.assertEqual(model.parametrizations.weight.original, torch.zeros_like(X)) |
| 2209 | + |
| 2210 | + def test_errors_parametrization(self): |
| 2211 | + # A parametrization shall not change the size of the parameter |
| 2212 | + class ChangeSize(nn.Module): |
| 2213 | + def forward(self, x): |
| 2214 | + return x[:-1] |
| 2215 | + |
| 2216 | + # A simple parametrization that does not implement a right_inverse |
| 2217 | + class Double(nn.Module): |
| 2218 | + def forward(self, x): |
| 2219 | + return 2 * x |
| 2220 | + |
| 2221 | + module = nn.Linear(3, 4) |
| 2222 | + # This should not throw when registering |
| 2223 | + parametrize.register_parametrization(module, "weight", ChangeSize()) |
| 2224 | + # It throws in the forward |
| 2225 | + with self.assertRaisesRegex(RuntimeError, "may not change the size"): |
| 2226 | + module(torch.rand(2)) |
| 2227 | + # Undo |
| 2228 | + parametrize.remove_parametrizations(module, "weight", leave_parametrized=False) |
| 2229 | + self.assertFalse(parametrize.is_parametrized(module)) |
| 2230 | + |
| 2231 | + # Removing a parametrization from an unparametrized tensor throws |
| 2232 | + with self.assertRaisesRegex(ValueError, "does not have a parametrization"): |
| 2233 | + parametrize.remove_parametrizations(module, "bias") |
| 2234 | + # Nothing odd happens |
| 2235 | + self.assertFalse(parametrize.is_parametrized(module)) |
| 2236 | + |
| 2237 | + # Registering a parametrization on a non-existing parameter raises a ValueError |
| 2238 | + with self.assertRaisesRegex(ValueError, "does not have a parameter"): |
| 2239 | + parametrize.register_parametrization(module, "foo", ChangeSize()) |
| 2240 | + self.assertFalse(parametrize.is_parametrized(module)) |
| 2241 | + |
| 2242 | + # Try to assign to a parametrization that does not implement `right_inverse` |
| 2243 | + parametrize.register_parametrization(module, "weight", Double()) |
| 2244 | + with self.assertRaisesRegex(RuntimeError, "right_inverse"): |
| 2245 | + module.weight = torch.rand(4, 3) |
| 2246 | + # Undo |
| 2247 | + parametrize.remove_parametrizations(module, "weight", leave_parametrized=False) |
| 2248 | + self.assertFalse(parametrize.is_parametrized(module)) |
| 2249 | + |
| 2250 | + def test_caching_parametrization(self): |
| 2251 | + r"""Test the caching system of a parametrization""" |
| 2252 | + # Define a couple matrix parametrizations |
| 2253 | + class Skew(nn.Module): |
| 2254 | + def forward(self, X): |
| 2255 | + X = X.tril(-1) |
| 2256 | + return X - X.T |
| 2257 | + |
| 2258 | + class Orthogonal(nn.Module): |
| 2259 | + def forward(self, X): |
| 2260 | + Id = torch.eye(X.size(0), device=X.device) |
| 2261 | + return torch.solve(Id - X, Id + X).solution |
| 2262 | + |
| 2263 | + model = nn.Linear(5, 5) |
| 2264 | + parametrize.register_parametrization(model, "weight", Skew()) |
| 2265 | + parametrize.register_parametrization(model, "weight", Orthogonal()) |
| 2266 | + |
| 2267 | + # Test that the caching system works |
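| | + # Inside cached() the parametrized value is computed once and the same tensor object is reused |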
| 2268 | + with parametrize.cached(): |
| 2269 | + X = model.weight |
| 2270 | + Y = model.weight |
| 2271 | + self.assertEqual(id(X), id(Y)) |
| 2272 | + |
| 2273 | + def test_dtype_parametrization(self): |
| 2274 | + r"""Test a case that is not allowed when removing a parametrization""" |
| 2275 | + class ChangeType(nn.Module): |
| 2276 | + def forward(self, X): |
| 2277 | + return X.double() |
| 2278 | + |
| 2279 | + module = nn.Linear(4, 4).float() |
| 2280 | + input_ = torch.rand(4).double() |
| 2281 | + # It is allowed to register a parametrization that changes the dtype |
| 2282 | + parametrize.register_parametrization(module, "weight", ChangeType()) |
| 2283 | + module(input_) |
| 2284 | + # We can remove it leaving the original tensor |
| 2285 | + parametrize.remove_parametrizations(module, "weight", leave_parametrized=False) |
| 2286 | + # But leaving it parametrized breaks |
| 2287 | + parametrize.register_parametrization(module, "weight", ChangeType()) |
| 2288 | + with self.assertRaisesRegex(ValueError, "changes the dtype"): |
| 2289 | + parametrize.remove_parametrizations(module, "weight", leave_parametrized=True) |
| 2290 | + |
1942 | 2291 | # torch/nn/utils/prune.py
|
1943 | 2292 | @unittest.skipIf(not TEST_NUMPY, "numpy not found")
|
1944 | 2293 | def test_validate_pruning_amount_init(self):
|
|