This repository was archived by the owner on May 6, 2025. It is now read-only.

Commit 4b183ce

Refactoring: a long-awaited refactor that splits the huge stax.py into subcomponents, among other things:

(1) Move implementations into `_src`, and import the public functions from their sources into the top-level modules. This makes our public API easier to manage, remember, and view. A downside is that it becomes less convenient to "hack" the library for users who want to reach into private functions or modules.
(2) Split `stax` into 4 parts: `requirements`, `elementwise`, `linear`, and `combinators`. This makes the structure of `stax` easier to understand and makes it easier to browse and implement new layers or combinators, except for those that don't fall squarely into any one category.
(5) Move `test_utils` out of the library and into the tests folder only. Decouple tests / stax / outside users further from library internals.
(6) Rename `batch` to `batching` to avoid confusing the module with the function.
(7) Remove the dependence on `jax.lib.xla_bridge`.

PiperOrigin-RevId: 429185591
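The user-facing API is unaffected by the move into `_src`: the same names are re-exported from the package root (see the `neural_tangents/__init__.py` diff below). A minimal sketch of the intended usage after this change; the network architecture, inputs, and sample count here are illustrative and not part of the commit:

```python
from jax import random
import neural_tangents as nt        # re-exports nt.batch, nt.empirical_kernel_fn, ...
from neural_tangents import stax

# Illustrative infinite-width network (not part of this commit).
init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(512), stax.Relu(), stax.Dense(1))

x1 = random.normal(random.PRNGKey(1), (10, 4))
x2 = random.normal(random.PRNGKey(2), (20, 4))

# Closed-form infinite-width NTK via `nt.stax`.
ntk = kernel_fn(x1, x2, 'ntk')                # (10, 20)

# Monte Carlo estimate of the same kernel; `monte_carlo_kernel_fn` now lives in
# `neural_tangents/_src/monte_carlo.py` but is still exposed at the top level.
mc_kernel_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, random.PRNGKey(3), 64)
ntk_mc = mc_kernel_fn(x1, x2, 'ntk')          # (10, 20)
```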
1 parent b9c2c57 commit 4b183ce


45 files changed, +13257 -12684 lines

LICENSE

Lines changed: 1 addition & 1 deletion
@@ -199,4 +199,4 @@
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
-limitations under the License.
+limitations under the License.

README.md

Lines changed: 3 additions & 1 deletion
@@ -235,7 +235,7 @@ The `neural_tangents` (`nt`) package contains the following modules and function
 
 * `monte_carlo_kernel_fn` - compute a Monte Carlo kernel estimate of _any_ `(init_fn, apply_fn)`, not necessarily specified via `nt.stax`, enabling the kernel computation of infinite networks without closed-form expressions.
 
-* Tools to investigate training dynamics of _wide but finite_ neural networks, like `linearize`, `taylor_expand`, `empirical_kernel_fn` and more. See [Training dynamics of wide but finite networks](#training-dynamics-of-wide-but-finite-networks) for details.
+* Tools to investigate training dynamics of _wide but finite_ neural networks, like `linearize`, `taylor_expand`, `empirical.kernel_fn` and more. See [Training dynamics of wide but finite networks](#training-dynamics-of-wide-but-finite-networks) for details.
 
 
 ## Technical gotchas
@@ -311,10 +311,12 @@ import jax.random as random
 import jax.numpy as np
 import neural_tangents as nt
 
+
 def apply_fn(params, x):
   W, b = params
   return np.dot(x, W) + b
 
+
 W_0 = np.array([[1., 0.], [0., 1.]])
 b_0 = np.zeros((2,))
 params = (W_0, b_0)
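For reference, the snippet above is the kind of finite network that the empirical tools listed earlier (`linearize`, `taylor_expand`, the empirical kernel functions) operate on. A minimal, self-contained sketch; the inputs `x` and the calls to `nt.empirical_ntk_fn` and `nt.linearize` are illustrative and not part of the README diff:

```python
import jax.numpy as np
import neural_tangents as nt

def apply_fn(params, x):
  W, b = params
  return np.dot(x, W) + b

W_0 = np.array([[1., 0.], [0., 1.]])
b_0 = np.zeros((2,))
params = (W_0, b_0)

x = np.array([[1., 2.], [3., 4.], [5., 6.]])   # illustrative inputs

# Empirical (finite-width) NTK of `apply_fn` at `params`: a (3, 3) matrix.
ntk_fn = nt.empirical_ntk_fn(apply_fn, vmap_axes=0)
ntk = ntk_fn(x, None, params)

# First-order (linear) approximation of `apply_fn` around `params`.
apply_fn_lin = nt.linearize(apply_fn, params)
```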

docs/index.rst

Lines changed: 3 additions & 3 deletions
@@ -9,10 +9,10 @@ neural networks.
    :caption: Topics:
 
    neural_tangents.stax
-   neural_tangents.empirical
+   neural_tangents._src.empirical
    neural_tangents.predict
-   neural_tangents.batching
-   neural_tangents.monte_carlo
+   neural_tangents._src.batching
+   neural_tangents._src.monte_carlo
 
 Indices and tables
 ==================

docs/neural_tangents.batching.rst

Lines changed: 1 addition & 1 deletion
@@ -2,5 +2,5 @@ Batching
 ===========================
 
 .. default-role:: code
-.. automodule:: neural_tangents.utils.batch
+.. automodule:: neural_tangents._src.batching
    :members:

docs/neural_tangents.empirical.rst

Lines changed: 1 addition & 1 deletion
@@ -2,5 +2,5 @@ Empirical
 ===========================
 
 .. default-role:: code
-.. automodule:: neural_tangents.utils.empirical
+.. automodule:: neural_tangents._src.empirical
    :members:

docs/neural_tangents.monte_carlo.rst

Lines changed: 1 addition & 1 deletion
@@ -2,5 +2,5 @@ Monte Carlo Sampling
 ===========================
 
 .. default-role:: code
-.. automodule:: neural_tangents.utils.monte_carlo
+.. automodule:: neural_tangents._src.monte_carlo
    :members:

examples/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-# Copyright 2019 Google LLC
+# Copyright 2022 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

examples/datasets.py

Lines changed: 5 additions & 2 deletions
@@ -91,8 +91,11 @@ def minibatch(x_train, y_train, batch_size, train_epochs):
 
     if end > x_train.shape[0]:
       key, split = random.split(key)
-      permutation = random.shuffle(split,
-                                   np.arange(x_train.shape[0], dtype=np.int64))
+      permutation = random.permutation(
+          split,
+          np.arange(x_train.shape[0], dtype=np.int64),
+          independent=True
+      )
       x_train = x_train[permutation]
       y_train = y_train[permutation]
       epoch += 1
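For context (an aside, not part of the diff): `jax.random.shuffle` is deprecated in JAX, and `jax.random.permutation` is its replacement; for a 1-D index array the two produce an equivalently shuffled result, which is what this hunk relies on. A minimal sketch of the substitution, with an illustrative key and index array:

```python
import jax.numpy as np
from jax import random

key = random.PRNGKey(0)
idx = np.arange(10)   # illustrative index array

# Deprecated: permutation = random.shuffle(key, idx)
# Replacement: returns a shuffled copy of `idx`; `independent=True` shuffles
# individual entries rather than permuting slices along an axis.
permutation = random.permutation(key, idx, independent=True)
```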

examples/weight_space.py

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@
 
 
 def main(unused_argv):
-  # Build data and .
+  # Load data and preprocess it.
   print('Loading data.')
   x_train, y_train, x_test, y_test = datasets.get_dataset('mnist',
                                                           permute_train=True)

neural_tangents/__init__.py

Lines changed: 10 additions & 10 deletions
@@ -16,15 +16,15 @@
 """Public Neural Tangents modules and functions."""
 
 
-__version__ = '0.4.0'
+__version__ = '0.5.0'
 
 
-from neural_tangents import predict
-from neural_tangents import stax
-from neural_tangents.utils.batch import batch
-from neural_tangents.utils.empirical import empirical_kernel_fn
-from neural_tangents.utils.empirical import empirical_nngp_fn
-from neural_tangents.utils.empirical import empirical_ntk_fn
-from neural_tangents.utils.empirical import linearize
-from neural_tangents.utils.empirical import taylor_expand
-from neural_tangents.utils.monte_carlo import monte_carlo_kernel_fn
+from . import predict
+from . import stax
+from ._src.batching import batch
+from ._src.empirical import empirical_kernel_fn
+from ._src.empirical import empirical_nngp_fn
+from ._src.empirical import empirical_ntk_fn
+from ._src.empirical import linearize
+from ._src.empirical import taylor_expand
+from ._src.monte_carlo import monte_carlo_kernel_fn
File renamed without changes.

neural_tangents/utils/batch.py renamed to neural_tangents/_src/batching.py

Lines changed: 10 additions & 10 deletions
@@ -55,9 +55,9 @@
 from jax.tree_util import tree_all
 from jax.tree_util import tree_map
 from jax.tree_util import tree_multimap, tree_flatten, tree_unflatten
-from neural_tangents.utils.kernel import Kernel
-from neural_tangents.utils import utils
-from neural_tangents.utils.typing import KernelFn, NTTree
+from .utils.kernel import Kernel
+from .utils import utils
+from .utils.typing import KernelFn, NTTree
 
 import numpy as onp
 
@@ -79,15 +79,18 @@ def batch(kernel_fn: KernelFn,
       `kernel_fn(x1, x2, *args, **kwargs)`. Here `x1` and `x2` are
       `np.ndarray`s of shapes `(n1,) + input_shape` and `(n2,) + input_shape`.
       The kernel function should return a `PyTree`.
+
     batch_size:
       specifies the size of each batch that gets processed per physical device.
      Because we parallelize the computation over columns it should be the case
      that `x1.shape[0]` is divisible by `device_count * batch_size` and
      `x2.shape[0]` is divisible by `batch_size`.
+
     device_count:
      specifies the number of physical devices to be used. If
      `device_count == -1` all devices are used. If `device_count == 0`, no
      device parallelism is used (a single default device is used).
+
     store_on_device:
      specifies whether the output should be kept on device or brought back to
      CPU RAM as it is computed. Defaults to `True`. Set to `False` to store
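Putting the arguments documented above together, a minimal sketch of how `nt.batch` is typically called; the kernel, input shapes, and argument values below are illustrative, not taken from the diff:

```python
import jax.numpy as np
import neural_tangents as nt
from neural_tangents import stax

_, _, kernel_fn = stax.serial(stax.Dense(128), stax.Relu(), stax.Dense(1))

# Compute the kernel in 8x8 blocks; `device_count=0` disables device
# parallelism and `store_on_device=False` accumulates the result in CPU RAM.
# Batch sizes: x1.shape[0] == 32 and x2.shape[0] == 16 are divisible by 8.
batched_kernel_fn = nt.batch(kernel_fn, batch_size=8, device_count=0,
                             store_on_device=False)

x1 = np.ones((32, 4))
x2 = np.ones((16, 4))
nngp = batched_kernel_fn(x1, x2, 'nngp')   # (32, 16), same values as kernel_fn
```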
@@ -249,7 +252,6 @@ def _flatten_kernel(k: Kernel,
 def _reshape_kernel_for_pmap(k: Kernel,
                              device_count: int,
                              n1_per_device: int) -> Kernel:
-  # pytype: disable=attribute-error
   cov2 = k.cov2
   if cov2 is None:
     cov2 = k.cov1
@@ -283,7 +285,6 @@ def _set_cov2_to_none(
   if isinstance(k, Kernel):
     k = k.replace(cov2=None)
   return k
-  # pytype: enable=attribute-error
 
 
 def _serial(kernel_fn: KernelFn,
@@ -444,8 +445,7 @@ def col_fn(n1, n2):
     in_kernel = slice_kernel(k, n1_slice, n2_slice)
     return (n1, kwargs1), kernel_fn(in_kernel, *args, **kwargs_merge)
 
-  cov2_is_none = utils.nt_tree_fn(reduce=lambda k: all(k))(lambda k:
-                                                           k.cov2 is None)(k)
+  cov2_is_none = utils.nt_tree_fn(reduce=all)(lambda k: k.cov2 is None)(k)
   _, k = _scan(row_fn, 0, (n1s, kwargs_np1))
   if cov2_is_none:
     k = _set_cov2_to_none(k)
@@ -520,7 +520,7 @@ def _check_dropout(n1, n2, kwargs):
         'Using `serial` (i.e. use a non-zero batch_size in the '
         '`batch` function.) could enforce square batch size in each device.')
 
-  def _get_n_per_device(n1, n2):
+  def _get_n_per_device(n1):
     _device_count = device_count
 
     n1_per_device, ragged = divmod(n1, device_count)
@@ -549,7 +549,7 @@ def get_batch_size(x):
   n2 = n1 if x2_is_none else get_batch_size(x2)
 
   _check_dropout(n1, n2, kwargs)
-  n1_per_device, _device_count = _get_n_per_device(n1, n2)
+  n1_per_device, _device_count = _get_n_per_device(n1)
 
   _kernel_fn = _jit_or_pmap_broadcast(kernel_fn, _device_count)
 
@@ -579,7 +579,7 @@ def get_batch_sizes(k):
 
   n1, n2 = get_batch_sizes(kernel)
   _check_dropout(n1, n2, kwargs)
-  n1_per_device, _device_count = _get_n_per_device(n1, n2)
+  n1_per_device, _device_count = _get_n_per_device(n1)
 
   _kernel_fn = _jit_or_pmap_broadcast(kernel_fn, _device_count)

neural_tangents/utils/empirical.py renamed to neural_tangents/_src/empirical.py

Lines changed: 32 additions & 32 deletions
@@ -17,11 +17,11 @@
 All functions in this module are applicable to any JAX functions of proper
 signatures (not only those from `nt.stax`).
 
-NNGP and NTK are computed using `empirical_nngp_fn`, `empirical_ntk_fn`, or
-`empirical_kernel_fn` (for both). The kernels have a very specific output shape
-convention that may be unexpected. Further, NTK has multiple implementations
-that may perform differently depending on the task. Please read individual
-functions' docstrings.
+NNGP and NTK are computed using `empirical_nngp_fn`, `nt.empirical_ntk_fn`, or
+`nt.empirical_kernel_fn` (for both). The kernels have a very specific output
+shape convention that may be unexpected. Further, NTK has multiple
+implementations that may perform differently depending on the task. Please read
+individual functions' docstrings.
 
 Example:
   >>> from jax import random
@@ -49,18 +49,18 @@
   >>> # Default setting: reducing over logits; pass `vmap_axes=0` because the
   >>> # network is iid along the batch axis, no BatchNorm. Use default
   >>> # `implementation=1` since the network has few trainable parameters.
-  >>> kernel_fn = nt.empirical_kernel_fn(f, trace_axes=(-1,),
-  >>>                                    vmap_axes=0, implementation=1)
+  >>> kernel_fn = nt.empirical_kernel_fn(
+  >>>     f, trace_axes=(-1,), vmap_axes=0, implementation=1)
   >>>
   >>> # (5, 20) np.ndarray test-train NNGP/NTK
-  >>> nngp_test_train = kernel_fn(x_test, x_train, 'nngp', params)
-  >>> ntk_test_train = kernel_fn(x_test, x_train, 'ntk', params)
+  >>> nngp_test_train = empirical_kernel_fn(x_test, x_train, 'nngp', params)
+  >>> ntk_test_train = empirical_kernel_fn(x_test, x_train, 'ntk', params)
   >>>
   >>> # Full kernel: not reducing over logits.
   >>> kernel_fn = nt.empirical_kernel_fn(f, trace_axes=(), vmap_axes=0)
   >>>
   >>> # (5, 20, 10, 10) np.ndarray test-train NNGP/NTK namedtuple.
-  >>> k_test_train = kernel_fn(x_test, x_train, params)
+  >>> k_test_train = empirical_kernel_fn(x_test, x_train, params)
   >>>
   >>> # A wide FCN with lots of parameters
   >>> init_fn, f, _ = stax.serial(
@@ -79,22 +79,22 @@
   >>> ntk_fn = nt.empirical_ntk_fn(f, vmap_axes=0, implementation=2)
   >>>
   >>> # (5, 5) np.ndarray test-test NTK
-  >>> ntk_test_train = ntk_fn(x_test, None, params)
+  >>> ntk_test_train = empirical_ntk_fn(x_test, None, params)
   >>>
   >>> # Compute only output variances:
   >>> nngp_fn = nt.empirical_nngp_fn(f, diagonal_axes=(0,))
   >>>
   >>> # (20,) np.ndarray train-train diagonal NNGP
-  >>> nngp_train_train_diag = nngp_fn(x_train, None, params)
+  >>> nngp_train_train_diag = empirical_nngp_fn(x_train, None, params)
 """
 
 import operator
 from typing import Union, Callable, Optional, Tuple, Dict
 from jax import eval_shape, jacobian, jvp, vjp, vmap, linear_transpose
 import jax.numpy as np
 from jax.tree_util import tree_flatten, tree_unflatten, tree_multimap, tree_reduce, tree_map
-from neural_tangents.utils import utils
-from neural_tangents.utils.typing import ApplyFn, EmpiricalKernelFn, NTTree, PyTree, Axes, VMapAxes, VMapAxisTriple
+from .utils import utils
+from .utils.typing import ApplyFn, EmpiricalKernelFn, NTTree, PyTree, Axes, VMapAxes, VMapAxisTriple
 
 
 def linearize(f: Callable[..., PyTree],
@@ -589,22 +589,22 @@ def empirical_ntk_fn(f: ApplyFn,
                 vmap_axes=vmap_axes)
 
   if implementation == 1:
-    return _empirical_direct_ntk_fn(**kwargs)
+    return _direct_ntk_fn(**kwargs)
 
   if implementation == 2:
-    return _empirical_implicit_ntk_fn(**kwargs)
+    return _implicit_ntk_fn(**kwargs)
 
   raise ValueError(implementation)
 
 
-def _empirical_implicit_ntk_fn(f: ApplyFn,
-                               trace_axes: Axes = (-1,),
-                               diagonal_axes: Axes = (),
-                               vmap_axes: VMapAxes = None
-                               ) -> Callable[[NTTree[np.ndarray],
-                                              Optional[NTTree[np.ndarray]],
-                                              PyTree],
-                                             NTTree[np.ndarray]]:
+def _implicit_ntk_fn(f: ApplyFn,
+                     trace_axes: Axes = (-1,),
+                     diagonal_axes: Axes = (),
+                     vmap_axes: VMapAxes = None
+                     ) -> Callable[[NTTree[np.ndarray],
+                                    Optional[NTTree[np.ndarray]],
+                                    PyTree],
+                                   NTTree[np.ndarray]]:
   """Compute NTK implicitly without instantiating full Jacobians."""
 
   def ntk_fn(x1: NTTree[np.ndarray],
@@ -688,14 +688,14 @@ def delta_vjp(delta):
   return ntk_fn
 
 
-def _empirical_direct_ntk_fn(f: ApplyFn,
-                             trace_axes: Axes = (-1,),
-                             diagonal_axes: Axes = (),
-                             vmap_axes: VMapAxes = None
-                             ) -> Callable[[NTTree[np.ndarray],
-                                            Optional[NTTree[np.ndarray]],
-                                            PyTree],
-                                           NTTree[np.ndarray]]:
+def _direct_ntk_fn(f: ApplyFn,
+                   trace_axes: Axes = (-1,),
+                   diagonal_axes: Axes = (),
+                   vmap_axes: VMapAxes = None
+                   ) -> Callable[[NTTree[np.ndarray],
+                                  Optional[NTTree[np.ndarray]],
+                                  PyTree],
+                                 NTTree[np.ndarray]]:
   """Compute NTK by directly instantiating Jacobians and contracting."""
 
   @utils.nt_tree_fn(tree_structure_argnum=0)
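The renames above are internal only; users still pick between the two NTK implementations through the `implementation` argument of `nt.empirical_ntk_fn`. Per the docstring examples in this diff, `implementation=1` (the former `_empirical_direct_ntk_fn`) instantiates Jacobians and contracts them, which suits networks with few trainable parameters, while `implementation=2` (the former `_empirical_implicit_ntk_fn`) computes the NTK implicitly via JVP/VJP and suits wide networks with many parameters. A minimal sketch with an illustrative model (not taken from the commit):

```python
import jax.numpy as np
from jax import random
import neural_tangents as nt
from neural_tangents import stax

init_fn, f, _ = stax.serial(stax.Dense(256), stax.Relu(), stax.Dense(10))
_, params = init_fn(random.PRNGKey(1), (-1, 8))
x = random.normal(random.PRNGKey(2), (5, 8))

# Direct Jacobian instantiation (implementation=1).
ntk_direct = nt.empirical_ntk_fn(f, vmap_axes=0, implementation=1)(x, None, params)

# Implicit JVP/VJP-based contraction (implementation=2).
ntk_implicit = nt.empirical_ntk_fn(f, vmap_axes=0, implementation=2)(x, None, params)

# Both give the same (5, 5) NTK up to numerical noise.
assert np.allclose(ntk_direct, ntk_implicit, atol=1e-4)
```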
