Commit 28ba8c2

refactor(scripts): optimize finetune_model.py
1 parent 3da0333 commit 28ba8c2

File tree

1 file changed: +4 −4 lines changed

scripts/train_model.py renamed to scripts/finetune_model.py

Lines changed: 4 additions & 4 deletions
@@ -2,13 +2,13 @@
 import glob
 import json
 from dotenv import load_dotenv
-from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling
+from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling, LlamaTokenizerFast
 from datasets import Dataset, load_dataset
 import torch

 load_dotenv()

-MODEL_NAME = "gpt2"  # You can change this to a different model
+MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"  # Changed to the specified model
 OUTPUT_DIR = "trained_model"
 TRAIN_FILE = "train.jsonl"
 MAX_LENGTH = 512
@@ -47,7 +47,7 @@ def train_model(dataset, tokenizer):
         prediction_loss_only=True,
         remove_unused_columns=False,
     )
-    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
+    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, trust_remote_code=True)
     data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
     trainer = Trainer(
         model=model,
@@ -61,7 +61,7 @@ def train_model(dataset, tokenizer):
 def main():
     posts_dir = "_posts"
     texts = create_training_data(posts_dir)
-    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    tokenizer = LlamaTokenizerFast.from_pretrained(MODEL_NAME, trust_remote_code=True, use_fast=True)
     tokenizer.pad_token = tokenizer.eos_token
     dataset = prepare_dataset(texts, tokenizer)
     train_model(dataset, tokenizer)
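
As a sanity check on the change, here is a minimal smoke test for the new model/tokenizer pair. This is a hedged sketch, not part of the commit or the repository; it assumes network access to the Hugging Face Hub and only confirms that the checkpoint loads and that padding falls back to the EOS token, as main() in finetune_model.py expects:

    import torch
    from transformers import AutoModelForCausalLM, LlamaTokenizerFast

    MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

    # The DeepSeek-R1 distill checkpoints ship a Llama-style tokenizer, which is
    # presumably why the commit loads it via LlamaTokenizerFast rather than AutoTokenizer.
    tokenizer = LlamaTokenizerFast.from_pretrained(MODEL_NAME, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token  # mirrors main() in finetune_model.py

    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, trust_remote_code=True)
    batch = tokenizer(["a short causal-LM smoke test"], return_tensors="pt")
    with torch.no_grad():
        loss = model(**batch, labels=batch["input_ids"]).loss
    print(f"dummy-batch loss: {loss.item():.3f}")  # a finite loss means the wiring is sound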
