
Commit 96faf80

enable gpu usage
1 parent 1b8159d commit 96faf80

File tree

1 file changed (+45, -11 lines)


src/trustyai/language/detoxify/tmarco.py

Lines changed: 45 additions & 11 deletions
@@ -51,6 +51,7 @@ def __init__(
         tokenizer=None,
         max_length=150,
         model_type: str = "causal_lm",
+        device=None,
     ):
         if expert_weights is None:
             expert_weights = [-0.5, 0.5]
@@ -94,6 +95,13 @@ def __init__(
         )
         self.content_feature = "comment_text"
 
+        if isinstance(device, str):
+            self.device = torch.device(device)
+        else:
+            self.device = torch.device(
+                "cuda" if torch.cuda.is_available() else "cpu"
+            )
+
     def load_models(self, experts: list, expert_weights: list = None):
         """Load expert models."""
         if expert_weights is not None:
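
Note: the device resolution added to __init__ only honours string values; an already-constructed torch.device instance falls through to the auto-detection branch. A standalone sketch of the same logic (resolve_device is a hypothetical helper, not part of this commit):

    import torch

    # Mirrors the constructor logic above: strings are wrapped in
    # torch.device, anything else (including None) triggers detection.
    def resolve_device(device=None):
        if isinstance(device, str):
            return torch.device(device)  # e.g. "cuda:0" or "cpu"
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print(resolve_device("cpu"))  # cpu
    print(resolve_device())       # cuda if a GPU is visible, else cpu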
@@ -102,7 +110,9 @@ def load_models(self, experts: list, expert_weights: list = None):
         for expert in experts:
             if isinstance(expert, str):
                 expert = BartForConditionalGeneration.from_pretrained(
-                    expert, forced_bos_token_id=self.tokenizer.bos_token_id
+                    expert,
+                    forced_bos_token_id=self.tokenizer.bos_token_id,
+                    device_map="auto",
                 )
             expert_models.append(expert)
         self.experts = expert_models
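
device_map="auto" delegates weight placement to Hugging Face Accelerate (which must be installed), spreading the model across available GPUs and falling back to CPU. A minimal sketch; the checkpoint is illustrative and not taken from this commit:

    from transformers import BartForConditionalGeneration

    # With device_map="auto", from_pretrained() places the weights on
    # available devices instead of loading everything onto the CPU.
    expert = BartForConditionalGeneration.from_pretrained(
        "facebook/bart-large",
        forced_bos_token_id=0,  # the real code passes self.tokenizer.bos_token_id
        device_map="auto",
    )
    print({p.device for p in expert.parameters()})  # where the weights landed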
@@ -200,15 +210,21 @@ def train_models(
 
         if model_type is None:
             gminus = BartForConditionalGeneration.from_pretrained(
-                base_model, forced_bos_token_id=self.tokenizer.bos_token_id
+                base_model,
+                forced_bos_token_id=self.tokenizer.bos_token_id,
+                device_map="auto",
             )
         elif model_type == "causal_lm":
             gminus = AutoModelForCausalLM.from_pretrained(
-                base_model, forced_bos_token_id=self.tokenizer.bos_token_id
+                base_model,
+                forced_bos_token_id=self.tokenizer.bos_token_id,
+                device_map="auto",
             )
         elif model_type == "seq2seq_lm":
             gminus = AutoModelForSeq2SeqLM.from_pretrained(
-                base_model, forced_bos_token_id=self.tokenizer.bos_token_id
+                base_model,
+                forced_bos_token_id=self.tokenizer.bos_token_id,
+                device_map="auto",
             )
         else:
             raise Exception(f"unsupported model type {model_type}")
@@ -254,15 +270,21 @@ def train_models(
 
         if model_type is None:
             gplus = BartForConditionalGeneration.from_pretrained(
-                base_model, forced_bos_token_id=self.tokenizer.bos_token_id
+                base_model,
+                forced_bos_token_id=self.tokenizer.bos_token_id,
+                device_map="auto",
             )
         elif model_type == "causal_lm":
             gplus = AutoModelForCausalLM.from_pretrained(
-                base_model, forced_bos_token_id=self.tokenizer.bos_token_id
+                base_model,
+                forced_bos_token_id=self.tokenizer.bos_token_id,
+                device_map="auto",
             )
         elif model_type == "seq2seq_lm":
             gplus = AutoModelForSeq2SeqLM.from_pretrained(
-                base_model, forced_bos_token_id=self.tokenizer.bos_token_id
+                base_model,
+                forced_bos_token_id=self.tokenizer.bos_token_id,
+                device_map="auto",
             )
         else:
             raise Exception(f"unsupported model type {model_type}")
@@ -380,6 +402,7 @@ def rephrase(
                     model=expert,
                     tokenizer=self.tokenizer,
                     top_k=self.tokenizer.vocab_size,
+                    device=self.device,
                 )
             )
         for idx in range(len(masked_sentence_tokens)):
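
Going by the surrounding context (a tokenizer plus top_k), the object being built here is presumably a fill-mask pipeline; pipeline() accepts device as an index, a string, or a torch.device and moves the model there at construction. A sketch with an illustrative checkpoint:

    import torch
    from transformers import pipeline

    # device= places the pipeline's model; top_k controls how many
    # candidate tokens are returned per mask.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    fill_masker = pipeline(
        "fill-mask",
        model="facebook/bart-base",
        top_k=5,
        device=device,
    )
    print(fill_masker("Paris is the <mask> of France.")[0]["token_str"])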
@@ -477,9 +500,10 @@ def compute_mask_logits(
         self, model, sequence, verbose: bool = False, mask: bool = True
     ):
         """Compute mask logits."""
+        model.to(self.device)
         if verbose:
             print(f"input sequence: {sequence}")
-        subseq_ids = self.tokenizer(sequence, return_tensors="pt")
+        subseq_ids = self.tokenizer(sequence, return_tensors="pt").to(self.device)
         if verbose:
             raw_outputs = model.generate(**subseq_ids)
             print(sequence)
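
This is the standard pattern: move the model and the tokenized inputs to the same device before calling generate(). tokenizer(...) returns a BatchEncoding, whose .to() moves every contained tensor at once. A self-contained sketch (checkpoint illustrative):

    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

    # BatchEncoding.to() moves input_ids and attention_mask together
    inputs = tokenizer("some text", return_tensors="pt").to(device)
    outputs = model.generate(**inputs, max_new_tokens=5)
    print(tokenizer.decode(outputs[0]))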
@@ -502,9 +526,12 @@ def compute_mask_logits_multiple(
         self, model, sequences, verbose: bool = False, mask: bool = True
     ):
         """Compute mask logits multiple."""
+        model.to(self.device)
         if verbose:
             print(f"input sequences: {sequences}")
-        subseq_ids = self.tokenizer(sequences, return_tensors="pt", padding=True)
+        subseq_ids = self.tokenizer(
+            sequences, return_tensors="pt", padding=True
+        ).to(self.device)
         if verbose:
             raw_outputs = model.generate(**subseq_ids)
             print(sequences)
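
The batched variant relies on padding=True so that sequences of different lengths form rectangular tensors, which a single .to(device) call can then move. A sketch, assuming a tokenizer that defines a pad token:

    import torch
    from transformers import AutoTokenizer

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
    batch = tokenizer(
        ["a short one", "a somewhat longer input sentence"],
        return_tensors="pt",
        padding=True,  # pad to the longest sequence in the batch
    ).to(device)
    print(batch["input_ids"].shape, batch["input_ids"].device)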
@@ -554,6 +581,7 @@ def score(
             model=model,
             tokenizer=self.tokenizer,
             top_k=10,
+            device=self.device,
         )
         for masked_sentence in masked_sentences:
             # approximated probabilities for top_k tokens
@@ -567,7 +595,9 @@ def score(
         js_distances = []
         for distr_pair in distr_pairs:
             js_distance = jensenshannon(
-                distr_pair[0], distr_pair[1], axis=1
+                distr_pair[0].cpu().clone().numpy(),
+                distr_pair[1].cpu().clone().numpy(),
+                axis=1,
             )
             if normalize:
                 js_distance = js_distance / np.average(js_distance)
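
SciPy's jensenshannon operates on NumPy arrays, and a CUDA tensor cannot be converted to one directly, hence the .cpu().clone().numpy() chain before the distance computation. A sketch with made-up distributions (the axis keyword needs SciPy >= 1.7):

    import torch
    from scipy.spatial.distance import jensenshannon

    # Two batches of row-wise probability distributions, standing in
    # for the fill-mask outputs compared above.
    p = torch.softmax(torch.randn(4, 10), dim=1)
    q = torch.softmax(torch.randn(4, 10), dim=1)
    js = jensenshannon(
        p.cpu().clone().numpy(),  # host copy that NumPy can wrap
        q.cpu().clone().numpy(),
        axis=1,  # one distance per row, matching the diff
    )
    print(js.shape)  # (4,)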
@@ -653,7 +683,10 @@ def reflect(
         chat_tokenizer.chat_template = chat_template
 
         converse_pipeline = pipeline(
-            "conversational", model=chat_model, tokenizer=chat_tokenizer
+            "conversational",
+            model=chat_model,
+            tokenizer=chat_tokenizer,
+            device=self.device,
         )
 
         for text_id in range(len(texts)):
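
pipeline() takes the device at construction time and moves the chat model there. The "conversational" task was later deprecated and removed from transformers, so the sketch below assumes an older release such as the one this commit targets; the checkpoint is illustrative only:

    import torch
    from transformers import Conversation, pipeline

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    converse = pipeline(
        "conversational",
        model="facebook/blenderbot-400M-distill",
        device=device,
    )
    print(converse(Conversation("How do I rephrase a toxic sentence?")))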
@@ -729,6 +762,7 @@ def reflect(
                 conversation_output = converse_pipeline(
                     formatted_messages,
                     pad_token_id=converse_pipeline.tokenizer.eos_token_id,
+                    device=self.device,
                 )
                 if verbose:
                     print(f"chat conversation:\n{conversation_output}")
