import glob
import json
from dotenv import load_dotenv
- from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling
+ from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling, LlamaTokenizerFast
from datasets import Dataset, load_dataset
import torch

load_dotenv()

- MODEL_NAME = "gpt2"  # You can change this to a different model
+ MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"  # Changed to the specified model
OUTPUT_DIR = "trained_model"
TRAIN_FILE = "train.jsonl"
MAX_LENGTH = 512
@@ -47,7 +47,7 @@ def train_model(dataset, tokenizer):
        prediction_loss_only=True,
        remove_unused_columns=False,
    )
-     model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
+     model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, trust_remote_code=True)
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    trainer = Trainer(
        model=model,
@@ -61,7 +61,7 @@ def train_model(dataset, tokenizer):
61
61
def main ():
62
62
posts_dir = "_posts"
63
63
texts = create_training_data (posts_dir )
64
- tokenizer = AutoTokenizer .from_pretrained (MODEL_NAME )
64
+ tokenizer = LlamaTokenizerFast .from_pretrained (MODEL_NAME , trust_remote_code = True , use_fast = True )
65
65
tokenizer .pad_token = tokenizer .eos_token
66
66
dataset = prepare_dataset (texts , tokenizer )
67
67
train_model (dataset , tokenizer )
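Below is a minimal smoke-test sketch (not part of the commit) for confirming that the new checkpoint loads before running a full training pass. The MODEL_NAME, trust_remote_code flag, LlamaTokenizerFast class, and pad_token assignment follow the diff above; the dtype choice and the sample prompt are assumptions for illustration only.

# Smoke test: confirm the tokenizer/model pair loads for the new checkpoint (sketch, not in the diff).
import torch
from transformers import AutoModelForCausalLM, LlamaTokenizerFast

MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

tokenizer = LlamaTokenizerFast.from_pretrained(MODEL_NAME, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token  # reuse EOS as pad, mirroring main()

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,  # assumption: half precision on GPU
)

# A single forward pass is enough to verify the wiring end to end.
inputs = tokenizer("Hello world", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.logits.shape)  # expected: (1, sequence_length, vocab_size)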