@@ -51,6 +51,7 @@ def __init__(
         tokenizer=None,
         max_length=150,
         model_type: str = "causal_lm",
+        device=None,
     ):
         if expert_weights is None:
             expert_weights = [-0.5, 0.5]
@@ -94,6 +95,13 @@ def __init__(
         )
         self.content_feature = "comment_text"
 
+        if isinstance(device, str):
+            self.device = torch.device(device)
+        else:
+            self.device = torch.device(
+                "cuda" if torch.cuda.is_available() else "cpu"
+            )
+
     def load_models(self, experts: list, expert_weights: list = None):
         """Load expert models."""
         if expert_weights is not None:
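Note: the new `device` argument is resolved once in `__init__` and falls back to CUDA autodetection when no string is passed. A minimal sketch of the same selection logic in isolation (the `resolve_device` helper is illustrative, not part of this PR):

```python
import torch

def resolve_device(device=None):
    # Honor an explicit device string such as "cuda:0" or "cpu";
    # otherwise prefer CUDA whenever it is available.
    if isinstance(device, str):
        return torch.device(device)
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(resolve_device("cpu"))  # device(type='cpu')
print(resolve_device())       # CUDA device on a GPU machine, CPU otherwise
```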
@@ -102,7 +110,9 @@ def load_models(self, experts: list, expert_weights: list = None):
         for expert in experts:
             if isinstance(expert, str):
                 expert = BartForConditionalGeneration.from_pretrained(
-                    expert, forced_bos_token_id=self.tokenizer.bos_token_id
+                    expert,
+                    forced_bos_token_id=self.tokenizer.bos_token_id,
+                    device_map="auto",
                 )
             expert_models.append(expert)
         self.experts = expert_models
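Note: `device_map="auto"` delegates weight placement to Hugging Face Accelerate at load time (the `accelerate` package must be installed), so expert weights land on GPU when one is present and on CPU otherwise. A hedged sketch of loading a single expert this way; the checkpoint name is a placeholder:

```python
from transformers import BartForConditionalGeneration, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
# Requires `pip install accelerate`; weights are placed across the
# available devices automatically instead of a manual .to(...) call.
expert = BartForConditionalGeneration.from_pretrained(
    "facebook/bart-large",
    forced_bos_token_id=tokenizer.bos_token_id,
    device_map="auto",
)
```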
@@ -200,15 +210,21 @@ def train_models(
 
         if model_type is None:
             gminus = BartForConditionalGeneration.from_pretrained(
-                base_model, forced_bos_token_id=self.tokenizer.bos_token_id
+                base_model,
+                forced_bos_token_id=self.tokenizer.bos_token_id,
+                device_map="auto",
             )
         elif model_type == "causal_lm":
             gminus = AutoModelForCausalLM.from_pretrained(
-                base_model, forced_bos_token_id=self.tokenizer.bos_token_id
+                base_model,
+                forced_bos_token_id=self.tokenizer.bos_token_id,
+                device_map="auto",
             )
         elif model_type == "seq2seq_lm":
             gminus = AutoModelForSeq2SeqLM.from_pretrained(
-                base_model, forced_bos_token_id=self.tokenizer.bos_token_id
+                base_model,
+                forced_bos_token_id=self.tokenizer.bos_token_id,
+                device_map="auto",
             )
         else:
             raise Exception(f"unsupported model type {model_type}")
@@ -254,15 +270,21 @@ def train_models(
 
         if model_type is None:
             gplus = BartForConditionalGeneration.from_pretrained(
-                base_model, forced_bos_token_id=self.tokenizer.bos_token_id
+                base_model,
+                forced_bos_token_id=self.tokenizer.bos_token_id,
+                device_map="auto",
             )
         elif model_type == "causal_lm":
             gplus = AutoModelForCausalLM.from_pretrained(
-                base_model, forced_bos_token_id=self.tokenizer.bos_token_id
+                base_model,
+                forced_bos_token_id=self.tokenizer.bos_token_id,
+                device_map="auto",
             )
         elif model_type == "seq2seq_lm":
             gplus = AutoModelForSeq2SeqLM.from_pretrained(
-                base_model, forced_bos_token_id=self.tokenizer.bos_token_id
+                base_model,
+                forced_bos_token_id=self.tokenizer.bos_token_id,
+                device_map="auto",
             )
         else:
             raise Exception(f"unsupported model type {model_type}")
@@ -380,6 +402,7 @@ def rephrase(
                     model=expert,
                     tokenizer=self.tokenizer,
                     top_k=self.tokenizer.vocab_size,
+                    device=self.device,
                 )
             )
         for idx in range(len(masked_sentence_tokens)):
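Note: the fill-mask pipeline used for rephrasing now receives the resolved device at construction time; `transformers` accepts an int, a string, or a `torch.device` for this argument. A standalone sketch (the model choice is an assumption):

```python
import torch
from transformers import pipeline

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
fill_masker = pipeline(
    "fill-mask",
    model="facebook/bart-large",
    top_k=5,
    device=device,  # moves the pipeline's model onto this device
)
print(fill_masker("The movie was absolutely <mask>.")[0]["token_str"])
```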
@@ -477,9 +500,10 @@ def compute_mask_logits(
         self, model, sequence, verbose: bool = False, mask: bool = True
     ):
         """Compute mask logits."""
+        model.to(self.device)
         if verbose:
             print(f"input sequence: {sequence}")
-        subseq_ids = self.tokenizer(sequence, return_tensors="pt")
+        subseq_ids = self.tokenizer(sequence, return_tensors="pt").to(self.device)
         if verbose:
             raw_outputs = model.generate(**subseq_ids)
             print(sequence)
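Note: `self.tokenizer(...)` returns a `BatchEncoding`, and its `.to(...)` moves every contained tensor (`input_ids`, `attention_mask`) at once, so the inputs land on the same device as the freshly moved model. The pattern in isolation (model and text are placeholders):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

# BatchEncoding.to(device) moves input_ids and attention_mask together.
inputs = tokenizer("a sample comment", return_tensors="pt").to(device)
logits = model(**inputs).logits  # lives on the same device as the model
```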
@@ -502,9 +526,12 @@ def compute_mask_logits_multiple(
         self, model, sequences, verbose: bool = False, mask: bool = True
     ):
         """Compute mask logits multiple."""
+        model.to(self.device)
         if verbose:
             print(f"input sequences: {sequences}")
-        subseq_ids = self.tokenizer(sequences, return_tensors="pt", padding=True)
+        subseq_ids = self.tokenizer(
+            sequences, return_tensors="pt", padding=True
+        ).to(self.device)
         if verbose:
             raw_outputs = model.generate(**subseq_ids)
             print(sequences)
@@ -554,6 +581,7 @@ def score(
             model=model,
             tokenizer=self.tokenizer,
             top_k=10,
+            device=self.device,
         )
         for masked_sentence in masked_sentences:
             # approximated probabilities for top_k tokens
@@ -567,7 +595,9 @@ def score(
         js_distances = []
         for distr_pair in distr_pairs:
             js_distance = jensenshannon(
-                distr_pair[0], distr_pair[1], axis=1
+                distr_pair[0].cpu().clone().numpy(),
+                distr_pair[1].cpu().clone().numpy(),
+                axis=1,
             )
             if normalize:
                 js_distance = js_distance / np.average(js_distance)
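Note: `scipy.spatial.distance.jensenshannon` operates on NumPy arrays, so distributions that now live on the GPU must come back to host memory first; the extra `.clone()` in the diff avoids handing SciPy a view of live tensor storage. A worked sketch of the round trip:

```python
import torch
from scipy.spatial.distance import jensenshannon

# Stand-ins for two models' token distributions (each row sums to 1).
p = torch.softmax(torch.randn(4, 10), dim=1)
q = torch.softmax(torch.randn(4, 10), dim=1)

# .cpu() copies off the GPU (a no-op on CPU); .numpy() exposes the buffer.
js = jensenshannon(p.cpu().clone().numpy(), q.cpu().clone().numpy(), axis=1)
print(js.shape)  # (4,): one distance per distribution pair
```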
@@ -653,7 +683,10 @@ def reflect(
             chat_tokenizer.chat_template = chat_template
 
             converse_pipeline = pipeline(
-                "conversational", model=chat_model, tokenizer=chat_tokenizer
+                "conversational",
+                model=chat_model,
+                tokenizer=chat_tokenizer,
+                device=self.device,
            )
 
             for text_id in range(len(texts)):
@@ -729,6 +762,7 @@ def reflect(
                 conversation_output = converse_pipeline(
                     formatted_messages,
                     pad_token_id=converse_pipeline.tokenizer.eos_token_id,
+                    device=self.device,
                 )
                 if verbose:
                     print(f"chat conversation:\n{conversation_output}")