---
audio: true
lang: en
layout: post
title: Finetune a model
translated: false
---

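The script below fine-tunes `deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B` on this blog's own content: it walks the per-language folders under `_posts`, strips the Jekyll front matter from each Markdown file, tokenizes the remaining text, and trains with the Hugging Face `Trainer` on a causal language modeling objective.
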
```python
import os
import glob
import json
from dotenv import load_dotenv
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling, LlamaTokenizerFast
from datasets import Dataset, load_dataset
import torch

load_dotenv()

MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"  # base model to fine-tune
OUTPUT_DIR = "trained_model"  # where the fine-tuned checkpoint is saved
TRAIN_FILE = "train.jsonl"  # not used below
MAX_LENGTH = 512  # maximum tokens per training example
BATCH_SIZE = 8
EPOCHS = 3

def create_training_data(posts_dir):
    """Collect the body text of every Markdown post under posts_dir (one subfolder per language)."""
    all_texts = []
    for lang_dir in os.listdir(posts_dir):
        lang_path = os.path.join(posts_dir, lang_dir)
        if not os.path.isdir(lang_path):
            continue
        for file_path in glob.glob(os.path.join(lang_path, "*.md")):
            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    content = f.read()
                # Drop the Jekyll front matter (everything up to the second "---")
                content = content.split("---", 2)[-1].strip()
                all_texts.append(content)
            except Exception as e:
                print(f"Error reading file {file_path}: {e}")
    return all_texts

def prepare_dataset(texts, tokenizer):
    """Tokenize the texts and wrap them in a datasets.Dataset."""
    # Return plain Python lists; the data collator converts them to tensors per batch.
    encodings = tokenizer(texts, truncation=True, padding=True, max_length=MAX_LENGTH)
    return Dataset.from_dict(dict(encodings))

def train_model(dataset, tokenizer):
    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        overwrite_output_dir=True,
        num_train_epochs=EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        save_steps=10_000,
        save_total_limit=2,
        prediction_loss_only=True,
        remove_unused_columns=False,
    )
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, trust_remote_code=True)
    # mlm=False gives a causal language modeling objective: the collator copies input_ids into labels.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        data_collator=data_collator,
    )
    trainer.train()
    trainer.save_model(OUTPUT_DIR)

def main():
    posts_dir = "_posts"
    texts = create_training_data(posts_dir)
    tokenizer = LlamaTokenizerFast.from_pretrained(MODEL_NAME, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token  # the model has no pad token, so reuse EOS for padding
    dataset = prepare_dataset(texts, tokenizer)
    train_model(dataset, tokenizer)

if __name__ == "__main__":
    main()
```
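
After training, the checkpoint saved in `trained_model` can be loaded back for a quick generation test. Here is a minimal sketch, assuming the script above has already run; the prompt is just an illustrative string, and the tokenizer is reloaded from the base model because it was not passed to the `Trainer` and therefore is not saved alongside the weights:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the fine-tuned weights saved by trainer.save_model().
model = AutoModelForCausalLM.from_pretrained("trained_model")

# The tokenizer was not saved with the checkpoint, so reload it from the base model.
tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

prompt = "Write a short note about keeping a multilingual Jekyll blog."  # illustrative prompt
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=100, do_sample=True, temperature=0.7)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```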