
Commit 2902839

Merge pull request #54 from huggingface/fix-tp-pipeline
Fix tp pipeline
2 parents: fb495fd + 83282a1 · commit 2902839

File tree: 4 files changed, +37 -9 lines changed

src/transformers/models/auto/tokenization_auto.py
src/transformers/models/llama4/modeling_llama4.py
src/transformers/pipelines/base.py
src/transformers/quantizers/base.py


src/transformers/models/auto/tokenization_auto.py

Lines changed: 14 additions & 1 deletion
@@ -292,7 +292,20 @@
                 "LlamaTokenizerFast" if is_tokenizers_available() else None,
             ),
         ),
-        ("llama4", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
+        (
+            "llama4",
+            (
+                "LlamaTokenizer" if is_sentencepiece_available() else None,
+                "LlamaTokenizerFast" if is_tokenizers_available() else None,
+            ),
+        ),
+        (
+            "llama4_text",
+            (
+                "LlamaTokenizer" if is_sentencepiece_available() else None,
+                "LlamaTokenizerFast" if is_tokenizers_available() else None,
+            ),
+        ),
         ("llava", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
         ("llava_next", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
         ("llava_next_video", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),

src/transformers/models/llama4/modeling_llama4.py

Lines changed: 6 additions & 3 deletions
@@ -746,7 +746,10 @@ def _update_causal_mask(
     ):
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and (attention_mask == 0.0).any():
-                return attention_mask, attention_mask  # flash does not support chunked attn
+                return attention_mask, attention_mask  # flash does not support chunked attn TODO support flash
+            return None, None
+
+        if self.config._attn_implementation not in ["sdpa", "flex_attention", "eager"]:
             return None, None

         sequence_length = input_tensor.shape[1]
@@ -808,8 +811,8 @@ def _update_causal_mask(
             if sequence_length == 1:
                 chunked_attention_mask = chunked_attention_mask[-1:]
             if self.config._attn_implementation == "eager":
-                chunked_attention_mask = chunked_attention_mask[None,None,:,:].to(dtype).masked_fill(
-                    chunked_attention_mask, min_dtype
+                chunked_attention_mask = (
+                    chunked_attention_mask[None, None, :, :].to(dtype).masked_fill(chunked_attention_mask, min_dtype)
                 )

         if (
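For the eager path above, the boolean chunked mask is converted into an additive float mask. A small standalone sketch of that conversion, with toy shapes and illustrative names:

# Toy reproduction of the reformatted eager-path expression; shapes and names are illustrative.
import torch

dtype = torch.float32
min_dtype = torch.finfo(dtype).min
# Boolean mask over [query_len, key_len]; True marks the positions masked_fill will overwrite.
chunked_attention_mask = torch.tensor([[False, True], [False, False]])

additive_mask = (
    chunked_attention_mask[None, None, :, :].to(dtype).masked_fill(chunked_attention_mask, min_dtype)
)
print(additive_mask.shape)  # torch.Size([1, 1, 2, 2])
print(additive_mask)        # True positions hold min_dtype (effectively -inf), the rest hold 0.0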

src/transformers/pipelines/base.py

Lines changed: 2 additions & 0 deletions
@@ -981,6 +981,8 @@ def __init__(
         else:
             self.device = device if device is not None else -1

+        if torch.distributed.is_initialized():
+            self.device = self.model.device
         logger.warning(f"Device set to use {self.device}")

         self.binary_output = binary_output
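The net effect, sketched below: in a torch.distributed run (for example a tensor-parallel model launched with torchrun), the pipeline stops guessing a device and adopts the device of the already-placed model. The model id is only a placeholder.

# Sketch only: "gpt2" is a placeholder model id; any model already placed by
# tensor parallelism or a device_map would behave the same way.
import torch
from transformers import pipeline

pipe = pipeline("text-generation", model="gpt2")
if torch.distributed.is_initialized():
    # With this change, the pipeline's device tracks the model's device.
    assert pipe.device == pipe.model.device
print(pipe.device)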

src/transformers/quantizers/base.py

Lines changed: 15 additions & 5 deletions
@@ -294,26 +294,35 @@ def is_serializable(self, safe_serialization=None): ...
     @property
     @abstractmethod
     def is_trainable(self): ...
-
+
     def _convert_model_for_quantization(self, model):
         from accelerate import init_empty_weights
+
         for name, module in model.named_modules():
             module_class_name = module.__class__.__name__
-            if module_class_name in MODULES_TO_PATCH_FOR_QUANTIZATION.keys() and self.quantization_config.quant_method == QuantizationMethod.COMPRESSED_TENSORS:
+            if (
+                module_class_name in MODULES_TO_PATCH_FOR_QUANTIZATION.keys()
+                and self.quantization_config.quant_method == QuantizationMethod.COMPRESSED_TENSORS
+            ):
                 with init_empty_weights():
                     parent_module, name = get_module_from_name(model, name)
-                    parent_module._modules[name] = MODULES_TO_PATCH_FOR_QUANTIZATION[module_class_name](model.config.get_text_config())
+                    parent_module._modules[name] = MODULES_TO_PATCH_FOR_QUANTIZATION[module_class_name](
+                        model.config.get_text_config()
+                    )
+

 class SequentialLlama4TextExperts(torch.nn.ModuleList):
     """
     A module that implements a compressed version of a list of expert modules.
     This is specifically designed to work with Llama4TextExperts in MoE layers.
     """
+
     def __init__(self, config):
         from transformers.models.llama4.modeling_llama4 import Llama4TextMLP
+
         super().__init__([Llama4TextMLP(config) for _ in range(config.num_local_experts)])
         self.num_experts = config.num_local_experts
-
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -324,4 +333,5 @@ def forward(
             routed_out[expert_idx] = self[expert_idx](hidden_states[expert_idx])
         return routed_out

-MODULES_TO_PATCH_FOR_QUANTIZATION = { "Llama4TextExperts": SequentialLlama4TextExperts }
+
+MODULES_TO_PATCH_FOR_QUANTIZATION = {"Llama4TextExperts": SequentialLlama4TextExperts}
