"""
Extension points in ``nn.Module`` for ``load_state_dict`` and tensor subclasses
===============================================================================

This tutorial introduces ``torch.utils.swap_tensors``, a new utility function,
as well as two new extension points where it has been integrated in
``nn.Module``:

1. ``nn.Module.to()`` and related methods
2. ``nn.Module.load_state_dict()``

.. note::
    This recipe requires PyTorch 2.3.0 or later.
"""

###############################################################################
# ``torch.utils.swap_tensors``
# ----------------------------
# ``torch.utils.swap_tensors`` (hereafter referred to as ``swap_tensors``) is a
# utility function that takes in two Python tensors and swaps them.

import torch
import torch.nn as nn
t1 = torch.arange(2)
t2 = torch.arange(3)
print(f"Before swapping, t1: {t1}, t2: {t2}")
torch.utils.swap_tensors(t1, t2)
print(f"After swapping, t1: {t1}, t2: {t2}")

# More specifically, ``swap_tensors`` swaps the Python ``__class__``, ``__dict__``
# and ``__slots__`` of the two tensors, as well as their associated ``at::Tensor``.
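# For instance, attributes stored in ``__dict__`` travel with the swap. A
# minimal sketch (``tag`` is an arbitrary attribute added here purely for
# illustration):

t1.tag = "first"
t2.tag = "second"
torch.utils.swap_tensors(t1, t2)
print(f"After swapping, t1.tag: {t1.tag}, t2.tag: {t2.tag}")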

################################################################################
# This utility is pertinent to ``nn.Module`` when a Python object outside
# of the module holds a reference to parameters of the module. If an ``nn.Module``
# modifies any of its parameters out of place, the object holding references to
# the parameters will not see the change. A classic example of this is the
# optimizer, which holds a reference to the parameters of the ``nn.Module``.

mod = torch.nn.Linear(1, 2, bias=False)
optimizer = torch.optim.SGD(mod.parameters())
print(f"weight in mod: {mod.weight}")
print(f"weight in optimizer: {optimizer.param_groups[0]['params']}")
mod.weight = torch.nn.Parameter(2 * mod.weight)
print(f"weight in mod: {mod.weight}")
print(f"weight in optimizer: {optimizer.param_groups[0]['params']}")
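
###############################################################################
# In contrast, here is a minimal sketch of how ``swap_tensors`` sidesteps this
# problem: the parameter object is mutated in place rather than rebound, so
# the optimizer's reference stays in sync. (``new_weight`` is created under
# ``torch.no_grad()`` so that no autograd state holds an extra reference to
# the parameter, which could prevent the swap.)

mod = torch.nn.Linear(1, 2, bias=False)
optimizer = torch.optim.SGD(mod.parameters())
with torch.no_grad():
    new_weight = torch.nn.Parameter(2 * mod.weight)
torch.utils.swap_tensors(mod.weight, new_weight)
print(f"weight in mod: {mod.weight}")
print(f"weight in optimizer: {optimizer.param_groups[0]['params']}")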

################################################################################
# Presently, the two broad classes of ``nn.Module`` methods that modify its
# parameters are:
#
# 1. ``nn.Module.to()`` and related methods
# 2. ``nn.Module.load_state_dict()``
#
# We discuss these in detail below.

################################################################################
# ``nn.Module.to()`` and related methods
# --------------------------------------
# This includes methods that change the device of the module (e.g. ``nn.Module.cpu()``),
# methods that change the dtype of the module (e.g. ``nn.Module.float()``), as well
# as methods that allow the module to be materialized (e.g. ``nn.Module.to_empty()``).
#
# At first glance, it might be non-intuitive that these methods are able to
# modify the parameters of the module in-place. The existing approach has been
# to set the ``.data`` of the parameter under the hood (``param.data = new_param``).
#
# Notably, the existing approach does not work in these cases:
#
# 1. when using ``__torch_dispatch__`` subclasses
# 2. when ``param`` and ``new_param`` do not have the same type
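
################################################################################
# For reference, a quick illustration of the existing approach: setting
# ``.data`` mutates the parameter "in place" from the perspective of Python
# object identity, so any outside reference to the parameter observes the
# change.

param = torch.nn.Parameter(torch.randn(2))
param_ref = param
param.data = torch.zeros(2)
print(f"param_ref is param: {param_ref is param}")
print(f"param: {param}")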

################################################################################
# In the following part of this tutorial, we will define a toy ``__torch_dispatch__``
# subclass ``MyQuantizedLinearWeight`` that represents quantized linear weights.
# This subclass will be used for illustration purposes throughout the rest of
# the tutorial. For brevity, we omit most of the ``__torch_dispatch__``
# implementation.

aten = torch.ops.aten

class MyQuantizedLinearWeight(torch.Tensor):
    @staticmethod
    def __new__(cls, elem, scale, **kwargs):
        return torch.Tensor._make_wrapper_subclass(
            cls,
            elem.shape,
            dtype=elem.dtype,
            layout=elem.layout,
            device=elem.device,
            requires_grad=elem.requires_grad,
            strides=elem.stride(),
            storage_offset=elem.storage_offset())

    def __init__(self, elem: torch.Tensor, scale: float, **kwargs):
        self.elem = elem
        self.scale = scale

    def __repr__(self):
        return f"MyQuantizedLinearWeight({self.elem}, scale={self.scale})"

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        if func in (aten.detach.default, aten._to_copy.default):
            new_elem = func(args[0].elem, *args[1:], **kwargs)
            return cls(new_elem, args[0].scale)
        # Special implementations for certain ops would be added here.
        # We omit them for brevity.
        # OP_TABLE = ...
        # elif func in OP_TABLE:
        #     return OP_TABLE[func](func, args, kwargs)
        raise NotImplementedError(f"Unsupported function {func}")

#################################################################################
# Let us create a ``Linear`` layer of dtype ``torch.float32`` where the weight is
# a ``MyQuantizedLinearWeight`` and try to convert it to ``torch.bfloat16``.
# Observe that the weight's dtype changes as expected. However, the dtype
# of the subclass's payload (i.e. ``elem``) does not change.

m = nn.Linear(3, 5, dtype=torch.float32)
m.weight = torch.nn.Parameter(MyQuantizedLinearWeight(m.weight, 0.5))
print(f"Before: id(m.weight)={id(m.weight)}, id(m.bias)={id(m.bias)}")
m.bfloat16()
print(f"After: id(m.weight)={id(m.weight)}, id(m.bias)={id(m.bias)}")
print(f"m.weight.dtype: {m.weight.dtype}")
print(f"m.weight.elem.dtype: {m.weight.elem.dtype}")
print(f"m.bias.dtype: {m.bias.dtype}")

################################################################################
# To address this, we introduce a global config
# ``torch.__future__.set_swap_module_params_on_conversion``. When this config
# is set, the conversion methods use ``swap_tensors`` to swap the parameters
# of the module instead of setting ``.data``, which both preserves references
# to the parameters and ensures that the dtype of the subclass's payload is
# properly converted.

torch.__future__.set_swap_module_params_on_conversion(True)
m = nn.Linear(3, 5, dtype=torch.float32)
m.weight = torch.nn.Parameter(MyQuantizedLinearWeight(m.weight, 0.5))
print(f"Before: id(m.weight)={id(m.weight)}, id(m.bias)={id(m.bias)}")
m.bfloat16()
print(f"After: id(m.weight)={id(m.weight)}, id(m.bias)={id(m.bias)}")
print(f"m.weight.dtype: {m.weight.dtype}")
print(f"m.weight.elem.dtype: {m.weight.elem.dtype}")
print(f"m.bias.dtype: {m.bias.dtype}")
torch.__future__.set_swap_module_params_on_conversion(False)

################################################################################
# ``nn.Module.load_state_dict()``
# -------------------------------
# At present, depending on the value of the ``assign`` keyword argument passed
# to ``load_state_dict()``, there are two ways to load the ``state_dict``:
#
# 1. ``assign=False``: in-place copy (i.e. ``param.copy_(state_dict['param'])``)
# 2. ``assign=True``: ``__setattr__`` (i.e. ``module.param = state_dict['param']``)
#
# Each has its own limitations: ``assign=False`` imposes the constraint that
# the type of the parameter in the ``state_dict`` must be the same as the type of
# the parameter in the module, while ``assign=True`` imposes the constraint that
# anything holding references to the module's parameters must be initialized
# after ``nn.Module.load_state_dict()``.
#
# We address both constraints by adding a ``swap_tensors`` path to ``load_state_dict()``
# and introducing a new extension point, ``torch.Tensor.module_load(self, other, assign=False)``.
# When the ``swap_tensors`` path is enabled via the ``__future__`` mentioned above,
# ``module_load`` can be overridden to apply a custom transformation to the value
# in the ``state_dict``. The result of this transformation will be swapped with
# the parameter in the module.
#
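# Conceptually, when the ``swap_tensors`` path is enabled, ``load_state_dict()``
# does something like the following for each parameter (a simplified sketch of
# the semantics described above, not the actual implementation):
#
# .. code-block:: python
#
#     new_param = param.module_load(state_dict[key], assign=assign)
#     torch.utils.swap_tensors(param, new_param)
#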
# Under this scheme, ``assign=True`` is a directive to preserve the properties of the tensor
# in the ``state_dict`` (excluding ``requires_grad``-ness) and ``assign=False`` is a directive
# to preserve the properties of the tensor in the module.
#
# In the following example, we will use the ``MyQuantizedLinearWeight`` subclass
# defined above to illustrate how we can use these features to apply a
# custom quantization scheme to the weights of a linear layer when
# loading the ``state_dict``.

################################################################################
# Note that the ``__torch_function__`` handler for ``module_load`` will be
# invoked if either ``self`` or ``other`` (in this instance ``param`` or
# ``state_dict[param_key]``) is a ``MyQuantizedLinearWeight`` subclass.
#
# Assume that we expect the ``state_dict`` to contain plain tensors and the
# module to contain ``MyQuantizedLinearWeight`` parameters, and that we want the
# tensors in the ``state_dict`` to be transformed into the subclass. Then we
# can define a ``__torch_function__`` handler for ``torch.Tensor.module_load``
# as follows:

@classmethod
def custom_torch_function(cls, func, types, args=(), kwargs=None):
    kwargs = {} if kwargs is None else kwargs

    def module_load(dest, src, assign=False):
        # The module holds the subclass and the ``state_dict`` holds a plain
        # tensor, so we wrap the incoming tensor with the existing scale.
        assert type(dest) == cls and type(src) == torch.Tensor
        return MyQuantizedLinearWeight(src, dest.scale)

    if func is torch.Tensor.module_load:
        return module_load(*args, **kwargs)
    else:
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

MyQuantizedLinearWeight.__torch_function__ = custom_torch_function

#################################################################################
# First, let us create a skeleton of a model on the meta device to avoid
# materializing storages. We convert all weights in the module to
# ``MyQuantizedLinearWeight`` subclasses while leaving the biases intact.

def fn(m):
    if isinstance(m, nn.Linear):
        requires_grad = m.weight.requires_grad
        m.weight = torch.nn.Parameter(
            MyQuantizedLinearWeight(m.weight, 0.5), requires_grad=requires_grad
        )

with torch.device("meta"):
    m = nn.Linear(3, 5)
    m.apply(fn)

#################################################################################
# We can then load the ``state_dict``. Observe that we use ``assign=True`` because
# for biases, we want to preserve the properties of the tensor in the ``state_dict``
# (i.e. we do not want the bias to be on the meta device after loading).

torch.__future__.set_swap_module_params_on_conversion(True)
print(f"Before: id(weight)={id(m.weight)}, id(bias)={id(m.bias)}")
print(f"m.state_dict() before load_state_dict():\n {m.state_dict()}")
state_dict = nn.Linear(3, 5).state_dict()
print(f"state_dict:\n {state_dict}")
m.load_state_dict(state_dict, assign=True)
print(f"After: id(weight)={id(m.weight)}, id(bias)={id(m.bias)}")
print(f"m.state_dict() after load_state_dict():\n {m.state_dict()}")

################################################################################
# The above is a toy example of how we can use the new extension point in
# ``nn.Module.load_state_dict()``. One can also imagine alternate scenarios, such
# as when we have tensor subclasses in the ``state_dict`` and plain ``nn.Parameter``/
# tensors in the module, or when both are tensor subclasses. Based on the use
# case, we can define the ``__torch_function__`` handler for ``module_load``
# to apply the transforms as needed, as sketched below.
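
################################################################################
# For instance, consider the first alternate scenario (a subclass in the
# ``state_dict`` and a plain tensor in the module). Below is a sketch of a
# handler that "dequantizes" the incoming value on load; it reuses the handler
# skeleton from above and is shown for illustration only, without wiring it up
# to a module here.

@classmethod
def custom_torch_function_dequantize(cls, func, types, args=(), kwargs=None):
    kwargs = {} if kwargs is None else kwargs

    def module_load(dest, src, assign=False):
        # ``src`` comes from the ``state_dict`` as a MyQuantizedLinearWeight;
        # return a plain tensor to be swapped into the module's parameter.
        assert type(src) == MyQuantizedLinearWeight
        return src.elem * src.scale

    if func is torch.Tensor.module_load:
        return module_load(*args, **kwargs)
    else:
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)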

###############################################################################
# Conclusion
# ----------
# In this tutorial, we learned about ``swap_tensors``, the importance
# of preserving references for parameters in ``nn.Module``, as well as how to
# use the two new extension points that are gated by
# ``torch.__future__.set_swap_module_params_on_conversion``.