
Commit c358a1b

fix post merge with main
1 parent ec85fa3 commit c358a1b

File tree

src/transformers/integrations/tensor_parallel.py
src/transformers/modeling_utils.py
src/transformers/models/llama4/convert_llama4_weights_to_hf.py
src/transformers/models/llama4/modeling_llama4.py

4 files changed: 17 additions & 17 deletions

src/transformers/integrations/tensor_parallel.py
Lines changed: 1 addition & 0 deletions

@@ -545,6 +545,7 @@ def shard_and_distribute_module(
         )
     else:
         # TODO log no plan modules in set
+        print("No plan for", parameter_name, end="\r")
         param = param[...].to(param_casting_dtype)
         if is_contiguous:
             param = param.contiguous()

src/transformers/modeling_utils.py
Lines changed: 15 additions & 13 deletions

@@ -1930,23 +1930,25 @@ def post_init(self):
             )

         # If current model is a base model, attach `base_model_tp_plan` and `base_model_pp_plan` from config
-        if self.base_model is self:
-            self._pp_plan = (
-                self.config.base_model_pp_plan.copy() if self.config.base_model_pp_plan is not None else None
-            )
-            self._tp_plan = self.config.base_model_tp_plan.copy() if self.config.base_model_tp_plan is not None else {}
-        else:
-            self._tp_plan = self._tp_plan or {}
-            for name, module in self.named_children():
-                if plan := getattr(module, "_tp_plan", None):
-                    self._tp_plan.update({f"{name}.{k}": v for k, v in plan.items()})
+        self._pp_plan = (
+            self.config.base_model_pp_plan.copy() if self.config.base_model_pp_plan is not None else None
+        )
+        self._tp_plan = self.config.base_model_tp_plan.copy() if self.config.base_model_tp_plan is not None else {}
+        for name, module in self.named_children():
+            if plan := getattr(module, "_tp_plan", None):
+                self._tp_plan.update({f"{name}.{k}": v for k, v in plan.copy().items()})

         if self._tp_plan is not None and is_torch_greater_or_equal("2.3"):
-            for _, v in self._tp_plan.items():
+            unique_names = {re.sub(r"\d+", "*", name) for name, _ in self.named_children() if len(name) > 0}
+            for k, v in self._tp_plan.items():
                 if v not in SUPPORTED_TP_STYLES:
                     raise ValueError(
                         f"Unsupported tensor parallel style {v}. Supported styles are {SUPPORTED_TP_STYLES}"
                     )
+                if k not in unique_names:
+                    raise ValueError(
+                        f"Unsupported tensor parallel mapping: {k} is not part of the model"
+                    )

     def dequantize(self):
         """
@@ -5819,10 +5821,10 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict,
         generic_name = re.sub(r"\.\d+\.", ".*.", param_name)
         param_byte_count //= torch.distributed.get_world_size() if tp_plan_regex.search(generic_name) else 1

-        parameter_count[device] += param_byte_count
+        total_byte_count[device] += param_byte_count

     # This will kick off the caching allocator to avoid having to Malloc afterwards
-    for device, byte_count in parameter_count.items():
+    for device, byte_count in total_byte_count.items():
         if device.type == "cuda":
             index = device.index if device.index is not None else torch.cuda.current_device()
             device_memory = torch.cuda.mem_get_info(index)[0]
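The second hunk is a rename (`parameter_count` to `total_byte_count`) that makes the accumulator's meaning explicit: it holds bytes per device, with TP-sharded parameters divided by the world size. A rough, self-contained sketch of that accounting follows; the parameter names, sizes, and regex are made up, and `world_size` stands in for `torch.distributed.get_world_size()`.

import re
from collections import defaultdict

# Hypothetical inputs: parameter name -> (device, byte count).
params = {
    "model.layers.0.mlp.up_proj.weight": ("cuda:0", 4_194_304),
    "model.layers.1.mlp.up_proj.weight": ("cuda:0", 4_194_304),
    "lm_head.weight": ("cuda:0", 8_388_608),
}
tp_plan_regex = re.compile(r"model\.layers\.\*\.mlp\.up_proj")  # assumed TP-plan pattern
world_size = 2  # stand-in for torch.distributed.get_world_size()

total_byte_count = defaultdict(int)
for param_name, (device, param_byte_count) in params.items():
    # Collapse the layer index so the name can be matched against the wildcarded plan.
    generic_name = re.sub(r"\.\d+\.", ".*.", param_name)
    # A sharded parameter only occupies 1/world_size of its bytes on each rank.
    param_byte_count //= world_size if tp_plan_regex.search(generic_name) else 1
    total_byte_count[device] += param_byte_count

print(dict(total_byte_count))  # {'cuda:0': 12582912}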

src/transformers/models/llama4/convert_llama4_weights_to_hf.py
Lines changed: 0 additions & 4 deletions

@@ -25,10 +25,6 @@

 torch.serialization.add_safe_globals([io.BytesIO])
 # fmt: off
-
-# layers.29.feed_forward.model.norm.weight
-# layers.30.attention.wqkv.layer_model.norm.weight
-# Still not sure what to do with those!
 # `None` means we drop the key

src/transformers/models/llama4/modeling_llama4.py
Lines changed: 1 addition & 0 deletions

@@ -1546,6 +1546,7 @@ class Llama4PreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = "past_key_values"
     _supports_cache_class = True
     _supports_flash_attn_2 = True
+    _supports_flex_attn = True
     _supports_sdpa = True
     _supports_quantized_cache = True
     _supports_static_cache = True