@@ -1930,23 +1930,25 @@ def post_init(self):
         )

         # If current model is a base model, attach `base_model_tp_plan` and `base_model_pp_plan` from config
-        if self.base_model is self:
-            self._pp_plan = (
-                self.config.base_model_pp_plan.copy() if self.config.base_model_pp_plan is not None else None
-            )
-            self._tp_plan = self.config.base_model_tp_plan.copy() if self.config.base_model_tp_plan is not None else {}
-        else:
-            self._tp_plan = self._tp_plan or {}
-            for name, module in self.named_children():
-                if plan := getattr(module, "_tp_plan", None):
-                    self._tp_plan.update({f"{name}.{k}": v for k, v in plan.items()})
+        self._pp_plan = (
+            self.config.base_model_pp_plan.copy() if self.config.base_model_pp_plan is not None else None
+        )
+        self._tp_plan = self.config.base_model_tp_plan.copy() if self.config.base_model_tp_plan is not None else {}
+        for name, module in self.named_children():
+            if plan := getattr(module, "_tp_plan", None):
+                self._tp_plan.update({f"{name}.{k}": v for k, v in plan.copy().items()})

         if self._tp_plan is not None and is_torch_greater_or_equal("2.3"):
-            for _, v in self._tp_plan.items():
+            unique_names = {re.sub(r"\d+", "*", name) for name, _ in self.named_children() if len(name) > 0}
+            for k, v in self._tp_plan.items():
                 if v not in SUPPORTED_TP_STYLES:
                     raise ValueError(
                         f"Unsupported tensor parallel style {v}. Supported styles are {SUPPORTED_TP_STYLES}"
                     )
+                if k not in unique_names:
+                    raise ValueError(
+                        f"Unsupported tensor parallel mapping: {k} is not part of the model"
+                    )

     def dequantize(self):
         """
@@ -5819,10 +5821,10 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict,
         generic_name = re.sub(r"\.\d+\.", ".*.", param_name)
         param_byte_count //= torch.distributed.get_world_size() if tp_plan_regex.search(generic_name) else 1

-        parameter_count[device] += param_byte_count
+        total_byte_count[device] += param_byte_count

     # This will kick off the caching allocator to avoid having to Malloc afterwards
-    for device, byte_count in parameter_count.items():
+    for device, byte_count in total_byte_count.items():
         if device.type == "cuda":
             index = device.index if device.index is not None else torch.cuda.current_device()
             device_memory = torch.cuda.mem_get_info(index)[0]
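This second hunk only renames the accumulator (`parameter_count` to `total_byte_count`) so the variable name matches what it holds: bytes per device rather than a parameter count. The following is a rough, self-contained sketch of the surrounding idea, with a made-up `expanded_device_map`, fabricated byte sizes, and an illustrative warm-up allocation rather than the library's exact logic.

```python
from collections import defaultdict

import torch

# Hypothetical map of parameter name -> target device (CPU only, so the example runs anywhere).
expanded_device_map = {
    "model.layers.0.self_attn.q_proj.weight": torch.device("cpu"),
    "model.layers.1.self_attn.q_proj.weight": torch.device("cpu"),
}
# Hypothetical per-parameter sizes in bytes (e.g. a 4096x4096 fp16 weight).
param_byte_sizes = {name: 4096 * 4096 * 2 for name in expanded_device_map}

# Accumulate the expected memory footprint per device.
total_byte_count = defaultdict(int)
for param_name, device in expanded_device_map.items():
    total_byte_count[device] += param_byte_sizes[param_name]

# One large allocation per CUDA device primes the caching allocator so later
# parameter loads reuse already-reserved memory instead of fresh mallocs.
for device, byte_count in total_byte_count.items():
    if device.type == "cuda":
        index = device.index if device.index is not None else torch.cuda.current_device()
        free_memory = torch.cuda.mem_get_info(index)[0]
        byte_count = min(byte_count, free_memory)
        _ = torch.empty(byte_count // 2, dtype=torch.float16, device=device, requires_grad=False)

print(dict(total_byte_count))  # {device(type='cpu'): 67108864}
```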