import argparse
import conversation as convo
import retrieval.wikipedia as wp
- from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, StoppingCriteria, StoppingCriteriaList
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, StoppingCriteria, StoppingCriteriaList, BitsAndBytesConfig
from accelerate import infer_auto_device_map, init_empty_weights
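The hunks below consolidate the two loading branches into a single `from_pretrained` call driven by a `BitsAndBytesConfig`. For background, a minimal sketch of the `infer_auto_device_map` path used in the `max_memory` branch; the checkpoint name and memory limits are illustrative assumptions, not values from this change:

```python
from accelerate import infer_auto_device_map, init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

# Illustrative checkpoint and memory budgets -- assumptions, not from this PR.
config = AutoConfig.from_pretrained("EleutherAI/pythia-70m")
with init_empty_weights():
    meta_model = AutoModelForCausalLM.from_config(config)  # allocates no real weights
meta_model.tie_weights()

device_map = infer_auto_device_map(
    meta_model,
    max_memory={0: "10GiB", "cpu": "30GiB"},   # per-device budgets
    no_split_module_classes=["GPTNeoXLayer"],  # keep each transformer block on one device
    dtype="float16",
)
print(device_map)  # maps module names to devices, e.g. 0 or 'cpu'
```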
@@ -51,14 +51,16 @@ class ChatModel:
    def __init__(self, model_name, gpu_id, max_memory, load_in_8bit):
        device = torch.device('cuda', gpu_id)  # TODO: allow sending to cpu

+       quantization_config = BitsAndBytesConfig(
+           load_in_8bit=load_in_8bit,
+           llm_int8_enable_fp32_cpu_offload=True,
+       )  # config to load in 8-bit when load_in_8bit is set
+
        # recommended default for devices with > 40 GB VRAM
        # load model onto one device
        if max_memory is None:
-           self._model = AutoModelForCausalLM.from_pretrained(
-               model_name, torch_dtype=torch.float16, device_map="auto", load_in_8bit=load_in_8bit)
-           if not load_in_8bit:
-               self._model.to(device)  # not supported by load_in_8bit
-       # load the model with the given max_memory config (for devices with insufficient VRAM or multi-gpu)
+           device_map = "auto"
+
        else:
            config = AutoConfig.from_pretrained(model_name)
            # load empty weights
@@ -67,21 +69,24 @@ def __init__(self, model_name, gpu_id, max_memory, load_in_8bit):
            model_from_conf.tie_weights()

-           # create a device_map from max_memory
+           #create a device_map from max_memory
            device_map = infer_auto_device_map(
                model_from_conf,
                max_memory=max_memory,
                no_split_module_classes=["GPTNeoXLayer"],
-               dtype="float16"
-           )
-           # load the model with the above device_map
-           self._model = AutoModelForCausalLM.from_pretrained(
-               model_name,
-               device_map=device_map,
-               offload_folder="offload",  # optional offload-to-disk overflow directory (auto-created)
-               offload_state_dict=True,
-               torch_dtype=torch.float16
+               dtype="float16",
            )
+
+       self._model = AutoModelForCausalLM.from_pretrained(
+           model_name,
+           torch_dtype=torch.float16,
+           device_map=device_map,
+           offload_folder="offload",  # offload-to-disk overflow directory (auto-created)
+           quantization_config=quantization_config,
+       )
+       if not load_in_8bit:
+           self._model.to(device)  # not supported by load_in_8bit
+
        self._tokenizer = AutoTokenizer.from_pretrained(model_name)

    def do_inference(self, prompt, max_new_tokens, do_sample, temperature, top_k, stream_callback=None):
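For reference, a minimal sketch of exercising the new unified loading path with 8-bit quantization enabled; the model name is a hypothetical stand-in, and bitsandbytes must be installed for the quantized branch:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Hypothetical checkpoint -- not part of this change.
model_name = "EleutherAI/pythia-70m"

quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_enable_fp32_cpu_offload=True,  # permit fp32 CPU offload alongside int8 GPU weights
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",  # corresponds to the max_memory is None branch
    quantization_config=quantization_config,
)
# With load_in_8bit=False, the diff's tail instead moves the model explicitly:
# model.to(torch.device("cuda", 0))  # .to() is not supported for 8-bit models
```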