From a71f00360a677fcc480f81861a8d39dd4718f5c6 Mon Sep 17 00:00:00 2001 From: 8bitmp3 <19637339+8bitmp3@users.noreply.github.com> Date: Mon, 17 Mar 2025 21:39:15 +0000 Subject: [PATCH] Upgrade JAX AI Stack Machine Translation doc --- docs/source/JAX_machine_translation.ipynb | 128 +++++++++++++++++++--- docs/source/JAX_machine_translation.md | 98 ++++++++++++++--- 2 files changed, 198 insertions(+), 28 deletions(-) diff --git a/docs/source/JAX_machine_translation.ipynb b/docs/source/JAX_machine_translation.ipynb index 99394c6..ade7592 100644 --- a/docs/source/JAX_machine_translation.ipynb +++ b/docs/source/JAX_machine_translation.ipynb @@ -5,7 +5,7 @@ "id": "ee3e1116-f6cd-497e-b617-1d89d5d1f744", "metadata": {}, "source": [ - "# Machine Translation with encoder-decoder transformer model\n", + "# Machine translation with a transformer using JAX\n", "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_machine_translation.ipynb)" ] @@ -15,9 +15,23 @@ "id": "50f0bd58-dcc6-41f4-9dc4-3a08c8ef751b", "metadata": {}, "source": [ - "This tutorial is adapted from [Keras' documentation on English-to-Spanish translation with a sequence-to-sequence Transformer](https://keras.io/examples/nlp/neural_machine_translation_with_transformer/), which is itself an adaptation from the book [Deep Learning with Python, Second Edition by François Chollet](https://www.manning.com/books/deep-learning-with-python-second-edition)\n", + "This tutorial will demonstrate how to use JAX, [Flax NNX](http://flax.readthedocs.io) and [Optax](http://optax.readthedocs.io) to perform machine translation. It was originally inspired by the [Keras English-to-Spanish translation tutorial](https://keras.io/examples/nlp/neural_machine_translation_with_transformer/) (which was adapted from [Deep Learning with Python, Second Edition by François Chollet](https://www.manning.com/books/deep-learning-with-python-second-edition)).\n", "\n", - "We step through an encoder-decoder transformer in JAX and train a model for English->Spanish translation." + "Here, you will learn how to:\n", + "\n", + "- Load and preprocess the dataset\n", + "- Define the transformer model - the encoder, decoder and positional embedding classes - with Flax and JAX\n", + "- Create the loss and training step functions\n", + "- Train the model\n", + "\n", + "If you are new to JAX for AI, check out the [introductory tutorial](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html), which covers neural network building with [Flax NNX](https://flax.readthedocs.io/en/latest/nnx_basics.html).\n", + "\n", + "\n", + "## Setup\n", + "\n", + "JAX installation is covered in [this guide](https://jax.readthedocs.io/en/latest/installation.html) on the JAX documentation site. We will use [Tiktoken](https://github.com/openai/tiktoken) for tokenization and [Grain](https://google-grain.readthedocs.io/en/latest/index.html) for data loading (`!pip install -Uq tiktoken grain)\n", + "\n", + "Import the necessary modules, including JAX NumPy, Flax NNX, Optax, Tiktoken, and tqdm:" ] }, { @@ -48,9 +62,9 @@ "id": "e1f324b0-140a-48fa-9fcb-d6308f098343", "metadata": {}, "source": [ - "## Pull down data to temp and extract into memory\n", + "## Loading and preprocessing the data\n", "\n", - "There are lots of ways to get this done, but for simplicity and clear visibility into what's happening this is downloaded to a temporary directory, extracted there, and read into a python object with processing." + "For simplicity, we will download the Spanish-to-English dataset to a temporary location, extract it, and read it into a Python object." ] }, { @@ -92,8 +106,10 @@ "id": "9524904b-fa17-493f-bcfa-335963cb7c45", "metadata": {}, "source": [ - "## Build train/validate/test pair sets\n", - "We'll stay close to the original tutorial so it's clear how to follow what's the same vs what's different; one early difference is the choice to go with an off-the-shelf encoder/tokenizer in tiktoken. Specifically \"cl100k_base\" - it has a wide range of language understanding and it's fast." + "We will stay close to the original Keras tutorial, but use the \"off-the-shelf\" `cl100k_base` tokenizer from the [Tiktoken](https://github.com/openai/tiktoken) library, as it has a wide range of language understanding (and it's fast).\n", + "\n", + "\n", + "We need to extract the data, format it, and tokenize the phrases with padding." ] }, { @@ -127,6 +143,14 @@ "print(f\"{len(test_pairs)} test pairs\")" ] }, + { + "cell_type": "markdown", + "id": "ac597030", + "metadata": {}, + "source": [ + "Instantiate the `cl100k_base` tokenizer:" + ] + }, { "cell_type": "code", "execution_count": 4, @@ -142,7 +166,7 @@ "id": "a714c4ea-9ff6-4dab-ae9c-1a884d4857e7", "metadata": {}, "source": [ - "We strip out punctuation to keep things simple and in line with the original tutorial - the `[` `]` are kept in so that our `[start]` and `[end]` formatting is preserved." + "Remove any punctuation to keep things simple and in line with the original tutorial. The square brackets `[` `]` are kept to preserve `[start]` and `[end]` formatting." ] }, { @@ -160,6 +184,14 @@ "sequence_length = 20" ] }, + { + "cell_type": "markdown", + "id": "3124c302", + "metadata": {}, + "source": [ + "Define the input standardization function:" + ] + }, { "cell_type": "code", "execution_count": 6, @@ -172,6 +204,14 @@ " return re.sub(f\"[{re.escape(strip_chars)}]\", \"\", lowercase)" ] }, + { + "cell_type": "markdown", + "id": "628608c3", + "metadata": {}, + "source": [ + "Define the tokenizer function that also adding padding:" + ] + }, { "cell_type": "code", "execution_count": 7, @@ -185,6 +225,14 @@ " return padded" ] }, + { + "cell_type": "markdown", + "id": "4c644fa4", + "metadata": {}, + "source": [ + "Define the dataset formatting function that applies both `custom_standardization` and `tokenize_and_pad`:" + ] + }, { "cell_type": "code", "execution_count": 8, @@ -204,6 +252,14 @@ " }" ] }, + { + "cell_type": "markdown", + "id": "40393664", + "metadata": {}, + "source": [ + "Format the dataset:" + ] + }, { "cell_type": "code", "execution_count": 9, @@ -221,7 +277,7 @@ "id": "90bbae98-48dd-4ae4-99bb-92336d7c0a1c", "metadata": {}, "source": [ - "At this point we've extracted the data, applied formatting, and tokenized the phrases with padding. The data is kept in train/validate/test sets that each have dictionary entries, which look like the following:" + "At this point we have extracted the data, applied formatting, and tokenized the phrases with padding. The data is kept in training, validate and test sets, with dictionary entries that look like this:" ] }, { @@ -248,7 +304,7 @@ "id": "24c6271b-e359-4aba-a583-f18c40eddba9", "metadata": {}, "source": [ - "The output should look something like\n", + "The output should look something like:\n", "\n", "{'encoder_inputs': [9514, 265, 3339, 264, 2466, 16930, 1618, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'decoder_inputs': [29563, 60, 1826, 7206, 71086, 37116, 653, 16109, 1493, 54189, 510, 408, 60, 0, 0, 0, 0, 0, 0], 'target_output': [60, 1826, 7206, 71086, 37116, 653, 16109, 1493, 54189, 510, 408, 60, 0, 0, 0, 0, 0, 0, 0]}" ] @@ -258,9 +314,11 @@ "id": "7a906a05-bd17-4a47-afe0-4422d2ea0f50", "metadata": {}, "source": [ - "## Define Transformer components: Encoder, Decoder, Positional Embed\n", + "## Defining the transformer model with Flax and JAX: Encoder, decoder, positional embedding\n", + "\n", + "Next, we will construct the transformer model using JAX and Flax NNX. In many ways our approach tries to stay close to the original Keras machine translation tutorial, with `ops` changing to `jnp` (JAX NumPy) and `keras` or `layers` becoming Flax NNX's `nnx` layers. Certain `Module`-specific arguments are also different, such as Flax NNX (`flax.nxx.rngs`), while `decode=False` in the `MultiHeadAttention` call.\n", "\n", - "In many ways this is very similar to the original source, with `ops` changing to `jnp` and `keras` or `layers` becoming `nnx`. Certain module-specific arguments come and go, like the rngs attached to most things in the updated version, and decode=False in the MultiHeadAttention call." + "Let's build with the transformer encoder class - `TransformerEncoder()`, `TransformerDecoder` and a token embedding class - `PositionalEmbedding()` - by subclassing `flax.nnx.Module`. The `PositionalEmbedding()` class transforms tokens and positions into embeddings that will be fed into the transformer. It will combine token embeddings (words in an input sentence) with positional embeddings (the position of each word in a sentence). (It handles embedding both word tokens and their positions within the sequence.)" ] }, { @@ -271,39 +329,69 @@ "outputs": [], "source": [ "class TransformerEncoder(nnx.Module):\n", + " \"\"\" A single Transformer encoder that processes the embedded sequences.\n", + "\n", + " Args:\n", + " embed_dim (int): Embedding dimensionality.\n", + " dense_dim (int): Dimensionality of the linear layers.\n", + " rngs (flax.nnx.Rngs): A Flax NNX stream of JAX PRNG keys.\n", + " \"\"\"\n", " def __init__(self, embed_dim: int, dense_dim: int, num_heads: int, rngs: nnx.Rngs, **kwargs):\n", " self.embed_dim = embed_dim\n", " self.dense_dim = dense_dim\n", " self.num_heads = num_heads\n", "\n", + " # Multi-Head Attention (MHA) with `flax.nnx.MultiHeadAttention`.\n", " self.attention = nnx.MultiHeadAttention(num_heads=num_heads,\n", " in_features=embed_dim,\n", " decode=False,\n", " rngs=rngs)\n", + " # Linear transformation with ReLU activation for the feed-forward network with `flax.nnx.Linear`\n", + " # and `flax.nnx.relu` activation.\n", " self.dense_proj = nnx.Sequential(\n", " nnx.Linear(embed_dim, dense_dim, rngs=rngs),\n", " nnx.relu,\n", " nnx.Linear(dense_dim, embed_dim, rngs=rngs),\n", " )\n", "\n", + " # First layer normalization with `flax.nnx.LayerNorm`.\n", " self.layernorm_1 = nnx.LayerNorm(embed_dim, rngs=rngs)\n", + " # Second layer normalization with `flax.nnx.LayerNorm`.\n", " self.layernorm_2 = nnx.LayerNorm(embed_dim, rngs=rngs)\n", "\n", " def __call__(self, inputs, mask=None):\n", + " # The padding mask for attention.\n", " if mask is not None:\n", " padding_mask = jnp.expand_dims(mask, axis=1).astype(jnp.int32)\n", " else:\n", " padding_mask = None\n", "\n", + " # Apply Multi-Head Attention (with/without a mask).\n", " attention_output = self.attention(\n", " inputs_q = inputs, inputs_k = inputs, inputs_v = inputs, mask=padding_mask, decode = False\n", " )\n", + " # Apply the first layer normalization.\n", " proj_input = self.layernorm_1(inputs + attention_output)\n", + " # The feed-forward network.\n", + " # Apply the first linear transformation.\n", " proj_output = self.dense_proj(proj_input)\n", + " # Apply the second linear transformation.\n", " return self.layernorm_2(proj_input + proj_output)\n", "\n", "\n", "class PositionalEmbedding(nnx.Module):\n", + " \"\"\" Combines token embeddings (words in an input sentence) with positional embeddings\n", + " (the position of each word in a sentence).\n", + "\n", + " Args:\n", + " sequence_length (int): Matimum sequence length.\n", + " vocab_size (int): Vocabulary size.\n", + " embed_dim (int): Embedding dimensionality.\n", + " rngs (flax.nnx.Rngs): A Flax NNX stream of JAX PRNG keys.\n", + " \"\"\"\n", + "\n", + " # Initializes the token embedding layer (using `flax.nnx.Embed`).\n", + " # Handles token and positional embeddings.\n", " def __init__(self, sequence_length: int, vocab_size: int, embed_dim: int, rngs: nnx.Rngs, **kwargs):\n", " self.token_embeddings = nnx.Embed(num_embeddings=vocab_size, features=embed_dim, rngs=rngs)\n", " self.position_embeddings = nnx.Embed(num_embeddings=sequence_length, features=embed_dim, rngs=rngs)\n", @@ -311,6 +399,8 @@ " self.vocab_size = vocab_size\n", " self.embed_dim = embed_dim\n", "\n", + " # Generates embeddings for the input tokens and their positions.\n", + " # Takes a token sequence (integers) and returns the combined token and positional embeddings.\n", " def __call__(self, inputs):\n", " length = inputs.shape[1]\n", " positions = jnp.arange(0, length)[None, :]\n", @@ -318,6 +408,7 @@ " embedded_positions = self.position_embeddings(positions)\n", " return embedded_tokens + embedded_positions\n", "\n", + " # Computes the attention mask.\n", " def compute_mask(self, inputs, mask=None):\n", " if mask is None:\n", " return None\n", @@ -325,6 +416,14 @@ " return jnp.not_equal(inputs, 0)\n", "\n", "class TransformerDecoder(nnx.Module):\n", + " \"\"\" A single Transformer encoder that processes the embedded sequences.\n", + "\n", + " Args:\n", + " embed_dim (int): Embedding dimensionality.\n", + " latent_dim (int):\n", + " num_heads (int):\n", + " rngs (flax.nnx.Rngs): A Flax NNX stream of JAX PRNG keys.\n", + " \"\"\"\n", " def __init__(self, embed_dim: int, latent_dim: int, num_heads: int, rngs: nnx.Rngs, **kwargs):\n", " self.embed_dim = embed_dim\n", " self.latent_dim = latent_dim\n", @@ -383,7 +482,7 @@ "id": "d033ae31-cc43-4e61-8d7f-cdc6d55b8bf9", "metadata": {}, "source": [ - "Here we finally use our earlier encoder, decoder, and positional embed classes to construct the Model that we'll train and later use for inference." + "Here we finally use our earlier encoder, decoder, and positional embedding classes to construct the transformer class that we'll train and later use for inference:" ] }, { @@ -426,7 +525,8 @@ "id": "1744cd95-afcc-4a82-9a00-18fef4f6f7df", "metadata": {}, "source": [ - "## Build out Data Loader and Training Definitions\n", + "## Building the Grain data loader\n", + "\n", "It can be more computationally efficient to use pygrain for the data load stage, but this way it's abundandtly clear what's happening: data pairs go in and sets of jnp arrays come out, in step with our original dictionaries. 'Encoder_inputs', 'decoder_inputs' and 'target_output'." ] }, diff --git a/docs/source/JAX_machine_translation.md b/docs/source/JAX_machine_translation.md index 16e8777..e447bc2 100644 --- a/docs/source/JAX_machine_translation.md +++ b/docs/source/JAX_machine_translation.md @@ -12,15 +12,29 @@ kernelspec: name: python3 --- -# Machine Translation with encoder-decoder transformer model +# Machine translation with a transformer using JAX [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_machine_translation.ipynb) +++ -This tutorial is adapted from [Keras' documentation on English-to-Spanish translation with a sequence-to-sequence Transformer](https://keras.io/examples/nlp/neural_machine_translation_with_transformer/), which is itself an adaptation from the book [Deep Learning with Python, Second Edition by François Chollet](https://www.manning.com/books/deep-learning-with-python-second-edition) +This tutorial will demonstrate how to use JAX, [Flax NNX](http://flax.readthedocs.io) and [Optax](http://optax.readthedocs.io) to perform machine translation. It was originally inspired by the [Keras English-to-Spanish translation tutorial](https://keras.io/examples/nlp/neural_machine_translation_with_transformer/) (which was adapted from [Deep Learning with Python, Second Edition by François Chollet](https://www.manning.com/books/deep-learning-with-python-second-edition)). -We step through an encoder-decoder transformer in JAX and train a model for English->Spanish translation. +Here, you will learn how to: + +- Load and preprocess the dataset +- Define the transformer model - the encoder, decoder and positional embedding classes - with Flax and JAX +- Create the loss and training step functions +- Train the model + +If you are new to JAX for AI, check out the [introductory tutorial](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html), which covers neural network building with [Flax NNX](https://flax.readthedocs.io/en/latest/nnx_basics.html). + + +## Setup + +JAX installation is covered in [this guide](https://jax.readthedocs.io/en/latest/installation.html) on the JAX documentation site. We will use [Tiktoken](https://github.com/openai/tiktoken) for tokenization and [Grain](https://google-grain.readthedocs.io/en/latest/index.html) for data loading (`!pip install -Uq tiktoken grain) + +Import the necessary modules, including JAX NumPy, Flax NNX, Optax, Tiktoken, and tqdm: ```{code-cell} ipython3 import pathlib @@ -39,9 +53,9 @@ import grain.python as grain import tqdm ``` -## Pull down data to temp and extract into memory +## Loading and preprocessing the data -There are lots of ways to get this done, but for simplicity and clear visibility into what's happening this is downloaded to a temporary directory, extracted there, and read into a python object with processing. +For simplicity, we will download the Spanish-to-English dataset to a temporary location, extract it, and read it into a Python object. ```{code-cell} ipython3 import requests @@ -71,8 +85,10 @@ with tempfile.TemporaryDirectory() as temp_dir: text_pairs.append((eng, spa)) ``` -## Build train/validate/test pair sets -We'll stay close to the original tutorial so it's clear how to follow what's the same vs what's different; one early difference is the choice to go with an off-the-shelf encoder/tokenizer in tiktoken. Specifically "cl100k_base" - it has a wide range of language understanding and it's fast. +We will stay close to the original Keras tutorial, but use the "off-the-shelf" `cl100k_base` tokenizer from the [Tiktoken](https://github.com/openai/tiktoken) library, as it has a wide range of language understanding (and it's fast). + + +We need to extract the data, format it, and tokenize the phrases with padding. ```{code-cell} ipython3 random.shuffle(text_pairs) @@ -88,11 +104,13 @@ print(f"{len(val_pairs)} validation pairs") print(f"{len(test_pairs)} test pairs") ``` +Instantiate the `cl100k_base` tokenizer: + ```{code-cell} ipython3 tokenizer = tiktoken.get_encoding("cl100k_base") ``` -We strip out punctuation to keep things simple and in line with the original tutorial - the `[` `]` are kept in so that our `[start]` and `[end]` formatting is preserved. +Remove any punctuation to keep things simple and in line with the original tutorial. The square brackets `[` `]` are kept to preserve `[start]` and `[end]` formatting. ```{code-cell} ipython3 strip_chars = string.punctuation + "¿" @@ -103,12 +121,16 @@ vocab_size = tokenizer.n_vocab sequence_length = 20 ``` +Define the input standardization function: + ```{code-cell} ipython3 def custom_standardization(input_string): lowercase = input_string.lower() return re.sub(f"[{re.escape(strip_chars)}]", "", lowercase) ``` +Define the tokenizer function that also adding padding: + ```{code-cell} ipython3 def tokenize_and_pad(text, tokenizer, max_length): tokens = tokenizer.encode(text)[:max_length] @@ -116,6 +138,8 @@ def tokenize_and_pad(text, tokenizer, max_length): return padded ``` +Define the dataset formatting function that applies both `custom_standardization` and `tokenize_and_pad`: + ```{code-cell} ipython3 def format_dataset(eng, spa, tokenizer, sequence_length): eng = custom_standardization(eng) @@ -129,64 +153,98 @@ def format_dataset(eng, spa, tokenizer, sequence_length): } ``` +Format the dataset: + ```{code-cell} ipython3 train_data = [format_dataset(eng, spa, tokenizer, sequence_length) for eng, spa in train_pairs] val_data = [format_dataset(eng, spa, tokenizer, sequence_length) for eng, spa in val_pairs] test_data = [format_dataset(eng, spa, tokenizer, sequence_length) for eng, spa in test_pairs] ``` -At this point we've extracted the data, applied formatting, and tokenized the phrases with padding. The data is kept in train/validate/test sets that each have dictionary entries, which look like the following: +At this point we have extracted the data, applied formatting, and tokenized the phrases with padding. The data is kept in training, validate and test sets, with dictionary entries that look like this: ```{code-cell} ipython3 ## data selection example print(train_data[135]) ``` -The output should look something like +The output should look something like: {'encoder_inputs': [9514, 265, 3339, 264, 2466, 16930, 1618, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'decoder_inputs': [29563, 60, 1826, 7206, 71086, 37116, 653, 16109, 1493, 54189, 510, 408, 60, 0, 0, 0, 0, 0, 0], 'target_output': [60, 1826, 7206, 71086, 37116, 653, 16109, 1493, 54189, 510, 408, 60, 0, 0, 0, 0, 0, 0, 0]} +++ -## Define Transformer components: Encoder, Decoder, Positional Embed +## Defining the transformer model with Flax and JAX: Encoder, decoder, positional embedding + +Next, we will construct the transformer model using JAX and Flax NNX. In many ways our approach tries to stay close to the original Keras machine translation tutorial, with `ops` changing to `jnp` (JAX NumPy) and `keras` or `layers` becoming Flax NNX's `nnx` layers. Certain `Module`-specific arguments are also different, such as Flax NNX (`flax.nxx.rngs`), while `decode=False` in the `MultiHeadAttention` call. -In many ways this is very similar to the original source, with `ops` changing to `jnp` and `keras` or `layers` becoming `nnx`. Certain module-specific arguments come and go, like the rngs attached to most things in the updated version, and decode=False in the MultiHeadAttention call. +Let's build with the transformer encoder class - `TransformerEncoder()`, `TransformerDecoder` and a token embedding class - `PositionalEmbedding()` - by subclassing `flax.nnx.Module`. The `PositionalEmbedding()` class transforms tokens and positions into embeddings that will be fed into the transformer. It will combine token embeddings (words in an input sentence) with positional embeddings (the position of each word in a sentence). (It handles embedding both word tokens and their positions within the sequence.) ```{code-cell} ipython3 class TransformerEncoder(nnx.Module): + """ A single Transformer encoder that processes the embedded sequences. + + Args: + embed_dim (int): Embedding dimensionality. + dense_dim (int): Dimensionality of the linear layers. + rngs (flax.nnx.Rngs): A Flax NNX stream of JAX PRNG keys. + """ def __init__(self, embed_dim: int, dense_dim: int, num_heads: int, rngs: nnx.Rngs, **kwargs): self.embed_dim = embed_dim self.dense_dim = dense_dim self.num_heads = num_heads + # Multi-Head Attention (MHA) with `flax.nnx.MultiHeadAttention`. self.attention = nnx.MultiHeadAttention(num_heads=num_heads, in_features=embed_dim, decode=False, rngs=rngs) + # Linear transformation with ReLU activation for the feed-forward network with `flax.nnx.Linear` + # and `flax.nnx.relu` activation. self.dense_proj = nnx.Sequential( nnx.Linear(embed_dim, dense_dim, rngs=rngs), nnx.relu, nnx.Linear(dense_dim, embed_dim, rngs=rngs), ) + # First layer normalization with `flax.nnx.LayerNorm`. self.layernorm_1 = nnx.LayerNorm(embed_dim, rngs=rngs) + # Second layer normalization with `flax.nnx.LayerNorm`. self.layernorm_2 = nnx.LayerNorm(embed_dim, rngs=rngs) def __call__(self, inputs, mask=None): + # The padding mask for attention. if mask is not None: padding_mask = jnp.expand_dims(mask, axis=1).astype(jnp.int32) else: padding_mask = None + # Apply Multi-Head Attention (with/without a mask). attention_output = self.attention( inputs_q = inputs, inputs_k = inputs, inputs_v = inputs, mask=padding_mask, decode = False ) + # Apply the first layer normalization. proj_input = self.layernorm_1(inputs + attention_output) + # The feed-forward network. + # Apply the first linear transformation. proj_output = self.dense_proj(proj_input) + # Apply the second linear transformation. return self.layernorm_2(proj_input + proj_output) class PositionalEmbedding(nnx.Module): + """ Combines token embeddings (words in an input sentence) with positional embeddings + (the position of each word in a sentence). + + Args: + sequence_length (int): Matimum sequence length. + vocab_size (int): Vocabulary size. + embed_dim (int): Embedding dimensionality. + rngs (flax.nnx.Rngs): A Flax NNX stream of JAX PRNG keys. + """ + + # Initializes the token embedding layer (using `flax.nnx.Embed`). + # Handles token and positional embeddings. def __init__(self, sequence_length: int, vocab_size: int, embed_dim: int, rngs: nnx.Rngs, **kwargs): self.token_embeddings = nnx.Embed(num_embeddings=vocab_size, features=embed_dim, rngs=rngs) self.position_embeddings = nnx.Embed(num_embeddings=sequence_length, features=embed_dim, rngs=rngs) @@ -194,6 +252,8 @@ class PositionalEmbedding(nnx.Module): self.vocab_size = vocab_size self.embed_dim = embed_dim + # Generates embeddings for the input tokens and their positions. + # Takes a token sequence (integers) and returns the combined token and positional embeddings. def __call__(self, inputs): length = inputs.shape[1] positions = jnp.arange(0, length)[None, :] @@ -201,6 +261,7 @@ class PositionalEmbedding(nnx.Module): embedded_positions = self.position_embeddings(positions) return embedded_tokens + embedded_positions + # Computes the attention mask. def compute_mask(self, inputs, mask=None): if mask is None: return None @@ -208,6 +269,14 @@ class PositionalEmbedding(nnx.Module): return jnp.not_equal(inputs, 0) class TransformerDecoder(nnx.Module): + """ A single Transformer encoder that processes the embedded sequences. + + Args: + embed_dim (int): Embedding dimensionality. + latent_dim (int): + num_heads (int): + rngs (flax.nnx.Rngs): A Flax NNX stream of JAX PRNG keys. + """ def __init__(self, embed_dim: int, latent_dim: int, num_heads: int, rngs: nnx.Rngs, **kwargs): self.embed_dim = embed_dim self.latent_dim = latent_dim @@ -261,7 +330,7 @@ class TransformerDecoder(nnx.Module): return mask ``` -Here we finally use our earlier encoder, decoder, and positional embed classes to construct the Model that we'll train and later use for inference. +Here we finally use our earlier encoder, decoder, and positional embedding classes to construct the transformer class that we'll train and later use for inference: ```{code-cell} ipython3 class TransformerModel(nnx.Module): @@ -292,7 +361,8 @@ class TransformerModel(nnx.Module): return logits ``` -## Build out Data Loader and Training Definitions +## Building the Grain data loader + It can be more computationally efficient to use pygrain for the data load stage, but this way it's abundandtly clear what's happening: data pairs go in and sets of jnp arrays come out, in step with our original dictionaries. 'Encoder_inputs', 'decoder_inputs' and 'target_output'. ```{code-cell} ipython3