"You'll see how to use each of these libraries to efficiently load data for a simple image classification task using the MNIST dataset.\n",
26
+
"You'll learn how to use each of these libraries to efficiently load data for an image classification task using the MNIST dataset.\n",
27
27
"\n",
28
-
"Building on the [Data Loaders on GPU](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_gpu_with_jax.html) tutorial, this guide introduces optimizations for distributed training across multiple GPUs or TPUs. It focuses on data sharding with `Mesh` and `NamedSharding` to efficiently partition and synchronize data across devices. By leveraging multi-device setups, you'll maximize resource utilization for large datasets in distributed environments."
28
+
"Building on the [Data Loaders on GPU](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_gpu_with_jax.html) tutorial, this guide covers advanced strategies for multi-device setups, such as data sharding with `Mesh` and `NamedSharding` to partition and synchronize data across devices. These techniques allow you to partition and synchronize data across multiple devices, balancing the complexities of distributed systems while optimizing resource usage for large-scale datasets.\n",
29
+
"\n",
30
+
"If you're looking for CPU-specific data loading advice, see [Data Loaders on CPU](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_cpu_with_jax.html).\n",
31
+
"\n",
32
+
"If you're looking for GPU-specific data loading advice, see [Data Loaders on GPU](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_gpu_with_jax.html)."
29
33
]
30
34
},
31
35
{
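As context for the revised intro above, here is a minimal sketch of what sharding a data batch with `Mesh` and `NamedSharding` can look like; the `batch` array, its shape, and the 1-D mesh layout are assumptions for illustration, not the notebook's exact code:

```python
import jax
import numpy as np
from jax.sharding import Mesh, PartitionSpec, NamedSharding

# Hypothetical batch of 32 MNIST images, each flattened to 784 features.
batch = np.zeros((32, 784), dtype=np.float32)

# Arrange all visible devices in a 1-D mesh whose axis is named "batch".
mesh = Mesh(np.array(jax.devices()), axis_names=("batch",))

# Shard the leading (batch) dimension across devices; features stay whole.
sharding = NamedSharding(mesh, PartitionSpec("batch"))

# device_put splits the array so each device holds a slice of the batch.
sharded_batch = jax.device_put(batch, sharding)
print(sharded_batch.sharding)
```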
@@ -57,7 +61,7 @@
     "id": "TsFdlkSZKp9S"
    },
    "source": [
-    "### Checking TPU Availability for JAX"
+    "## Checking TPU Availability for JAX"
    ]
   },
   {
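The device check that this section performs boils down to something like the following; the platform fallback warning is an illustrative addition, not part of the notebook:

```python
import jax

# List the accelerators JAX can see; on a TPU VM this returns TpuDevice objects.
devices = jax.devices()
print(devices)

# Optional guard to confirm you are actually running on TPU.
if devices[0].platform != "tpu":
    print("Warning: no TPU detected, running on", devices[0].platform)
```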
@@ -99,7 +103,7 @@
     "id": "qyJ_WTghDnIc"
    },
    "source": [
-    "### Setting Hyperparameters and Initializing Parameters\n",
+    "## Setting Hyperparameters and Initializing Parameters\n",
     "\n",
     "You'll define hyperparameters for your model and data loading, including layer sizes, learning rate, batch size, and the data directory. You'll also initialize the weights and biases for a fully-connected neural network."
    ]
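A hedged sketch of the kind of setup this section describes; the layer sizes, hyperparameter values, data directory, and helper names (`random_layer_params`, `init_network_params`) are assumptions rather than the notebook's exact code:

```python
import jax

# Assumed hyperparameters for a small fully-connected MNIST classifier.
layer_sizes = [784, 512, 512, 10]
learning_rate = 0.01
batch_size = 128
data_dir = "/tmp/mnist_dataset"  # assumption: any writable directory works

def random_layer_params(m, n, key, scale=1e-2):
    """Randomly initialize weights and biases for one dense layer of shape (n, m)."""
    w_key, b_key = jax.random.split(key)
    return scale * jax.random.normal(w_key, (n, m)), scale * jax.random.normal(b_key, (n,))

def init_network_params(sizes, key):
    """Initialize weights and biases for every layer in the network."""
    keys = jax.random.split(key, len(sizes) - 1)
    return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

params = init_network_params(layer_sizes, jax.random.PRNGKey(0))
```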
@@ -141,7 +145,7 @@
     "id": "rHLdqeI7D2WZ"
    },
    "source": [
-    "### Model Prediction with Auto-Batching\n",
+    "## Model Prediction with Auto-Batching\n",
     "\n",
     "In this section, you'll define the `predict` function for your neural network. This function computes the output of the network for a single input image.\n",
     "\n",
@@ -182,7 +186,7 @@
     "id": "AMWmxjVEpH2D"
    },
    "source": [
-    "Multi-device setup using a Mesh of devices"
+    "## Multi-device setup using a Mesh of devices"
    ]
   },
   {
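A small sketch of the multi-device mesh setup this section names; the 1-D mesh with a single "batch" axis and the dummy array shape are assumptions about the layout used:

```python
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Lay out every visible device in a 1-D mesh whose single axis is called "batch".
devices = mesh_utils.create_device_mesh((jax.device_count(),))
mesh = Mesh(devices, axis_names=("batch",))
sharding = NamedSharding(mesh, PartitionSpec("batch"))

# Place a dummy array and inspect how its rows are split across the devices.
x = jax.device_put(jnp.zeros((8, 784)), sharding)
jax.debug.visualize_array_sharding(x)
```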
@@ -210,7 +214,7 @@
     "id": "rLqfeORsERek"
    },
    "source": [
-    "### Utility and Loss Functions\n",
+    "## Utility and Loss Functions\n",
     "\n",
     "You'll now define utility functions for:\n",
     "- One-hot encoding: Converts class indices to binary vectors.\n",
@@ -1676,9 +1680,9 @@
    "source": [
     "## Summary\n",
     "\n",
-    "This notebook has introduced efficient methods for multi-device distributed data loading on TPUs with JAX. You explored how to leverage popular libraries like PyTorch DataLoader, TensorFlow Datasets, Grain, and Hugging Face Datasets to streamline the data loading process for machine learning tasks. Each library offers distinct advantages, allowing you to select the best approach for your specific project needs.\n",
+    "This notebook introduced efficient methods for multi-device distributed data loading on TPUs with JAX. You explored how to leverage popular libraries like PyTorch DataLoader, TensorFlow Datasets, Grain, and Hugging Face Datasets to optimize the data loading process for machine learning tasks. Each library offers unique advantages, enabling you to choose the best approach based on your project’s requirements.\n",
     "\n",
-    "For more detailed strategies on distributed data loading with JAX, including global data pipelines and per-device processing, refer to the [Distributed Data Loading Guide](https://jax.readthedocs.io/en/latest/distributed_data_loading.html)."
+    "For more in-depth strategies on distributed data loading with JAX, including global data pipelines and per-device processing, refer to the [Distributed Data Loading Guide](https://jax.readthedocs.io/en/latest/distributed_data_loading.html)."
-You'll see how to use each of these libraries to efficiently load data for a simple image classification task using the MNIST dataset.
+You'll learn how to use each of these libraries to efficiently load data for an image classification task using the MNIST dataset.
 
-Building on the [Data Loaders on GPU](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_gpu_with_jax.html) tutorial, this guide introduces optimizations for distributed training across multiple GPUs or TPUs. It focuses on data sharding with `Mesh` and `NamedSharding` to efficiently partition and synchronize data across devices. By leveraging multi-device setups, you'll maximize resource utilization for large datasets in distributed environments.
+Building on the [Data Loaders on GPU](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_gpu_with_jax.html) tutorial, this guide covers advanced strategies for multi-device setups, such as data sharding with `Mesh` and `NamedSharding`. These techniques let you partition and synchronize data across multiple devices, balancing the complexities of distributed systems while optimizing resource usage for large-scale datasets.
+
+If you're looking for CPU-specific data loading advice, see [Data Loaders on CPU](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_cpu_with_jax.html).
+
+If you're looking for GPU-specific data loading advice, see [Data Loaders on GPU](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_gpu_with_jax.html).
 
 +++ {"id": "-rsMgVtO6asW"}
 
@@ -44,7 +48,7 @@ from jax.sharding import Mesh, PartitionSpec, NamedSharding
 
 +++ {"id": "TsFdlkSZKp9S"}
 
-###Checking TPU Availability for JAX
+## Checking TPU Availability for JAX
 
 ```{code-cell}
 ---
@@ -58,7 +62,7 @@ jax.devices()
 
 +++ {"id": "qyJ_WTghDnIc"}
 
-###Setting Hyperparameters and Initializing Parameters
+## Setting Hyperparameters and Initializing Parameters
 
 You'll define hyperparameters for your model and data loading, including layer sizes, learning rate, batch size, and the data directory. You'll also initialize the weights and biases for a fully-connected neural network.
 In this section, you'll define the `predict` function for your neural network. This function computes the output of the network for a single input image.
-This notebook has introduced efficient methods for multi-device distributed data loading on TPUs with JAX. You explored how to leverage popular libraries like PyTorch DataLoader, TensorFlow Datasets, Grain, and Hugging Face Datasets to streamline the data loading process for machine learning tasks. Each library offers distinct advantages, allowing you to select the best approach for your specific project needs.
+This notebook introduced efficient methods for multi-device distributed data loading on TPUs with JAX. You explored how to leverage popular libraries like PyTorch DataLoader, TensorFlow Datasets, Grain, and Hugging Face Datasets to optimize the data loading process for machine learning tasks. Each library offers unique advantages, enabling you to choose the best approach based on your project’s requirements.
 
-For more detailed strategies on distributed data loading with JAX, including global data pipelines and per-device processing, refer to the [Distributed Data Loading Guide](https://jax.readthedocs.io/en/latest/distributed_data_loading.html).
+For more in-depth strategies on distributed data loading with JAX, including global data pipelines and per-device processing, refer to the [Distributed Data Loading Guide](https://jax.readthedocs.io/en/latest/distributed_data_loading.html).