17 | 17 | DEFAULT_MASK_VALUE = -0.7 * float(torch.finfo(torch.float32).max)
18 | 18 |
19 | 19 |
| 20 | +def _shard_map(func, mesh, input_specs, output_specs):
| 21 | +  """Map a function over shards of data.
| 22 | +
| 23 | + Note: |
| 24 | +      ``shard_map`` is an experimental API and still subject to change. For an
| 25 | +      introduction to sharded data and a more in-depth look at using
| 26 | +      ``shard_map``, refer to
| 27 | +      [SPMD multi-device parallelism with shard_map](https://docs.jax.dev/en/latest/notebooks/shard_map.html)
| 28 | +
| 29 | + Args: |
| 30 | +      func: callable to be mapped. Each application of ``func``, or "instance" of
| 31 | +        ``func``, takes as input a shard of the mapped-over arguments and produces
| 32 | +        a shard of the output.
| 33 | +      mesh: a ``Mesh`` representing the array of devices over which
| 34 | +        to shard the data and on which to execute instances of ``func``. The
| 35 | +        names of the ``Mesh`` axes can be used in collective communication
| 36 | +        operations in ``func``. This is typically a
| 37 | +        :class:`torch_xla.distributed.spmd.Mesh`.
| 38 | +      input_specs: a tuple of tuples of str, one per positional input of ``func``,
| 39 | +        giving that input's partition spec. Keyword arguments are not supported yet.
| 40 | +      output_specs: a tuple of tuples of str, with the same length as the
| 41 | +        number of outputs of ``func``.
| 42 | +
| 43 | + Returns: |
| 44 | +      A callable that applies ``func`` across data sharded according to the
| 45 | +      ``mesh``, ``input_specs``, and ``output_specs``.
| 46 | +
| 47 | + Reference: |
| 48 | +    This function behaves identically to JAX's ``shard_map``:
| 49 | + https://docs.jax.dev/en/latest/_autosummary/jax.experimental.shard_map.shard_map.html |
| 50 | + """ |
| 51 | + |
| 52 | + def _full_shape(a, spec): |
| 53 | +    # `a` is the local shard and `spec` its per-dimension sharding spec.
| 54 | +    # Returns the logical shape of the global tensor: each dimension of the
| 55 | +    # shard is multiplied by the sizes of the mesh axes it is sharded over.
| 56 | + mesh_name_to_size = mesh.shape() |
| 57 | + |
| 58 | + result_shape = [] |
| 59 | + for axis_size, axis_sharding in zip(a.shape, spec): |
| 60 | + if axis_sharding is None: |
| 61 | + axis_sharding = () |
| 62 | + mesh_mult = [] |
| 63 | + if isinstance(axis_sharding, (str, int)): |
| 64 | + axis_sharding = [axis_sharding] |
| 65 | + for axis in axis_sharding: |
| 66 | + size = mesh_name_to_size[axis] or 1 |
| 67 | + mesh_mult.append(size) |
| 68 | + new_size = axis_size * math.prod(mesh_mult) |
| 69 | + result_shape.append(new_size) |
| 70 | + return tuple(result_shape) |
| 71 | + |
| 72 | + def wrapped(*args): |
| 73 | + assert len(args) == len( |
| 74 | + input_specs), f'args={len(args)}; input_specs={len(input_specs)}' |
| 75 | + new_args = [] |
| 76 | + for i, (a, spec) in enumerate(zip(args, input_specs)): |
| 77 | + if isinstance(a, torch.Tensor): |
| 78 | + assert (len(a.shape) == len(spec) |
| 79 | + ), f'{i}th input has wrong shape: {a.shape} for {spec}' |
| 80 | + new_a = xs.enable_manual_sharding(a, spec, mesh=mesh).global_tensor |
| 81 | + new_args.append(new_a) |
| 82 | + else: |
| 83 | + new_args.append(a) |
| 84 | + |
| 85 | + res = func(*new_args) |
| 86 | + if isinstance(res, tuple): |
| 87 | + res_updated = [] |
| 88 | + for i, (r, spec) in enumerate(zip(res, output_specs)): |
| 89 | + if isinstance(r, torch.Tensor) and spec is not None: |
| 90 | + assert str(r.device).startswith('xla'), f'{i}th device is {r.device}' |
| 91 | + assert len(r.shape) == len( |
| 92 | + spec), f'{i}th shape is {r.shape}, sharding is {output_specs[i]}' |
| 93 | + new_r = xs.disable_manual_sharding( |
| 94 | + r, spec, _full_shape(r, spec), mesh=mesh).global_tensor |
| 95 | + else: |
| 96 | + new_r = r |
| 97 | + res_updated.append(new_r) |
| 98 | + return res_updated |
| 99 | + else: |
| 100 | + return xs.disable_manual_sharding( |
| 101 | + res, output_specs[0], _full_shape(res, output_specs[0]), |
| 102 | + mesh=mesh).global_tensor |
| 103 | + |
| 104 | + return wrapped |
| 105 | + |
| 106 | + |
20 | 107 | def _shard_map(func, mesh, input_specs, output_specs):
21 | 108 | """Map a function over shards of data.
22 | 109 |
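
Not shown in the diff is how the wrapped function is called. The following is a
minimal usage sketch, not part of the commit, assuming a multi-device torch_xla
SPMD runtime; the mesh axis name 'data', the toy `add` function, and the tensor
shapes are illustrative assumptions.

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('data',))

def add(a, b):
  # Each instance of `add` sees only its local shard of `a` and `b`.
  return a + b

sharded_add = _shard_map(
    add, mesh,
    input_specs=(('data', None), ('data', None)),  # shard dim 0 over 'data'
    output_specs=(('data', None),))                # one output, same layout

x = torch.randn(num_devices * 2, 4).to(xm.xla_device())
y = torch.randn(num_devices * 2, 4).to(xm.xla_device())
out = sharded_add(x, y)  # re-assembled to the global (num_devices * 2, 4) shape

Note that ``output_specs`` is indexed (``output_specs[0]``) even when ``func``
returns a single tensor, so it is passed as a one-element tuple here.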