From 945a3394a99377261b2e1cd16a4904ad034d0d7f Mon Sep 17 00:00:00 2001
From: binhang
Date: Fri, 31 Mar 2023 01:47:09 +0800
Subject: [PATCH] update the 7B model script for memory efficiency

---
 README.md                                | 14 ++++++++------
 training/finetune_Pythia-Chat-Base-7B.sh |  6 +++---
 2 files changed, 11 insertions(+), 9 deletions(-)

diff --git a/README.md b/README.md
index 26c968f..b57d911 100644
--- a/README.md
+++ b/README.md
@@ -9,19 +9,21 @@ In this repo, you'll find code for:
 
 # Contents
 
+- [OpenChatKit](#openchatkit)
+- [Contents](#contents)
 - [Requirements](#requirements)
 - [Pre-trained Weights](#pre-trained-weights)
 - [Datasets](#datasets)
-  * [Data Contributions](#data-contributions)
+  - [Data Contributions](#data-contributions)
 - [Pretrained Base Model](#pretrained-base-model)
 - [Training and Finetuning](#training-and-finetuning)
-  * [(Optional) 8bit Adam](#optional-8bit-adam)
-  * [Train GPT-NeoX-Chat-Base-20B](#train-gpt-neox-chat-base-20b)
+  - [(Optional) 8bit Adam](#optional-8bit-adam)
+  - [Train GPT-NeoX-Chat-Base-20B](#train-gpt-neox-chat-base-20b)
 - [Converting Weights to Huggingface Format](#converting-weights-to-huggingface-format)
 - [Inference](#inference)
 - [Monitoring](#monitoring)
-  * [Loguru](#loguru)
-  * [Weights & Biases](#weights--biases)
+  - [Loguru](#loguru)
+  - [Weights \& Biases](#weights--biases)
 - [Experimental: Retrieval-Augmented Models](#experimental-retrieval-augmented-models)
 - [License](#license)
 - [Citing OpenChatKit](#citing-openchatkit)
@@ -122,7 +124,7 @@ As the training loop runs, checkpoints are saved to the `model_ckpts` directory
 
 Please see [the training README](training/README.md) for more details about customizing the training run.
 
-The `training/finetune_Pythia-Chat-Base-7B.sh` script is another example to fine-tune a 7B pythia (gpt-neox) model. The script launches 8 processes with a pipeline-parallel degree of 4 and a data-parallel degree of 2.
+The `training/finetune_Pythia-Chat-Base-7B.sh` script is another example of fine-tuning a 7B Pythia (GPT-NeoX) model. The script launches 8 processes with a pipeline-parallel degree of 8 and a data-parallel degree of 1.
 
 # Converting Weights to Huggingface Format
 
diff --git a/training/finetune_Pythia-Chat-Base-7B.sh b/training/finetune_Pythia-Chat-Base-7B.sh
index 92a92e0..6b22dba 100644
--- a/training/finetune_Pythia-Chat-Base-7B.sh
+++ b/training/finetune_Pythia-Chat-Base-7B.sh
@@ -52,10 +52,10 @@ ARGS="--model-name ${BASE_MODEL} \
 --checkpoint-path ${DIR}/../model_ckpts/${MODEL_NAME} \
 --total-steps 20000 --warmup-steps 10 --train-warmup-steps 0 \
 --checkpoint-steps ${CHECKPOINT_STEPS} \
---lr 1e-5 --seq-length 2048 --batch-size 32 --micro-batch-size 1 --gradient-accumulate-step 1 \
+--lr 1e-5 --seq-length 2048 --batch-size 16 --micro-batch-size 1 --gradient-accumulate-step 4 \
 --dist-url tcp://127.0.0.1:7033 \
---num-layers 8 --embedding-dim 4096 \
---world-size 8 --pipeline-group-size 4 --data-group-size 2 \
+--num-layers 4 --embedding-dim 4096 \
+--world-size 8 --pipeline-group-size 8 --data-group-size 1 \
 --job-id 0 --net-interface ${netif} \
 --fp16 \
 --dp-backend nccl \
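
Reviewer note: a quick sanity check on the new process layout and batch arithmetic. The snippet below is an illustrative sketch, not part of the patch; it assumes `--batch-size` counts samples per data-parallel replica and that `--num-layers` is the number of transformer layers per pipeline stage (Pythia-6.9B has 32 layers), which is consistent with both the old and new flag values. The variable names are for illustration only and are not read by the training script.

```bash
#!/usr/bin/env bash
# Sketch of the arithmetic behind the new flags (assumptions noted above).
PIPELINE_GROUP_SIZE=8      # was 4
DATA_GROUP_SIZE=1          # was 2
NUM_LAYERS_PER_STAGE=4     # was 8; Pythia-6.9B has 32 transformer layers in total
BATCH_SIZE=16              # was 32
GRAD_ACC_STEPS=4           # was 1

# World size must match --world-size 8: 8 pipeline stages x 1 data-parallel replica.
echo "world size: $((PIPELINE_GROUP_SIZE * DATA_GROUP_SIZE))"

# Layers covered by the pipeline: 8 stages x 4 layers per stage = 32.
echo "total layers: $((PIPELINE_GROUP_SIZE * NUM_LAYERS_PER_STAGE))"

# Samples per optimizer step: 16 x 4 x 1 = 64, the same as the old 32 x 1 x 2,
# so the effective batch size is unchanged while per-GPU activation memory drops.
echo "effective batch: $((BATCH_SIZE * GRAD_ACC_STEPS * DATA_GROUP_SIZE))"
```

Under these assumptions, the patch trades data parallelism for deeper pipeline partitioning: each GPU now holds 4 layers instead of 8, and the larger gradient-accumulation count keeps the effective batch size the same, which matches the "memory efficiency" goal in the subject line.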