docs/source/JAX_for_LLM_pretraining.md (+82 -39)
@@ -13,19 +13,27 @@ kernelspec:
+++ {"id": "NIOXoY1xgiww"}
-# Pre-training an LLM (miniGPT)
+# Train a miniGPT language model with JAX for AI
[](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_for_LLM_pretraining.ipynb)
-This tutorial demonstrates how to use JAX/Flax for LLM pretraining via data and tensor parallelism. It is originally inspired by this [Keras miniGPT tutorial](https://keras.io/examples/generative/text_generation_with_miniature_gpt/).
+This tutorial will demonstrate how to use JAX, [Flax NNX](http://flax.readthedocs.io) and [Optax](http://optax.readthedocs.io) for language model training using data and tensor [parallelism](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization) for [Single-Program Multi-Data (SPMD)](https://en.wikipedia.org/wiki/Single_program,_multiple_data). It was originally inspired by the [Keras miniGPT tutorial](https://keras.io/examples/generative/text_generation_with_miniature_gpt/).
-We will use Google TPUs and [SPMD](https://en.wikipedia.org/wiki/Single_program,_multiple_data) to train a language model `miniGPT`. Instead of using a GPU, you should use the free TPU on Colab or Kaggle for this tutorial.
+Here, you will learn how to:
+
+- Define the miniGPT model with Flax and JAX automatic parallelism
+- Load and preprocess the dataset
+- Create the loss and training step functions
+- Train the model on Google Colab’s Cloud TPU v2
+- Profile for hyperparameter tuning
+
+If you are new to JAX for AI, check out the [introductory tutorial](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html), which covers neural network building with [Flax NNX](https://flax.readthedocs.io/en/latest/nnx_basics.html).
+++ {"id": "hTmz5Cbco7n_"}
## Setup
-Install JAX and Flax first. We will install Tiktoken for tokenization and Grain for data loading as well.
+JAX installation is covered in [this guide](https://jax.readthedocs.io/en/latest/installation.html) on the JAX documentation site. We will use [Tiktoken](https://github.com/openai/tiktoken) for tokenization and [Grain](https://google-grain.readthedocs.io/en/latest/index.html) for data loading.
```{code-cell}
---
@@ -34,13 +42,14 @@ colab:
id: 6zMsOIc7ouCO
outputId: 037d56a9-b18f-4504-f80a-3a4fa2945068
---
-!pip install -q jax-ai-stack
!pip install -Uq tiktoken grain matplotlib
```
+++ {"id": "Rcji_799n4eA"}
-Confirm we have TPUs set up.
+**Note:** If you are using [Google Colab](https://colab.research.google.com/), select the free Google Cloud TPU v2 as the hardware accelerator.
+
+Check the available JAX devices, or [`jax.Device`](https://jax.readthedocs.io/en/latest/_autosummary/jax.Device.html), with [`jax.devices()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.devices.html). The output of the cell below will show a list of 8 (eight) devices.
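For reference, here is a minimal sketch of such a device check (assuming a Colab TPU v2 runtime with 8 cores):

```python
import jax

# List the accelerator devices visible to JAX; on a Colab TPU v2 runtime
# this should print 8 TPU devices.
print(jax.devices())
```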
Import the necessary modules, including JAX NumPy, Flax NNX, Optax, Grain, pandas, and Tiktoken:
```{code-cell}
:id: MKYFNOhdLq98
import jax
import jax.numpy as jnp
+
+from jax.sharding import Mesh, PartitionSpec as P, NamedSharding # For data and model parallelism (explained in more detail later)
+from jax.experimental import mesh_utils
+
import flax.nnx as nnx
import optax
+
from dataclasses import dataclass
import grain.python as pygrain
-from jax.experimental import mesh_utils
-from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
import pandas as pd
import tiktoken
import time
```
+++ {"id": "rPyt7MV6prz1"}
-## Build the model
+## Define the miniGPT model with Flax and JAX automatic parallelism
+
+### Leveraging JAX parallelism
+
+One of the most powerful features of JAX is [device parallelism](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization) for SPMD.
+
+- Data parallelism splits the training data into multiple parts (this is called sharding) - batches - that are processed in parallel and simultaneously across different devices, such as GPUs and Google TPUs. This allows us to use larger batch sizes to speed up training.
+- Tensor parallelism allows us to split the model parameter tensors across several devices (sharding model tensors).
+- You can learn more about the basics of JAX parallelism in the [Introduction to parallel programming](https://jax.readthedocs.io/en/latest/sharded-computation.html) on the JAX documentation site.
+
+In this example, we'll utilize a 4-way data parallel and 2-way tensor parallel setup. The free Google Cloud TPU v2 on Google Colab offers 4 chips, each with 2 TPU cores. The TPU v2 architecture aligns well with this setup.
-One of the biggest advantages of JAX is how easy it is to enable parallelism. To demonstrate this, we are going to use 4-way data parallel and 2-way tensor parallel. Tensor parallelism is one kind of model parallelism, which shards model tensors; there are other kinds of model parallelism, which we won't cover in this tutorial.
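As a rough illustration of the data parallel idea, here is a small sketch (the array shape and variable names are made up for the example, not taken from the tutorial code):

```python
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A (4, 2) mesh over the 8 TPU cores: a 4-way 'batch' axis and a 2-way 'model' axis.
mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('batch', 'model'))

# Shard the leading (batch) dimension of an example batch 4 ways across the
# 'batch' axis; `None` leaves the feature dimension replicated on every device.
batch = jnp.ones((32, 128))
sharded_batch = jax.device_put(batch, NamedSharding(mesh, P('batch', None)))
print(sharded_batch.sharding)
```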
+### jax.sharding.Mesh
-As a background, data parallel means splitting a batch of training data into multiple parts (this is called sharding); this way you can use bigger batch sizes to accelerate training, if you have multiple devices that can run in parallel. On the other hand, you can shard not just the training data. Sometimes your model is so big that the model parameters don't fit on a single accelerator. In this case, tensor parallel helps splitting the parameter tensors within a model onto multiple accelerators so that the model can actually run. Both approaches can take advantage of modern accelerators. For example, TPU v2 on the free Colab tier offers 4 chips, each of which has 2 TPU cores. So this architeture works well with 4-way data parallel and 2-way tensor parallel.
+Earlier, we imported [`jax.sharding.Mesh`](https://jax.readthedocs.io/en/latest/jax.sharding.html#jax.sharding.Mesh), a multidimensional NumPy array of JAX devices where each axis of the mesh has a name, such as `'x'` or `'y'`. The `Mesh` encapsulates the information about the TPU resource organization that is needed for distributing computations across the devices.
-To get a detailed understanding of how JAX automatic parallelism works, please refer to this [JAX tutorial](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#way-batch-data-parallelism-and-2-way-model-tensor-parallelism). In our case to leverage parallelism, we first need to define a `Mesh`, which declares the TPU resources with 2 axes: `batch` axis as 4 and `model` axis as 2, which maps to the TPU v2 cores. Here, the `model` axis enables the tensor parallel for us.
+Our `Mesh` will take two arguments:
+- `devices`: This will take the value of [`jax.experimental.mesh_utils.create_device_mesh((4, 2))`](https://jax.readthedocs.io/en/latest/jax.experimental.mesh_utils.html), enabling us to build a device mesh. It is a NumPy ndarray of JAX devices (the devices from the JAX backend, as obtained from [`jax.devices()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.devices.html#jax.devices)), arranged into a `(4, 2)` grid.
+- `axis_names`, where:
+  - `batch`: 4 devices along the first axis - i.e. sharded into 4 - for data parallelism; and
+  - `model`: 2 devices along the second axis - i.e. sharded into 2 - for tensor parallelism, mapping to the TPU v2 cores.
+
+This matches the `(4, 2)` structure of Colab's TPU v2 setup.
+
+Let's instantiate `Mesh` as `mesh` and declare the TPU configuration to define how data and model parameters are distributed across the devices:
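A minimal sketch of that instantiation, assuming the 4-chip x 2-core TPU v2 layout described above:

```python
from jax.experimental import mesh_utils
from jax.sharding import Mesh

# Arrange the 8 devices into a (4, 2) grid and name the axes:
# 'batch' for 4-way data parallelism, 'model' for 2-way tensor parallelism.
mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('batch', 'model'))
print(mesh.shape)  # {'batch': 4, 'model': 2}
```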
-To use model parallel, we need to tell JAX compiler how to shard the model tensors. We first use `PartitionSpec` (shorted to `P` in the code) to describe how to shard a tensor: in our case a tensor could be either sharded along the `model` axis or be replicated on other dimensions (which is denoted by `None`). [`NamedSharding`](https://jax.readthedocs.io/en/latest/jax.sharding.html#jax.sharding.NamedSharding) can then specify how a model tensor is sharded across the devices mesh using a pair of `Mesh` and `PartitionSpec`.
+To leverage model parallelism, we need to instruct the JAX compiler how to shard the model tensors across the TPU devices. Earlier, we also imported [`jax.sharding.PartitionSpec`](https://jax.readthedocs.io/en/latest/jax.sharding.html#jax.sharding.PartitionSpec) and [`jax.sharding.NamedSharding`](https://jax.readthedocs.io/en/latest/jax.sharding.html#jax.sharding.NamedSharding):
+- [`PartitionSpec`](https://jax.readthedocs.io/en/latest/jax.sharding.html#jax.sharding.PartitionSpec) (aliased as `P`) defines how tensors are sharded across the devices in our `Mesh`. Its elements describe how an input dimension is partitioned across the mesh dimensions. For example, in `PartitionSpec('x', 'y')` the first dimension of the data is sharded across the `x` axis of the mesh, and the second across the `y` axis.
+  - We'll use `PartitionSpec` to describe how a tensor is sharded - for example, along the `model` axis - and replicated on other dimensions (denoted by `None`).
+- [`NamedSharding`](https://jax.readthedocs.io/en/latest/jax.sharding.html#jax.sharding.NamedSharding) is a (`Mesh`, `PartitionSpec`) pair that describes how to shard a model tensor across our `mesh`.
+  - We combine the `Mesh` (the TPU resources) with a `PartitionSpec` to create a `NamedSharding`, which specifies how to shard each model tensor across the TPU devices.
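As a concrete example (a sketch that reuses the `mesh` defined above; the tensor itself is hypothetical):

```python
from jax.sharding import NamedSharding, PartitionSpec as P

# Split the second dimension of a 2D weight across the 'model' mesh axis;
# the first dimension (None) stays replicated on every device.
weight_sharding = NamedSharding(mesh, P(None, 'model'))
```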
-Finally, we use `nnx.with_partitioning` to let the layers know that their tensors need to be shared/replicated according to our spec. You need to do this for every tensor/layer in your model.
+Additionally, we'll use Flax NNX's [`flax.nnx.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.with_partitioning) to let each model layer know that its weights (tensors) need to be sharded according to our specification. We need to do this for every tensor/layer in the model.
+- `nnx.with_partitioning` takes two arguments: an `initializer` (such as [`flax.nnx.initializers.xavier_uniform`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/initializers.html#flax.nnx.initializers.xavier_uniform) or [`flax.nnx.initializers.zeros_init`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/initializers.html#flax.nnx.initializers.zeros_init)) and a `sharding` (e.g. `NamedSharding(Mesh, PartitionSpec)`, such as `NamedSharding(mesh, P('model'))` in our case).
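For instance, here is a sketch of how a single linear layer's kernel and bias could be annotated; the layer sizes are illustrative and `mesh` is the mesh defined earlier, not the tutorial's actual model code:

```python
import flax.nnx as nnx
from jax.sharding import NamedSharding, PartitionSpec as P

# Annotate the (in, out) kernel so its output dimension is sharded along the
# 'model' axis, and shard the bias along 'model' as well.
layer = nnx.Linear(
    in_features=256,
    out_features=256,
    kernel_init=nnx.with_partitioning(
        nnx.initializers.xavier_uniform(),
        NamedSharding(mesh, P(None, 'model')),
    ),
    bias_init=nnx.with_partitioning(
        nnx.initializers.zeros_init(),
        NamedSharding(mesh, P('model')),
    ),
    rngs=nnx.Rngs(0),
)
```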
```{code-cell}
:id: z0p-IHurrB9i
+# Define a triangular mask for causal attention with `jax.numpy.tril` and `jax.numpy.ones`.
@@ -550,14 +604,3 @@ By looking at the Trace Viewer tool and looking under each TPU's ops, we can see
```
By changing hyperparameters and comparing profiles, we can gain significant insights into our bottlenecks and limitations. These are just two examples of hyperparameters to tune; many others will also have significant effects on training speed and resource utilization.