diff --git a/site/en/gemma/docs/core/keras_inference.ipynb b/site/en/gemma/docs/core/keras_inference.ipynb index cf532f3fe..4dcd12fa9 100644 --- a/site/en/gemma/docs/core/keras_inference.ipynb +++ b/site/en/gemma/docs/core/keras_inference.ipynb @@ -69,20 +69,11 @@ "id": "PXNm5_p_oxMF" }, "source": [ - "# Get started with Gemma using KerasNLP\n", + "# Run Gemma with Keras\n", "\n", - "This tutorial shows you how to get started with Gemma using [KerasNLP](https://keras.io/keras_nlp/). Gemma is a family of lightweight, state-of-the art open models built from the same research and technology used to create the Gemini models. KerasNLP is a collection of natural language processing (NLP) models implemented in [Keras](https://keras.io/) and runnable on JAX, PyTorch, and TensorFlow.\n", + "Generating content, summarizing, and analysing content are just some of the tasks you can accomplish with Gemma open models. This tutorial shows you how to get started running Gemma using Keras, including generating text content with text and image input. [Keras](https://keras.io/) provides implementations for running Gemma and other models using JAX, PyTorch, and TensorFlow. If you're new to Keras, you might want to read [Getting started with Keras](https://keras.io/getting_started/) before you begin.\n", "\n", - "In this tutorial, you'll use Gemma to generate text responses to several prompts. If you're new to Keras, you might want to read [Getting started with Keras](https://keras.io/getting_started/) before you begin, but you don't have to. You'll learn more about Keras as you work through this tutorial." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "mERVCCsGUPIJ" - }, - "source": [ - "## Setup" + "Gemma 3 and later models support text and image input. Earlier versions of Gemma only support text input, except for some variants, including [PaliGemma](https://ai.google.dev/gemma/docs/setup)." ] }, { @@ -91,16 +82,16 @@ "id": "QQ6W7NzRe1VM" }, "source": [ - "### Gemma setup\n", + "## Setup\n", "\n", - "To complete this tutorial, you'll first need to complete the setup instructions at [Gemma setup](https://ai.google.dev/gemma/docs/setup). The Gemma setup instructions show you how to do the following:\n", + "Before starting this tutorial, make sure you have completed the following steps:\n", "\n", - "* Get access to Gemma on kaggle.com.\n", + "* Get access to Gemma on [kaggle.com](https://www.kaggle.com).\n", "* Select a Colab runtime with sufficient resources to run\n", - " the Gemma 2B model.\n", + " the Gemma model size you want to run. [Learn more](https://ai.google.dev/gemma/docs/core#sizes).\n", "* Generate and configure a Kaggle username and API key.\n", "\n", - "After you've completed the Gemma setup, move on to the next section, where you'll set environment variables for your Colab environment." + "If you need help completing these steps, see the [Gemma setup](https://ai.google.dev/gemma/docs/setup) instructions. After you've completed the Gemma setup, move on to the next section, where you'll set environment variables for your Colab environment." ] }, { @@ -116,7 +107,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": { "id": "DrBoa_Urw9Vx" }, @@ -137,9 +128,9 @@ "id": "z9oy3QUmXtSd" }, "source": [ - "### Install dependencies\n", + "### Install Keras packages\n", "\n", - "Install Keras and KerasNLP." + "Install the Keras and KerasHub Python packages." ] }, { @@ -150,9 +141,8 @@ }, "outputs": [], "source": [ - "# Install Keras 3 last. 
See https://keras.io/getting_started/ for more details.\n", - "!pip install -q -U keras-nlp\n", - "!pip install -q -U \"keras>=3\"" + "!pip install -q -U keras-hub\n", + "!pip install -q -U keras" ] }, { @@ -163,21 +153,19 @@ "source": [ "### Select a backend\n", "\n", - "Keras is a high-level, multi-framework deep learning API designed for simplicity and ease of use. [Keras 3](https://keras.io/keras_3) lets you choose the backend: TensorFlow, JAX, or PyTorch. All three will work for this tutorial." + "Keras is a high-level, multi-framework deep learning API designed for simplicity and ease of use. [Keras 3](https://keras.io/keras_3) lets you choose the backend: TensorFlow, JAX, or PyTorch. All three will work for this tutorial. For this tutorial, configure the backend for JAX as it typically provides the better performance." ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": { "id": "7rS7ryTs5wjf" }, "outputs": [], "source": [ - "import os\n", - "\n", "os.environ[\"KERAS_BACKEND\"] = \"jax\" # Or \"tensorflow\" or \"torch\".\n", - "os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"] = \"0.9\"" + "os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"] = \"1.00\"" ] }, { @@ -188,19 +176,19 @@ "source": [ "### Import packages\n", "\n", - "Import Keras and KerasNLP." + "Import the Keras and KerasHub packages." ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": { "id": "f2fa267d75bc" }, "outputs": [], "source": [ "import keras\n", - "import keras_nlp" + "import keras_hub" ] }, { @@ -209,11 +197,9 @@ "id": "ZsxDCbLN555T" }, "source": [ - "## Create a model\n", - "\n", - "KerasNLP provides implementations of many popular [model architectures](https://keras.io/api/keras_nlp/models/). In this tutorial, you'll create a model using `GemmaCausalLM`, an end-to-end Gemma model for causal language modeling. A causal language model predicts the next token based on previous tokens.\n", + "## Load model\n", "\n", - "Create the model using the `from_preset` method:" + "Keras provides implementations of many popular [model architectures](https://keras.io/api/keras_nlp/models/). Download and configure a Gemma model using the `Gemma3CausalLM` class to build an end-to-end, causal language modeling implementation for Gemma 3 models. Create the model using the `from_preset()` method, as shown in the following code example:" ] }, { @@ -224,7 +210,10 @@ }, "outputs": [], "source": [ - "gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(\"gemma2_2b_en\")\n" + "gemma_lm = keras_hub.models.Gemma3CausalLM.from_preset(\n", + " \"gemma3_instruct_4b\",\n", + " dtype=\"bfloat16\",\n", + ")" ] }, { @@ -233,16 +222,7 @@ "id": "XrAWvsU6pI0E" }, "source": [ - "The `GemmaCausalLM.from_preset()` function instantiates the model from a preset architecture and weights. In the code above, the string `\"gemma2_2b_en\"` specifies the preset the Gemma 2 2B model with 2 billion parameters. Gemma models with [7B, 9B, and 27B parameters](/gemma/docs/get_started#models-list) are also available. You can find the code strings for Gemma models in their **Model Variation** listings on [Kaggle](https://www.kaggle.com/models/google/gemma).\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Ij73k0PfUhjE" - }, - "source": [ - "Note: To run the larger models in Colab, you need access to the premium GPUs available in paid plans. 
Alternatively, you can perform inferences using Kaggle notebooks or Google Cloud projects.\n" + "The `Gemma3CausalLM.from_preset()` method instantiates the model from a preset architecture and weights. In the code above, the string `\"gemma#_xxxxxxx\"` specifies a preset version and parameter size for Gemma. You can find the code strings for Gemma models in their **Model Variation** listings on [Kaggle](https://www.kaggle.com/models/keras/gemma3).\n" ] }, { "cell_type": "markdown", "metadata": { "id": "E-cSEjULUhST" }, "source": [ - "Use `summary` to get more info about the model:" + "Once you have downloaded the model, use the `summary()` function to get more info about the model:" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": { "id": "e5nEbTdApL7W" }, "outputs": [ { "data": { "text/html": [ - "
Preprocessor: \"gemma_causal_lm_preprocessor\"\n",
- "
\n"
- ],
- "text/plain": [
- "\u001b[1mPreprocessor: \"gemma_causal_lm_preprocessor\"\u001b[0m\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/html": [
- "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", - "┃ Tokenizer (type) ┃ Vocab # ┃\n", - "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", - "│ gemma_tokenizer (GemmaTokenizer) │ 256,000 │\n", - "└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘\n", - "\n" - ], - "text/plain": [ - "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", - "┃\u001b[1m \u001b[0m\u001b[1mTokenizer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Vocab #\u001b[0m\u001b[1m \u001b[0m┃\n", - "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", - "│ gemma_tokenizer (\u001b[38;5;33mGemmaTokenizer\u001b[0m) │ \u001b[38;5;34m256,000\u001b[0m │\n", - "└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
Model: \"gemma_causal_lm\"\n",
- "
\n"
- ],
- "text/plain": [
- "\u001b[1mModel: \"gemma_causal_lm\"\u001b[0m\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/html": [
- "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "Preprocessor: \"gemma3_causal_lm_preprocessor\"\n", + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃ Layer (type) ┃ Config ┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ gemma3_tokenizer (Gemma3Tokenizer) │ Vocab size: 262,144 │\n", + "├───────────────────────────────────────────────────────────────┼──────────────────────────────────────────┤\n", + "│ gemma3_image_converter (Gemma3ImageConverter) │ Image size: (896, 896) │\n", + "└───────────────────────────────────────────────────────────────┴──────────────────────────────────────────┘\n", + "Model: \"gemma3_causal_lm_1\"\n", + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", "┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ images (InputLayer) │ (None, None, 896, 896, 3) │ 0 │ - │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", "│ padding_mask (InputLayer) │ (None, None) │ 0 │ - │\n", "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", "│ token_ids (InputLayer) │ (None, None) │ 0 │ - │\n", "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", - "│ gemma_backbone │ (None, None, 2304) │ 2,614,341,888 │ padding_mask[0][0], │\n", - "│ (GemmaBackbone) │ │ │ token_ids[0][0] │\n", - "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", - "│ token_embedding │ (None, None, 256000) │ 589,824,000 │ gemma_backbone[0][0] │\n", - "│ (ReversibleEmbedding) │ │ │ │\n", - "└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n", - "\n" - ], - "text/plain": [ - "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", - "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n", - "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", - "│ padding_mask (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", + "│ vision_indices (InputLayer) │ (None, None) │ 0 │ - │\n", "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", - "│ token_ids (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", + "│ vision_mask (InputLayer) │ (None, None) │ 0 │ - │\n", "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", - "│ gemma_backbone │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2304\u001b[0m) │ \u001b[38;5;34m2,614,341,888\u001b[0m │ 
padding_mask[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", - "│ (\u001b[38;5;33mGemmaBackbone\u001b[0m) │ │ │ token_ids[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ gemma3_backbone │ (None, 256, 2560) │ 4,299,915,632 │ images[0][0], │\n", + "│ (Gemma3Backbone) │ │ │ padding_mask[0][0], │\n", + "│ │ │ │ token_ids[0][0], │\n", + "│ │ │ │ vision_indices[0][0], │\n", + "│ │ │ │ vision_mask[0][0] │\n", "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", - "│ token_embedding │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256000\u001b[0m) │ \u001b[38;5;34m589,824,000\u001b[0m │ gemma_backbone[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", - "│ (\u001b[38;5;33mReversibleEmbedding\u001b[0m) │ │ │ │\n", - "└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "Total params: 2,614,341,888 (9.74 GB)\n", - "\n" - ], - "text/plain": [ - "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m2,614,341,888\u001b[0m (9.74 GB)\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "Trainable params: 2,614,341,888 (9.74 GB)\n", - "\n" - ], - "text/plain": [ - "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m2,614,341,888\u001b[0m (9.74 GB)\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "Non-trainable params: 0 (0.00 B)\n", + "│ token_embedding │ (None, 256, 262144) │ 671,088,640 │ gemma3_backbone[0][0] │\n", + "│ (ReversibleEmbedding) │ │ │ │\n", + "└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n", + " Total params: 4,299,915,632 (8.79 GB)\n", + " Trainable params: 4,299,915,632 (8.79 GB)\n", + " Non-trainable params: 0 (0.00 B)\n", "\n" ], "text/plain": [ - "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" + "\u001b[1mModel: \"gemma3_causal_lm_1\"\u001b[0m\n" ] }, "metadata": {}, @@ -395,9 +298,17 @@ "id": "81KHdRYOrWYm" }, "source": [ - "As you can see from the summary, the model has 2.6 billion trainable parameters.\n", - "\n", - "Note: For purposes of naming the model (\"2B\"), the embedding layer is not counted against the number of parameters." + "The output of the summary shows the models total number of trainable parameters.\n", + "For purposes of naming the model, the embedding layer is not counted against the number of parameters." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Ij73k0PfUhjE" + }, + "source": [ + "Note: To run larger Gemma models with Google Colab, you need access to the premium GPUs available in paid plans. Alternatively, you can perform inferences using [Kaggle](https://www.kaggle.com/code) notebooks or Google Cloud projects.\n" ] }, { @@ -406,34 +317,18 @@ "id": "FOBW7piN5-sl" }, "source": [ - "## Generate text\n", - "\n", - "Now it's time to generate some text! The model has a `generate` method that generates text based on a prompt. The optional `max_length` argument specifies the maximum length of the generated sequence.\n", + "## Generate text with text\n", "\n", - "Try it out with the prompt `\"what is keras in 3 bullet points?\"`." + "Generate text with a text prompt with using `generate()` method of the Gemma model object you configured in the previous steps. 
The optional `max_length` argument specifies the maximum length of the generated sequence. The following code examples shows a few ways to prompt the model." ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": { "id": "aae5GHrdpj2_" }, - "outputs": [ - { - "data": { - "application/vnd.google.colaboratory.intrinsic+json": { - "type": "string" - }, - "text/plain": [ - "'what is keras in 3 bullet points?\\n\\n[Answer 1]\\n\\nKeras is a high-level neural networks API, written in Python and capable of running on top of TensorFlow, CNTK, Theano, or PlaidML. It is designed to be user-friendly and easy to extend.\\n\\n'" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "gemma_lm.generate(\"what is keras in 3 bullet points?\", max_length=64)" ] @@ -441,35 +336,24 @@ { "cell_type": "markdown", "metadata": { - "id": "qH0eFH_DvYwM" + "id": "mw5XkiHU11Ft" }, "source": [ - "Try calling `generate` again with a different prompt." + "You can also provide batched prompts using a list as input:" ] }, { "cell_type": "code", - "execution_count": 16, + "execution_count": null, "metadata": { - "id": "VEyTnnNGvgGG" + "id": "xV6vs8_C2BGt" }, - "outputs": [ - { - "data": { - "application/vnd.google.colaboratory.intrinsic+json": { - "type": "string" - }, - "text/plain": [ - "'The universe is a vast and mysterious place, filled with countless stars, planets, and galaxies. But what if there was a way to see the universe in a whole new way? What if we could see the universe as it was when it was first created? What if we could see the universe as it is now'" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "gemma_lm.generate(\"The universe is\", max_length=64)" + "gemma_lm.generate(\n", + " [\"what is keras in 3 bullet points?\",\n", + " \"The universe is\"],\n", + " max_length=64)" ] }, { @@ -478,42 +362,55 @@ "id": "vVlCnY7Gvm7U" }, "source": [ - "If you're running on JAX or TensorFlow backends, you'll notice that the second `generate` call returns nearly instantly. This is because each call to `generate` for a given batch size and `max_length` is compiled with XLA. The first run is expensive, but subsequent runs are much faster." + "If you're running on JAX or TensorFlow backends, you should notice that the second `generate()` call returns an answer more quickly. This performance improvement is because each call to `generate()` for a given batch size and `max_length` is compiled with XLA. The first run is expensive, but subsequent runs are faster." ] }, { "cell_type": "markdown", "metadata": { - "id": "mw5XkiHU11Ft" + "id": "CLZodRy8bqBQ" }, "source": [ - "You can also provide batched prompts using a list as input:" + "### Use a prompt template\n", + "\n", + "When building more complex requests or multi-turn chat interactions use a prompt template to structure your request. The following code creates a standard template for Gemma prompts:" ] }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 1, "metadata": { - "id": "xV6vs8_C2BGt" + "id": "suAz3uOEb4rb" }, - "outputs": [ - { - "data": { - "text/plain": [ - "['what is keras in 3 bullet points?\\n\\n[Answer 1]\\n\\nKeras is a high-level neural networks API, written in Python and capable of running on top of TensorFlow, CNTK, Theano, or PlaidML. 
It is designed to be user-friendly and easy to extend.\\n\\n',\n", - " 'The universe is a vast and mysterious place, filled with countless stars, planets, and galaxies. But what if there was a way to see the universe in a whole new way? What if we could see the universe as it was when it was first created? What if we could see the universe as it is now']" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "gemma_lm.generate(\n", - " [\"what is keras in 3 bullet points?\",\n", - " \"The universe is\"],\n", - " max_length=64)" + "PROMPT_TEMPLATE = \"\"\"user\n", + "{question}\n", + " \n", + " model\n", + "\"\"\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0VXrxEVScl_P" + }, + "source": [ + "The following code shows how to use the template to format a simple request:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "G6SuO3BdftM7" + }, + "outputs": [], + "source": [ + "question = \"\"\"\"what is keras in 3 bullet points?\"\"\"\n", + "prompt = PROMPT_TEMPLATE.format(question=question)\n", + "gemma_lm.generate(prompt)" ] }, { @@ -524,32 +421,16 @@ "source": [ "### Optional: Try a different sampler\n", "\n", - "You can control the generation strategy for `GemmaCausalLM` by setting the `sampler` argument on `compile()`. By default, [`\"greedy\"`](https://keras.io/api/keras_nlp/samplers/greedy_sampler/#greedysampler-class) sampling will be used.\n", - "\n", - "As an experiment, try setting a [`\"top_k\"`](https://keras.io/api/keras_nlp/samplers/top_k_sampler/) strategy:" + "You can control the generation strategy for model object by setting the `sampler` argument on `compile()`. By default, [`\"greedy\"`](https://keras.io/api/keras_nlp/samplers/greedy_sampler/#greedysampler-class) sampling will be used. As an experiment, try setting a [`\"top_k\"`](https://keras.io/api/keras_nlp/samplers/top_k_sampler/) strategy:" ] }, { "cell_type": "code", - "execution_count": 20, + "execution_count": null, "metadata": { "id": "mx55VQpN4DAK" }, - "outputs": [ - { - "data": { - "application/vnd.google.colaboratory.intrinsic+json": { - "type": "string" - }, - "text/plain": [ - "'The universe is a big place, and there are so many things we do not know or understand about it.\\n\\nBut we can learn a lot about our world by studying what is known to us.\\n\\nFor example, if you look at the moon, it has many features that can be seen from the surface.'" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "gemma_lm.compile(sampler=\"top_k\")\n", "gemma_lm.generate(\"The universe is\", max_length=64)" @@ -561,9 +442,179 @@ "id": "-okKgK4LfO0f" }, "source": [ - "While the default greedy algorithm always picks the token with the largest probability, the top-K algorithm randomly picks the next token from the tokens of top K probability.\n", + "While the default greedy algorithm always picks the token with the largest probability, the top-K algorithm randomly picks the next token from the tokens of top K probability. You don't have to specify a sampler, and you can ignore the last code snippet if it's not helpful to your use case. If you'd like learn more about the available samplers, see [Samplers](https://keras.io/api/keras_nlp/samplers/)." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NvQ1ajNlCjU6" + }, + "source": [ + "## Generate text with image data\n", + "\n", + "With Gemma 3 and later models, you can use images as part of a prompt to generate output. This capability allows you to use Gemma to interpret visual content or use images as data for content generation.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SDp_zfIYD4b3" + }, + "source": [ + "### Create image loader function\n", + "\n", + "The following function loads an image file from a URL and tokenizes it for use in Gemma prompt:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "WUgo0S__EX9O" + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "import PIL\n", + "\n", + "def read_image(url):\n", + " \"\"\"Reads image from URL as NumPy array.\"\"\"\n", + "\n", + " image_path = keras.utils.get_file(origin=url)\n", + " image = PIL.Image.open(image_path)\n", + " image = np.array(image)\n", + " return image" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RBqnJeQReUH_" + }, + "source": [ + "### Load image for a prompt\n", + "\n", + "Load the image and format the data so the model can process it. Use `read_image()` function defined in the previous section, as shown in the example code below:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "UnIMXwl2evBe" + }, + "outputs": [], + "source": [ + "from matplotlib import pyplot as plt\n", + "\n", + "image = read_image(\n", + " \"https://ai.google.dev/gemma/docs/images/thali-indian-plate.jpg\"\n", + ")\n", + "plt.imshow(image)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "tTe6GUozMfja" + }, + "source": [ + " \n", + "\n", + "**Figure 1.** Image of Thali Indian food on a metal plate." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Hq-RhOZMcgkY" + }, + "source": [ + "### Run request with an image\n", + "\n", + "When prompting the Gemma model with image content, you use a specific string sequence, `
`, within your prompt to include the image as part of the prompt. Use a prompt template, such as the `PROMPT_TEMPLATE` string defined previously, to format your request as shown in the following prompt code:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "HtMiNmD-eESl" + }, + "outputs": [], + "source": [ + "question = \"\"\"Which cuisine is this: ? \\\n", + "Identify the food items present. Which macros is the meal \\\n", + "high and low on? Keep your answer short.\\\n", + "\"\"\"\n", + "\n", + "gemma_lm.generate(\n", + " {\n", + " \"images\": image,\n", + " \"prompts\": PROMPT_TEMPLATE.format(question=question),\n", + " },\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "fjUsHgumE9jD" + }, + "source": [ + "If you are using a smaller GPU, and encountering out of memory (OOM) errors, you can set `max_images_per_prompt` and `sequence_length` to smaller values. The following code shows how to reduce sequence length to 768." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "h1nZRvyNFMxt" + }, + "outputs": [], + "source": [ + "gemma_lm.preprocessor.max_images_per_prompt = 2\n", + "gemma_lm.preprocessor.sequence_length = 768" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7E6h-doZfmmb" + }, + "source": [ + "### Run requests with multiple images\n", + "\n", + "When using more than one image in a prompt, use multiple ` ` tokens for each provided image, as shown in the following example:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "XTpFhba-gT5v" + }, + "outputs": [], + "source": [ + "dog_a = read_image(\"http://localhost/images/dog-a.jpg\")\n", + "dog_b = read_image(\"http://localhost/images/dog-b.jpg\")\n", "\n", - "You don't have to specify a sampler, and you can ignore the last code snippet if it's not helpful to your use case. If you'd like learn more about the available samplers, see [Samplers](https://keras.io/api/keras_nlp/samplers/)." + "question = \"\"\"I have two images:\n", + "\n", + "Dog A: \n", + "Dog B: \n", + "\n", + "Which breeds are they? Tell me a bit about them. \\\n", + "Keep it short.\\\n", + "\"\"\"\n", + "\n", + "gemma_lm.generate(\n", + " {\n", + " \"images\": [dog_a, dog_b],\n", + " \"prompts\": PROMPT_TEMPLATE.format(question=question),\n", + " },\n", + ")" ] }, { @@ -574,7 +625,7 @@ "source": [ "## What's next\n", "\n", - "In this tutorial, you learned how to generate text using KerasNLP and Gemma. Here are a few suggestions for what to learn next:\n", + "In this tutorial, you learned how to generate text using Keras and Gemma. Here are a few suggestions for what to learn next:\n", "\n", "* Learn how to [finetune a Gemma model](https://ai.google.dev/gemma/docs/core/lora_tuning).\n", "* Learn how to perform [distributed fine-tuning and inference on a Gemma model](https://ai.google.dev/gemma/docs/core/distributed_tuning).\n", diff --git a/site/en/gemma/docs/core/lora_tuning.ipynb b/site/en/gemma/docs/core/lora_tuning.ipynb index ed8ea92f2..391f25f8b 100644 --- a/site/en/gemma/docs/core/lora_tuning.ipynb +++ b/site/en/gemma/docs/core/lora_tuning.ipynb @@ -47,7 +47,7 @@ "id": "SDEExiAk4fLb" }, "source": [ - "# Fine-tune Gemma models in Keras using LoRA" + "# Fine-tune Gemma in Keras using LoRA" ] }, { @@ -77,22 +77,9 @@ "id": "lSGRSsRPgkzK" }, "source": [ - "Large Language Models (LLMs) like Gemma have been shown to be effective at a variety of NLP tasks. 
An LLM is first pre-trained on a large corpus of text in a self-supervised fashion. Pre-training helps LLMs learn general-purpose knowledge, such as statistical relationships between words. An LLM can then be fine-tuned with domain-specific data to perform downstream tasks (such as sentiment analysis).\n", + "Generative artificial intelligent (AI) models like Gemma are effective at a variety of tasks. You can further fine-tune Gemma models with domain-specific data to perform tasks such as sentiment analysis. However, full fine-tuning of generative models by updating billions of parameters is resource intensive, requiring specialized hardware, such as GPUs, processing time, and memory to load the model parameters.\n", "\n", - "LLMs are extremely large in size (parameters in the order of billions). Full fine-tuning (which updates all the parameters in the model) is not required for most applications because typical fine-tuning datasets are relatively much smaller than the pre-training datasets.\n", - "\n", - "[Low Rank Adaptation (LoRA)](https://arxiv.org/abs/2106.09685) is a fine-tuning technique which greatly reduces the number of trainable parameters for downstream tasks by freezing the weights of the model and inserting a smaller number of new weights into the model. This makes training with LoRA much faster and more memory-efficient, and produces smaller model weights (a few hundred MBs), all while maintaining the quality of the model outputs.\n", - "\n", - "This tutorial walks you through using KerasNLP to perform LoRA fine-tuning on a Gemma 2B model using the [Databricks Dolly 15k dataset](https://huggingface.co/datasets/databricks/databricks-dolly-15k). This dataset contains 15,000 high-quality human-generated prompt / response pairs specifically designed for fine-tuning LLMs." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "w1q6-W_mKIT-" - }, - "source": [ - "## Setup" + "[Low Rank Adaptation](https://arxiv.org/abs/2106.09685) (LoRA) is a fine-tuning technique which greatly reduces the number of trainable parameters for downstream tasks by freezing the weights of the model and inserting a smaller number of new weights into the model. This technique makes training with LoRA much faster and more memory-efficient, and produces smaller model weights (a few hundred MBs), all while maintaining the quality of the model outputs. This tutorial walks you through using Keras to perform LoRA fine-tuning on a Gemma model." ] }, { @@ -101,13 +88,13 @@ "id": "lyhHCMfoRZ_v" }, "source": [ - "### Get access to Gemma\n", + "## Setup\n", "\n", "To complete this tutorial, you will first need to complete the setup instructions at [Gemma setup](https://ai.google.dev/gemma/docs/setup). The Gemma setup instructions show you how to do the following:\n", "\n", "* Get access to Gemma on [kaggle.com](https://kaggle.com).\n", - "* Select a Colab runtime with sufficient resources to run\n", - " the Gemma 2B model.\n", + "* Select a Colab runtime with sufficient resources to tune\n", + " the Gemma model you want to run. [Learn more](https://ai.google.dev/gemma/docs/core#sizes).\n", "* Generate and configure a Kaggle username and API key.\n", "\n", "After you've completed the Gemma setup, move on to the next section, where you'll set environment variables for your Colab environment." 
@@ -119,7 +106,7 @@ "id": "AZ5Qo0fxRZ1V" }, "source": [ - "### Select the runtime\n", + "### Select a Colab runtime\n", "\n", "To complete this tutorial, you'll need to have a Colab runtime with sufficient resources to run the Gemma model. In this case, you can use a T4 GPU:\n", "\n", @@ -138,7 +125,7 @@ "\n", "To use Gemma, you must provide your Kaggle username and a Kaggle API key.\n", "\n", - "To generate a Kaggle API key, go to the **Account** tab of your Kaggle user profile and select **Create New Token**. This will trigger the download of a `kaggle.json` file containing your API credentials.\n", + "To generate a Kaggle API key, go to the **Account** tab of your Kaggle user profile and select **Create New Token**. This triggers the download of a `kaggle.json` file containing your API credentials.\n", "\n", "In Colab, select **Secrets** (🔑) in the left pane and add your Kaggle username and Kaggle API key. Store your username under the name `KAGGLE_USERNAME` and your API key under the name `KAGGLE_KEY`." ] @@ -156,7 +143,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": { "id": "0_EdOg9DPK6Q" }, @@ -178,9 +165,9 @@ "id": "CuEUAKJW1QkQ" }, "source": [ - "### Install dependencies\n", + "### Install Keras packages\n", "\n", - "Install Keras, KerasNLP, and other dependencies." + "Install the Keras and KerasHub Python packages." ] }, { @@ -191,9 +178,8 @@ }, "outputs": [], "source": [ - "# Install Keras 3 last. See https://keras.io/getting_started/ for more details.\n", - "!pip install -q -U keras-nlp\n", - "!pip install -q -U \"keras>=3\"" + "!pip install -q -U keras-hub\n", + "!pip install -q -U keras" ] }, { @@ -204,14 +190,12 @@ "source": [ "### Select a backend\n", "\n", - "Keras is a high-level, multi-framework deep learning API designed for simplicity and ease of use. Using Keras 3, you can run workflows on one of three backends: TensorFlow, JAX, or PyTorch.\n", - "\n", - "For this tutorial, configure the backend for JAX." + "Keras is a high-level, multi-framework deep learning API designed for simplicity and ease of use. Using Keras 3, you can run workflows on one of three backends: TensorFlow, JAX, or PyTorch. For this tutorial, configure the backend for JAX as it typically provides the better performance." ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": { "id": "yn5uy8X8sdD0" }, @@ -230,95 +214,19 @@ "source": [ "### Import packages\n", "\n", - "Import Keras and KerasNLP." + "Import the Python packages needed for this tutorial, including Keras and KerasHub." ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": { "id": "FYHyPUA9hKTf" }, "outputs": [], "source": [ "import keras\n", - "import keras_nlp" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "9T7xe_jzslv4" - }, - "source": [ - "## Load Dataset" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "id": "xRaNCPUXKoa7" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "--2024-07-31 01:56:39-- https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl\n", - "Resolving huggingface.co (huggingface.co)... 18.164.174.23, 18.164.174.17, 18.164.174.55, ...\n", - "Connecting to huggingface.co (huggingface.co)|18.164.174.23|:443... connected.\n", - "HTTP request sent, awaiting response... 
302 Found\n", - "Location: https://cdn-lfs.huggingface.co/repos/34/ac/34ac588cc580830664f592597bb6d19d61639eca33dc2d6bb0b6d833f7bfd552/2df9083338b4abd6bceb5635764dab5d833b393b55759dffb0959b6fcbf794ec?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27databricks-dolly-15k.jsonl%3B+filename%3D%22databricks-dolly-15k.jsonl%22%3B&Expires=1722650199&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyMjY1MDE5OX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy8zNC9hYy8zNGFjNTg4Y2M1ODA4MzA2NjRmNTkyNTk3YmI2ZDE5ZDYxNjM5ZWNhMzNkYzJkNmJiMGI2ZDgzM2Y3YmZkNTUyLzJkZjkwODMzMzhiNGFiZDZiY2ViNTYzNTc2NGRhYjVkODMzYjM5M2I1NTc1OWRmZmIwOTU5YjZmY2JmNzk0ZWM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=nITF8KrgvPBdCRtwfpzGV9ulH2joFLXIDct5Nq-aZqb-Eum8XiVGOai76mxahgAK2mCO4ekuNVCxVsa9Q7h40cZuzViZZC3zAF8QVQlbbkd3FBY4SN3QA4nDNQGcuRYoMKcalA9vRBasFhmdWgupxVqYgMVfJvgSApUcMHMm1HqRBn8AGKpEsaXhEMX4I0N-KtDH5ojDZjz5QBDgkWEmPYUeDQbjVHMjXsRG5z4vH3nK1W9gzC7dkWicJZlzl6iGs44w-EqnD3h-McDCgFnXUacPydm1hdgin-wutx7V4Z3Yv82Fi-TPlDYCnioesUr9Rx8xYujPuXmWP24kPca17Q__&Key-Pair-Id=K3ESJI6DHPFC7 [following]\n", - "--2024-07-31 01:56:39-- https://cdn-lfs.huggingface.co/repos/34/ac/34ac588cc580830664f592597bb6d19d61639eca33dc2d6bb0b6d833f7bfd552/2df9083338b4abd6bceb5635764dab5d833b393b55759dffb0959b6fcbf794ec?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27databricks-dolly-15k.jsonl%3B+filename%3D%22databricks-dolly-15k.jsonl%22%3B&Expires=1722650199&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyMjY1MDE5OX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy8zNC9hYy8zNGFjNTg4Y2M1ODA4MzA2NjRmNTkyNTk3YmI2ZDE5ZDYxNjM5ZWNhMzNkYzJkNmJiMGI2ZDgzM2Y3YmZkNTUyLzJkZjkwODMzMzhiNGFiZDZiY2ViNTYzNTc2NGRhYjVkODMzYjM5M2I1NTc1OWRmZmIwOTU5YjZmY2JmNzk0ZWM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=nITF8KrgvPBdCRtwfpzGV9ulH2joFLXIDct5Nq-aZqb-Eum8XiVGOai76mxahgAK2mCO4ekuNVCxVsa9Q7h40cZuzViZZC3zAF8QVQlbbkd3FBY4SN3QA4nDNQGcuRYoMKcalA9vRBasFhmdWgupxVqYgMVfJvgSApUcMHMm1HqRBn8AGKpEsaXhEMX4I0N-KtDH5ojDZjz5QBDgkWEmPYUeDQbjVHMjXsRG5z4vH3nK1W9gzC7dkWicJZlzl6iGs44w-EqnD3h-McDCgFnXUacPydm1hdgin-wutx7V4Z3Yv82Fi-TPlDYCnioesUr9Rx8xYujPuXmWP24kPca17Q__&Key-Pair-Id=K3ESJI6DHPFC7\n", - "Resolving cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)... 18.154.206.4, 18.154.206.17, 18.154.206.28, ...\n", - "Connecting to cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)|18.154.206.4|:443... connected.\n", - "HTTP request sent, awaiting response... 200 OK\n", - "Length: 13085339 (12M) [text/plain]\n", - "Saving to: ‘databricks-dolly-15k.jsonl’\n", - "\n", - "databricks-dolly-15 100%[===================>] 12.48M 73.7MB/s in 0.2s \n", - "\n", - "2024-07-31 01:56:40 (73.7 MB/s) - ‘databricks-dolly-15k.jsonl’ saved [13085339/13085339]\n", - "\n" - ] - } - ], - "source": [ - "!wget -O databricks-dolly-15k.jsonl https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "45UpBDfBgf0I" - }, - "source": [ - "Preprocess the data. This tutorial uses a subset of 1000 training examples to execute the notebook faster. Consider using more training data for higher quality fine-tuning." 
- ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "id": "ZiS-KU9osh_N" - }, - "outputs": [], - "source": [ - "import json\n", - "data = []\n", - "with open(\"databricks-dolly-15k.jsonl\") as file:\n", - " for line in file:\n", - " features = json.loads(line)\n", - " # Filter out examples with context, to keep it simple.\n", - " if features[\"context\"]:\n", - " continue\n", - " # Format the entire example as a single string.\n", - " template = \"Instruction:\\n{instruction}\\n\\nResponse:\\n{response}\"\n", - " data.append(template.format(**features))\n", - "\n", - "# Only use 1000 training examples, to keep it fast.\n", - "data = data[:1000]" + "import keras_hub" ] }, { @@ -327,146 +235,20 @@ "id": "7RCE3fdGhDE5" }, "source": [ - "## Load Model\n", - "\n", - "KerasNLP provides implementations of many popular [model architectures](https://keras.io/api/keras_nlp/models/). In this tutorial, you'll create a model using `GemmaCausalLM`, an end-to-end Gemma model for causal language modeling. A causal language model predicts the next token based on previous tokens.\n", + "## Load model\n", "\n", - "Create the model using the `from_preset` method:" + "Keras provides implementations of Gemma and many other popular [model architectures](https://keras.io/keras_hub/api/models/). Use the `Gemma3CausalLM.from_preset()` method to configure an end-to-end Gemma model for causal language modeling. A causal language model predicts the next token based on previous tokens." ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": { "id": "vz5zLEyLstfn" }, - "outputs": [ - { - "data": { - "text/html": [ - " Preprocessor: \"gemma_causal_lm_preprocessor\"\n", - "
\n" - ], - "text/plain": [ - "\u001b[1mPreprocessor: \"gemma_causal_lm_preprocessor\"\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", - "┃ Tokenizer (type) ┃ Vocab # ┃\n", - "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", - "│ gemma_tokenizer (GemmaTokenizer) │ 256,000 │\n", - "└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘\n", - "\n" - ], - "text/plain": [ - "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", - "┃\u001b[1m \u001b[0m\u001b[1mTokenizer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Vocab #\u001b[0m\u001b[1m \u001b[0m┃\n", - "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", - "│ gemma_tokenizer (\u001b[38;5;33mGemmaTokenizer\u001b[0m) │ \u001b[38;5;34m256,000\u001b[0m │\n", - "└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "Model: \"gemma_causal_lm\"\n", - "
\n" - ], - "text/plain": [ - "\u001b[1mModel: \"gemma_causal_lm\"\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", - "┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃\n", - "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", - "│ padding_mask (InputLayer) │ (None, None) │ 0 │ - │\n", - "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", - "│ token_ids (InputLayer) │ (None, None) │ 0 │ - │\n", - "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", - "│ gemma_backbone │ (None, None, 2304) │ 2,614,341,888 │ padding_mask[0][0], │\n", - "│ (GemmaBackbone) │ │ │ token_ids[0][0] │\n", - "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", - "│ token_embedding │ (None, None, 256000) │ 589,824,000 │ gemma_backbone[0][0] │\n", - "│ (ReversibleEmbedding) │ │ │ │\n", - "└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n", - "\n" - ], - "text/plain": [ - "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", - "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n", - "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", - "│ padding_mask (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", - "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", - "│ token_ids (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", - "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", - "│ gemma_backbone │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2304\u001b[0m) │ \u001b[38;5;34m2,614,341,888\u001b[0m │ padding_mask[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", - "│ (\u001b[38;5;33mGemmaBackbone\u001b[0m) │ │ │ token_ids[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", - "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", - "│ token_embedding │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256000\u001b[0m) │ \u001b[38;5;34m589,824,000\u001b[0m │ gemma_backbone[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", - "│ (\u001b[38;5;33mReversibleEmbedding\u001b[0m) │ │ │ │\n", - "└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "Total params: 2,614,341,888 (9.74 GB)\n", - "\n" - ], - "text/plain": [ - "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m2,614,341,888\u001b[0m (9.74 GB)\n" - ] - }, - 
"metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "Trainable params: 2,614,341,888 (9.74 GB)\n", - "\n" - ], - "text/plain": [ - "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m2,614,341,888\u001b[0m (9.74 GB)\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "Non-trainable params: 0 (0.00 B)\n", - "\n" - ], - "text/plain": [ - "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ - "gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(\"gemma2_2b_en\")\n", + "gemma_lm = keras_hub.models.Gemma3CausalLM.from_preset(\"gemma3_instruct_1b\")\n", "gemma_lm.summary()" ] }, @@ -476,10 +258,7 @@ "id": "Nl4lvPy5zA26" }, "source": [ - "The `from_preset` method instantiates the model from a preset architecture and weights. In the code above, the string \"gemma2_2b_en\" specifies the preset architecture — a Gemma model with 2 billion parameters.\n", - "\n", - "NOTE: A Gemma model with 7\n", - "billion parameters is also available. To run the larger model in Colab, you need access to the premium GPUs available in paid plans. Alternatively, you can perform [distributed tuning on a Gemma 7B model](https://ai.google.dev/gemma/docs/distributed_tuning) on Kaggle or Google Cloud." + "The `Gemma3CausalLM.from_preset()` method instantiates the model from a preset architecture and weights. In the code above, the string `\"gemma#_xxxxxxx\"` specifies a preset version and parameter size for Gemma. You can find the code strings for Gemma models in their **Model Variation** listings on [Kaggle](https://www.kaggle.com/models/keras/gemma3)." ] }, { @@ -490,7 +269,7 @@ "source": [ "## Inference before fine tuning\n", "\n", - "In this section, you will query the model with various prompts to see how it responds." + "Once you have downloaded and configured a Gemma model, you can query it with various prompts to see how it responds." ] }, { @@ -499,14 +278,14 @@ "id": "PVLXadptyo34" }, "source": [ - "### Europe Trip Prompt\n", + "### Europe trip prompt\n", "\n", "Query the model for suggestions on what to do on a trip to Europe." ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": { "id": "ZwQz3xxxKciD" }, @@ -519,40 +298,23 @@ "What should I do on a trip to Europe?\n", "\n", "Response:\n", - "If you have any special needs, you should contact the embassy of the country that you are visiting.\n", - "You should contact the embassy of the country that I will be visiting.\n", + "The first thing to know is that you will have a great time!\n", "\n", - "What are my responsibilities when I go on a trip?\n", + "Europe is a great place for a vacation. The countries of Europe are all very different and offer a wide range of activities and attractions. The countries of Europe are also very close to each other, which means you can visit many different places within a short time.\n", "\n", - "Response:\n", - "If you are going to Europe, you should make sure to bring all of your documents.\n", - "If you are going to Europe, make sure that you have all of your documents.\n", + "The best way to plan a trip to Europe is to look up the countries you want to visit and see what activities are offered in each country. 
You can also look for tours and tours that offer a good value for money.\n", "\n", - "When do you travel abroad?\n", + "You can also look for hotels and flights that offer good deals. If you are looking for a good value for money, you should look for hotels and flights that offer good deals. This means you will have a great time on your trip!\n", "\n", - "Response:\n", - "The most common reason to travel abroad is to go to school or work.\n", - "The most common reason to travel abroad is to work.\n", + "The next step is to book your tickets to the countries you want to visit. If you are planning to visit many countries, it's a good idea to book your tickets early. This means you’ll be able to get the best deal and avoid the long queues.\n", "\n", - "How can I get a visa to Europe?\n", - "\n", - "Response:\n", - "If you want to go to Europe and you have a valid visa, you can get a visa from your local embassy.\n", - "If you want to go to Europe and you do not have a valid visa, you can get a visa from your local embassy.\n", - "\n", - "When should I go to Europe?\n", - "\n", - "Response:\n", - "You should go to Europe when the weather is nice.\n", - "You should go to Europe when the weather is bad.\n", - "\n", - "How can I make a reservation for a trip?\n", - "\n", - "\n" + "The next step is to plan your itinerary. You can use a travel guide to plan your itinerary\n" ] } ], "source": [ + "template = \"Instruction:\\n{instruction}\\n\\nResponse:\\n{response}\"\n", + "\n", "prompt = template.format(\n", " instruction=\"What should I do on a trip to Europe?\",\n", " response=\"\",\n", @@ -577,14 +339,14 @@ "id": "YQ74Zz_S0iVv" }, "source": [ - "### ELI5 Photosynthesis Prompt\n", + "### Photosynthesis prompt\n", "\n", "Prompt the model to explain photosynthesis in terms simple enough for a 5 year old child to understand." ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": { "id": "lorJMbsusgoo" }, @@ -597,25 +359,13 @@ "Explain the process of photosynthesis in a way that a child could understand.\n", "\n", "Response:\n", - "Plants need water, air, sunlight, and carbon dioxide. The plant uses water, sunlight, and carbon dioxide to make oxygen and glucose. The process is also known as photosynthesis.\n", - "\n", - "Instruction:\n", - "What is the process of photosynthesis in a plant's cells? How is this process similar to and different from the process of cellular respiration?\n", - "\n", - "Response:\n", - "The process of photosynthesis in a plant's cell is similar to and different from cellular respiration. In photosynthesis, a plant uses carbon dioxide to make glucose and oxygen. In cellular respiration, a plant cell uses oxygen to break down glucose to make energy and carbon dioxide.\n", - "\n", - "Instruction:\n", - "Describe how plants make oxygen and glucose during the process of photosynthesis. Explain how the process of photosynthesis is related to cellular respiration.\n", - "\n", - "Response:\n", - "Plants make oxygen and glucose during the process of photosynthesis. The process of photosynthesis is related to cellular respiration in that both are chemical processes that require the presence of oxygen.\n", - "\n", - "Instruction:\n", - "How does photosynthesis occur in the cells of a plant? What is the purpose for each part of the cell?\n", - "\n", - "Response:\n", - "Photosynthesis occurs in the cells of a plant. The purpose of\n" + "Photosynthesis is a biological process that occurs in plants, algae, and some other organisms. 
In the process, light energy is captured and converted into the energy stored in the bonds of organic molecules. The process is crucial for life on Earth because it enables plants to use carbon dioxide and water to produce glucose and oxygen, which are essential for all living things.\n", + "The process involves several stages:\n", + "1. Light Reactions: Light energy is absorbed by pigments in the chloroplasts of the plant, converting it into chemical energy in the form of ATP and reducing power.\n", + "2. Carbon Fixation: During this stage, carbon dioxide is combined with hydrogen to form organic molecules such as starch or glucose, which are used as a source of energy.\n", + "3. Calvin Cycle: The process of carbon fixation occurs in the stroma of the chloroplasts. It involves the capture and reduction of carbon dioxide, producing glucose and reducing power in the form of ATP and NADPH molecules.\n", + "4. Stroma: The stroma is the fluid-filled space where the light reactions occur in the chloroplasts.\n", + "5. Chloroplasts: The chloroplasts contain the green pigments that absorb\n" ] } ], @@ -642,151 +392,144 @@ "id": "Pt7Nr6a7tItO" }, "source": [ - "## LoRA Fine-tuning\n", - "\n", - "To get better responses from the model, fine-tune the model with Low Rank Adaptation (LoRA) using the Databricks Dolly 15k dataset.\n", + "## LoRA fine-tuning\n", "\n", - "The LoRA rank determines the dimensionality of the trainable matrices that are added to the original weights of the LLM. It controls the expressiveness and precision of the fine-tuning adjustments.\n", - "\n", - "A higher rank means more detailed changes are possible, but also means more trainable parameters. A lower rank means less computational overhead, but potentially less precise adaptation.\n", + "This section shows you how to do fine-tuning using the Low Rank Adaptation (LoRA) tuning technique. This approach allows you to change the behavior of Gemma models using fewer compute resources." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9T7xe_jzslv4" + }, + "source": [ + "### Load dataset\n", "\n", - "This tutorial uses a LoRA rank of 4. In practice, begin with a relatively small rank (such as 4, 8, 16). This is computationally efficient for experimentation. Train your model with this rank and evaluate the performance improvement on your task. Gradually increase the rank in subsequent trials and see if that further boosts performance." + "Prepare a dataset for tuning by downloading an existing data set and formatting if for use with the the Keras `fit()` fine-tuning method. This tutorial uses the [Databricks Dolly 15k dataset](https://huggingface.co/datasets/databricks/databricks-dolly-15k) for fine-tuning. The dataset contains 15,000 high-quality human-generated prompt and response pairs specifically designed for tuning generative models." ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": { - "id": "RCucu6oHz53G" + "id": "xRaNCPUXKoa7" }, "outputs": [ { - "data": { - "text/html": [ - "Preprocessor: \"gemma_causal_lm_preprocessor\"\n", - "
\n" - ], - "text/plain": [ - "\u001b[1mPreprocessor: \"gemma_causal_lm_preprocessor\"\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", - "┃ Tokenizer (type) ┃ Vocab # ┃\n", - "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", - "│ gemma_tokenizer (GemmaTokenizer) │ 256,000 │\n", - "└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘\n", - "\n" - ], - "text/plain": [ - "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", - "┃\u001b[1m \u001b[0m\u001b[1mTokenizer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Vocab #\u001b[0m\u001b[1m \u001b[0m┃\n", - "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", - "│ gemma_tokenizer (\u001b[38;5;33mGemmaTokenizer\u001b[0m) │ \u001b[38;5;34m256,000\u001b[0m │\n", - "└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "Model: \"gemma_causal_lm\"\n", - "
\n" - ], - "text/plain": [ - "\u001b[1mModel: \"gemma_causal_lm\"\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", - "┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃\n", - "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", - "│ padding_mask (InputLayer) │ (None, None) │ 0 │ - │\n", - "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", - "│ token_ids (InputLayer) │ (None, None) │ 0 │ - │\n", - "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", - "│ gemma_backbone │ (None, None, 2304) │ 2,617,270,528 │ padding_mask[0][0], │\n", - "│ (GemmaBackbone) │ │ │ token_ids[0][0] │\n", - "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", - "│ token_embedding │ (None, None, 256000) │ 589,824,000 │ gemma_backbone[0][0] │\n", - "│ (ReversibleEmbedding) │ │ │ │\n", - "└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n", - "\n" - ], - "text/plain": [ - "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", - "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n", - "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", - "│ padding_mask (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", - "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", - "│ token_ids (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", - "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", - "│ gemma_backbone │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2304\u001b[0m) │ \u001b[38;5;34m2,617,270,528\u001b[0m │ padding_mask[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", - "│ (\u001b[38;5;33mGemmaBackbone\u001b[0m) │ │ │ token_ids[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", - "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", - "│ token_embedding │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256000\u001b[0m) │ \u001b[38;5;34m589,824,000\u001b[0m │ gemma_backbone[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", - "│ (\u001b[38;5;33mReversibleEmbedding\u001b[0m) │ │ │ │\n", - "└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "Total params: 2,617,270,528 (9.75 GB)\n", - "\n" - ], - "text/plain": [ - "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m2,617,270,528\u001b[0m (9.75 GB)\n" - ] - }, - 
"metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "Trainable params: 2,928,640 (11.17 MB)\n", - "\n" - ], - "text/plain": [ - "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m2,928,640\u001b[0m (11.17 MB)\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "Non-trainable params: 2,614,341,888 (9.74 GB)\n", - "\n" - ], - "text/plain": [ - "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m2,614,341,888\u001b[0m (9.74 GB)\n" - ] - }, - "metadata": {}, - "output_type": "display_data" + "name": "stdout", + "output_type": "stream", + "text": [ + "--2025-04-10 20:48:49-- https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl\n", + "Resolving huggingface.co (huggingface.co)... 3.163.189.37, 3.163.189.114, 3.163.189.74, ...\n", + "Connecting to huggingface.co (huggingface.co)|3.163.189.37|:443... connected.\n", + "HTTP request sent, awaiting response... 302 Found\n", + "Location: https://cdn-lfs.hf.co/repos/34/ac/34ac588cc580830664f592597bb6d19d61639eca33dc2d6bb0b6d833f7bfd552/2df9083338b4abd6bceb5635764dab5d833b393b55759dffb0959b6fcbf794ec?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27databricks-dolly-15k.jsonl%3B+filename%3D%22databricks-dolly-15k.jsonl%22%3B&Expires=1744321729&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc0NDMyMTcyOX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5oZi5jby9yZXBvcy8zNC9hYy8zNGFjNTg4Y2M1ODA4MzA2NjRmNTkyNTk3YmI2ZDE5ZDYxNjM5ZWNhMzNkYzJkNmJiMGI2ZDgzM2Y3YmZkNTUyLzJkZjkwODMzMzhiNGFiZDZiY2ViNTYzNTc2NGRhYjVkODMzYjM5M2I1NTc1OWRmZmIwOTU5YjZmY2JmNzk0ZWM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=vh0VIGB-UkK57FSfRikYCREpKuHt%7EnDKPcHHgC1V9rDXLABIRF81nK7olQhAq6zSbAqEtMNnvHgd8IBK1j54mdIYdVLiBwImqez3xu2CPhzYBtKWInnXj9lTXW0p-9GEHcbU%7Eoot22qFSdwyZf1UIdmHZLTHPWjtLhfRkKbg-ptA3CFeegtmvCtY-WG2GffJ%7Em2q2bbs-U1m0yI7cSTW18nD8VSBihxGOMnS1IhkO-LgE4I6GJISXROTk-61%7EJiEIKcagcijL4QGi8j1g9xeQamBXX4hWBdkbJgX5PtX15Ftd0HCM4zCzcJAUrE3ZEJRLe2XRUwfKU3ai7-%7ErPpnSA__&Key-Pair-Id=K3RPWS32NSSJCE [following]\n", + "--2025-04-10 20:48:49-- https://cdn-lfs.hf.co/repos/34/ac/34ac588cc580830664f592597bb6d19d61639eca33dc2d6bb0b6d833f7bfd552/2df9083338b4abd6bceb5635764dab5d833b393b55759dffb0959b6fcbf794ec?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27databricks-dolly-15k.jsonl%3B+filename%3D%22databricks-dolly-15k.jsonl%22%3B&Expires=1744321729&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc0NDMyMTcyOX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5oZi5jby9yZXBvcy8zNC9hYy8zNGFjNTg4Y2M1ODA4MzA2NjRmNTkyNTk3YmI2ZDE5ZDYxNjM5ZWNhMzNkYzJkNmJiMGI2ZDgzM2Y3YmZkNTUyLzJkZjkwODMzMzhiNGFiZDZiY2ViNTYzNTc2NGRhYjVkODMzYjM5M2I1NTc1OWRmZmIwOTU5YjZmY2JmNzk0ZWM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=vh0VIGB-UkK57FSfRikYCREpKuHt%7EnDKPcHHgC1V9rDXLABIRF81nK7olQhAq6zSbAqEtMNnvHgd8IBK1j54mdIYdVLiBwImqez3xu2CPhzYBtKWInnXj9lTXW0p-9GEHcbU%7Eoot22qFSdwyZf1UIdmHZLTHPWjtLhfRkKbg-ptA3CFeegtmvCtY-WG2GffJ%7Em2q2bbs-U1m0yI7cSTW18nD8VSBihxGOMnS1IhkO-LgE4I6GJISXROTk-61%7EJiEIKcagcijL4QGi8j1g9xeQamBXX4hWBdkbJgX5PtX15Ftd0HCM4zCzcJAUrE3ZEJRLe2XRUwfKU3ai7-%7ErPpnSA__&Key-Pair-Id=K3RPWS32NSSJCE\n", + "Resolving cdn-lfs.hf.co (cdn-lfs.hf.co)... 18.238.217.63, 18.238.217.81, 18.238.217.120, ...\n", + "Connecting to cdn-lfs.hf.co (cdn-lfs.hf.co)|18.238.217.63|:443... connected.\n", + "HTTP request sent, awaiting response... 
200 OK\n", + "Length: 13085339 (12M) [text/plain]\n", + "Saving to: ‘databricks-dolly-15k.jsonl’\n", + "\n", + "databricks-dolly-15 100%[===================>] 12.48M --.-KB/s in 0.08s \n", + "\n", + "2025-04-10 20:48:49 (156 MB/s) - ‘databricks-dolly-15k.jsonl’ saved [13085339/13085339]\n", + "\n" + ] } ], + "source": [ + "!wget -O databricks-dolly-15k.jsonl https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "45UpBDfBgf0I" + }, + "source": [ + "### Format tuning data\n", + "\n", + "Format the downloaded data for use with the Keras `fit()` method. The following code extracts a subset of the training examples to execute the notebook faster. Consider using more training data for higher quality fine-tuning." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ZiS-KU9osh_N" + }, + "outputs": [], + "source": [ + "import json\n", + "\n", + "prompts = []\n", + "responses = []\n", + "line_count = 0\n", + "\n", + "with open(\"databricks-dolly-15k.jsonl\") as file:\n", + " for line in file:\n", + " if line_count >= 1000:\n", + " break # Limit the training examples, to reduce execution time.\n", + "\n", + " examples = json.loads(line)\n", + " # Filter out examples with context, to keep it simple.\n", + " if examples[\"context\"]:\n", + " continue\n", + " # Format data into prompts and response lists.\n", + " prompts.append(examples[\"instruction\"])\n", + " responses.append(examples[\"response\"])\n", + "\n", + " line_count += 1\n", + "\n", + "data = {\n", + " \"prompts\": prompts,\n", + " \"responses\": responses\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "cBLW5hiGj31i" + }, + "source": [ + "### Configure LoRA tuning\n", + "\n", + "Activate LoRA tuning using the Keras `model.backbone.enable_lora()` method, including a LoRA rank value. The *LoRA rank* determines the dimensionality of the trainable matrices that are added to the original weights of the LLM. It controls the expressiveness and precision of the fine-tuning adjustments. A higher rank means more detailed changes are possible, but also means more trainable parameters. A lower rank means less computational overhead, but potentially less precise adaptation.\n", + "\n", + "This example uses a LoRA rank of 4. In practice, begin with a relatively small rank (such as 4, 8, 16). This setting is computationally efficient for experimentation. Train your model with this rank and evaluate the performance improvement on your task. Gradually increase the rank in subsequent trials and see if that further boosts performance." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "RCucu6oHz53G" + }, + "outputs": [], "source": [ "# Enable LoRA for the model and set the LoRA rank to 4.\n", - "gemma_lm.backbone.enable_lora(rank=4)\n", + "gemma_lm.backbone.enable_lora(rank=4)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "PlMLp_NVbRoQ" + }, + "source": [ + "Check the model summary after setting the LoRA rank. 
Notice that enabling LoRA reduces the number of trainable parameters significantly compared to the total number of parameters in the model:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "KqYyS0gm6pNy" + }, + "outputs": [], + "source": [ "gemma_lm.summary()" ] }, @@ -796,12 +539,48 @@ "id": "hQQ47kcdpbZ9" }, "source": [ - "Note that enabling LoRA reduces the number of trainable parameters significantly (from 2.6 billion to 2.9 million)." + "Configure the rest of the fine-tuning settings, including the preprocessor settings, optimizer, number of tuning epochs, and batch size:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "p9sBNH8SAjgB" + }, + "outputs": [], + "source": [ + "# Limit the input sequence length to 256 (to control memory usage).\n", + "gemma_lm.preprocessor.sequence_length = 256\n", + "# Use AdamW (a common optimizer for transformer models).\n", + "optimizer = keras.optimizers.AdamW(\n", + " learning_rate=5e-5,\n", + " weight_decay=0.01,\n", + ")\n", + "# Exclude layernorm and bias terms from decay.\n", + "optimizer.exclude_from_weight_decay(var_names=[\"bias\", \"scale\"])\n", + "\n", + "gemma_lm.compile(\n", + " loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", + " optimizer=optimizer,\n", + " weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "OA0ozGC66tk1" + }, + "source": [ + "### Run the fine-tune process\n", + "\n", + "Run the fine-tuning process using the `fit()` method. This process can take several minutes depending on your compute resources, data size, and number of epochs:" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": { "id": "_Peq7TnLtHse" }, @@ -825,21 +604,6 @@ } ], "source": [ - "# Limit the input sequence length to 256 (to control memory usage).\n", - "gemma_lm.preprocessor.sequence_length = 256\n", - "# Use AdamW (a common optimizer for transformer models).\n", - "optimizer = keras.optimizers.AdamW(\n", - " learning_rate=5e-5,\n", - " weight_decay=0.01,\n", - ")\n", - "# Exclude layernorm and bias terms from decay.\n", - "optimizer.exclude_from_weight_decay(var_names=[\"bias\", \"scale\"])\n", - "\n", - "gemma_lm.compile(\n", - " loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", - " optimizer=optimizer,\n", - " weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],\n", - ")\n", "gemma_lm.fit(data, epochs=1, batch_size=1)" ] }, @@ -849,12 +613,9 @@ "id": "bx3m8f1dB7nk" }, "source": [ - "### Note on mixed precision fine-tuning on NVIDIA GPUs\n", - "\n", - "Full precision is recommended for fine-tuning. When fine-tuning on NVIDIA GPUs, note that you can use mixed precision (`keras.mixed_precision.set_global_policy('mixed_bfloat16')`) to speed up training with minimal effect on training quality. Mixed precision fine-tuning does consume more memory so is useful only on larger GPUs.\n", + "#### Mixed precision fine-tuning on NVIDIA GPUs\n", "\n", - "\n", - "For inference, half-precision (`keras.config.set_floatx(\"bfloat16\")`) will work and save memory while mixed precision is not applicable." + "Full precision is recommended for fine-tuning. When fine-tuning on NVIDIA GPUs, you can use mixed precision (`keras.mixed_precision.set_global_policy('mixed_bfloat16')`) to speed up training with minimal effect on training quality." 
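+    ,
+    "\n",
+    "For example, a minimal sketch (assuming an NVIDIA GPU with bfloat16 support) is to set the policy before loading the model:\n",
+    "\n",
+    "```python\n",
+    "import keras\n",
+    "\n",
+    "# Set the policy before calling from_preset() so the model's layers are\n",
+    "# built with bfloat16 compute and float32 variables.\n",
+    "keras.mixed_precision.set_global_policy('mixed_bfloat16')\n",
+    "```"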
] }, { @@ -876,7 +637,8 @@ }, "source": [ "## Inference after fine-tuning\n", - "After fine-tuning, responses follow the instruction provided in the prompt." + "\n", + "After fine-tuning, you should see changes in the responses when the tuned model is given the same prompt." ] }, { @@ -885,12 +647,14 @@ "id": "H55JYJ1a1Kos" }, "source": [ - "### Europe Trip Prompt" + "### Europe trip prompt\n", + "\n", + "Try the Europe trip prompt from earlier and note the differences in the response." ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": { "id": "Y7cDJHy8WfCB" }, @@ -923,7 +687,7 @@ "id": "OXP6gg2mjs6u" }, "source": [ - "The model now recommends places to visit in Europe." + "The model now provides a shorter response to a question about visiting Europe." ] }, { @@ -932,12 +696,14 @@ "id": "H7nVd8Mi1Yta" }, "source": [ - "### ELI5 Photosynthesis Prompt" + "### Photosynthesis prompt\n", + "\n", + "Try the photosynthesis explanation prompt from earlier and note the differences in the response." ] }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "metadata": { "id": "X-2sYl2jqwl7" }, @@ -977,7 +743,9 @@ "id": "I8kFG12l0mVe" }, "source": [ - "Note that for demonstration purposes, this tutorial fine-tunes the model on a small subset of the dataset for just one epoch and with a low LoRA rank value. To get better responses from the fine-tuned model, you can experiment with:\n", + "## Improving fine-tune results\n", + "\n", + "For demonstration purposes, this tutorial fine-tunes the model on a small subset of the dataset for just one epoch and with a low LoRA rank value. To get better responses from the fine-tuned model, you can experiment with:\n", "\n", "1. Increasing the size of the fine-tuning dataset\n", "2. Training for more steps (epochs)\n", @@ -993,12 +761,12 @@ "source": [ "## Summary and next steps\n", "\n", - "This tutorial covered LoRA fine-tuning on a Gemma model using KerasNLP. Check out the following docs next:\n", + "This tutorial covered LoRA fine-tuning on a Gemma model using Keras. Check out the following docs next:\n", "\n", "* Learn how to [generate text with a Gemma model](https://ai.google.dev/gemma/docs/get_started).\n", "* Learn how to perform [distributed fine-tuning and inference on a Gemma model](https://ai.google.dev/gemma/docs/core/distributed_tuning).\n", "* Learn how to [use Gemma open models with Vertex AI](https://cloud.google.com/vertex-ai/docs/generative-ai/open-models/use-gemma).\n", - "* Learn how to [fine-tune Gemma using KerasNLP and deploy to Vertex AI](https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_gemma_kerasnlp_to_vertexai.ipynb)." + "* Learn how to [fine-tune Gemma using Keras and deploy to Vertex AI](https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_gemma_kerasnlp_to_vertexai.ipynb)." ] } ],