Commit f285999

Update bot.py

1 parent dfd8f8b commit f285999

File tree

1 file changed: +21 -16 lines


inference/bot.py

@@ -11,7 +11,7 @@
 import argparse
 import conversation as convo
 import retrieval.wikipedia as wp
-from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, StoppingCriteria, StoppingCriteriaList
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, StoppingCriteria, StoppingCriteriaList, BitsAndBytesConfig
 from accelerate import infer_auto_device_map, init_empty_weights
 
 
@@ -51,14 +51,16 @@ class ChatModel:
     def __init__(self, model_name, gpu_id, max_memory, load_in_8bit):
         device = torch.device('cuda', gpu_id)  # TODO: allow sending to cpu
 
+        quantization_config = BitsAndBytesConfig(
+            load_in_8bit=load_in_8bit,
+            llm_int8_enable_fp32_cpu_offload=True,
+        )  # config to load in 8-bit if load_in_8bit
+
         # recommended default for devices with > 40 GB VRAM
         # load model onto one device
         if max_memory is None:
-            self._model = AutoModelForCausalLM.from_pretrained(
-                model_name, torch_dtype=torch.float16, device_map="auto", load_in_8bit=load_in_8bit)
-            if not load_in_8bit:
-                self._model.to(device)  # not supported by load_in_8bit
-        # load the model with the given max_memory config (for devices with insufficient VRAM or multi-gpu)
+            device_map="auto"
+
         else:
             config = AutoConfig.from_pretrained(model_name)
             # load empty weights
@@ -67,21 +69,24 @@ def __init__(self, model_name, gpu_id, max_memory, load_in_8bit):
 
         model_from_conf.tie_weights()
 
-            # create a device_map from max_memory
+            #create a device_map from max_memory
             device_map = infer_auto_device_map(
                 model_from_conf,
                 max_memory=max_memory,
                 no_split_module_classes=["GPTNeoXLayer"],
-                dtype="float16"
-            )
-            # load the model with the above device_map
-            self._model = AutoModelForCausalLM.from_pretrained(
-                model_name,
-                device_map=device_map,
-                offload_folder="offload",  # optional offload-to-disk overflow directory (auto-created)
-                offload_state_dict=True,
-                torch_dtype=torch.float16
+                dtype="float16",
             )
+
+        self._model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            torch_dtype=torch.float16,
+            device_map=device_map,
+            offload_folder="offload",
+            quantization_config=quantization_config,
+        )
+        if not load_in_8bit:
+            self._model.to(device)  # not supported by load_in_8bit
+
         self._tokenizer = AutoTokenizer.from_pretrained(model_name)
 
     def do_inference(self, prompt, max_new_tokens, do_sample, temperature, top_k, stream_callback=None):
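
In short, the commit decouples device placement from quantization: the if/else now only chooses a device_map (either "auto" or one planned by infer_auto_device_map against an explicit memory budget), while a single BitsAndBytesConfig carries the 8-bit choice into one shared AutoModelForCausalLM.from_pretrained call, replacing the two duplicated calls and the bare load_in_8bit flag. Below is a minimal standalone sketch of that loading path. The checkpoint name and memory budget are illustrative assumptions, and the empty-weights construction elided from the diff context is reconstructed from the usual accelerate pattern, so treat this as a sketch rather than the file's exact code.

    import torch
    from transformers import AutoConfig, AutoModelForCausalLM, BitsAndBytesConfig
    from accelerate import infer_auto_device_map, init_empty_weights

    model_name = "togethercomputer/GPT-NeoXT-Chat-Base-20B"  # illustrative checkpoint, not from the diff
    max_memory = {0: "13GiB", "cpu": "48GiB"}  # illustrative budget; None means "fits on one device"
    load_in_8bit = True

    # Built once for both branches; as in the commit, the config is passed
    # either way and is meant to take effect only when load_in_8bit is True.
    quantization_config = BitsAndBytesConfig(
        load_in_8bit=load_in_8bit,
        llm_int8_enable_fp32_cpu_offload=True,  # allow fp32 CPU offload for int8 models
    )

    if max_memory is None:
        # Enough VRAM: let accelerate place the whole model automatically.
        device_map = "auto"
    else:
        # Constrained VRAM / multi-GPU: plan placement against the explicit
        # budget without materializing any weights (reconstructed pattern).
        config = AutoConfig.from_pretrained(model_name)
        with init_empty_weights():
            model_from_conf = AutoModelForCausalLM.from_config(config)
        model_from_conf.tie_weights()
        device_map = infer_auto_device_map(
            model_from_conf,
            max_memory=max_memory,
            no_split_module_classes=["GPTNeoXLayer"],  # never split a transformer block
            dtype="float16",
        )

    # One from_pretrained call now serves both branches.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map=device_map,
        offload_folder="offload",  # disk spill directory for layers that fit nowhere
        quantization_config=quantization_config,
    )

Note that, as in the diff, a quantized model must not be moved with .to(device); the commit keeps that move behind "if not load_in_8bit:" for the plain fp16 case.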
