
Commit c0737f1

Upgrade JAX AI Stack for LLM pretraining doc
1 parent f211d7f commit c0737f1

File tree

1 file changed: +82 -39 lines changed


docs/source/JAX_for_LLM_pretraining.md

+82-39
@@ -13,19 +13,27 @@ kernelspec:
1313

1414
+++ {"id": "NIOXoY1xgiww"}
1515

16-
# Pre-training an LLM (miniGPT)
16+
# Train a miniGPT language model with JAX for AI
1717

1818
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_for_LLM_pretraining.ipynb)
1919

20-
This tutorial demonstrates how to use JAX/Flax for LLM pretraining via data and tensor parallelism. It is originally inspired by this [Keras miniGPT tutorial](https://keras.io/examples/generative/text_generation_with_miniature_gpt/).
20+
This tutorial will demonstrate how to use JAX, [Flax NNX](http://flax.readthedocs.io) and [Optax](http://optax.readthedocs.io) for language model training using data and tensor [parallelism](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization) in the [Single-Program Multi-Data (SPMD)](https://en.wikipedia.org/wiki/Single_program,_multiple_data) paradigm. It was originally inspired by the [Keras miniGPT tutorial](https://keras.io/examples/generative/text_generation_with_miniature_gpt/).
2121

22-
We will use Google TPUs and [SPMD](https://en.wikipedia.org/wiki/Single_program,_multiple_data) to train a language model `miniGPT`. Instead of using a GPU, you should use the free TPU on Colab or Kaggle for this tutorial.
22+
Here, you will learn how to:
23+
24+
- Define the miniGPT model with Flax and JAX automatic parallelism
25+
- Load and preprocess the dataset
26+
- Create the loss and training step functions
27+
- Train the model on Google Colab’s Cloud TPU v2
28+
- Profile for hyperparameter tuning
29+
30+
If you are new to JAX for AI, check out the [introductory tutorial](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html), which covers neural network building with [Flax NNX](https://flax.readthedocs.io/en/latest/nnx_basics.html).
2331

2432
+++ {"id": "hTmz5Cbco7n_"}
2533

2634
## Setup
2735

28-
Install JAX and Flax first. We will install Tiktoken for tokenization and Grain for data loading as well.
36+
JAX installation is covered in [this guide](https://jax.readthedocs.io/en/latest/installation.html) on the JAX documentation site. We will use [Tiktoken](https://github.com/openai/tiktoken) for tokenization and [Grain](https://google-grain.readthedocs.io/en/latest/index.html) for data loading.
2937

3038
```{code-cell}
3139
---
@@ -34,13 +42,14 @@ colab:
3442
id: 6zMsOIc7ouCO
3543
outputId: 037d56a9-b18f-4504-f80a-3a4fa2945068
3644
---
37-
!pip install -q jax-ai-stack
3845
!pip install -Uq tiktoken grain matplotlib
3946
```
4047

4148
+++ {"id": "Rcji_799n4eA"}
4249

43-
Confirm we have TPUs set up.
50+
**Note:** If you are using [Google Colab](https://colab.research.google.com/), select the free Google Cloud TPU v2 as the hardware accelerator.
51+
52+
Check the available JAX devices, or [`jax.Device`](https://jax.readthedocs.io/en/latest/_autosummary/jax.Device.html), with [`jax.devices()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.devices.html). The output of the cell below will show a list of 8 (eight) devices.
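
For a quick manual check, here is a minimal sketch of what the next cell does, assuming a Colab TPU v2 runtime:

```python
# A minimal sketch of the device check, assuming a Colab TPU v2 runtime with 8 cores.
import jax

devices = jax.devices()
print(len(devices))          # Expected: 8
print(devices[0].platform)   # Expected: 'tpu'
```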
4453

4554
```{code-cell}
4655
---
@@ -69,48 +78,69 @@ outputId: e6eff24e-5578-4277-a0f9-24e27bd91ee0
6978

7079
+++ {"id": "sKE2uUafLobI"}
7180

72-
Take care of the imports.
81+
Import the necessary modules, including JAX NumPy, Flax NNX, Optax, Grain, pandas, and Tiktoken:
7382

7483
```{code-cell}
7584
:id: MKYFNOhdLq98
7685
7786
import jax
7887
import jax.numpy as jnp
88+
89+
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding # For data and model parallelism (explained in more detail later)
90+
from jax.experimental import mesh_utils
91+
7992
import flax.nnx as nnx
8093
import optax
94+
8195
from dataclasses import dataclass
8296
import grain.python as pygrain
83-
from jax.experimental import mesh_utils
84-
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
8597
import pandas as pd
8698
import tiktoken
8799
import time
88100
```
89101

90102
+++ {"id": "rPyt7MV6prz1"}
91103

92-
## Build the model
104+
## Define the miniGPT model with Flax and JAX automatic parallelism
105+
106+
### Leveraging JAX parallelism
107+
108+
One of the most powerful features of JAX is [device parallelism](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization) for SPMD.
109+
110+
- The data parallelism technique splits the training data into multiple parts, or batches (this is called sharding), which then run in parallel and simultaneously across different devices, such as GPUs and Google TPUs. This lets us use larger batch sizes to speed up training.
111+
- Tensor parallelism allows us to split the model parameter tensors across several devices (sharding model tensors).
112+
- You can learn more about the basics of JAX parallelism in [Introduction to parallel programming](https://jax.readthedocs.io/en/latest/sharded-computation.html) on the JAX documentation site.
113+
114+
In this example, we'll use a 4-way data parallel and 2-way tensor parallel setup. The free Google Cloud TPU v2 on Google Colab offers 4 chips, each with 2 TPU cores. The TPU v2 architecture aligns with the proposed setup.
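
As a quick sanity check before building the `Mesh` (a sketch that assumes an 8-device runtime such as the Colab TPU v2), you can confirm that the devices can be arranged into a `(4, 2)` grid:

```python
# A quick sanity check (assumes an 8-device runtime such as the Colab TPU v2):
# the 8 cores can be arranged into a 4 x 2 grid, i.e. 4-way data parallelism
# by 2-way tensor parallelism.
from jax.experimental import mesh_utils

device_grid = mesh_utils.create_device_mesh((4, 2))
print(device_grid.shape)  # Expected: (4, 2)
```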
93115

94-
One of the biggest advantages of JAX is how easy it is to enable parallelism. To demonstrate this, we are going to use 4-way data parallel and 2-way tensor parallel. Tensor parallelism is one kind of model parallelism, which shards model tensors; there are other kinds of model parallelism, which we won't cover in this tutorial.
116+
### jax.sharding.Mesh
95117

96-
As a background, data parallel means splitting a batch of training data into multiple parts (this is called sharding); this way you can use bigger batch sizes to accelerate training, if you have multiple devices that can run in parallel. On the other hand, you can shard not just the training data. Sometimes your model is so big that the model parameters don't fit on a single accelerator. In this case, tensor parallel helps splitting the parameter tensors within a model onto multiple accelerators so that the model can actually run. Both approaches can take advantage of modern accelerators. For example, TPU v2 on the free Colab tier offers 4 chips, each of which has 2 TPU cores. So this architeture works well with 4-way data parallel and 2-way tensor parallel.
118+
Earlier, we imported [`jax.sharding.Mesh`](https://jax.readthedocs.io/en/latest/jax.sharding.html#jax.sharding.Mesh), which is a multidimensional NumPy array of JAX devices, where each axis of the mesh has a name, such as `'x'` or `'y'`. It will help encapsulate information about the TPU resource organization for distributing computations across the devices.
97119

98-
To get a detailed understanding of how JAX automatic parallelism works, please refer to this [JAX tutorial](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#way-batch-data-parallelism-and-2-way-model-tensor-parallelism). In our case to leverage parallelism, we first need to define a `Mesh`, which declares the TPU resources with 2 axes: `batch` axis as 4 and `model` axis as 2, which maps to the TPU v2 cores. Here, the `model` axis enables the tensor parallel for us.
120+
Our `Mesh` will have two arguments:
121+
- `devices`: This will take the value of [`jax.experimental.mesh_utils.create_device_mesh((4, 2))`](https://jax.readthedocs.io/en/latest/jax.experimental.mesh_utils.html), which builds a device mesh - a NumPy ndarray of JAX devices (the list of devices from the JAX backend, as obtained from [`jax.devices()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.devices.html#jax.devices)).
122+
- `axis_names`, where:
123+
- `batch`: 4 devices along the first axis - i.e. sharded into 4 - for data parallelism; and
124+
- `model`: 2 devices along the second axis - i.e. sharded into 2 - for tensor parallelism, mapping to the TPU v2 cores.
125+
126+
This matches the `(4, 2)` structure of the Colab TPU v2 setup.
127+
128+
Let's instantiate `Mesh` as `mesh` and declare the TPU configuration to define how data and model parameters are distributed across the devices:
99129

100130
```{code-cell}
101131
:id: xuMlCK3Q8WJD
102132
103133
mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('batch', 'model'))
104134
105-
### Alternative 8-way data parallel with only one line of code change.
135+
### Alternatively, we could use 8-way data parallelism with only a one-line code change.
106136
### JAX enables quick experimentation with different partitioning strategies
107137
### like this. We will come back to this point at the end of this tutorial.
108138
# mesh = Mesh(mesh_utils.create_device_mesh((8, 1)), ('batch', 'model'))
109139
```
110140

111141
+++ {"id": "_ZKdhNo98NgG"}
112142

113-
We are going to use the GPT-2 tokenizer via [Tiktoken](https://github.com/openai/tiktoken).
143+
We will use the GPT-2 tokenizer from the [Tiktoken](https://github.com/openai/tiktoken) library:
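
As a quick illustration (a sketch, not part of the original notebook), the tokenizer round-trips text to integer token IDs and back:

```python
# A small sketch (not part of the original notebook): round-trip a string
# through the GPT-2 tokenizer.
import tiktoken

tokenizer = tiktoken.get_encoding("gpt2")
token_ids = tokenizer.encode("JAX is a high-performance array computing library.")
print(token_ids)                     # A short list of integer token IDs.
print(tokenizer.decode(token_ids))   # Recovers the original string.
print(tokenizer.n_vocab)             # 50257 for the GPT-2 encoding.
```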
114144

115145
```{code-cell}
116146
:id: iWbkk1V7-Isg
@@ -120,69 +150,92 @@ tokenizer = tiktoken.get_encoding("gpt2")
120150

121151
+++ {"id": "0XHQ0BQ9-KIj"}
122152

123-
To use model parallel, we need to tell JAX compiler how to shard the model tensors. We first use `PartitionSpec` (shorted to `P` in the code) to describe how to shard a tensor: in our case a tensor could be either sharded along the `model` axis or be replicated on other dimensions (which is denoted by `None`). [`NamedSharding`](https://jax.readthedocs.io/en/latest/jax.sharding.html#jax.sharding.NamedSharding) can then specify how a model tensor is sharded across the devices mesh using a pair of `Mesh` and `PartitionSpec`.
153+
To leverage model parallelism, we need to instruct the JAX compiler how to shard the model tensors across the TPU devices. Earlier, we also imported [`jax.sharding.PartitionSpec`](https://jax.readthedocs.io/en/latest/jax.sharding.html#jax.sharding.PartitionSpec) and [`jax.sharding.NamedSharding`](https://jax.readthedocs.io/en/latest/jax.sharding.html#jax.sharding.NamedSharding):
154+
- [`PartitionSpec`](https://jax.readthedocs.io/en/latest/jax.sharding.html#jax.sharding.PartitionSpec) (using the alias `P`) defines how tensors are sharded across the devices in our `Mesh`. Its elements describe how an input dimension is partitioned across the mesh dimensions. For example, in `PartitionSpec('x', 'y')` the first dimension of the data is sharded across the `x` axis of the mesh, and the second across the `y` axis.
155+
- We'll use `PartitionSpec` to describe how to shard a tensor along, for example, the `model` axis, or to replicate it on other dimensions (denoted by `None`).
156+
- [`NamedSharding`](https://jax.readthedocs.io/en/latest/jax.sharding.html#jax.sharding.NamedSharding) is a (`Mesh`, `PartitionSpec`) pair that describes how to shard a model tensor across our `mesh`.
157+
- We combine `Mesh` (the TPU resources) with `PartitionSpec` and create a `NamedSharding`, which instructs how to shard each model tensor across the TPU devices.
124158

125-
Finally, we use `nnx.with_partitioning` to let the layers know that their tensors need to be shared/replicated according to our spec. You need to do this for every tensor/layer in your model.
159+
Additionally, we'll use Flax NNX's [`flax.nnx.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.with_partitioning) to let each model layer know that the model weights or tensors need to be sharded according to our specification. We need to do this for every tensor/layer in the model.
160+
- `nnx.with_partitioning` takes two arguments: an `initializer` (such as [`flax.nnx.initializers.xavier_uniform`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/initializers.html#flax.nnx.initializers.xavier_uniform) or [`flax.nnx.initializers.zeros_init`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/initializers.html#flax.nnx.initializers.zeros_init)) and a `sharding` (e.g. `NamedSharding(Mesh, PartitionSpec)`, or `NamedSharding(mesh, P('model'))` in our case).
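
To make these pieces concrete, here is a standalone sketch (it assumes the 8-device runtime and the `(4, 2)` mesh from above; the toy `nnx.Linear` is purely illustrative and not one of the miniGPT layers):

```python
# A standalone sketch: assumes the 8-device runtime and the (4, 2) mesh
# defined above; the toy nnx.Linear is illustrative and not part of miniGPT.
import jax
import jax.numpy as jnp
import flax.nnx as nnx
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('batch', 'model'))

# Shard a data array: dimension 0 across the 'batch' axis, dimension 1 replicated.
x = jax.device_put(jnp.ones((32, 16)), NamedSharding(mesh, P('batch', None)))
jax.debug.visualize_array_sharding(x)

# Annotate a layer so its kernel is sharded along the 'model' axis of the mesh.
layer = nnx.Linear(
    in_features=16,
    out_features=64,
    kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(),
                                      NamedSharding(mesh, P(None, 'model'))),
    bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(),
                                    NamedSharding(mesh, P('model'))),
    rngs=nnx.Rngs(0),
)
```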
126161

127162
```{code-cell}
128163
:id: z0p-IHurrB9i
129164
165+
# Define a triangular mask for causal attention with `jax.numpy.tril` and `jax.numpy.ones`.
130166
def causal_attention_mask(seq_len):
131167
return jnp.tril(jnp.ones((seq_len, seq_len)))
132168
169+
# Define a single Transformer block.
133170
class TransformerBlock(nnx.Module):
171+
# Initialize layers of the Transformer block.
134172
def __init__(self, embed_dim: int, num_heads: int, ff_dim: int, *, rngs: nnx.Rngs, rate: float = 0.1):
173+
# Multi-Head Attention (MHA) with `flax.nnx.MultiHeadAttention`.
135174
self.mha = nnx.MultiHeadAttention(num_heads=num_heads,
136175
in_features=embed_dim,
137-
kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), NamedSharding(mesh, P(None, 'model'))),
138-
bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P('model'))),
176+
kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), NamedSharding(mesh, P(None, 'model'))), # Specify tensor sharding.
177+
bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P('model'))), # Specify tensor sharding.
139178
rngs=rngs)
179+
# The first dropout with `flax.nnx.Dropout`.
140180
self.dropout1 = nnx.Dropout(rate=rate)
181+
# First layer normalization with `flax.nnx.LayerNorm`.
141182
self.layer_norm1 = nnx.LayerNorm(epsilon=1e-6,
142183
num_features=embed_dim,
143184
scale_init=nnx.with_partitioning(nnx.initializers.ones_init(), NamedSharding(mesh, P('model'))),
144185
bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P('model'))),
145186
rngs=rngs)
187+
# The first linear transformation for the feed-forward network with `flax.nnx.Linear`.
146188
self.linear1 = nnx.Linear(in_features=embed_dim,
147189
out_features=ff_dim,
148190
kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), NamedSharding(mesh, P(None, 'model'))),
149191
bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P('model'))),
150192
rngs=rngs)
193+
# The second linear transformation for the feed-forward network with `flax.nnx.Linear`.
151194
self.linear2 = nnx.Linear(in_features=ff_dim,
152195
out_features=embed_dim,
153196
kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), NamedSharding(mesh, P(None, 'model'))),
154197
bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P('model'))),
155198
rngs=rngs)
199+
# The second dropout with `flax.nnx.Dropout`.
156200
self.dropout2 = nnx.Dropout(rate=rate)
201+
# Second layer normalization with `flax.nnx.LayerNorm`.
157202
self.layer_norm2 = nnx.LayerNorm(epsilon=1e-6,
158203
num_features=embed_dim,
159204
scale_init=nnx.with_partitioning(nnx.initializers.ones_init(), NamedSharding(mesh, P(None, 'model'))),
160205
bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P(None, 'model'))),
161206
rngs=rngs)
162207
163208
209+
# Apply the Transformer block to the input sequence.
164210
def __call__(self, inputs, training: bool = False):
165211
input_shape = inputs.shape
166212
_, seq_len, _ = input_shape
167213
168-
# Create causal mask
214+
# Instantiate the causal attention mask.
169215
mask = causal_attention_mask(seq_len)
170216
171-
# Apply MultiHeadAttention with causal mask
217+
# Apply Multi-Head Attention with the causal attention mask.
172218
attention_output = self.mha(
173219
inputs_q=inputs,
174220
mask=mask,
175221
decode=False
176222
)
223+
# Apply the first dropout.
177224
attention_output = self.dropout1(attention_output, deterministic=not training)
225+
# Apply the first layer normalization.
178226
out1 = self.layer_norm1(inputs + attention_output)
179227
180-
# Feed-forward network
228+
# Feed-forward network.
229+
# Apply the first linear transformation.
181230
ffn_output = self.linear1(out1)
231+
# Apply the ReLU activation with `flax.nnx.relu`.
182232
ffn_output = nnx.relu(ffn_output)
233+
# Apply the second linear transformation.
183234
ffn_output = self.linear2(ffn_output)
235+
# Apply the second dropout.
184236
ffn_output = self.dropout2(ffn_output, deterministic=not training)
185237
238+
# Apply the second layer normalization and return the output of the Transformer block.
186239
return self.layer_norm2(out1 + ffn_output)
187240
188241
@@ -275,7 +328,7 @@ num_epochs = 1
275328

276329
+++ {"id": "mI1ci-HyMspJ"}
277330

278-
## Prepare data
331+
## Loading and preprocessing the data
279332

280333
Data loading and preprocessing with [Grain](https://github.com/google/grain).
281334

@@ -327,9 +380,7 @@ text_dl = load_and_preprocess_data('TinyStories-train.txt', batch_size, maxlen)
327380

328381
+++ {"id": "BKVSD8KSM1um"}
329382

330-
## Train the model
331-
332-
Define loss function and training step function.
383+
## Defining the loss function and training step function
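
The cell that follows defines these for miniGPT. For orientation only, a typical next-token cross-entropy loss built with Optax looks roughly like the sketch below (assumed logits/labels shapes; this is not the notebook's exact code):

```python
# A rough sketch (assumed shapes; not the notebook's exact implementation):
# next-token cross-entropy over logits of shape (batch, seq_len, vocab_size)
# and integer labels of shape (batch, seq_len).
import optax

def sketch_loss(logits, labels):
    per_token_loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=labels)
    return per_token_loss.mean()
```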
333384

334385
```{code-cell}
335386
:id: 8rRuTmABNV4b
@@ -349,6 +400,8 @@ def train_step(model: MiniGPT, optimizer: nnx.Optimizer, metrics: nnx.MultiMetri
349400

350401
+++ {"id": "5um2vkeUNckm"}
351402

403+
## Training the model
404+
352405
Start training. It takes ~50 minutes on Colab.
353406

354407
Note that for data parallel, we are sharding the training data along the `batch` axis using `jax.device_put` with `NamedSharding`.
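
In code, that placement looks roughly like the sketch below (the batch shape is hypothetical, and the snippet reuses the `mesh` defined earlier):

```python
# A minimal sketch (the batch shape is hypothetical; `mesh` is the one defined
# earlier in the tutorial): shard a batch of token IDs along the 'batch' mesh
# axis and replicate it along 'model'.
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

batch = jnp.zeros((256, 512), dtype=jnp.int32)  # (batch_size, maxlen), made-up sizes
batch = jax.device_put(batch, NamedSharding(mesh, P('batch', None)))
```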
@@ -441,7 +494,8 @@ As you can see, the model goes from generating completely random words at the be
441494

442495
+++ {"id": "soPqiR1JNmjf"}
443496

444-
## Saving
497+
## Saving the checkpoint
498+
445499
Save the model checkpoint.
446500

447501
```{code-cell}
@@ -462,7 +516,7 @@ checkpointer.save('/content/save', state)
462516
!ls /content/save/
463517
```
464518

465-
## Profiling for Hyperparameter Tuning
519+
## Profiling for hyperparameter tuning
466520

467521
```{code-cell}
468522
!pip install -Uq tensorboard-plugin-profile tensorflow tensorboard
@@ -550,14 +604,3 @@ By looking at the Trace Viewer tool and looking under each TPU's ops, we can see
550604
```
551605

552606
By changing hyperparameters and comparing profiles, we can gain significant insights into our bottlenecks and limitations. These are just two examples of hyperparameters to tune, but many others can have significant effects on training speed and resource utilization.
553-
554-
+++ {"id": "jCApVd7671c1"}
555-
556-
## Disconnect the Colab runtime
557-
558-
```{code-cell}
559-
:id: NsqYdbrDVKSq
560-
561-
from google.colab import runtime
562-
runtime.unassign()
563-
```
