Skip to content

Commit ed1707a

Browse files
authored
Add initial nnx/optax test (#4)
1 parent 46cf3a0 commit ed1707a

File tree

4 files changed

+98
-2
lines changed

4 files changed

+98
-2
lines changed

.github/workflows/test.yaml

+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
name: Test
2+
3+
on:
4+
# Trigger the workflow on push or pull request, but only on main branch
5+
push:
6+
branches:
7+
- main
8+
pull_request:
9+
branches:
10+
- main
11+
12+
permissions:
13+
contents: read # to fetch code
14+
15+
jobs:
16+
build:
17+
name: ${{ matrix.os }} Python ${{ matrix.python-version }}
18+
runs-on: ${{ matrix.os }}
19+
strategy:
20+
matrix:
21+
os: ["ubuntu-latest"]
22+
python-version: ["3.9", "3.10", "3.11", "3.12"]
23+
include:
24+
- os: macOS-11
25+
python-version: "3.11"
26+
- os: windows-2019
27+
python-version: "3.11"
28+
29+
steps:
30+
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
31+
with:
32+
submodules: true
33+
- name: Set up Python ${{ matrix.python-version }}
34+
uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5
35+
with:
36+
python-version: ${{ matrix.python-version }}
37+
- name: Install dependencies
38+
run: |
39+
python -m pip install --upgrade pip
40+
pip install .[dev]
41+
- name: Run tests
42+
run: |
43+
pytest -n auto jax_ml_stack

jax_ml_stack/tests/__init__.py

Whitespace-only changes.
+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Copyright 2024 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
import jax
17+
import jax.numpy as jnp
18+
from flax import nnx
19+
import optax
20+
21+
22+
class SimpleModel(nnx.Module):
23+
24+
def __init__(self, rngs):
25+
self.layer1 = nnx.Linear(2, 5, rngs=rngs)
26+
self.layer2 = nnx.Linear(5, 3, rngs=rngs)
27+
28+
def __call__(self, x):
29+
for layer in [self.layer1, self.layer2]:
30+
x = layer(x)
31+
return x
32+
33+
34+
class NNXOptaxTest(unittest.TestCase):
35+
36+
def test_nnx_optax(self):
37+
key = jax.random.key(1701)
38+
x = jax.random.normal(key, (1, 2))
39+
y = jnp.ones((1, 3))
40+
41+
model = SimpleModel(nnx.Rngs(0))
42+
optimizer = optax.adam(learning_rate=1e-3)
43+
state = nnx.Optimizer(model, optimizer)
44+
45+
def loss(model, x=x, y=y):
46+
return jnp.mean((model(x) - y) ** 2)
47+
48+
initial_loss = loss(model)
49+
grads = nnx.grad(loss, wrt=nnx.Param)(state.model)
50+
state.update(grads)
51+
final_loss = loss(model)
52+
53+
self.assertNotAlmostEqual(initial_loss, final_loss)

pyproject.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ keywords = []
1616
# pip dependencies of the project
1717
dependencies = [
1818
"jax",
19+
"flax",
20+
"optax",
1921
]
2022

2123
[project.urls]
@@ -29,8 +31,6 @@ repository = "https://github.com/jax-ml/jax_ml_stack"
2931
dev = [
3032
"pytest",
3133
"pytest-xdist",
32-
"pylint>=2.6.0",
33-
"pyink",
3434
]
3535

3636
[tool.pyink]

0 commit comments

Comments
 (0)