Skip to content

Commit b52de3b

Browse files
committed
Upgrade JAX AI Stack Machine Translation doc
1 parent dff0eec commit b52de3b

File tree

2 files changed

+200
-28
lines changed

2 files changed

+200
-28
lines changed

docs/source/JAX_machine_translation.ipynb

Lines changed: 115 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
"id": "ee3e1116-f6cd-497e-b617-1d89d5d1f744",
66
"metadata": {},
77
"source": [
8-
"# Machine Translation with encoder-decoder transformer model\n",
8+
"# Machine translation with a transformer using JAX AI\n",
99
"\n",
1010
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_machine_translation.ipynb)"
1111
]
@@ -15,9 +15,23 @@
1515
"id": "50f0bd58-dcc6-41f4-9dc4-3a08c8ef751b",
1616
"metadata": {},
1717
"source": [
18-
"This tutorial is adapted from [Keras' documentation on English-to-Spanish translation with a sequence-to-sequence Transformer](https://keras.io/examples/nlp/neural_machine_translation_with_transformer/), which is itself an adaptation from the book [Deep Learning with Python, Second Edition by François Chollet](https://www.manning.com/books/deep-learning-with-python-second-edition)\n",
18+
"This tutorial will demonstrate how to use JAX, [Flax NNX](http://flax.readthedocs.io) and [Optax](http://optax.readthedocs.io) to perform machine translation. It was originally inspired by the [Keras English-to-Spanish translation tutorial](https://keras.io/examples/nlp/neural_machine_translation_with_transformer/) (which was adaptated from [Deep Learning with Python, Second Edition by François Chollet](https://www.manning.com/books/deep-learning-with-python-second-edition)).\n",
1919
"\n",
20-
"We step through an encoder-decoder transformer in JAX and train a model for English->Spanish translation."
20+
"Here, you will learn how to:\n",
21+
"\n",
22+
"- Load and preprocess the dataset\n",
23+
"- Define the transformer model - the encoder, decoder and positional embedding classes - with Flax and JAX\n",
24+
"- Create the loss and training step functions\n",
25+
"- Train the model\n",
26+
"\n",
27+
"If you are new to JAX for AI, check out the [introductory tutorial](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html), which covers neural network building with [Flax NNX](https://flax.readthedocs.io/en/latest/nnx_basics.html).\n",
28+
"\n",
29+
"\n",
30+
"## Setup\n",
31+
"\n",
32+
"JAX installation is covered in [this guide](https://jax.readthedocs.io/en/latest/installation.html) on the JAX documentation site. We will use [Tiktoken](https://github.com/openai/tiktoken) for tokenization and [Grain](https://google-grain.readthedocs.io/en/latest/index.html) for data loading (`!pip install -Uq tiktoken grain)\n",
33+
"\n",
34+
"Import the necessary modules, including JAX NumPy, Flax NNX, Optax, Tiktoken, and tqdm:"
2135
]
2236
},
2337
{
@@ -48,9 +62,9 @@
4862
"id": "e1f324b0-140a-48fa-9fcb-d6308f098343",
4963
"metadata": {},
5064
"source": [
51-
"## Pull down data to temp and extract into memory\n",
65+
"## Loading and preprocessing the data\n",
5266
"\n",
53-
"There are lots of ways to get this done, but for simplicity and clear visibility into what's happening this is downloaded to a temporary directory, extracted there, and read into a python object with processing."
67+
"For simplicity, we'll download the Spanish-to-English dataset to a temporary location, extract it, read into a Python object."
5468
]
5569
},
5670
{
@@ -92,8 +106,9 @@
92106
"id": "9524904b-fa17-493f-bcfa-335963cb7c45",
93107
"metadata": {},
94108
"source": [
95-
"## Build train/validate/test pair sets\n",
96-
"We'll stay close to the original tutorial so it's clear how to follow what's the same vs what's different; one early difference is the choice to go with an off-the-shelf encoder/tokenizer in tiktoken. Specifically \"cl100k_base\" - it has a wide range of language understanding and it's fast."
109+
"We'll stay close the original Keras tutorial, but use the \"off-the-shelf\" `cl100k_base` tokenizer from the [Tiktoken](https://github.com/openai/tiktoken) library, as it has a wide range of language understanding (and it's fast).\n",
110+
"\n",
111+
"We need to extracte the data, format it, and tokenize the phrases with padding."
97112
]
98113
},
99114
{
@@ -127,6 +142,14 @@
127142
"print(f\"{len(test_pairs)} test pairs\")"
128143
]
129144
},
145+
{
146+
"cell_type": "markdown",
147+
"id": "ac597030",
148+
"metadata": {},
149+
"source": [
150+
"Instantiate the `cl100k_base` tokenizer:"
151+
]
152+
},
130153
{
131154
"cell_type": "code",
132155
"execution_count": 4,
@@ -142,7 +165,7 @@
142165
"id": "a714c4ea-9ff6-4dab-ae9c-1a884d4857e7",
143166
"metadata": {},
144167
"source": [
145-
"We strip out punctuation to keep things simple and in line with the original tutorial - the `[` `]` are kept in so that our `[start]` and `[end]` formatting is preserved."
168+
"Remove any punctuation to keep things simple and in line with the original tutorial. The square brackets `[` `]` are kept to preserve `[start]` and `[end]` formatting."
146169
]
147170
},
148171
{
@@ -160,6 +183,14 @@
160183
"sequence_length = 20"
161184
]
162185
},
186+
{
187+
"cell_type": "markdown",
188+
"id": "3124c302",
189+
"metadata": {},
190+
"source": [
191+
"Define the input standardization function:"
192+
]
193+
},
163194
{
164195
"cell_type": "code",
165196
"execution_count": 6,
@@ -172,6 +203,14 @@
172203
" return re.sub(f\"[{re.escape(strip_chars)}]\", \"\", lowercase)"
173204
]
174205
},
206+
{
207+
"cell_type": "markdown",
208+
"id": "628608c3",
209+
"metadata": {},
210+
"source": [
211+
"Define the tokenizer function that also adding padding:"
212+
]
213+
},
175214
{
176215
"cell_type": "code",
177216
"execution_count": 7,
@@ -185,6 +224,14 @@
185224
" return padded"
186225
]
187226
},
227+
{
228+
"cell_type": "markdown",
229+
"id": "4c644fa4",
230+
"metadata": {},
231+
"source": [
232+
"Define the dataset formatting function that applies both `custom_standardization` and `tokenize_and_pad`:"
233+
]
234+
},
188235
{
189236
"cell_type": "code",
190237
"execution_count": 8,
@@ -204,6 +251,14 @@
204251
" }"
205252
]
206253
},
254+
{
255+
"cell_type": "markdown",
256+
"id": "40393664",
257+
"metadata": {},
258+
"source": [
259+
"Format the dataset:"
260+
]
261+
},
207262
{
208263
"cell_type": "code",
209264
"execution_count": 9,
@@ -221,7 +276,7 @@
221276
"id": "90bbae98-48dd-4ae4-99bb-92336d7c0a1c",
222277
"metadata": {},
223278
"source": [
224-
"At this point we've extracted the data, applied formatting, and tokenized the phrases with padding. The data is kept in train/validate/test sets that each have dictionary entries, which look like the following:"
279+
"At this point we have extracted the data, applied formatting, and tokenized the phrases with padding. The data is kept in training, validate and test sets, with dictionary entries that look like this:"
225280
]
226281
},
227282
{
@@ -248,7 +303,7 @@
248303
"id": "24c6271b-e359-4aba-a583-f18c40eddba9",
249304
"metadata": {},
250305
"source": [
251-
"The output should look something like\n",
306+
"The output should look something like:\n",
252307
"\n",
253308
"{'encoder_inputs': [9514, 265, 3339, 264, 2466, 16930, 1618, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'decoder_inputs': [29563, 60, 1826, 7206, 71086, 37116, 653, 16109, 1493, 54189, 510, 408, 60, 0, 0, 0, 0, 0, 0], 'target_output': [60, 1826, 7206, 71086, 37116, 653, 16109, 1493, 54189, 510, 408, 60, 0, 0, 0, 0, 0, 0, 0]}"
254309
]
@@ -258,9 +313,11 @@
258313
"id": "7a906a05-bd17-4a47-afe0-4422d2ea0f50",
259314
"metadata": {},
260315
"source": [
261-
"## Define Transformer components: Encoder, Decoder, Positional Embed\n",
316+
"## Defining the transformer model with Flax and JAX: Encoder, decoder, positional embedding\n",
317+
"\n",
318+
"Next, we will construct the transformer model using JAX and Flax NNX. In many ways our approach tries to stay close to the original Keras machine translation tutorial, with `ops` changing to `jnp` (JAX NumPy) and `keras` or `layers` becoming Flax NNX's `nnx` layers. Certain `Module`-specific arguments are also different, such as Flax NNX (`flax.nxx.rngs`), while `decode=False` in the `MultiHeadAttention` call.\n",
262319
"\n",
263-
"In many ways this is very similar to the original source, with `ops` changing to `jnp` and `keras` or `layers` becoming `nnx`. Certain module-specific arguments come and go, like the rngs attached to most things in the updated version, and decode=False in the MultiHeadAttention call."
320+
"Let's build with the transformer encoder class - `TransformerEncoder()`, `TransformerDecoder` and a token embedding class - `PositionalEmbedding()` - by subclassing `flax.nnx.Module`. The `PositionalEmbedding()` class transforms tokens and positions into embeddings that will be fed into the transformer. It will combine token embeddings (words in an input sentence) with positional embeddings (the position of each word in a sentence). (It handles embedding both word tokens and their positions within the sequence.)"
264321
]
265322
},
266323
{
@@ -271,60 +328,101 @@
271328
"outputs": [],
272329
"source": [
273330
"class TransformerEncoder(nnx.Module):\n",
331+
" \"\"\" A single Transformer encoder that processes the embedded sequences.\n",
332+
"\n",
333+
" Args:\n",
334+
" embed_dim (int): Embedding dimensionality.\n",
335+
" dense_dim (int): Dimensionality of the linear layers.\n",
336+
" rngs (flax.nnx.Rngs): A Flax NNX stream of JAX PRNG keys.\n",
337+
" \"\"\"\n",
274338
" def __init__(self, embed_dim: int, dense_dim: int, num_heads: int, rngs: nnx.Rngs, **kwargs):\n",
275339
" self.embed_dim = embed_dim\n",
276340
" self.dense_dim = dense_dim\n",
277341
" self.num_heads = num_heads\n",
278342
"\n",
343+
" # Multi-Head Attention (MHA) with `flax.nnx.MultiHeadAttention`.\n",
279344
" self.attention = nnx.MultiHeadAttention(num_heads=num_heads,\n",
280345
" in_features=embed_dim,\n",
281346
" decode=False,\n",
282347
" rngs=rngs)\n",
348+
" # Linear transformation with ReLU activation for the feed-forward network with `flax.nnx.Linear`\n",
349+
" # and `flax.nnx.relu` activation.\n",
283350
" self.dense_proj = nnx.Sequential(\n",
284351
" nnx.Linear(embed_dim, dense_dim, rngs=rngs),\n",
285352
" nnx.relu,\n",
286353
" nnx.Linear(dense_dim, embed_dim, rngs=rngs),\n",
287354
" )\n",
288355
"\n",
356+
" # First layer normalization with `flax.nnx.LayerNorm`.\n",
289357
" self.layernorm_1 = nnx.LayerNorm(embed_dim, rngs=rngs)\n",
358+
" # Second layer normalization with `flax.nnx.LayerNorm`.\n",
290359
" self.layernorm_2 = nnx.LayerNorm(embed_dim, rngs=rngs)\n",
291360
"\n",
292361
" def __call__(self, inputs, mask=None):\n",
362+
" # The padding mask for attention.\n",
293363
" if mask is not None:\n",
294364
" padding_mask = jnp.expand_dims(mask, axis=1).astype(jnp.int32)\n",
295365
" else:\n",
296366
" padding_mask = None\n",
297367
"\n",
368+
" # Apply Multi-Head Attention (with/without a mask).\n",
298369
" attention_output = self.attention(\n",
299370
" inputs_q = inputs, inputs_k = inputs, inputs_v = inputs, mask=padding_mask, decode = False\n",
300371
" )\n",
372+
" # Apply the first layer normalization.\n",
301373
" proj_input = self.layernorm_1(inputs + attention_output)\n",
374+
" # The feed-forward network.\n",
375+
" # Apply the first linear transformation.\n",
302376
" proj_output = self.dense_proj(proj_input)\n",
377+
" # Apply the second linear transformation.\n",
303378
" return self.layernorm_2(proj_input + proj_output)\n",
304379
"\n",
305380
"\n",
306381
"class PositionalEmbedding(nnx.Module):\n",
382+
" \"\"\" Combines token embeddings (words in an input sentence) with positional embeddings\n",
383+
" (the position of each word in a sentence).\n",
384+
" \n",
385+
" Args:\n",
386+
" sequence_length (int): Matimum sequence length.\n",
387+
" vocab_size (int): Vocabulary size.\n",
388+
" embed_dim (int): Embedding dimensionality.\n",
389+
" rngs (flax.nnx.Rngs): A Flax NNX stream of JAX PRNG keys.\n",
390+
" \"\"\"\n",
391+
"\n",
392+
" # Initializes the token embedding layer (using `flax.nnx.Embed`).\n",
393+
" # Handles token and positional embeddings.\n",
307394
" def __init__(self, sequence_length: int, vocab_size: int, embed_dim: int, rngs: nnx.Rngs, **kwargs):\n",
308395
" self.token_embeddings = nnx.Embed(num_embeddings=vocab_size, features=embed_dim, rngs=rngs)\n",
309396
" self.position_embeddings = nnx.Embed(num_embeddings=sequence_length, features=embed_dim, rngs=rngs)\n",
310397
" self.sequence_length = sequence_length\n",
311398
" self.vocab_size = vocab_size\n",
312399
" self.embed_dim = embed_dim\n",
313400
"\n",
401+
" # Generates embeddings for the input tokens and their positions.\n",
402+
" # Takes a token sequence (integers) and returns the combined token and positional embeddings.\n",
314403
" def __call__(self, inputs):\n",
315404
" length = inputs.shape[1]\n",
316405
" positions = jnp.arange(0, length)[None, :]\n",
317406
" embedded_tokens = self.token_embeddings(inputs)\n",
318407
" embedded_positions = self.position_embeddings(positions)\n",
319408
" return embedded_tokens + embedded_positions\n",
320409
"\n",
410+
" # Computes the attention mask.\n",
321411
" def compute_mask(self, inputs, mask=None):\n",
322412
" if mask is None:\n",
323413
" return None\n",
324414
" else:\n",
325415
" return jnp.not_equal(inputs, 0)\n",
326416
"\n",
327417
"class TransformerDecoder(nnx.Module):\n",
418+
" \"\"\" A single Transformer encoder that processes the embedded sequences.\n",
419+
"\n",
420+
" Args:\n",
421+
" embed_dim (int): Embedding dimensionality.\n",
422+
" latent_dim (int):\n",
423+
" num_heads (int):\n",
424+
" rngs (flax.nnx.Rngs): A Flax NNX stream of JAX PRNG keys.\n",
425+
" \"\"\"\n",
328426
" def __init__(self, embed_dim: int, latent_dim: int, num_heads: int, rngs: nnx.Rngs, **kwargs):\n",
329427
" self.embed_dim = embed_dim\n",
330428
" self.latent_dim = latent_dim\n",
@@ -383,7 +481,7 @@
383481
"id": "d033ae31-cc43-4e61-8d7f-cdc6d55b8bf9",
384482
"metadata": {},
385483
"source": [
386-
"Here we finally use our earlier encoder, decoder, and positional embed classes to construct the Model that we'll train and later use for inference."
484+
"Here we finally use our earlier encoder, decoder, and positional embedding classes to construct the transformer class that we'll train and later use for inference:"
387485
]
388486
},
389487
{
@@ -426,7 +524,8 @@
426524
"id": "1744cd95-afcc-4a82-9a00-18fef4f6f7df",
427525
"metadata": {},
428526
"source": [
429-
"## Build out Data Loader and Training Definitions\n",
527+
"## Building the Grain data loader\n",
528+
"\n",
430529
"It can be more computationally efficient to use pygrain for the data load stage, but this way it's abundandtly clear what's happening: data pairs go in and sets of jnp arrays come out, in step with our original dictionaries. 'Encoder_inputs', 'decoder_inputs' and 'target_output'."
431530
]
432531
},
@@ -494,6 +593,8 @@
494593
"id": "40d9707d-a73c-47f5-8c12-1f336e526e61",
495594
"metadata": {},
496595
"source": [
596+
"# \n",
597+
"\n",
497598
"Optax doesn't have the identical loss function that the source tutorial uses, but this softmax cross entropy works well here - you can one_hot_encode if you don't use the `_with_integer_labels` version of the loss."
498599
]
499600
},

0 commit comments

Comments
 (0)