===============================================================================

This tutorial introduces a new utility function ``torch.utils.swap_tensors``
- and introduces two new extension points where it has been integrated in
+ as well as two new extension points where it has been integrated in
``nn.Module``.

- 1. ``nn.Module.to()` and related methods
+ 1. ``nn.Module.to()`` and related methods
2. ``nn.Module.load_state_dict()``

.. note::

torch.utils.swap_tensors(t1, t2)
print(f"After swapping, t1: {t1}, t2: {t2}")

+ ################################################################################
# More specifically, ``swap_tensors`` swaps the python ``__class__``, ``__dict__``
# and ``__slots__`` of the two tensors, as well as their associated ``at::Tensor``.
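#
# As a quick illustrative sketch (using throwaway tensors ``a`` and ``b``),
# note that any other Python reference to ``a`` observes the swap, because the
# tensor objects themselves are exchanged rather than rebound:

a = torch.zeros(3)
b = torch.ones(3)
holder = [a]                   # an external reference to the same tensor object
torch.utils.swap_tensors(a, b)
print(holder[0])               # prints tensor([1., 1., 1.]): the reference sees the swap
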
-
- ################################################################################
+ #
+ #
+ # Application to ``nn.Module``
+ # ----------------------------
# This utility is pertinent to ``nn.Module`` when a python object outside
# of the module holds a reference to parameters of the module. If an ``nn.Module``
# modifies any of its parameters out of place, the object holding references to
# the parameters will not see the change. A classic example of this is the
- # optimizer, which holds a reference to the parameters of the nn.Module.
+ # optimizer, which holds a reference to the parameters of the ``nn.Module``.

mod = torch.nn.Linear(1, 2, bias=False)
optimizer = torch.optim.SGD(mod.parameters())

################################################################################
# Presently, the two broad classes of ``nn.Module`` methods that modify the
# parameters are
- # 1. ``nn.Module.to()`` and related methods
- # 2. ``nn.Module.load_state_dict()``
+ #
+ # * ``nn.Module.to()`` and related methods
+ # * ``nn.Module.load_state_dict()``
+ #
# We discuss these in detail below.
-
- ################################################################################
+ #
+ #
# ``nn.Module.to()`` and related methods
# --------------------------------------
# This includes methods that change the device of the module (e.g. ``nn.Module.cpu()``),
- # methods that change the dtype of the module (e.g. ``nn.Module.float()``) as well
- # as methods that allow the module to be materialized (``nn.Module.to_empty()``).
+ # methods that change the ``dtype`` of the module (e.g. ``nn.Module.float()``)
+ # as well as methods that allow the module to be materialized
+ # (i.e. ``nn.Module.to_empty()``).
#
# At first glance, it might be non-intuitive that these methods are able to
# modify the parameters of the module in-place. The existing approach has been
# to set the ``.data`` of the module under the hood (``param.data = new_param``).
#
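# As a quick sketch (reusing the ``mod`` and ``optimizer`` created above, with a
# hypothetical replacement tensor), setting ``.data`` keeps the original
# ``Parameter`` object alive, so existing references remain valid:

weight_ref = mod.weight              # e.g. the reference an optimizer holds
mod.weight.data = torch.randn(2, 1)  # the legacy ``.data``-setting pattern
print(weight_ref is mod.weight)      # True: the Parameter object was not replaced
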
- # Notably, the existing approach does not work
- # 1. when using ``__torch_dispatch__`` subclasses
- # 2. when ``param`` and ``new_param`` do not have the same type
-
- ################################################################################
+ # Notably, the existing approach does not work in these cases:
+ #
+ # * when using ``__torch_dispatch__`` subclasses
+ # * when ``param`` and ``new_param`` do not have the same type
+ #
# In the following part of this tutorial, we will define a toy ``__torch_dispatch__``
# subclass ``MyQuantizedLinearWeight`` that represents quantized linear weights.
# This subclass will be used for illustration purposes throughout the rest of
@@ -100,17 +106,17 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
if func in (aten.detach.default, aten._to_copy.default):
    new_elem = func(args[0].elem, *args[1:], **kwargs)
    return cls(new_elem, args[0].scale)
- # Special implementations for certains ops would be added here.
+ # Implementations for certain ops would be added to ``OP_TABLE``.
# We omit this for brevity.
- # OP_TABLE = ...
- # elif func is in OP_TABLE:
- # return OP_TABLE[func](func, args, kwargs)
+ OP_TABLE = dict()
+ if func in OP_TABLE:
+     return OP_TABLE[func](func, args, kwargs)
raise NotImplementedError(f"Unsupported function {func}")

#################################################################################
- # Let us create a Linear layer of dtype ``torch.float32`` where the weight is
+ # Let us create a Linear layer of ``dtype`` ``torch.float32`` where the weight is
# a ``MyQuantizedLinearWeight`` and try to convert it to ``torch.bfloat16``.
- # Observe that the weight's dtype changes as expected. However, the dtype
+ # Observe that the weight's ``dtype`` changes as expected. However, the ``dtype``
# of the subclass' payload (i.e. ``elem``) does not change.

m = nn.Linear(3, 5, dtype=torch.float32)
@@ -128,7 +134,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
# ``swap_tensors`` to swap the parameters of the module while preserving
# references in place of ``.data`` setting. When this config is set,
# ``swap_tensors`` will be used during the conversion, which ensures that
- # the dtype of the payload is properly converted.
+ # the ``dtype`` of the payload is properly converted.

torch.__future__.set_swap_module_params_on_conversion(True)
m = nn.Linear(3, 5, dtype=torch.float32)
@@ -150,32 +156,27 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
# 1. ``assign=False``: in-place copy (i.e. ``param.copy_(state_dict['param'])``)
# 2. ``assign=True``: ``__setattr__`` (i.e. ``module.param = state_dict['param']``)
#
+ #
# Each has its own limitations -- ``assign=False`` imposes the constraint that
# the type of the parameter in the state_dict must be the same as the type of
# the parameter in the module while ``assign=True`` imposes the constraint that
# anything that holds references to the module's parameters must be initialized
# after ``nn.Module.load_state_dict()``.
#
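# As a small illustrative sketch (with a throwaway ``meta_mod``, separate from the
# main example below): a module created on the ``meta`` device can be populated
# with real tensors via ``assign=True``, and an optimizer over its parameters
# should only be constructed afterwards so that it references the loaded parameters.

with torch.device("meta"):
    meta_mod = nn.Linear(3, 5)
sd = nn.Linear(3, 5).state_dict()
meta_mod.load_state_dict(sd, assign=True)             # parameters now come from ``sd``
opt = torch.optim.SGD(meta_mod.parameters(), lr=0.1)  # created only after loading
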
- # We address both constraints by adding a swap_tensors path to ``load_state_dict()``
+ # We address both constraints by adding a ``swap_tensors`` path to ``load_state_dict()``
# and introducing a new extension point ``torch.Tensor.module_load(self, other, assign=False)``.
# When the ``swap_tensors`` path is enabled via the ``__future__`` mentioned above,
- # ``module_load`` can be overriden to apply a custom transformation to the value
- # in the ``state_dict``. The result of this transformation will be swapped with
- # the parameter in the module.
- #
- # In this world, ``assign=True`` is a directive to preserve the properties of the tensor
- # in the state_dict (excluding ``requires_grad``-ness) and ``assign=False`` is a directive
- # to preserve the properties of the tensor in the module.
+ # we can use a ``__torch_function__`` handler for ``module_load`` to apply a
+ # custom transformation to the value in the ``state_dict``. The result of this
+ # transformation will be swapped with the parameter in the module.
#
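# As a rough sketch of the general shape of such a handler (the class below is
# purely illustrative and is not the handler used for ``MyQuantizedLinearWeight``):

class _ModuleLoadSketch(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs
        if func is torch.Tensor.module_load:
            dest, src = args[0], args[1]
            # Transform ``src`` (the value from the ``state_dict``) as needed here.
            # The returned tensor will be swapped into the module in place of ``dest``.
            return src.detach()
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)
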
# In the following example, we will use the ``MyQuantizedLinearWeight`` subclass
# defined above to illustrate how we can use these features to apply a
# custom quantization scheme to the weights of a linear layer when
# loading the ``state_dict``.
-
-
- ################################################################################
+ #
# Recall that the ``__torch_function__`` handler for ``module_load`` will be
- # invoked if either ``self`` or ``other`` (in this instance ``param`` or
+ # invoked if either ``self`` or ``other`` (in this case ``param`` or
# ``state_dict[param_key]``) are ``MyQuantizedLinearWeight`` subclasses.
#
# Assume that we expect the ``state_dict`` to contain plain tensors and the
@@ -218,7 +219,7 @@ def fn(m):
#################################################################################
# We can then load the ``state_dict``. Observe that we use ``assign=True`` because
# for biases, we want to preserve the properties of the tensor in the ``state_dict``
- # (i.e. we do not want the bias to be on the meta device after loading).
+ # (i.e. we do not want the bias to be on the ``meta`` device after loading).

torch.__future__.set_swap_module_params_on_conversion(True)
print(f"Before: id(weight)={id(m.weight)}, id(bias)={id(m.bias)}")
@@ -229,17 +230,17 @@ def fn(m):
print(f"After: id(weight)={id(m.weight)}, id(bias)={id(m.bias)}")
print(f"m.state_dict() after load_state_dict():\n{m.state_dict()}")

+ #################################################################################
# The above is a toy example of how we can use the new extension point in
# ``nn.Module.load_state_dict()``. One can also imagine alternate scenarios such
# as when we have tensor subclasses in the state_dict and plain ``nn.Parameters``/
# tensors in the module or when both are tensor subclasses. Based on the use
# case, we can define the ``__torch_function__`` handler for ``module_load``
# to apply the transforms as needed.
-
- ###############################################################################
+ #
# Conclusion
# ----------
- # In this tutorial, we learnt about ``swap_tensors``, the importance
+ # In this tutorial, we learned about ``swap_tensors``, the importance
# of preserving references for parameters in ``nn.Module`` as well as how to
# use the two new extension points that are gated by
# ``torch.__future__.set_swap_module_params_on_conversion``.