Commit 98d233a
Add tutorial for swap_tensors in nn.Module
1 parent 19f34fa commit 98d233a

2 files changed: +252 -0 lines changed

Lines changed: 245 additions & 0 deletions
@@ -0,0 +1,245 @@
"""
Extension points in ``nn.Module`` for ``load_state_dict`` and tensor subclasses
===============================================================================

This tutorial introduces a new utility function ``torch.utils.swap_tensors``
and two new extension points where it has been integrated in
``nn.Module``:

1. ``nn.Module.to()`` and related methods
2. ``nn.Module.load_state_dict()``

.. note::
    This recipe requires PyTorch 2.3.0 or later.
"""

###############################################################################
# ``torch.utils.swap_tensors``
# ----------------------------
# ``torch.utils.swap_tensors`` (hereafter referred to as ``swap_tensors``) is a
# utility function that takes in two Python tensors and swaps them.

import torch
import torch.nn as nn
t1 = torch.arange(2)
t2 = torch.arange(3)
print(f"Before swapping, t1: {t1}, t2: {t2}")
torch.utils.swap_tensors(t1, t2)
print(f"After swapping, t1: {t1}, t2: {t2}")

# More specifically, ``swap_tensors`` swaps the Python ``__class__``, ``__dict__``,
# and ``__slots__`` of the two tensors, as well as their associated ``at::Tensor``.

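# For example, attributes stored in a tensor's ``__dict__`` travel with the
# swap. A small illustrative sketch (not part of the original recipe):

t3 = torch.arange(2)
t4 = torch.arange(3)
t3.note = "originally t3"  # arbitrary Python attribute, stored in __dict__
torch.utils.swap_tensors(t3, t4)
print(f"t4.note: {t4.note}")  # the __dict__ moved over in the swap
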
################################################################################
# This utility is pertinent to ``nn.Module`` when a Python object outside
# of the module holds a reference to parameters of the module. If an ``nn.Module``
# modifies any of its parameters out of place, the object holding references to
# the parameters will not see the change. A classic example of this is the
# optimizer, which holds a reference to the parameters of the ``nn.Module``.

mod = torch.nn.Linear(1, 2, bias=False)
optimizer = torch.optim.SGD(mod.parameters())
print(f"weight in mod: {mod.weight}")
print(f"weight in optimizer: {optimizer.param_groups[0]['params']}")
mod.weight = torch.nn.Parameter(2 * mod.weight)
print(f"weight in mod: {mod.weight}")
print(f"weight in optimizer: {optimizer.param_groups[0]['params']}")

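################################################################################
# If we instead use ``swap_tensors`` to exchange the payloads, references held
# outside the module remain valid. This is an illustrative sketch, not part of
# the original recipe (it assumes neither tensor has weakrefs pointing to it,
# which ``swap_tensors`` does not allow):

mod = torch.nn.Linear(1, 2, bias=False)
optimizer = torch.optim.SGD(mod.parameters())
new_weight = torch.nn.Parameter(2 * mod.weight)
# Swap in place rather than rebinding ``mod.weight`` to a new object
torch.utils.swap_tensors(mod.weight, new_weight)
print(f"weight in mod: {mod.weight}")
print(f"weight in optimizer: {optimizer.param_groups[0]['params']}")
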
################################################################################
# Presently, the two broad classes of ``nn.Module`` methods that modify the
# parameters are:
#
# 1. ``nn.Module.to()`` and related methods
# 2. ``nn.Module.load_state_dict()``
#
# We discuss these in detail below.

################################################################################
# ``nn.Module.to()`` and related methods
# --------------------------------------
# This includes methods that change the device of the module (e.g. ``nn.Module.cpu()``),
# methods that change the dtype of the module (e.g. ``nn.Module.float()``), as well
# as methods that allow the module to be materialized (``nn.Module.to_empty()``).
#
# At first glance, it might be non-intuitive that these methods are able to
# modify the parameters of the module in place. The existing approach has been
# to set the ``.data`` of the parameter under the hood (``param.data = new_param``);
# a short sketch of this follows the list below.
#
# Notably, the existing approach does not work:
#
# 1. when using ``__torch_dispatch__`` subclasses
# 2. when ``param`` and ``new_param`` do not have the same type

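################################################################################
# For illustration, here is a sketch (not part of the original recipe) of why
# ``.data`` setting preserves outside references for plain tensors: the
# ``Parameter`` object itself stays alive, only its data is replaced.

mod = torch.nn.Linear(1, 2, bias=False)
optimizer = torch.optim.SGD(mod.parameters())
mod.weight.data = 2 * mod.weight.data
# The optimizer still points at the very same Parameter object
print(mod.weight is optimizer.param_groups[0]['params'][0])
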
################################################################################
# In the following part of this tutorial, we will define a toy ``__torch_dispatch__``
# subclass ``MyQuantizedLinearWeight`` that represents quantized linear weights.
# This subclass will be used for illustration purposes throughout the rest of
# the tutorial. For brevity, we omit most of the ``__torch_dispatch__``
# implementation.

aten = torch.ops.aten

class MyQuantizedLinearWeight(torch.Tensor):
    @staticmethod
    def __new__(cls, elem, scale, **kwargs):
        return torch.Tensor._make_wrapper_subclass(
            cls,
            elem.shape,
            dtype=elem.dtype,
            layout=elem.layout,
            device=elem.device,
            requires_grad=elem.requires_grad,
            strides=elem.stride(),
            storage_offset=elem.storage_offset())

    def __init__(self, elem: torch.Tensor, scale: float, **kwargs):
        self.elem = elem
        self.scale = scale

    def __repr__(self):
        return f"MyQuantizedLinearWeight({self.elem}, scale={self.scale})"

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        if func in (aten.detach.default, aten._to_copy.default):
            new_elem = func(args[0].elem, *args[1:], **kwargs)
            return cls(new_elem, args[0].scale)
        # Special implementations for certain ops would be added here.
        # We omit this for brevity.
        # OP_TABLE = ...
        # elif func in OP_TABLE:
        #     return OP_TABLE[func](func, args, kwargs)
        raise NotImplementedError(f"Unsupported function {func}")

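# As a quick sanity check (illustrative, not part of the original recipe),
# wrapping a tensor records the payload and the scale, and ``detach`` routes
# through the ``__torch_dispatch__`` handler above:

w = MyQuantizedLinearWeight(torch.randn(3, 5), 0.5)
print(w)           # repr shows the payload and the scale
print(w.detach())  # dispatches aten.detach.default, returning a new subclass instance
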
#################################################################################
# Let us create a Linear layer of dtype ``torch.float32`` where the weight is
# a ``MyQuantizedLinearWeight`` and try to convert it to ``torch.bfloat16``.
# Observe that the weight's dtype changes as expected. However, the dtype
# of the subclass' payload (i.e. ``elem``) does not change.

m = nn.Linear(3, 5, dtype=torch.float32)
m.weight = torch.nn.Parameter(MyQuantizedLinearWeight(m.weight, 0.5))
print(f"Before: id(m.weight)={id(m.weight)}, id(m.bias)={id(m.bias)}")
m.bfloat16()
print(f"After: id(m.weight)={id(m.weight)}, id(m.bias)={id(m.bias)}")
print(f"m.weight.dtype: {m.weight.dtype}")
print(f"m.weight.elem.dtype: {m.weight.elem.dtype}")
print(f"m.bias.dtype: {m.bias.dtype}")

################################################################################
# To this end, we introduce a global config
# ``torch.__future__.set_swap_module_params_on_conversion`` that will use
# ``swap_tensors`` to swap the parameters of the module (thereby preserving
# their references) instead of setting ``.data``. When this config is set,
# ``swap_tensors`` will be used during the conversion, which ensures that
# the dtype of the payload is properly converted.

torch.__future__.set_swap_module_params_on_conversion(True)
m = nn.Linear(3, 5, dtype=torch.float32)
m.weight = torch.nn.Parameter(MyQuantizedLinearWeight(m.weight, 0.5))
print(f"Before: id(m.weight)={id(m.weight)}, id(m.bias)={id(m.bias)}")
m.bfloat16()
print(f"After: id(m.weight)={id(m.weight)}, id(m.bias)={id(m.bias)}")
print(f"m.weight.dtype: {m.weight.dtype}")
print(f"m.weight.elem.dtype: {m.weight.elem.dtype}")
print(f"m.bias.dtype: {m.bias.dtype}")
torch.__future__.set_swap_module_params_on_conversion(False)

################################################################################
# ``nn.Module.load_state_dict()``
# --------------------------------
# At present, depending on the value of the ``assign`` keyword argument passed
# to ``load_state_dict()``, there are two ways to load the ``state_dict``:
#
# 1. ``assign=False``: in-place copy (i.e. ``param.copy_(state_dict['param'])``)
# 2. ``assign=True``: ``__setattr__`` (i.e. ``module.param = state_dict['param']``)
#
# Each has its own limitations -- ``assign=False`` imposes the constraint that
# the type of the parameter in the ``state_dict`` must be the same as the type of
# the parameter in the module, while ``assign=True`` imposes the constraint that
# anything that holds references to the module's parameters must be initialized
# after ``nn.Module.load_state_dict()``.
#
# We address both constraints by adding a ``swap_tensors`` path to ``load_state_dict()``
# and introducing a new extension point ``torch.Tensor.module_load(self, other, assign=False)``.
# When the ``swap_tensors`` path is enabled via the ``__future__`` mentioned above,
# ``module_load`` can be overridden to apply a custom transformation to the value
# in the ``state_dict``. The result of this transformation will be swapped with
# the parameter in the module.
#
# In this world, ``assign=True`` is a directive to preserve the properties of the tensor
# in the ``state_dict`` (excluding ``requires_grad``-ness) and ``assign=False`` is a directive
# to preserve the properties of the tensor in the module, as the following sketch
# illustrates.
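
################################################################################
# To make these directives concrete, here is a small illustrative sketch with
# plain tensors (not part of the original recipe; it assumes the semantics
# described above):

torch.__future__.set_swap_module_params_on_conversion(True)
lin = nn.Linear(2, 2, dtype=torch.float64)
sd = {"weight": torch.zeros(2, 2), "bias": torch.zeros(2)}  # float32 tensors
lin.load_state_dict(sd, assign=False)
print(f"assign=False keeps the module's dtype: {lin.weight.dtype}")
lin.load_state_dict(sd, assign=True)
print(f"assign=True takes the state_dict's dtype: {lin.weight.dtype}")
torch.__future__.set_swap_module_params_on_conversion(False)
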
################################################################################
# In the following example, we will use the ``MyQuantizedLinearWeight`` subclass
# defined above to illustrate how we can use these features to apply a
# custom quantization scheme to the weights of a linear layer when
# loading the ``state_dict``.

################################################################################
# Recall that the ``__torch_function__`` handler for ``module_load`` will be
# invoked if either ``self`` or ``other`` (in this instance ``param`` or
# ``state_dict[param_key]``) are ``MyQuantizedLinearWeight`` subclasses.
#
# Assume that we expect the ``state_dict`` to contain plain tensors and the
# module to contain ``MyQuantizedLinearWeight`` parameters where we want the
# tensors in the ``state_dict`` to be transformed into the subclass. Then we
# can define a ``__torch_function__`` handler for ``torch.Tensor.module_load``
# as follows:

@classmethod
def custom_torch_function(cls, func, types, args=(), kwargs=None):
    kwargs = {} if kwargs is None else kwargs
    def module_load(dest, src, assign=False):
        assert type(dest) == cls and type(src) == torch.Tensor
        return MyQuantizedLinearWeight(src, dest.scale)

    if func is torch.Tensor.module_load:
        return module_load(*args, **kwargs)
    else:
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

MyQuantizedLinearWeight.__torch_function__ = custom_torch_function

#################################################################################
# First, let us create a skeleton of a model on the meta device to avoid
# materializing storages. We convert all weights in the modules to
# ``MyQuantizedLinearWeight`` subclasses while leaving biases intact.

def fn(m):
    if isinstance(m, nn.Linear):
        requires_grad = m.weight.requires_grad
        m.weight = torch.nn.Parameter(
            MyQuantizedLinearWeight(m.weight, 0.5), requires_grad=requires_grad
        )

with torch.device("meta"):
    m = nn.Linear(3, 5)
    m.apply(fn)

#################################################################################
# We can then load the ``state_dict``. Observe that we use ``assign=True`` because
# for biases, we want to preserve the properties of the tensor in the ``state_dict``
# (i.e. we do not want the bias to be on the meta device after loading).

torch.__future__.set_swap_module_params_on_conversion(True)
print(f"Before: id(weight)={id(m.weight)}, id(bias)={id(m.bias)}")
print(f"m.state_dict() before load_state_dict():\n {m.state_dict()}")
state_dict = nn.Linear(3, 5).state_dict()
print(f"state_dict:\n {state_dict}")
m.load_state_dict(state_dict, assign=True)
print(f"After: id(weight)={id(m.weight)}, id(bias)={id(m.bias)}")
print(f"m.state_dict() after load_state_dict():\n {m.state_dict()}")

# The above is a toy example of how we can use the new extension point in
# ``nn.Module.load_state_dict()``. One can also imagine alternate scenarios, such
# as when we have tensor subclasses in the ``state_dict`` and plain ``nn.Parameter``/
# tensors in the module, or when both are tensor subclasses. Based on the use
# case, we can define the ``__torch_function__`` handler for ``module_load``
# to apply the transforms as needed; one such scenario is sketched below.

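# As an illustrative sketch (hypothetical, not part of the original recipe):
# if the ``state_dict`` held ``MyQuantizedLinearWeight`` tensors while the
# module held plain parameters, the inner ``module_load`` could instead
# dequantize on load:
#
# .. code-block:: python
#
#     def module_load(dest, src, assign=False):
#         if isinstance(src, MyQuantizedLinearWeight):
#             # hypothetical dequantization: materialize a plain tensor
#             return src.elem * src.scale
#         return src
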
###############################################################################
# Conclusion
# ----------
# In this tutorial, we learned about ``swap_tensors``, the importance
# of preserving references for parameters in ``nn.Module``, as well as how to
# use the two new extension points that are gated by
# ``torch.__future__.set_swap_module_params_on_conversion``.

recipes_source/recipes_index.rst

Lines changed: 7 additions & 0 deletions
@@ -151,6 +151,13 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu
   :link: ../recipes/torch_logs.html
   :tags: Basics

.. customcarditem::
   :header: Extension points in nn.Module for loading state_dict and tensor subclasses
   :card_description: New extension points in nn.Module.
   :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
   :link: ../recipes/recipes/module_load_state_dict_tips.html
   :tags: Basics

.. Interpretability