
Commit 05995f9

Fix spelling and formatting
1 parent 98d233a commit 05995f9

recipes_source/recipes/swap_tensors.py

Lines changed: 40 additions & 39 deletions
@@ -3,10 +3,10 @@
 ===============================================================================
 This tutorial introduces a new utility function ``torch.utils.swap_tensors``
-and introduces two new extension points where it has been integrated in
+as well as two new extension points where it has been integrated in
 ``nn.Module``.
-1. ``nn.Module.to()` and related methods
+1. ``nn.Module.to()`` and related methods
 2. ``nn.Module.load_state_dict()``
 .. note::
@@ -27,15 +27,18 @@
 torch.utils.swap_tensors(t1, t2)
 print(f"After swapping, t1: {t1}, t2: {t2}")
 
+################################################################################
 # More specifically, ``swap_tensors`` swaps the python ``__class__``, ``__dict__``
 # and ``__slots__`` of the two tensors, as well as their associated ``at::Tensor``.
-################################################################################
+#
+#
+# Application to ``nn.Module``
+# ----------------------------
 # This utility is pertinent to ``nn.Module`` when a python object outside
 # of the module holds a reference to parameters of the module. If an ``nn.Module``
 # modifies any of its parameters out of place, the object holding references to
 # the parameters will not see the change. A classic example of this is the
-# optimizer, which holds a reference to the parameters of the nn.Module.
+# optimizer, which holds a reference to the parameters of the ``nn.Module``.
 
 mod = torch.nn.Linear(1, 2, bias=False)
 optimizer = torch.optim.SGD(mod.parameters())
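
The hunk above is about preserving references held outside the module. As a minimal sketch (not part of the patch; it assumes a PyTorch build that provides ``torch.utils.swap_tensors``, per the tutorial's version note), an outside reference observes the swap because the Python objects themselves are preserved:

import torch

t1 = torch.arange(3.0)
t2 = torch.zeros(3)
alias = t1                         # an outside name bound to the same Python object as t1
torch.utils.swap_tensors(t1, t2)   # swaps __class__/__dict__/__slots__ and the at::Tensor
print(alias)                       # tensor([0., 0., 0.]) -- the alias now holds t2's old payload
print(t2)                          # tensor([0., 1., 2.])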
@@ -48,26 +51,29 @@
 ################################################################################
 # Presently, the two broad classes of ``nn.Module`` methods that modify the
 # parameters are
-# 1. ``nn.Module.to()`` and related methods
-# 2. ``nn.Module.load_state_dict()``
+#
+# * ``nn.Module.to()`` and related methods
+# * ``nn.Module.load_state_dict()``
+#
 # We discuss these in detail below.
-################################################################################
+#
+#
 # ``nn.Module.to()`` and related methods
 # --------------------------------------
 # This includes methods that change the device of the module (e.g. ``nn.Module.cpu()``),
-# methods that change the dtype of the module (e.g. ``nn.Module.float()``) as well
-# as methods that allow the module to be materialized (``nn.Module.to_empty()``).
+# methods that change the ``dtype`` of the module (e.g. ``nn.Module.float()``)
+# as well as methods that allow the module to be materialized
+# (i.e. ``nn.Module.to_empty()``).
 #
 # At first glance, it might be non-intuitive that these methods are able to
 # modify the parameters of the module in-place. The existing approach has been
 # to set the ``.data`` of the module under the hood (``param.data = new_param``).
 #
-# Notably, the existing approach does not work
-# 1. when using ``__torch_dispatch__`` subclasses
-# 2. when ``param`` and ``new_param`` do not have the same type
-################################################################################
+# Notably, the existing approach does not work in these cases:
+#
+# * when using ``__torch_dispatch__`` subclasses
+# * when ``param`` and ``new_param`` do not have the same type
+#
 # In the following part of this tutorial, we will define a toy ``__torch_dispatch__``
 # subclass ``MyQuantizedLinearWeight`` that represents quantized linear weights.
 # This subclass will be used for illustration purposes throughout the rest of
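
As a small sketch of the existing ``.data``-setting approach described in this hunk (plain tensors only; not part of the patch), note that it also keeps the ``Parameter`` object intact, which is why held references keep working in the simple case:

import torch

mod = torch.nn.Linear(2, 2)
weight_ref = mod.weight                  # e.g. the reference an optimizer would hold
mod.weight.data = torch.zeros(2, 2)      # the existing in-place approach
print(weight_ref is mod.weight)          # True: the Parameter object is untouched
print(weight_ref)                        # the reference observes the new values

It is exactly the two cases listed above (``__torch_dispatch__`` subclasses, or ``param`` and ``new_param`` of different types) where this approach falls short and ``swap_tensors`` is needed.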
@@ -100,17 +106,17 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
         if func in (aten.detach.default, aten._to_copy.default):
             new_elem = func(args[0].elem, *args[1:], **kwargs)
             return cls(new_elem, args[0].scale)
-        # Special implementations for certains ops would be added here.
+        # Implementations for certain ops would be added to ``OP_TABLE``.
         # We omit this for brevity.
-        # OP_TABLE = ...
-        # elif func is in OP_TABLE:
-        #     return OP_TABLE[func](func, args, kwargs)
+        OP_TABLE = dict()
+        if func in OP_TABLE:
+            return OP_TABLE[func](func, args, kwargs)
         raise NotImplementedError(f"Unsupported function {func}")
 
 #################################################################################
-# Let us create a Linear layer of dtype ``torch.float32`` where the weight is
+# Let us create a Linear layer of ``dtype`` ``torch.float32`` where the weight is
 # a ``MyQuantizedLinearWeight`` and try to convert it to ``torch.bfloat16``.
-# Observe that the weight's dtype changes as expected. However, the dtype
+# Observe that the weight's ``dtype`` changes as expected. However, the ``dtype``
 # of the subclass' payload (i.e.``elem``) does not change.
 
 m = nn.Linear(3, 5, dtype=torch.float32)
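
For context around the ``__torch_dispatch__`` method shown in this hunk, here is a self-contained sketch of the general wrapper-subclass pattern. ``LoggedTensor`` is hypothetical (it is not the tutorial's ``MyQuantizedLinearWeight``), and it assumes the private ``torch.Tensor._make_wrapper_subclass`` constructor that such subclasses typically use:

import torch
from torch.utils._pytree import tree_map_only

class LoggedTensor(torch.Tensor):
    def __new__(cls, elem):
        # wrapper subclass: the instance carries ``elem`` as its payload
        return torch.Tensor._make_wrapper_subclass(
            cls, elem.shape, dtype=elem.dtype, device=elem.device
        )

    def __init__(self, elem):
        self.elem = elem

    def __repr__(self):
        return f"LoggedTensor({self.elem!r})"

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        print(f"dispatching {func}")
        # unwrap to the payload, run the op, and rewrap the result
        args, kwargs = tree_map_only(cls, lambda t: t.elem, (args, kwargs))
        out = func(*args, **kwargs)
        return tree_map_only(torch.Tensor, lambda t: cls(t), out)

x = LoggedTensor(torch.randn(2, 2))
y = x + 1   # prints the dispatched aten op; y is again a LoggedTensor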
@@ -128,7 +134,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
 # ``swap_tensors`` to swap the parameters of the module while preserving
 # references in place of ``.data`` setting. When this config is set,
 # ``swap_tensors`` will be used during the conversion, which ensures that
-# the dtype of the payload is properly converted.
+# the ``dtype`` of the payload is properly converted.
 
 torch.__future__.set_swap_module_params_on_conversion(True)
 m = nn.Linear(3, 5, dtype=torch.float32)
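
Under this setting, conversions such as ``nn.Module.bfloat16()`` go through ``swap_tensors``, so references held outside the module keep pointing at the same, now-converted parameters. A minimal sketch (plain ``nn.Linear``; the ``torch.__future__`` call is the real API named in the hunk above):

import torch
import torch.nn as nn

torch.__future__.set_swap_module_params_on_conversion(True)
m = nn.Linear(3, 5)
weight_ref = m.weight              # a reference held outside the module
m.bfloat16()                       # conversion swaps the parameters in place
print(weight_ref is m.weight)      # True: the Parameter object is preserved
print(weight_ref.dtype)            # torch.bfloat16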
@@ -150,32 +156,27 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
 # 1. ``assign=False``: in-place copy (i.e. ``param.copy_(state_dict['param'])``)
 # 2. ``assign=True``: ``__setattr__`` (i.e. ``module.param = state_dict['param']``)
 #
+#
 # Each has its own limitations -- ``assign=False`` imposes the constraint that
 # the type of the parameter in the state_dict must be the same as the type of
 # the parameter in the module while ``assign=True`` imposes the constraint that
 # anything that holds references to the module's parameters must be initialized
 # after ``nn.Module.load_state_dict()``.
 #
-# We address both constraints by adding a swap_tensors path to ``load_state_dict()``
+# We address both constraints by adding a ``swap_tensors`` path to ``load_state_dict()``
 # and introducing a new extension point ``torch.Tensor.module_load(self, other, assign=False)``.
 # When the ``swap_tensors`` path is enabled via the ``__future__`` mentioned above,
-# ``module_load`` can be overriden to apply a custom transformation to the value
-# in the ``state_dict``. The result of this transformation will be swapped with
-# the parameter in the module.
-#
-# In this world, ``assign=True`` is a directive to preserve the properties of the tensor
-# in the state_dict (excluding ``requires_grad``-ness) and ``assign=False`` is a directive
-# to preserve the properties of the tensor in the module.
+# we can use a ``__torch_function__`` handler for ``module_load`` to apply a
+# custom transformation to the value in the ``state_dict``. The result of this
+# transformation will be swapped with the parameter in the module.
 #
 # In the following example, we will use the ``MyQuantizedLinearWeight`` subclass
 # defined above to illustrate how we can use these features to apply a
 # custom quantization scheme to the weights of a linear layer when
 # loading the ``state_dict``.
-
-################################################################################
+#
 # Recall that the ``__torch_function__`` handler for ``module_load`` will be
-# invoked if either ``self`` or ``other`` (in this instance ``param`` or
+# invoked if either ``self`` or ``other`` (in this case ``param`` or
 # ``state_dict[param_key]``) are ``MyQuantizedLinearWeight`` subclasses.
 #
 # Assume that we expect the ``state_dict`` to contain plain tensors and the
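
To sketch the shape of such a handler (schematic only: ``UpcastOnLoad`` is a hypothetical subclass, not the tutorial's ``MyQuantizedLinearWeight``, and the upcast is an arbitrary stand-in transformation), it intercepts ``torch.Tensor.module_load`` and returns the tensor that the swap path will place into the module:

import torch

class UpcastOnLoad(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs
        if func is torch.Tensor.module_load:
            param, value = args[0], args[1]   # ``self`` and ``other`` per the signature above
            # custom transformation of the incoming state_dict value;
            # the returned tensor is what gets swapped into the module
            return value.detach().to(torch.float64)
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

As described above, this handler runs whenever either the module parameter or the corresponding ``state_dict`` value is an instance of the subclass and the ``swap_tensors`` path is enabled.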
@@ -218,7 +219,7 @@ def fn(m):
 #################################################################################
 # We can then load the ``state_dict``. Observe that we use ``assign=True`` because
 # for biases, we want to preserve the properties of the tensor in the ``state_dict``
-# (i.e. we do not want the bias to be on the meta device after loading).
+# (i.e. we do not want the bias to be on the ``meta`` device after loading).
 
 torch.__future__.set_swap_module_params_on_conversion(True)
 print(f"Before: id(weight)={id(m.weight)}, id(bias)={id(m.bias)}")
@@ -229,17 +230,17 @@ def fn(m):
 print(f"After: id(weight)={id(m.weight)}, id(bias)={id(m.bias)}")
 print(f"m.state_dict() after load_state_dict():\n {m.state_dict()}")
 
+#################################################################################
 # The above is a toy example of how we can use the new extension point in
 # ``nn.Module.load_state_dict()``. One can also imagine alternate scenarios such
 # as when we have tensor subclasses in the state_dict and plain ``nn.Parameters``/
 # tensors in the module or when both are tensor subclasses. Based on the use
 # case, we can define the ``__torch_function__`` handler for ``module_load``
 # to apply the transforms as needed.
-###############################################################################
+#
 # Conclusion
 # ----------
-# In this tutorial, we learnt about ``swap_tensors``, the importance
+# In this tutorial, we learned about ``swap_tensors``, the importance
 # of preserving references for parameters in ``nn.Module`` as well as how to
 # use the two new extension points that are gated by
 # ``torch.__future__.set_swap_module_params_on_conversion``.
