diff --git a/docs/api-reference.md b/docs/api-reference.md
index 32205248..2483a55d 100644
--- a/docs/api-reference.md
+++ b/docs/api-reference.md
@@ -8,6 +8,7 @@
 
     at
     atleast_nd
+    broadcast_shapes
     cov
     create_diagonal
     expand_dims
diff --git a/docs/conf.py b/docs/conf.py
index afa3bd5e..eff2a33d 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -54,6 +54,7 @@
 intersphinx_mapping = {
     "python": ("https://docs.python.org/3", None),
     "array-api": ("https://data-apis.org/array-api/draft", None),
+    "numpy": ("https://numpy.org/doc/stable", None),
     "jax": ("https://jax.readthedocs.io/en/latest", None),
 }
 
diff --git a/src/array_api_extra/__init__.py b/src/array_api_extra/__init__.py
index 840dd8e7..4a49fd48 100644
--- a/src/array_api_extra/__init__.py
+++ b/src/array_api_extra/__init__.py
@@ -4,6 +4,7 @@
 from ._lib._at import at
 from ._lib._funcs import (
     atleast_nd,
+    broadcast_shapes,
     cov,
     create_diagonal,
     expand_dims,
@@ -20,6 +21,7 @@
     "__version__",
     "at",
     "atleast_nd",
+    "broadcast_shapes",
     "cov",
     "create_diagonal",
     "expand_dims",
diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py
index f7eb8c88..a5729559 100644
--- a/src/array_api_extra/_lib/_funcs.py
+++ b/src/array_api_extra/_lib/_funcs.py
@@ -17,6 +17,7 @@
 
 __all__ = [
     "atleast_nd",
+    "broadcast_shapes",
     "cov",
     "create_diagonal",
     "expand_dims",
@@ -71,6 +72,69 @@ def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array
     return x
 
 
+# `float` in signature to accept `math.nan` for Dask.
+# `int`s are still accepted as `float` is a superclass of `int` in typing
+def broadcast_shapes(*shapes: tuple[float | None, ...]) -> tuple[int | None, ...]:
+    """
+    Compute the shape of the broadcasted arrays.
+
+    Duplicates :func:`numpy.broadcast_shapes`, with additional support for
+    None and NaN sizes.
+
+    This is equivalent to ``xp.broadcast_arrays(arr1, arr2, ...)[0].shape``
+    without needing to worry about the backend potentially deep copying
+    the arrays.
+
+    Parameters
+    ----------
+    *shapes : tuple[int | None, ...]
+        Shapes of the arrays to broadcast.
+
+    Returns
+    -------
+    tuple[int | None, ...]
+        The shape of the broadcasted arrays.
+
+    See Also
+    --------
+    numpy.broadcast_shapes : Equivalent NumPy function.
+    array_api.broadcast_arrays : Function to broadcast actual arrays.
+
+    Notes
+    -----
+    This function accepts the Array API's ``None`` for unknown sizes,
+    as well as Dask's non-standard ``math.nan``.
+    Regardless of input, the output always contains ``None`` for unknown sizes.
+
+    Examples
+    --------
+    >>> import array_api_extra as xpx
+    >>> xpx.broadcast_shapes((2, 3), (2, 1))
+    (2, 3)
+    >>> xpx.broadcast_shapes((4, 2, 3), (2, 1), (1, 3))
+    (4, 2, 3)
+    """
+    if not shapes:
+        return ()  # Match numpy output
+
+    ndim = max(len(shape) for shape in shapes)
+    out: list[int | None] = []
+    for axis in range(-ndim, 0):
+        sizes = {shape[axis] for shape in shapes if axis >= -len(shape)}
+        # Dask uses NaN for unknown shape, which predates the Array API spec for None
+        none_size = None in sizes or math.nan in sizes
+        sizes -= {1, None, math.nan}
+        if len(sizes) > 1:
+            msg = (
+                "shape mismatch: objects cannot be broadcast to a single shape: "
+                f"{shapes}."
+            )
+            raise ValueError(msg)
+        out.append(None if none_size else cast(int, sizes.pop()) if sizes else 1)
+
+    return tuple(out)
+
+
 def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
     """
     Estimate a covariance matrix.
diff --git a/tests/test_funcs.py b/tests/test_funcs.py
index 2b900a85..84d2f5d1 100644
--- a/tests/test_funcs.py
+++ b/tests/test_funcs.py
@@ -1,4 +1,5 @@
 import contextlib
+import math
 import warnings
 from types import ModuleType
 
@@ -8,6 +9,7 @@
 from array_api_extra import (
     at,
     atleast_nd,
+    broadcast_shapes,
     cov,
     create_diagonal,
     expand_dims,
@@ -113,6 +115,63 @@ def test_xp(self, xp: ModuleType):
         xp_assert_equal(y, xp.ones((1,)))
 
 
+class TestBroadcastShapes:
+    @pytest.mark.parametrize(
+        "args",
+        [
+            (),
+            ((),),
+            ((), ()),
+            ((1,),),
+            ((1,), (1,)),
+            ((2,), (1,)),
+            ((3, 1, 4), (2, 1)),
+            ((1, 1, 4), (2, 1)),
+            ((1,), ()),
+            ((), (2,), ()),
+            ((0,),),
+            ((0,), (1,)),
+            ((2, 0), (1, 1)),
+            ((2, 0, 3), (2, 1, 1)),
+        ],
+    )
+    def test_simple(self, args: tuple[tuple[int, ...], ...]):
+        expect = np.broadcast_shapes(*args)
+        actual = broadcast_shapes(*args)
+        assert actual == expect
+
+    @pytest.mark.parametrize(
+        "args",
+        [
+            ((2,), (3,)),
+            ((2, 3), (1, 2)),
+            ((2,), (0,)),
+            ((2, 0, 2), (1, 3, 1)),
+        ],
+    )
+    def test_fail(self, args: tuple[tuple[int, ...], ...]):
+        match = "cannot be broadcast to a single shape"
+        with pytest.raises(ValueError, match=match):
+            _ = np.broadcast_shapes(*args)
+        with pytest.raises(ValueError, match=match):
+            _ = broadcast_shapes(*args)
+
+    @pytest.mark.parametrize(
+        "args",
+        [
+            ((None,), (None,)),
+            ((math.nan,), (None,)),
+            ((1, None, 2, 4), (2, 3, None, 1), (2, None, None, 4)),
+            ((1, math.nan, 2), (4, 2, 3, math.nan), (4, 2, None, None)),
+            ((math.nan, 1), (None, 2), (None, 2)),
+        ],
+    )
+    def test_none(self, args: tuple[tuple[float | None, ...], ...]):
+        expect = args[-1]
+        actual = broadcast_shapes(*args[:-1])
+        assert actual == expect
+
+
 @pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype")
 class TestCov:
     def test_basic(self, xp: ModuleType):