* Note that FC / GlobalAvgPool layers will still enforce `NC` dimension numbers for now, and batching still only works along the leading batch (`N`) dimension, so keeping `N` as the leading dimension is highly recommended.
* `NHWC` is currently recommended for attention, and `NCHW` or `NHWC` for CNNs.
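As a rough illustration of the dimension-number support described above (a sketch only; it assumes `nt.stax.GeneralConv` mirrors the `(dimension_numbers, out_chan, filter_shape, ...)` signature of `jax.experimental.stax.GeneralConv`, and the architecture is purely illustrative):

```python
from neural_tangents import stax

# `NCHW` inputs, `OIHW` filters, `NCHW` outputs; note that `N` stays the
# leading axis, as recommended above.
init_fn, apply_fn, kernel_fn = stax.serial(
    stax.GeneralConv(('NCHW', 'OIHW', 'NCHW'), 64, (3, 3), padding='SAME'),
    stax.Relu(),
    stax.Flatten(),
    stax.Dense(10))
```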
Also:
1) Remove some warnings that I feel are redundant given our codebase - please let me know if I'm wrong.
2) Make the code slightly more generic, hopefully facilitating future N-D CNN cases.
3) Make `Flatten` work on batches of size 0.
4) Make `GeneralConv` public as it now supports different dimension numbers.
PiperOrigin-RevId: 290168022
- [Weight Space Linearization](https://colab.research.google.com/github/google/neural-tangents/blob/master/notebooks/weight_space_linearization.ipynb)
- [Function Space Linearization](https://colab.research.google.com/github/google/neural-tangents/blob/master/notebooks/function_space_linearization.ipynb)
We can define a more complex, (infinitely) Wide Residual Network [[14]](#14-wide-residual-networks-bmvc-2018-sergey-zagoruyko-nikos-komodakis) using the same `nt.stax` building blocks:

```python
from neural_tangents import stax
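# Below is a sketch of an (infinitely) wide WRN assembled from `nt.stax`
# combinators; the widths, block counts and weight/bias standard deviations
# are illustrative rather than prescriptive.

def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
  # Pre-activation residual block: fan the input out into a main branch and a
  # (possibly projecting) shortcut, apply them in parallel, and sum the results.
  Main = stax.serial(
      stax.Relu(), stax.Conv(channels, (3, 3), strides, padding='SAME'),
      stax.Relu(), stax.Conv(channels, (3, 3), padding='SAME'))
  Shortcut = stax.Identity() if not channel_mismatch else stax.Conv(
      channels, (3, 3), strides, padding='SAME')
  return stax.serial(stax.FanOut(2),
                     stax.parallel(Main, Shortcut),
                     stax.FanInSum())


def WideResnetGroup(n, channels, strides=(1, 1)):
  blocks = [WideResnetBlock(channels, strides, channel_mismatch=True)]
  for _ in range(n - 1):
    blocks += [WideResnetBlock(channels, (1, 1))]
  return stax.serial(*blocks)


def WideResnet(block_size, k, num_classes):
  return stax.serial(
      stax.Conv(16, (3, 3), padding='SAME'),
      WideResnetGroup(block_size, int(16 * k)),
      WideResnetGroup(block_size, int(32 * k), (2, 2)),
      WideResnetGroup(block_size, int(64 * k), (2, 2)),
      stax.AvgPool((8, 8)),
      stax.Flatten(),
      stax.Dense(num_classes, W_std=1., b_std=0.))


init_fn, apply_fn, kernel_fn = WideResnet(block_size=4, k=1, num_classes=10)
```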
We note the following differences between our library and the JAX one:
* All `nt.stax` layers are instantiated with a function call, i.e. `nt.stax.Relu()` vs `jax.experimental.stax.Relu`.
* All layers with trainable parameters use the _NTK parameterization_ by default (see [[10]](#10-neural-tangent-kernel-convergence-and-generalization-in-neural-networks-neurips-2018-arthur-jacot-franck-gabriel-clément-hongler), Remark 1). However, Dense and Conv layers also support the _standard parameterization_ via a `parameterization` keyword argument (see [[15]](#15-on-the-infinite-width-limit-of-neural-networks-with-a-standard-parameterization) and the sketch after this list).
* `nt.stax` and `jax.experimental.stax` may have different layers and options available (for example `nt.stax` layers support `CIRCULAR` padding, but only `NHWC` data format).
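For instance, a minimal sketch of switching parameterizations (the width and `W_std` / `b_std` values below are arbitrary; argument names assume the `nt.stax.Dense(out_dim, W_std, b_std, parameterization)` signature):

```python
from neural_tangents import stax

# NTK parameterization (the default).
init_fn, apply_fn, kernel_fn = stax.Dense(512, W_std=1.5, b_std=0.05)

# The same layer under the standard parameterization.
init_fn_std, apply_fn_std, kernel_fn_std = stax.Dense(
    512, W_std=1.5, b_std=0.05, parameterization='standard')
```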
### Python 2 is not supported
### Weight space
Continuous gradient descent in an infinite network has been shown in [[11]](#11-wide-neural-networks-of-any-depth-evolve-as-linear-models-under-gradient-descent-neurips-2019-jaehoon-lee-lechao-xiao-samuel-s-schoenholz-yasaman-bahri-roman-novak-jascha-sohl-dickstein-jeffrey-pennington) to correspond to training a _linear_ (in trainable parameters) model, which makes linearized neural networks an important subject of study for understanding the behavior of parameters in wide models.
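For example, a finite network can be linearized around its initial parameters with `nt.linearize` (a minimal sketch; the architecture and shapes are arbitrary):

```python
import jax.random as random
import neural_tangents as nt
from neural_tangents import stax

# A small finite-width network and its parameters at initialization.
init_fn, apply_fn, _ = stax.serial(stax.Dense(512), stax.Relu(), stax.Dense(1))
key = random.PRNGKey(1)
x = random.normal(key, (4, 16))
_, params_0 = init_fn(key, x.shape)

# `apply_fn_lin(params, x)` is the first-order Taylor expansion of
# `apply_fn(params, x)` in `params` around `params_0`.
apply_fn_lin = nt.linearize(apply_fn, params_0)
logits = apply_fn_lin(params_0, x)  # Identical to apply_fn(params_0, x).
```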
Outputs of a linearized model evolve identically to those of an infinite one [[11]](#11-wide-neural-networks-of-any-depth-evolve-as-linear-models-under-gradient-descent-neurips-2019-jaehoon-lee-lechao-xiao-samuel-s-schoenholz-yasaman-bahri-roman-novak-jascha-sohl-dickstein-jeffrey-pennington) but with a different kernel - specifically, the Neural Tangent Kernel [[10]](#10-neural-tangent-kernel-convergence-and-generalization-in-neural-networks-neurips-2018-arthur-jacot-franck-gabriel-clément-hongler) evaluated on the specific `apply_fn` of the finite network given the specific `params_0` that the network is initialized with. For this we provide the `nt.empirical_kernel_fn` function that accepts any `apply_fn` and returns a `kernel_fn(x1, x2, params)` that can be used to compute the empirical NTK and NNGP kernels at specific `params`.
#### Example:
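A minimal sketch of computing empirical kernels at a particular random initialization (the architecture and shapes are arbitrary; for concreteness this uses the single-kernel helpers `nt.empirical_ntk_fn` / `nt.empirical_nngp_fn`, which `nt.empirical_kernel_fn` combines, and exact signatures may differ between library versions):

```python
import jax.random as random
import neural_tangents as nt
from neural_tangents import stax

# A finite-width network; `apply_fn(params, x)` is an ordinary forward pass.
init_fn, apply_fn, _ = stax.serial(stax.Dense(512), stax.Relu(), stax.Dense(1))

key1, key2, key3 = random.split(random.PRNGKey(1), 3)
x1 = random.normal(key1, (10, 100))
x2 = random.normal(key2, (5, 100))
_, params = init_fn(key3, x1.shape)

# Empirical (finite-width) NTK and NNGP between `x1` and `x2` at these `params`.
ntk = nt.empirical_ntk_fn(apply_fn)(x1, x2, params)
nngp = nt.empirical_nngp_fn(apply_fn)(x1, x2, params)
```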
## Papers
Neural Tangents has been used in the following papers:

* [Disentangling Trainability and Generalization in Deep Learning.](https://arxiv.org/abs/1912.13053)\
Lechao Xiao, Jeffrey Pennington, Samuel S. Schoenholz

* [Information in Infinite Ensembles of Infinitely-Wide Neural Networks.](https://arxiv.org/abs/1911.09189)\
Ravid Shwartz-Ziv, Alexander A. Alemi

* [Training Dynamics of Deep Networks using Stochastic Gradient Descent via Neural Tangent Kernel.](https://arxiv.org/abs/1905.13654)\
Soufiane Hayou, Arnaud Doucet, Judith Rousseau

* [Wide Neural Networks of Any Depth Evolve as Linear Models Under Gradient Descent.](https://arxiv.org/abs/1902.06720)\
Jaehoon Lee*, Lechao Xiao*, Samuel S. Schoenholz, Yasaman Bahri, Roman Novak, Jascha Sohl-Dickstein, Jeffrey Pennington

* [On the Infinite Width Limit of Neural Networks with a Standard Parameterization.](https://arxiv.org/pdf/2001.07301.pdf)\
Jascha Sohl-Dickstein, Roman Novak, Samuel S. Schoenholz, Jaehoon Lee
Please let us know if you make use of the code in a publication and we'll add it to the list!
##### [13] [Mean Field Residual Networks: On the Edge of Chaos.](https://arxiv.org/abs/1712.08969) *NeurIPS 2017.* Greg Yang, Samuel S. Schoenholz
##### [15] [On the Infinite Width Limit of Neural Networks with a Standard Parameterization.](https://arxiv.org/pdf/2001.07301.pdf) *arXiv 2020.* Jascha Sohl-Dickstein, Roman Novak, Samuel S. Schoenholz, Jaehoon Lee