
Commit 68eab1f

committed "referece_tutorial_links_added"
1 parent de194c3

4 files changed: +35, -20 lines

docs/conf.py (+6, -1)

@@ -30,16 +30,20 @@
 html_theme = 'sphinx_book_theme'
 html_title = 'JAX AI Stack'
 html_static_path = ['_static']
+html_css_files = ['css/custom.css']
+html_logo = '_static/ai-stack-logo.svg'
+html_favicon = '_static/favicon.png'
 
 # Theme-specific options
 # https://sphinx-book-theme.readthedocs.io/en/stable/reference.html
 html_theme_options = {
     'show_navbar_depth': 2,
     'show_toc_level': 2,
     'repository_url': 'https://github.com/jax-ml/jax-ai-stack',
-    'path_to_docs': 'docs/',
+    'path_to_docs': 'docs/source/',
     'use_repository_button': True,
     'navigation_with_keys': True,
+    'home_page_in_toc': True,
 }
 
 exclude_patterns = [
@@ -67,6 +71,7 @@
 
 suppress_warnings = [
     'misc.highlighting_failure', # Suppress warning in exception in digits_vae
+    'mystnb.unknown_mime_type', # Suppress warning for unknown mime type (e.g. colab-display-data+json)
 ]
 
 # -- Options for myst ----------------------------------------------
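
For context, a minimal sketch (not part of the diff) of how the static-asset settings added above resolve, assuming the conventional layout where `docs/_static/` sits next to `conf.py`:

```python
# Sketch only: how Sphinx resolves the paths added in this commit, assuming
# the assets live under docs/_static/ (the directory listed in html_static_path).
html_static_path = ['_static']            # docs/_static/ is copied into the built site
html_css_files = ['css/custom.css']       # relative to html_static_path -> docs/_static/css/custom.css
html_logo = '_static/ai-stack-logo.svg'   # relative to the conf.py directory
html_favicon = '_static/favicon.png'      # likewise relative to the conf.py directory
```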

docs/data_loaders_for_multi_device_setups_with_jax.ipynb (+13, -9)

@@ -23,9 +23,13 @@
 "* [**Grain**](https://github.com/google/grain)\n",
 "* [**Hugging Face**](https://huggingface.co/docs/datasets/en/use_with_jax#data-loading)\n",
 "\n",
-"You'll see how to use each of these libraries to efficiently load data for a simple image classification task using the MNIST dataset.\n",
+"You'll learn how to use each of these libraries to efficiently load data for an image classification task using the MNIST dataset.\n",
 "\n",
-"Building on the [Data Loaders on GPU](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_gpu_with_jax.html) tutorial, this guide introduces optimizations for distributed training across multiple GPUs or TPUs. It focuses on data sharding with `Mesh` and `NamedSharding` to efficiently partition and synchronize data across devices. By leveraging multi-device setups, you'll maximize resource utilization for large datasets in distributed environments."
+"Building on the [Data Loaders on GPU](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_gpu_with_jax.html) tutorial, this guide covers advanced strategies for multi-device setups, such as data sharding with `Mesh` and `NamedSharding` to partition and synchronize data across devices. These techniques help you manage the complexities of distributed systems while optimizing resource usage for large-scale datasets.\n",
+"\n",
+"If you're looking for CPU-specific data loading advice, see [Data Loaders on CPU](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_cpu_with_jax.html).\n",
+"\n",
+"If you're looking for GPU-specific data loading advice, see [Data Loaders on GPU](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_gpu_with_jax.html)."
 ]
 },
 {
@@ -57,7 +61,7 @@
 "id": "TsFdlkSZKp9S"
 },
 "source": [
-"### Checking TPU Availability for JAX"
+"## Checking TPU Availability for JAX"
 ]
 },
 {
@@ -99,7 +103,7 @@
 "id": "qyJ_WTghDnIc"
 },
 "source": [
-"### Setting Hyperparameters and Initializing Parameters\n",
+"## Setting Hyperparameters and Initializing Parameters\n",
 "\n",
 "You'll define hyperparameters for your model and data loading, including layer sizes, learning rate, batch size, and the data directory. You'll also initialize the weights and biases for a fully-connected neural network."
 ]
@@ -141,7 +145,7 @@
 "id": "rHLdqeI7D2WZ"
 },
 "source": [
-"### Model Prediction with Auto-Batching\n",
+"## Model Prediction with Auto-Batching\n",
 "\n",
 "In this section, you'll define the `predict` function for your neural network. This function computes the output of the network for a single input image.\n",
 "\n",
@@ -182,7 +186,7 @@
 "id": "AMWmxjVEpH2D"
 },
 "source": [
-"Multi-device setup using a Mesh of devices"
+"## Multi-device setup using a Mesh of devices"
 ]
 },
 {
@@ -210,7 +214,7 @@
 "id": "rLqfeORsERek"
 },
 "source": [
-"### Utility and Loss Functions\n",
+"## Utility and Loss Functions\n",
 "\n",
 "You'll now define utility functions for:\n",
 "- One-hot encoding: Converts class indices to binary vectors.\n",
@@ -1676,9 +1680,9 @@
 "source": [
 "## Summary\n",
 "\n",
-"This notebook has introduced efficient methods for multi-device distributed data loading on TPUs with JAX. You explored how to leverage popular libraries like PyTorch DataLoader, TensorFlow Datasets, Grain, and Hugging Face Datasets to streamline the data loading process for machine learning tasks. Each library offers distinct advantages, allowing you to select the best approach for your specific project needs.\n",
+"This notebook introduced efficient methods for multi-device distributed data loading on TPUs with JAX. You explored how to leverage popular libraries like PyTorch DataLoader, TensorFlow Datasets, Grain, and Hugging Face Datasets to optimize the data loading process for machine learning tasks. Each library offers unique advantages, enabling you to choose the best approach based on your project’s requirements.\n",
 "\n",
-"For more detailed strategies on distributed data loading with JAX, including global data pipelines and per-device processing, refer to the [Distributed Data Loading Guide](https://jax.readthedocs.io/en/latest/distributed_data_loading.html)."
+"For more in-depth strategies on distributed data loading with JAX, including global data pipelines and per-device processing, refer to the [Distributed Data Loading Guide](https://jax.readthedocs.io/en/latest/distributed_data_loading.html)."
 ]
 }
 ],
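
The `Mesh`/`NamedSharding` approach mentioned in the new intro paragraph amounts to laying the available devices out on a one-dimensional mesh and placing each host batch onto it. A rough sketch under those assumptions (not the notebook's exact code, which chooses its own batch size and mesh setup):

```python
# Illustrative sketch: shard a host batch across all available devices along a
# one-dimensional 'device' mesh axis using Mesh, PartitionSpec and NamedSharding.
import numpy as np
import jax
from jax.sharding import Mesh, PartitionSpec, NamedSharding

mesh = Mesh(np.array(jax.devices()), ('device',))   # 1-D mesh over all devices
sharding_spec = PartitionSpec('device')             # shard the leading (batch) axis
sharding = NamedSharding(mesh, sharding_spec)

# Batch size chosen as a multiple of the device count so it splits evenly.
batch = np.zeros((jax.device_count() * 8, 28 * 28), dtype=np.float32)
sharded_batch = jax.device_put(batch, sharding)     # each device holds an equal slice of rows
print(sharded_batch.sharding)
```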

docs/data_loaders_for_multi_device_setups_with_jax.md (+13, -9)

@@ -25,9 +25,13 @@ This tutorial explores various data loading strategies for **JAX** in **multi-de
 * [**Grain**](https://github.com/google/grain)
 * [**Hugging Face**](https://huggingface.co/docs/datasets/en/use_with_jax#data-loading)
 
-You'll see how to use each of these libraries to efficiently load data for a simple image classification task using the MNIST dataset.
+You'll learn how to use each of these libraries to efficiently load data for an image classification task using the MNIST dataset.
 
-Building on the [Data Loaders on GPU](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_gpu_with_jax.html) tutorial, this guide introduces optimizations for distributed training across multiple GPUs or TPUs. It focuses on data sharding with `Mesh` and `NamedSharding` to efficiently partition and synchronize data across devices. By leveraging multi-device setups, you'll maximize resource utilization for large datasets in distributed environments.
+Building on the [Data Loaders on GPU](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_gpu_with_jax.html) tutorial, this guide covers advanced strategies for multi-device setups, such as data sharding with `Mesh` and `NamedSharding` to partition and synchronize data across devices. These techniques help you manage the complexities of distributed systems while optimizing resource usage for large-scale datasets.
+
+If you're looking for CPU-specific data loading advice, see [Data Loaders on CPU](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_cpu_with_jax.html).
+
+If you're looking for GPU-specific data loading advice, see [Data Loaders on GPU](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_gpu_with_jax.html).
 
 +++ {"id": "-rsMgVtO6asW"}
 
@@ -44,7 +48,7 @@ from jax.sharding import Mesh, PartitionSpec, NamedSharding
 
 +++ {"id": "TsFdlkSZKp9S"}
 
-### Checking TPU Availability for JAX
+## Checking TPU Availability for JAX
 
 ```{code-cell}
 ---
@@ -58,7 +62,7 @@ jax.devices()
 
 +++ {"id": "qyJ_WTghDnIc"}
 
-### Setting Hyperparameters and Initializing Parameters
+## Setting Hyperparameters and Initializing Parameters
 
 You'll define hyperparameters for your model and data loading, including layer sizes, learning rate, batch size, and the data directory. You'll also initialize the weights and biases for a fully-connected neural network.
 
@@ -90,7 +94,7 @@ params = init_network_params(layer_sizes, random.PRNGKey(0))
 
 +++ {"id": "rHLdqeI7D2WZ"}
 
-### Model Prediction with Auto-Batching
+## Model Prediction with Auto-Batching
 
 In this section, you'll define the `predict` function for your neural network. This function computes the output of the network for a single input image.
 
@@ -121,7 +125,7 @@ batched_predict = vmap(predict, in_axes=(None, 0))
 
 +++ {"id": "AMWmxjVEpH2D"}
 
-Multi-device setup using a Mesh of devices
+## Multi-device setup using a Mesh of devices
 
 ```{code-cell}
 :id: 4Jc5YLFnpE-_
@@ -139,7 +143,7 @@ sharding_spec = PartitionSpec('device')
 
 +++ {"id": "rLqfeORsERek"}
 
-### Utility and Loss Functions
+## Utility and Loss Functions
 
 You'll now define utility functions for:
 - One-hot encoding: Converts class indices to binary vectors.
@@ -714,6 +718,6 @@ train_model(num_epochs, params, hf_training_generator)
 
 ## Summary
 
-This notebook has introduced efficient methods for multi-device distributed data loading on TPUs with JAX. You explored how to leverage popular libraries like PyTorch DataLoader, TensorFlow Datasets, Grain, and Hugging Face Datasets to streamline the data loading process for machine learning tasks. Each library offers distinct advantages, allowing you to select the best approach for your specific project needs.
+This notebook introduced efficient methods for multi-device distributed data loading on TPUs with JAX. You explored how to leverage popular libraries like PyTorch DataLoader, TensorFlow Datasets, Grain, and Hugging Face Datasets to optimize the data loading process for machine learning tasks. Each library offers unique advantages, enabling you to choose the best approach based on your project’s requirements.
 
-For more detailed strategies on distributed data loading with JAX, including global data pipelines and per-device processing, refer to the [Distributed Data Loading Guide](https://jax.readthedocs.io/en/latest/distributed_data_loading.html).
+For more in-depth strategies on distributed data loading with JAX, including global data pipelines and per-device processing, refer to the [Distributed Data Loading Guide](https://jax.readthedocs.io/en/latest/distributed_data_loading.html).
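
The "Utility and Loss Functions" section referenced in the hunks above lists one-hot encoding among the helpers it defines. A hedged sketch of what such a helper typically looks like in this style of JAX MNIST tutorial (the notebook's own version may differ in name and dtype):

```python
# Hypothetical helper: convert integer class indices to one-hot (binary) vectors.
import jax.numpy as jnp

def one_hot(labels, num_classes, dtype=jnp.float32):
    return jnp.array(labels[:, None] == jnp.arange(num_classes), dtype)

# Example: three MNIST labels -> a (3, 10) one-hot matrix.
print(one_hot(jnp.array([3, 0, 9]), 10))
```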

docs/tutorials.md (+3, -1)

@@ -1,6 +1,8 @@
 # Tutorials
 
-*Note: this is a work in progress; visit again soon for updated content!*
+```{note}
+This is a work in progress; visit again soon for updated content!
+```
 
 The following tutorials are meant as an intro to the full stack:
 