Commit 98ce60d

Allow arbitrary dimension numbers in stax layers.
* FC / GlobalAvgPool layers will still enforce `NC` dimension numbers for now.
* Batching still only works with the leading batch (`N`) dimension, so keeping `N` as the leading dimension is highly recommended.
* `NHWC` is currently recommended for attention, and `NCHW` or `NHWC` for CNNs.

Also:

1) Remove some warnings that I feel are redundant given our codebase - please let me know if I'm wrong.
2) Make the code slightly more generic, hopefully facilitating future ND-CNN cases.
3) Make `Flatten` work on batches of size 0.
4) Make `GeneralConv` public as it now supports different dimension numbers.

PiperOrigin-RevId: 290168022
1 parent 818678a commit 98ce60d
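
To make the new dimension-number support concrete, here is a minimal sketch of how the now-public `GeneralConv` might be used. It is not taken from this commit; the argument order is assumed to mirror `jax.experimental.stax.GeneralConv` (dimension numbers, output channels, filter shape, strides, padding), and the shapes and widths are illustrative only.

```python
# Hedged sketch: assumes nt.stax.GeneralConv follows the
# (dimension_numbers, out_chan, filter_shape, strides, padding) convention
# of jax.experimental.stax.GeneralConv. Not taken from the commit itself.
from jax import random
from neural_tangents import stax

# `NHWC` inputs; keep the batch dimension `N` leading so batching keeps working.
init_fn, apply_fn, kernel_fn = stax.serial(
    stax.GeneralConv(('NHWC', 'HWIO', 'NHWC'), 64, (3, 3), (1, 1), 'SAME'),
    stax.Relu(),
    stax.Flatten(),   # per the commit, now also handles batches of size 0
    stax.Dense(10),
)

key = random.PRNGKey(1)
x = random.normal(key, (4, 8, 8, 3))  # (N, H, W, C)
ntk = kernel_fn(x, x, 'ntk')          # analytic NTK of the infinite-width network
```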

File tree: 6 files changed, +265 / -5172 lines.
README.md (+8 / -21)
@@ -48,7 +48,7 @@ Then either run
 ```
 pip install neural-tangents
 ```
-or, to use the bleeding-edge version from GitHub source,
+or, to build the bleeding-edge version from source,
 ```
 git clone https://github.com/google/neural-tangents
 pip install -e neural-tangents
@@ -79,7 +79,6 @@ colab examples:
 - [Neural Tangents Cookbook](https://colab.sandbox.google.com/github/google/neural-tangents/blob/master/notebooks/neural_tangents_cookbook.ipynb)
 - [Weight Space Linearization](https://colab.research.google.com/github/google/neural-tangents/blob/master/notebooks/weight_space_linearization.ipynb)
 - [Function Space Linearization](https://colab.research.google.com/github/google/neural-tangents/blob/master/notebooks/function_space_linearization.ipynb)
-- [Neural Network Phase Diagram](https://colab.sandbox.google.com/github/google/neural-tangents/blob/master/notebooks/phase_diagram.ipynb)
 
 
 ## 5-Minute intro
@@ -168,7 +167,7 @@ y_test_ntk = nt.predict.gp_inference(kernel_fn, x_train, y_train, x_test,
 
 ### Infinitely WideResnet
 
-We can define a more compex, (infinitely) Wide Residual Network [[14]](#14-wide-residual-networks-bmvc-2018-sergey-zagoruyko-nikos-komodakis) using the same `nt.stax` building blocks:
+We can define a more compex, (infinitely) Wide Residual Network [[14]](#8-wide-residual-networks-bmvc-2018-sergey-zagoruyko-nikos-komodakis) using the same `nt.stax` building blocks:
 
 ```python
 from neural_tangents import stax
@@ -246,7 +245,7 @@ import neural_tangents as nt # 64-bit precision enabled
 We remark the following differences between our library and the JAX one.
 
 * All `nt.stax` layers are instantiated with a function call, i.e. `nt.stax.Relu()` vs `jax.experimental.stax.Relu`.
-* All layers with trainable parameters use the _NTK parameterization_ by default (see [[10]](#10-neural-tangent-kernel-convergence-and-generalization-in-neural-networks-neurips-2018-arthur-jacot-franck-gabriel-clément-hongler), Remark 1). However, Dense and Conv layers also support the _standard parameterization_ via a `parameterization` keyword argument (see [[15]](#15-on-the-infinite-width-limit-of-neural-networks-with-a-standard-parameterization)).
+* All layers with trainable parameters use the _NTK parameterization_ by default (see [[10]](#5-neural-tangent-kernel-convergence-and-generalization-in-neural-networks-neurips-2018-arthur-jacot-franck-gabriel-clément-hongler), Remark 1). However, Dense and Conv layers also support the _standard parameterization_ via a `parameterization` keyword argument. <!-- TODO(jaschasd) add link to note deriving NTK for standard parameterization -->
 * `nt.stax` and `jax.experimental.stax` may have different layers and options available (for example `nt.stax` layers support `CIRCULAR` padding, but only `NHWC` data format).
 
 ### Python 2 is not supported
@@ -259,7 +258,7 @@ The kernel of an infinite network `kernel_fn(x1, x2).ntk` combined with `nt.pre
 
 ### Weight space
 
-Continuous gradient descent in an infinite network has been shown in [[11]](#11-wide-neural-networks-of-any-depth-evolve-as-linear-models-under-gradient-descent-neurips-2019-jaehoon-lee-lechao-xiao-samuel-s-schoenholz-yasaman-bahri-roman-novak-jascha-sohl-dickstein-jeffrey-pennington) to correspond to training a _linear_ (in trainable parameters) model, which makes linearized neural networks an important subject of study for understanding the behavior of parameters in wide models.
+Continuous gradient descent in an infinite network has been shown in [[11]](#6-wide-neural-networks-of-any-depth-evolve-as-linear-models-under-gradient-descent-neurips-2019-jaehoon-lee-lechao-xiao-samuel-s-schoenholz-yasaman-bahri-roman-novak-jascha-sohl-dickstein-jeffrey-pennington) to correspond to training a _linear_ (in trainable parameters) model, which makes linearized neural networks an important subject of study for understanding the behavior of parameters in wide models.
 
 For this, we provide two convenient methods:
 
@@ -298,7 +297,7 @@ logits = apply_fn_lin((W, b), x) # (3, 2) np.ndarray
 
 ### Function space:
 
-Outputs of a linearized model evolve identically to those of an infinite one [[11]](#11-wide-neural-networks-of-any-depth-evolve-as-linear-models-under-gradient-descent-neurips-2019-jaehoon-lee-lechao-xiao-samuel-s-schoenholz-yasaman-bahri-roman-novak-jascha-sohl-dickstein-jeffrey-pennington) but with a different kernel - specifically, the Neural Tangent Kernel [[10]](#10-neural-tangent-kernel-convergence-and-generalization-in-neural-networks-neurips-2018-arthur-jacot-franck-gabriel-clément-hongler) evaluated on the specific `apply_fn` of the finite network given specific `params_0` that the network is initialized with. For this we provide the `nt.empirical_kernel_fn` function that accepts any `apply_fn` and returns a `kernel_fn(x1, x2, params)` that allows to compute the empirical NTK and NNGP kernels on specific `params`.
+Outputs of a linearized model evolve identically to those of an infinite one [[11]](#6-wide-neural-networks-of-any-depth-evolve-as-linear-models-under-gradient-descent-neurips-2019-jaehoon-lee-lechao-xiao-samuel-s-schoenholz-yasaman-bahri-roman-novak-jascha-sohl-dickstein-jeffrey-pennington) but with a different kernel - specifically, the Neural Tangent Kernel [[10]](#5-neural-tangent-kernel-convergence-and-generalization-in-neural-networks-neurips-2018-arthur-jacot-franck-gabriel-clément-hongler) evaluated on the specific `apply_fn` of the finite network given specific `params_0` that the network is initialized with. For this we provide the `nt.empirical_kernel_fn` function that accepts any `apply_fn` and returns a `kernel_fn(x1, x2, params)` that allows to compute the empirical NTK and NNGP kernels on specific `params`.
 
 #### Example:
 
@@ -356,25 +355,15 @@ a small dataset using a small learning rate.
 
 ## Papers
 
-Neural Tangents has been used in the following papers:
-
-
-* [Disentangling Trainability and Generalization in Deep Learning.](https://arxiv.org/abs/1912.13053) \
-Lechao Xiao, Jeffrey Pennington, Samuel S. Schoenholz
-
-* [Information in Infinite Ensembles of Infinitely-Wide Neural Networks.](https://arxiv.org/abs/1911.09189) \
-Ravid Shwartz-Ziv, Alexander A. Alemi
-
-* [Training Dynamics of Deep Networks using Stochastic Gradient Descent via Neural Tangent Kernel.](https://arxiv.org/abs/1905.13654) \
-Soufiane Hayou, Arnaud Doucet, Judith Rousseau
+Neural tangents has been used in the following papers:
 
 * [Wide Neural Networks of Any Depth Evolve as Linear Models Under Gradient
 Descent.](https://arxiv.org/abs/1902.06720) \
 Jaehoon Lee*, Lechao Xiao*, Samuel S. Schoenholz, Yasaman Bahri, Roman Novak, Jascha
 Sohl-Dickstein, Jeffrey Pennington
 
-* [On the Infinite Width Limit of Neural Networks with a Standard Parameterization.](https://arxiv.org/pdf/2001.07301.pdf) \
-Jascha Sohl-Dickstein, Roman Novak, Samuel S. Schoenholz, Jaehoon Lee
+* [Training Dynamics of Deep Networks using Stochastic Gradient Descent via Neural Tangent Kernel.](https://arxiv.org/abs/1905.13654) \
+Soufiane Hayou, Arnaud Doucet, Judith Rousseau
 
 Please let us know if you make use of the code in a publication and we'll add it
 to the list!
@@ -427,5 +416,3 @@ If you use the code in a publication, please cite the repo using the .bib,
 ##### [13] [Mean Field Residual Networks: On the Edge of Chaos.](https://arxiv.org/abs/1712.08969) *NeurIPS 2017.* Greg Yang, Samuel S. Schoenholz
 
 ##### [14] [Wide Residual Networks.](https://arxiv.org/abs/1605.07146) *BMVC 2018.* Sergey Zagoruyko, Nikos Komodakis
-
-##### [15] [On the Infinite Width Limit of Neural Networks with a Standard Parameterization.](https://arxiv.org/pdf/2001.07301.pdf) *arXiv 2020.* Jascha Sohl-Dickstein, Roman Novak, Samuel S. Schoenholz, Jaehoon Lee
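
As background for the weight-space and function-space hunks above, here is a small illustrative sketch of the linearization workflow the README describes, built around `nt.linearize`. The toy architecture, shapes, and variable names are assumptions for illustration and do not come from this commit.

```python
# Illustrative sketch only; network size and data shapes are made up.
from jax import random
import neural_tangents as nt
from neural_tangents import stax

init_fn, apply_fn, _ = stax.serial(stax.Dense(512), stax.Relu(), stax.Dense(2))

key = random.PRNGKey(0)
x = random.normal(key, (3, 32))        # toy batch of 3 inputs
_, params = init_fn(key, x.shape)

# Linearize the network around its initialization: the resulting model is
# linear in the parameters (though generally not in the inputs).
apply_fn_lin = nt.linearize(apply_fn, params)
logits = apply_fn_lin(params, x)       # shape (3, 2), matching apply_fn's output
```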
