diff --git a/training/lora/README.md b/training/lora/README.md
new file mode 100644
index 0000000..8d4c90f
--- /dev/null
+++ b/training/lora/README.md
@@ -0,0 +1,84 @@
+## Fine-tuning with DeeperSpeed
+### Install dependencies
+
+`mamba install -c conda-forge cudatoolkit-dev`
+
+`export CUDA_HOME=$CONDA_PREFIX`
+
+`pip install evaluate datasets peft transformers git+https://github.com/EleutherAI/DeeperSpeed.git`
+
+`pip install 'transformers[sklearn]'`
+
+#### Install bitsandbytes if loading in 8-bit
+`pip install bitsandbytes`
+
+### Get started
+
+`cd training/lora`
+
+## Examples
+#### From a Hugging Face dataset:
+```
+deepspeed --num_gpus=1 finetune.py \
+--deepspeed example/config.json \
+--model_name_or_path togethercomputer/RedPajama-INCITE-Base-3B-v1 \
+--dataset_name imdb \
+--do_train \
+--do_eval \
+--fp16 \
+--overwrite_cache \
+--evaluation_strategy="steps" \
+--output_dir finetuned \
+--num_train_epochs 1 \
+--eval_steps 15 \
+--gradient_accumulation_steps 1 \
+--per_device_train_batch_size 4 \
+--use_fast_tokenizer True \
+--learning_rate 1e-5 \
+--warmup_steps 10
+```
+#### From train and validation files:
+```
+deepspeed --num_gpus=1 finetune.py \
+--deepspeed example/config.json \
+--model_name_or_path togethercomputer/RedPajama-INCITE-Base-3B-v1 \
+--train_file train.csv \
+--validation_file validation.csv \
+--do_train \
+--do_eval \
+--fp16 \
+--overwrite_cache \
+--evaluation_strategy="steps" \
+--output_dir finetuned \
+--num_train_epochs 1 \
+--eval_steps 15 \
+--gradient_accumulation_steps 1 \
+--per_device_train_batch_size 4 \
+--use_fast_tokenizer True \
+--learning_rate 1e-5 \
+--warmup_steps 10
+```
+
+#### In 8-bit
+**Note:** change `fp16.enabled` to `false` in `example/config.json`.
+```
+deepspeed --num_gpus=1 finetune.py \
+--deepspeed example/config.json \
+--model_name_or_path togethercomputer/RedPajama-INCITE-Base-3B-v1 \
+--dataset_name imdb \
+--do_train \
+--do_eval \
+--int8 \
+--low_cpu_mem_usage \
+--overwrite_cache \
+--evaluation_strategy="steps" \
+--output_dir finetuned \
+--num_train_epochs 1 \
+--eval_steps 15 \
+--gradient_accumulation_steps 1 \
+--per_device_train_batch_size 4 \
+--use_fast_tokenizer True \
+--learning_rate 1e-5 \
+--warmup_steps 10 \
+--no_cache
+```
diff --git a/training/lora/example/config.json b/training/lora/example/config.json
new file mode 100644
index 0000000..2a6b6b6
--- /dev/null
+++ b/training/lora/example/config.json
@@ -0,0 +1,39 @@
+{
+    "train_batch_size": "auto",
+    "fp16": {
+        "enabled": true,
+        "min_loss_scale": 1,
+        "opt_level": "O2"
+    },
+    "zero_optimization": {
+        "stage": 2,
+        "offload_param": {
+            "device": "cpu"
+        },
+        "offload_optimizer": {
+            "device": "cpu"
+        },
+        "allgather_partitions": true,
+        "allgather_bucket_size": 5e8,
+        "contiguous_gradients": true
+    },
+    "optimizer": {
+        "type": "AdamW",
+        "params": {
+            "lr": "auto",
+            "betas": [
+                0.9,
+                0.999
+            ],
+            "eps": 1e-08
+        }
+    },
+    "scheduler": {
+        "type": "WarmupLR",
+        "params": {
+            "warmup_min_lr": 0,
+            "warmup_max_lr": "auto",
+            "warmup_num_steps": "auto"
+        }
+    }
+}
diff --git a/training/lora/example/finetuning.ipynb b/training/lora/example/finetuning.ipynb
new file mode 100644
index 0000000..be74eb3
--- /dev/null
+++ b/training/lora/example/finetuning.ipynb
@@ -0,0 +1,250 @@
+{
+  "nbformat": 4,
+  "nbformat_minor": 0,
+  "metadata": {
+    "colab": {
+      "provenance": [],
+      "gpuType": "T4",
+      "include_colab_link": true
+    },
+    "kernelspec": {
+      "name": "python3",
+      "display_name": "Python 3"
+    },
+    "language_info": {
+      "name": "python"
+    },
+    "accelerator": "GPU",
+ "gpuClass": "standard" + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "<a href=\"https://colab.research.google.com/github/orangetin/OpenChatKit/blob/peft/training/lora/example/finetuning.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" + ] + }, + { + "cell_type": "markdown", + "source": [ + "# OpenChatKit - Fine-tuning" + ], + "metadata": { + "id": "sLrKqm0BULlD" + } + }, + { + "cell_type": "markdown", + "source": [ + "### Check GPU availability" + ], + "metadata": { + "id": "eZsgPnayURrc" + } + }, + { + "cell_type": "code", + "source": [ + "!nvidia-smi" + ], + "metadata": { + "id": "qy_ENUlFgG4a" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### Install conda" + ], + "metadata": { + "id": "0gy7ssnoT_SI" + } + }, + { + "cell_type": "code", + "source": [ + "!wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && chmod +x Miniconda3-latest-Linux-x86_64.sh && ./Miniconda3-latest-Linux-x86_64.sh -b -f -p /usr/local" + ], + "metadata": { + "id": "11MMVFkAKtyg" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### Setting up conda environment" + ], + "metadata": { + "id": "CD7yF4rvT3Y8" + } + }, + { + "cell_type": "code", + "source": [ + "!conda install mamba -n base -c conda-forge -y" + ], + "metadata": { + "id": "-W6PrOSILQoc" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "!git clone https://github.com/orangetin/OpenChatKit.git --branch peft && cd OpenChatKit && mamba create -n OpenChatKit python=3.10.9 -y" + ], + "metadata": { + "id": "hC8ob6kuLSn2" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "!source activate OpenChatKit && mamba install pytorch torchvision torchaudio cudatoolkit-dev pytorch-cuda=11.6 -c pytorch -c nvidia -c conda-forge -y" + ], + "metadata": { + "id": "waQdRff3Dee4" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "!source activate OpenChatKit && export CUDA_HOME=$CONDA_PREFIX && pip install accelerate evaluate datasets peft chardet cchardet transformers git+https://github.com/EleutherAI/DeeperSpeed.git bitsandbytes && pip install 'transformers[sklearn]'" + ], + "metadata": { + "id": "T_K3hXCVz7I1" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### Download dataset and convert jsonl to json" + ], + "metadata": { + "id": "cVc_deb3O9q1" + } + }, + { + "cell_type": "code", + "source": [ + "!cd OpenChatKit/training/lora && mkdir data && mkdir data_jsonl" + ], + "metadata": { + "id": "RoNQGlepO-Uj" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "!cd OpenChatKit/training/lora/data_jsonl && wget https://huggingface.co/datasets/laion/OIG/resolve/main/unified_chip2.jsonl" + ], + "metadata": { + "id": "2xZJ3uSdO_xT" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "import json\n", + "\n", + "with open('OpenChatKit/training/lora/data_jsonl/unified_chip2.jsonl', 'r') as in_file:\n", + " lines = [json.loads(line) for line in in_file.readlines()]\n", + "\n", + "with open('OpenChatKit/training/lora/data/unified_chip2.json', 'w') as out_file:\n", + " json.dump(lines, out_file)" + ], + "metadata": { + "id": 
"peZQbFRXPA4q" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### Initialize training in 8-bit" + ], + "metadata": { + "id": "jOKRM0VVUjwk" + } + }, + { + "cell_type": "markdown", + "source": [ + "Edits config to disable fp16" + ], + "metadata": { + "id": "RLk6ghH1PgZ8" + } + }, + { + "cell_type": "code", + "source": [ + "!cd OpenChatKit/training/lora && sed -i -e 's/\"enabled\": true,/\"enabled\": false,/g' example/config.json" + ], + "metadata": { + "id": "AzkcI5ll-mDt" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "To change to fp16, replace `--int8 \\ --low_cpu_mem_usage \\` with `--fp16 \\`" + ], + "metadata": { + "id": "0kmhEjGlPjzZ" + } + }, + { + "cell_type": "code", + "source": [ + "!source activate OpenChatKit && export CUDA_HOME=$CONDA_PREFIX && cd OpenChatKit/training/lora && deepspeed --num_gpus=1 finetune.py \\\n", + "--deepspeed example/config.json \\\n", + "--model_name_or_path togethercomputer/RedPajama-INCITE-Chat-3B-v1 \\\n", + "--train_file data/unified_chip2.json \\\n", + "--validation_split_percentage 10 \\\n", + "--do_train \\\n", + "--do_eval \\\n", + "--overwrite_cache \\\n", + "--evaluation_strategy=\"steps\" \\\n", + "--output_dir finetuned \\\n", + "--num_train_epochs 1 \\\n", + "--eval_steps 15 \\\n", + "--gradient_accumulation_steps 2 \\\n", + "--per_device_train_batch_size 4 \\\n", + "--use_fast_tokenizer True \\\n", + "--learning_rate 1e-5 \\\n", + "--warmup_steps 10 \\\n", + "--int8 \\\n", + "--low_cpu_mem_usage \\\n", + "--no_cache" + ], + "metadata": { + "id": "82cyWiyi8y9f" + }, + "execution_count": null, + "outputs": [] + } + ] +} diff --git a/training/lora/finetune.py b/training/lora/finetune.py new file mode 100644 index 0000000..4756e96 --- /dev/null +++ b/training/lora/finetune.py @@ -0,0 +1,618 @@ +import logging +import math +import os +import sys +from dataclasses import dataclass, field +from itertools import chain +from typing import Optional + +import datasets +import evaluate +import torch +from datasets import load_dataset + +from peft import get_peft_model, LoraConfig, TaskType, prepare_model_for_int8_training + +import transformers +from transformers import ( + CONFIG_MAPPING, + MODEL_FOR_CAUSAL_LM_MAPPING, + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, + HfArgumentParser, + Trainer, + TrainingArguments, + default_data_collator, + is_torch_tpu_available, + set_seed, +) +from transformers.testing_utils import CaptureLogger +from transformers.trainer_utils import get_last_checkpoint +from transformers.utils import check_min_version +from transformers.utils.versions import require_version + + +# Will error if the minimal version of Transformers is not installed. Remove at your own risks. +check_min_version("4.29.0.dev0") + +require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") + +logger = logging.getLogger(__name__) + + +MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys()) +MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch. + """ + + model_name_or_path: Optional[str] = field( + default=None, + metadata={ + "help": ( + "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch." 
+            )
+        },
+    )
+    model_type: Optional[str] = field(
+        default=None,
+        metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
+    )
+    config_overrides: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Override some existing default config settings when a model is trained from scratch. Example: "
+                "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
+            )
+        },
+    )
+    config_name: Optional[str] = field(
+        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
+    )
+    tokenizer_name: Optional[str] = field(
+        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
+    )
+    cache_dir: Optional[str] = field(
+        default=None,
+        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
+    )
+    use_fast_tokenizer: bool = field(
+        default=True,
+        metadata={"help": "Whether to use one of the fast tokenizers (backed by the tokenizers library) or not."},
+    )
+    model_revision: str = field(
+        default="main",
+        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
+    )
+    use_auth_token: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
+                "with private models)."
+            )
+        },
+    )
+    torch_dtype: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
+                "dtype will be automatically derived from the model's weights."
+            ),
+            "choices": ["auto", "bfloat16", "float16", "float32"],
+        },
+    )
+    int8: bool = field(
+        default=False,
+        metadata={"help": "Load the base model in 8-bit and prepare it for int8 training (requires bitsandbytes)."},
+    )
+    low_cpu_mem_usage: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "Create the model as an empty shell, then only materialize its parameters when the pretrained "
+                "weights are loaded. Setting this to True benefits LLM loading time and RAM consumption."
+            )
+        },
+    )
+
+    def __post_init__(self):
+        if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
+            raise ValueError(
+                "--config_overrides can't be used in combination with --config_name or --model_name_or_path"
+            )
+
+
+@dataclass
+class DataTrainingArguments:
+    """
+    Arguments pertaining to what data we are going to input our model for training and eval.
+    """
+
+    dataset_name: Optional[str] = field(
+        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
+    )
+    dataset_config_name: Optional[str] = field(
+        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
+    )
+    train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
+    validation_file: Optional[str] = field(
+        default=None,
+        metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
+    )
+    max_train_samples: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": (
+                "For debugging purposes or quicker training, truncate the number of training examples to this "
+                "value if set."
+            )
+        },
+    )
+    max_eval_samples: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": (
+                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+                "value if set."
+            )
+        },
+    )
+    streaming: bool = field(default=False, metadata={"help": "Enable streaming mode"})
+    block_size: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Optional input sequence length after tokenization. "
+                "The training dataset will be truncated into blocks of this size for training. "
+                "Defaults to the model max input length for single sentence inputs (take into account special tokens)."
+            )
+        },
+    )
+    overwrite_cache: bool = field(
+        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
+    )
+    no_cache: bool = field(
+        default=False, metadata={"help": "Disable the model's KV cache (model.config.use_cache=False) during training."}
+    )
+    validation_split_percentage: Optional[int] = field(
+        default=5,
+        metadata={
+            "help": "The percentage of the train set used as validation set in case there's no validation split"
+        },
+    )
+    preprocessing_num_workers: Optional[int] = field(
+        default=None,
+        metadata={"help": "The number of processes to use for the preprocessing."},
+    )
+    keep_linebreaks: bool = field(
+        default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."}
+    )
+
+    def __post_init__(self):
+        if self.streaming:
+            require_version("datasets>=2.0.0", "The streaming feature requires `datasets>=2.0.0`")
+
+        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
+            raise ValueError("Need either a dataset name or a training/validation file.")
+        else:
+            if self.train_file is not None:
+                extension = self.train_file.split(".")[-1]
+                assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
+            if self.validation_file is not None:
+                extension = self.validation_file.split(".")[-1]
+                assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
+
+
+def main():
+    # See all possible arguments in src/transformers/training_args.py
+    # or by passing the --help flag to this script.
+    # We now keep distinct sets of args, for a cleaner separation of concerns.
+
+    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
+    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+        # If we pass only one argument to the script and it's the path to a json file,
+        # let's parse it to get our arguments.
+        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+    else:
+        model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+    # Setup logging
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        handlers=[logging.StreamHandler(sys.stdout)],
+    )
+
+    if training_args.should_log:
+        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
+        transformers.utils.logging.set_verbosity_info()
+
+    log_level = training_args.get_process_log_level()
+    logger.setLevel(log_level)
+    datasets.utils.logging.set_verbosity(log_level)
+    transformers.utils.logging.set_verbosity(log_level)
+    transformers.utils.logging.enable_default_handler()
+    transformers.utils.logging.enable_explicit_format()
+
+    # Log on each process the small summary:
+    logger.warning(
+        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
+        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
+    )
+    logger.info(f"Training/evaluation parameters {training_args}")
+
+    # Detecting last checkpoint.
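+    # If output_dir already contains a checkpoint and --overwrite_output_dir is not
+    # set, training resumes from that checkpoint; a non-empty output_dir without a
+    # checkpoint is treated as an error instead (see the logic below).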
+ last_checkpoint = None + if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." + ) + elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) + + # Set seed before initializing model. + set_seed(training_args.seed) + + if data_args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + raw_datasets = load_dataset( + data_args.dataset_name, + data_args.dataset_config_name, + cache_dir=model_args.cache_dir, + use_auth_token=True if model_args.use_auth_token else None, + streaming=data_args.streaming, + ) + if "validation" not in raw_datasets.keys(): + raw_datasets["validation"] = load_dataset( + data_args.dataset_name, + data_args.dataset_config_name, + split=f"train[:{data_args.validation_split_percentage}%]", + cache_dir=model_args.cache_dir, + use_auth_token=True if model_args.use_auth_token else None, + streaming=data_args.streaming, + ) + raw_datasets["train"] = load_dataset( + data_args.dataset_name, + data_args.dataset_config_name, + split=f"train[{data_args.validation_split_percentage}%:]", + cache_dir=model_args.cache_dir, + use_auth_token=True if model_args.use_auth_token else None, + streaming=data_args.streaming, + ) + else: + data_files = {} + dataset_args = {} + if data_args.train_file is not None: + data_files["train"] = data_args.train_file + if data_args.validation_file is not None: + data_files["validation"] = data_args.validation_file + extension = ( + data_args.train_file.split(".")[-1] + if data_args.train_file is not None + else data_args.validation_file.split(".")[-1] + ) + if extension == "txt": + extension = "text" + dataset_args["keep_linebreaks"] = data_args.keep_linebreaks + raw_datasets = load_dataset( + extension, + data_files=data_files, + cache_dir=model_args.cache_dir, + use_auth_token=True if model_args.use_auth_token else None, + **dataset_args, + ) + # If no validation data is there, validation_split_percentage will be used to divide the dataset. 
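+        # For example, with --validation_split_percentage 10, the first 10% of the
+        # train split becomes the validation set and the remaining 90% is trained on.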
+ if "validation" not in raw_datasets.keys(): + raw_datasets["validation"] = load_dataset( + extension, + data_files=data_files, + split=f"train[:{data_args.validation_split_percentage}%]", + cache_dir=model_args.cache_dir, + use_auth_token=True if model_args.use_auth_token else None, + **dataset_args, + ) + raw_datasets["train"] = load_dataset( + extension, + data_files=data_files, + split=f"train[{data_args.validation_split_percentage}%:]", + cache_dir=model_args.cache_dir, + use_auth_token=True if model_args.use_auth_token else None, + **dataset_args, + ) + + config_kwargs = { + "cache_dir": model_args.cache_dir, + "revision": model_args.model_revision, + "use_auth_token": True if model_args.use_auth_token else None, + } + if model_args.config_name: + config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs) + elif model_args.model_name_or_path: + config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs) + else: + config = CONFIG_MAPPING[model_args.model_type]() + logger.warning("You are instantiating a new config instance from scratch.") + if model_args.config_overrides is not None: + logger.info(f"Overriding config: {model_args.config_overrides}") + config.update_from_string(model_args.config_overrides) + logger.info(f"New config: {config}") + + tokenizer_kwargs = { + "cache_dir": model_args.cache_dir, + "use_fast": model_args.use_fast_tokenizer, + "revision": model_args.model_revision, + "use_auth_token": True if model_args.use_auth_token else None, + } + + target_modules = ["query_key_value", "xxx"] # workaround to use 8bit training on this model + + peft_config = LoraConfig( + r=16, lora_alpha=32, target_modules=target_modules, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM" + ) + + if model_args.tokenizer_name: + tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs) + elif model_args.model_name_or_path: + tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs) + else: + raise ValueError( + "You are instantiating a new tokenizer from scratch. This is not supported by this script." + "You can do it from another script, save it, and load it from here, using --tokenizer_name." + ) + + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + if model_args.model_name_or_path: + torch_dtype = ( + model_args.torch_dtype + if model_args.torch_dtype in ["auto", None] + else getattr(torch, model_args.torch_dtype) + ) + model = AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, + device_map="auto" if model_args.int8 else None, + from_tf=bool(".ckpt" in model_args.model_name_or_path), + config=config, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + torch_dtype=torch_dtype, + low_cpu_mem_usage=model_args.low_cpu_mem_usage, + load_in_8bit=model_args.int8, + ) + else: + model = AutoModelForCausalLM.from_config(config) + n_params = sum({p.data_ptr(): p.numel() for p in model.parameters()}.values()) + logger.info(f"Training new model from scratch - Total size={n_params/2**20:.2f}M params") + + if model_args.int8: + model = prepare_model_for_int8_training(model) + + + model.gradient_checkpointing_enable() # reduce number of stored activations + model.enable_input_require_grads() + + model = get_peft_model(model, peft_config) + model.print_trainable_parameters() + + # We resize the embeddings only when necessary to avoid index errors. 
+    # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
+    # on a small vocab and want a smaller embedding size, remove this test.
+    embedding_size = model.get_input_embeddings().weight.shape[0]
+    if len(tokenizer) > embedding_size:
+        model.resize_token_embeddings(len(tokenizer))
+
+    # Preprocessing the datasets.
+    # First we tokenize all the texts.
+    if training_args.do_train:
+        column_names = list(raw_datasets["train"].features)
+    else:
+        column_names = list(raw_datasets["validation"].features)
+    text_column_name = "text" if "text" in column_names else column_names[0]
+
+    # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function
+    tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")
+
+    def tokenize_function(examples):
+        with CaptureLogger(tok_logger) as cl:
+            output = tokenizer(examples[text_column_name], padding="max_length", truncation=True)
+        # clm input could be much much longer than block_size
+        if "Token indices sequence length is longer than the" in cl.out:
+            tok_logger.warning(
+                "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits"
+                " before being passed to the model."
+            )
+        return output
+
+    with training_args.main_process_first(desc="dataset map tokenization"):
+        if not data_args.streaming:
+            tokenized_datasets = raw_datasets.map(
+                tokenize_function,
+                batched=True,
+                num_proc=data_args.preprocessing_num_workers,
+                remove_columns=column_names,
+                load_from_cache_file=not data_args.overwrite_cache,
+                desc="Running tokenizer on dataset",
+            )
+        else:
+            tokenized_datasets = raw_datasets.map(
+                tokenize_function,
+                batched=True,
+                remove_columns=column_names,
+            )
+
+    if data_args.block_size is None:
+        block_size = tokenizer.model_max_length
+        # The stock Hugging Face example caps block_size at 1024 here; that cap is
+        # intentionally disabled so the model's full context length is used.
+    else:
+        if data_args.block_size > tokenizer.model_max_length:
+            logger.warning(
+                f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model "
+                f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
+            )
+        block_size = min(data_args.block_size, tokenizer.model_max_length)
+
+    # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
+    def group_texts(examples):
+        # Concatenate all texts.
+        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
+        total_length = len(concatenated_examples[list(examples.keys())[0]])
+        # We drop the small remainder; we could pad instead if the model supported it.
+        # You can customize this part to your needs.
+        if total_length >= block_size:
+            total_length = (total_length // block_size) * block_size
+        # Split by chunks of max_len.
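+        # For example, with block_size=4 and 10 concatenated tokens t0..t9, total_length
+        # is floored to 8 and the chunks are [t0..t3] and [t4..t7]; t8 and t9 are dropped.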
+ result = { + k: [t[i : i + block_size] for i in range(0, total_length, block_size)] + for k, t in concatenated_examples.items() + } + result["labels"] = result["input_ids"].copy() + return result + + with training_args.main_process_first(desc="grouping texts together"): + if not data_args.streaming: + lm_datasets = tokenized_datasets.map( + group_texts, + batched=True, + num_proc=data_args.preprocessing_num_workers, + load_from_cache_file=not data_args.overwrite_cache, + desc=f"Grouping texts in chunks of {block_size}", + ) + else: + lm_datasets = tokenized_datasets.map( + group_texts, + batched=True, + ) + + if training_args.do_train: + if "train" not in tokenized_datasets: + raise ValueError("--do_train requires a train dataset") + train_dataset = lm_datasets["train"] + if data_args.max_train_samples is not None: + max_train_samples = min(len(train_dataset), data_args.max_train_samples) + train_dataset = train_dataset.select(range(max_train_samples)) + + if training_args.do_eval: + if "validation" not in tokenized_datasets: + raise ValueError("--do_eval requires a validation dataset") + eval_dataset = lm_datasets["validation"] + if data_args.max_eval_samples is not None: + max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) + eval_dataset = eval_dataset.select(range(max_eval_samples)) + + def preprocess_logits_for_metrics(logits, labels): + if isinstance(logits, tuple): + # Depending on the model and config, logits may contain extra tensors, + # like past_key_values, but logits always come first + logits = logits[0] + return logits.argmax(dim=-1) + + metric = evaluate.load("accuracy") + + def compute_metrics(eval_preds): + preds, labels = eval_preds + # preds have the same shape as the labels, after the argmax(-1) has been calculated + # by preprocess_logits_for_metrics but we need to shift the labels + labels = labels[:, 1:].reshape(-1) + preds = preds[:, :-1].reshape(-1) + return metric.compute(predictions=preds, references=labels) + + # Initialize our Trainer + trainer = Trainer( + model=model, + args=training_args, + train_dataset=train_dataset if training_args.do_train else None, + eval_dataset=eval_dataset if training_args.do_eval else None, + tokenizer=tokenizer, + # Data collator will default to DataCollatorWithPadding, so we change it. 
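+        # default_data_collator just stacks the features into batch tensors; no dynamic
+        # padding is needed because group_texts already produced fixed-size blocks.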
+ data_collator=default_data_collator, + compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None, + preprocess_logits_for_metrics=preprocess_logits_for_metrics + if training_args.do_eval and not is_torch_tpu_available() + else None, + ) + + # Training + if training_args.do_train: + checkpoint = None + if training_args.resume_from_checkpoint is not None: + checkpoint = training_args.resume_from_checkpoint + elif last_checkpoint is not None: + checkpoint = last_checkpoint + + if data_args.no_cache: + model.config.use_cache = False + + train_result = trainer.train(resume_from_checkpoint=checkpoint) + + model.save_pretrained(training_args.output_dir) + + metrics = train_result.metrics + + max_train_samples = ( + data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) + ) + metrics["train_samples"] = min(max_train_samples, len(train_dataset)) + + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() + + # Evaluation + if training_args.do_eval: + logger.info("*** Evaluate ***") + + metrics = trainer.evaluate() + + max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) + metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) + try: + perplexity = math.exp(metrics["eval_loss"]) + except OverflowError: + perplexity = float("inf") + metrics["perplexity"] = perplexity + + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-generation"} + if data_args.dataset_name is not None: + kwargs["dataset_tags"] = data_args.dataset_name + if data_args.dataset_config_name is not None: + kwargs["dataset_args"] = data_args.dataset_config_name + kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}" + else: + kwargs["dataset"] = data_args.dataset_name + + if training_args.push_to_hub: + trainer.push_to_hub(**kwargs) + else: + trainer.create_model_card(**kwargs) + + +def _mp_fn(index): + # For xla_spawn (TPUs) + main() + + +if __name__ == "__main__": + main()
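After training, `finetune.py` saves only the LoRA adapter weights to `--output_dir` via `model.save_pretrained`. Below is a minimal sketch of loading the adapter back for inference with PEFT, assuming the base model and the `finetuned` output directory from the examples above (adjust both to your run):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Assumption: this is the base model the adapter was trained on.
base_name = "togethercomputer/RedPajama-INCITE-Base-3B-v1"
base = AutoModelForCausalLM.from_pretrained(base_name, torch_dtype=torch.float16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(base_name)

# Attach the LoRA adapter saved by finetune.py (--output_dir finetuned).
model = PeftModel.from_pretrained(base, "finetuned")
model.eval()

inputs = tokenizer("The movie was", return_tensors="pt").to(model.device)
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```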