 # SPDX-License-Identifier: Apache-2.0

 import pickle
+from collections.abc import Sequence
+from inspect import isclass
 from types import FunctionType
-from typing import Any, Optional
+from typing import Any, Optional, Union

 import cloudpickle
+import numpy as np
 import torch
+import zmq
 from msgspec import msgpack

-CUSTOM_TYPE_TENSOR = 1
-CUSTOM_TYPE_PICKLE = 2
-CUSTOM_TYPE_CLOUDPICKLE = 3
+CUSTOM_TYPE_PICKLE = 1
+CUSTOM_TYPE_CLOUDPICKLE = 2

+# TODO calibrate this size
+INLINE_BUF_SIZE_THRESHOLD = 256

-class MsgpackEncoder:
-    """Encoder with custom torch tensor serialization."""
+bytestr = Union[bytes, bytearray, memoryview, zmq.Frame]

-    def __init__(self):
-        self.encoder = msgpack.Encoder(enc_hook=custom_enc_hook)

-    def encode(self, obj: Any) -> bytes:
-        return self.encoder.encode(obj)
+class MsgpackEncoder:
+    """Encoder with custom torch tensor and numpy array serialization.

-    def encode_into(self, obj: Any, buf: bytearray) -> None:
-        self.encoder.encode_into(obj, buf)
+    Note that unlike vanilla `msgspec` Encoders, this interface is generally
+    not thread-safe when encoding tensors / numpy arrays.
+    """
+
+    def __init__(self):
+        self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
+        # This is used as a local stash of buffers that we can then access from
+        # our custom `msgspec` hook, `enc_hook`. We don't have a way to
+        # pass custom data to the hook otherwise.
+        self.aux_buffers: Optional[list[bytestr]] = None
+
+    def encode(self, obj: Any) -> Sequence[bytestr]:
+        try:
+            self.aux_buffers = bufs = [b'']
+            bufs[0] = self.encoder.encode(obj)
+            # This `bufs` list allows us to collect direct pointers to backing
+            # buffers of tensors and np arrays, and return them along with the
+            # top-level encoded buffer instead of copying their data into the
+            # new buffer.
+            return bufs
+        finally:
+            self.aux_buffers = None
+
+    def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]:
+        try:
+            self.aux_buffers = [buf]
+            bufs = self.aux_buffers
+            self.encoder.encode_into(obj, buf)
+            return bufs
+        finally:
+            self.aux_buffers = None
+
+    def enc_hook(self, obj: Any) -> Any:
+        if isinstance(obj, torch.Tensor):
+            return self._encode_ndarray(obj.numpy())
+
+        # Fall back to pickle for object or void kind ndarrays.
+        if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'):
+            return self._encode_ndarray(obj)
+
+        if isinstance(obj, FunctionType):
+            # `pickle` is generally faster than cloudpickle, but can have
+            # problems serializing methods.
+            return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj))
+
+        return msgpack.Ext(CUSTOM_TYPE_PICKLE,
+                           pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL))
+
+    def _encode_ndarray(
+        self, obj: np.ndarray
+    ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
+        assert self.aux_buffers is not None
+        if not obj.shape or obj.nbytes < INLINE_BUF_SIZE_THRESHOLD:
+            # Encode small arrays and scalars inline.
+            data = obj.data
+        else:
+            # Otherwise encode index of backing buffer.
+            obj = np.ascontiguousarray(obj)
+            data = len(self.aux_buffers)
+            self.aux_buffers.append(obj.data)
+        # We serialize the ndarray as a tuple of native types.
+        # The data is either inlined if small, or an index into a list of
+        # backing buffers that we've stashed in `aux_buffers`.
+        return obj.dtype.str, obj.shape, data


 class MsgpackDecoder:
-    """Decoder with custom torch tensor serialization."""
+    """Decoder with custom torch tensor and numpy array serialization.
+
+    Note that unlike vanilla `msgspec` Decoders, this interface is generally
+    not thread-safe when decoding tensors / numpy arrays.
+    """

     def __init__(self, t: Optional[Any] = None):
         args = () if t is None else (t, )
-        self.decoder = msgpack.Decoder(*args, ext_hook=custom_ext_hook)
-
-    def decode(self, obj: Any):
-        return self.decoder.decode(obj)
-
-
-def custom_enc_hook(obj: Any) -> Any:
-    if isinstance(obj, torch.Tensor):
-        # NOTE(rob): it is fastest to use numpy + pickle
-        # when serializing torch tensors.
-        # https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501
-        return msgpack.Ext(CUSTOM_TYPE_TENSOR, pickle.dumps(obj.numpy()))
-
-    if isinstance(obj, FunctionType):
-        return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj))
-
-    return msgpack.Ext(CUSTOM_TYPE_PICKLE, pickle.dumps(obj))
-
-
-def custom_ext_hook(code: int, data: memoryview) -> Any:
-    if code == CUSTOM_TYPE_TENSOR:
-        return torch.from_numpy(pickle.loads(data))
-    if code == CUSTOM_TYPE_PICKLE:
-        return pickle.loads(data)
-    if code == CUSTOM_TYPE_CLOUDPICKLE:
-        return cloudpickle.loads(data)
-
-    raise NotImplementedError(f"Extension type code {code} is not supported")
+        self.decoder = msgpack.Decoder(*args,
+                                       ext_hook=self.ext_hook,
+                                       dec_hook=self.dec_hook)
+        self.aux_buffers: Sequence[bytestr] = ()
+
+    def decode(self, bufs: Union[bytestr, Sequence[bytestr]]) -> Any:
+        if isinstance(bufs, (bytes, bytearray, memoryview, zmq.Frame)):
+            # TODO - This check can become `isinstance(bufs, bytestr)`
+            # as of Python 3.10.
+            return self.decoder.decode(bufs)
+
+        self.aux_buffers = bufs
+        try:
+            return self.decoder.decode(bufs[0])
+        finally:
+            self.aux_buffers = ()
+
+    def dec_hook(self, t: type, obj: Any) -> Any:
+        # Given native types in `obj`, convert to type `t`.
+        if isclass(t):
+            if issubclass(t, np.ndarray):
+                return self._decode_ndarray(obj)
+            if issubclass(t, torch.Tensor):
+                return torch.from_numpy(self._decode_ndarray(obj))
+        return obj
+
+    def _decode_ndarray(self, arr: Any) -> np.ndarray:
+        dtype, shape, data = arr
+        buffer = self.aux_buffers[data] if isinstance(data, int) else data
+        return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape)
+
+    def ext_hook(self, code: int, data: memoryview) -> Any:
+        if code == CUSTOM_TYPE_PICKLE:
+            return pickle.loads(data)
+        if code == CUSTOM_TYPE_CLOUDPICKLE:
+            return cloudpickle.loads(data)
+
+        raise NotImplementedError(
+            f"Extension type code {code} is not supported")
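
For context, a minimal round-trip sketch of how the new encoder/decoder pair could be used. This is not part of the commit: the Payload struct, field names, and tensor sizes below are illustrative assumptions, and it only relies on the classes and the INLINE_BUF_SIZE_THRESHOLD constant defined above. Tensors whose data fits under the threshold are inlined into the msgpack body, while larger ones come back as extra zero-copy buffers that can be sent as separate frames and are referenced by index on the receiving side.

# Illustrative usage sketch (not from the commit). Assumes MsgpackEncoder,
# MsgpackDecoder and INLINE_BUF_SIZE_THRESHOLD from the module above are in
# scope; Payload and its field names are hypothetical.
import msgspec
import torch


class Payload(msgspec.Struct):
    name: str
    small: torch.Tensor  # 16 bytes -> inlined in the msgpack body
    large: torch.Tensor  # 4 KiB -> shipped as a separate zero-copy buffer


encoder = MsgpackEncoder()
decoder = MsgpackDecoder(Payload)

msg = Payload(name="example",
              small=torch.arange(4, dtype=torch.int32),
              large=torch.zeros(1024, dtype=torch.float32))

# bufs[0] is the encoded msgpack body; bufs[1:] are memoryviews over the
# large tensor's backing storage, with no copy made during encoding.
bufs = encoder.encode(msg)

# Passing the same sequence of buffers back rebuilds the tensors; the large
# one is read directly out of bufs[1] via dec_hook / _decode_ndarray.
out = decoder.decode(bufs)
assert torch.equal(out.small, msg.small)
assert torch.equal(out.large, msg.large)

The point of returning a buffer list rather than a single bytes object is that a transport which supports multipart messages (such as a zmq socket) can ship each backing buffer as its own frame, avoiding a serialize-time copy of large tensor data.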