
Commit 0cfbf9c

Force torch>=2.6 with torch.load to avoid vulnerability issue (#37785)
* fix all main files
* fix test files
* oops, forgot modular
* add link
* update message
1 parent eefc86a commit 0cfbf9c

24 files changed: +88 −9 lines
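Every file below gets the same treatment: a call to the new `check_torch_load_is_safe()` helper is placed immediately before each `torch.load(...)`, and every load keeps `weights_only=True`. A minimal sketch of the pattern (the checkpoint path here is illustrative, not from the diff):

import torch

from transformers.utils import check_torch_load_is_safe

def load_checkpoint(path="checkpoint.bin"):
    # Raises ValueError on torch < 2.6 (CVE-2025-32434) before any bytes are unpickled.
    check_torch_load_is_safe()
    return torch.load(path, map_location="cpu", weights_only=True)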

src/transformers/data/datasets/glue.py

+2 −1

@@ -24,7 +24,7 @@
 from torch.utils.data import Dataset

 from ...tokenization_utils_base import PreTrainedTokenizerBase
-from ...utils import logging
+from ...utils import check_torch_load_is_safe, logging
 from ..processors.glue import glue_convert_examples_to_features, glue_output_modes, glue_processors
 from ..processors.utils import InputFeatures

@@ -122,6 +122,7 @@ def __init__(
         with FileLock(lock_path):
             if os.path.exists(cached_features_file) and not args.overwrite_cache:
                 start = time.time()
+                check_torch_load_is_safe()
                 self.features = torch.load(cached_features_file, weights_only=True)
                 logger.info(
                     f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start

src/transformers/data/datasets/squad.py

+2 −1

@@ -24,7 +24,7 @@

 from ...models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
 from ...tokenization_utils import PreTrainedTokenizer
-from ...utils import logging
+from ...utils import check_torch_load_is_safe, logging
 from ..processors.squad import SquadFeatures, SquadV1Processor, SquadV2Processor, squad_convert_examples_to_features


@@ -148,6 +148,7 @@ def __init__(
         with FileLock(lock_path):
             if os.path.exists(cached_features_file) and not args.overwrite_cache:
                 start = time.time()
+                check_torch_load_is_safe()
                 self.old_features = torch.load(cached_features_file, weights_only=True)

                 # Legacy cache files have only features, while new cache files

src/transformers/modeling_flax_pytorch_utils.py

+3 −1

@@ -27,7 +27,7 @@
 import transformers

 from . import is_safetensors_available, is_torch_available
-from .utils import logging
+from .utils import check_torch_load_is_safe, logging


 if is_torch_available():
@@ -71,6 +71,7 @@ def load_pytorch_checkpoint_in_flax_state_dict(
             )
             raise

+        check_torch_load_is_safe()
         pt_state_dict = torch.load(pt_path, map_location="cpu", weights_only=True)
         logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")

@@ -247,6 +248,7 @@ def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
     flax_state_dict = {}
     for shard_file in shard_filenames:
         # load using msgpack utils
+        check_torch_load_is_safe()
         pt_state_dict = torch.load(shard_file, weights_only=True)
         weight_dtypes = {k: v.dtype for k, v in pt_state_dict.items()}
         pt_state_dict = {

src/transformers/modeling_tf_pytorch_utils.py

+2

@@ -21,6 +21,7 @@

 from .utils import (
     ExplicitEnum,
+    check_torch_load_is_safe,
     expand_dims,
     is_numpy_array,
     is_safetensors_available,
@@ -198,6 +199,7 @@ def load_pytorch_checkpoint_in_tf2_model(
         if pt_path.endswith(".safetensors"):
             state_dict = safe_load_file(pt_path)
         else:
+            check_torch_load_is_safe()
             state_dict = torch.load(pt_path, map_location="cpu", weights_only=True)

         pt_state_dict.update(state_dict)

src/transformers/modeling_utils.py

+10 −1

@@ -94,6 +94,7 @@
     ModelOutput,
     PushToHubMixin,
     cached_file,
+    check_torch_load_is_safe,
     copy_func,
     download_url,
     extract_commit_hash,
@@ -445,7 +446,11 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
         error_message += f"\nMissing key(s): {str_unexpected_keys}."
         raise RuntimeError(error_message)

-    loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu", weights_only=True)
+    if load_safe:
+        loader = safe_load_file
+    else:
+        check_torch_load_is_safe()
+        loader = partial(torch.load, map_location="cpu", weights_only=True)

     for shard_file in shard_files:
         state_dict = loader(os.path.join(folder, shard_file))
@@ -490,6 +495,7 @@ def load_state_dict(
     """
     Reads a `safetensor` or a `.bin` checkpoint file. We load the checkpoint on "cpu" by default.
     """
+    # Use safetensors if possible
     if checkpoint_file.endswith(".safetensors") and is_safetensors_available():
         with safe_open(checkpoint_file, framework="pt") as f:
             metadata = f.metadata()
@@ -512,6 +518,9 @@ def load_state_dict(
                 state_dict[k] = f.get_tensor(k)
             return state_dict

+    # Fallback to torch.load (if weights_only was explicitly False, do not check safety as this is known to be unsafe)
+    if weights_only:
+        check_torch_load_is_safe()
     try:
         if map_location is None:
             if (
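The `load_state_dict` change above is the one place the guard is conditional: it only fires when `weights_only` is truthy, because a caller passing `weights_only=False` has already opted into full unpickling, which no torch version makes safe. A condensed sketch of the resulting control flow, simplified from the diff (error handling omitted):

import torch

from safetensors.torch import load_file as safe_load_file
from transformers.utils import check_torch_load_is_safe

def load_state_dict(checkpoint_file, map_location="cpu", weights_only=True):
    # safetensors never runs pickle, so it carries no torch version restriction.
    if checkpoint_file.endswith(".safetensors"):
        return safe_load_file(checkpoint_file)
    # For .bin checkpoints, require torch >= 2.6 unless the caller explicitly
    # accepted unsafe loading with weights_only=False.
    if weights_only:
        check_torch_load_is_safe()
    return torch.load(checkpoint_file, map_location=map_location, weights_only=weights_only)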

src/transformers/models/deprecated/transfo_xl/tokenization_transfo_xl.py

+4

@@ -29,6 +29,7 @@
 from ....tokenization_utils import PreTrainedTokenizer
 from ....utils import (
     cached_file,
+    check_torch_load_is_safe,
     is_sacremoses_available,
     is_torch_available,
     logging,
@@ -222,6 +223,7 @@ def __init__(
                     "from a PyTorch pretrained vocabulary, "
                     "or activate it with environment variables USE_TORCH=1 and USE_TF=0."
                 )
+            check_torch_load_is_safe()
             vocab_dict = torch.load(pretrained_vocab_file, weights_only=True)

             if vocab_dict is not None:
@@ -705,6 +707,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):

         # Instantiate tokenizer.
         corpus = cls(*inputs, **kwargs)
+        check_torch_load_is_safe()
         corpus_dict = torch.load(resolved_corpus_file, weights_only=True)
         for key, value in corpus_dict.items():
             corpus.__dict__[key] = value
@@ -784,6 +787,7 @@ def get_lm_corpus(datadir, dataset):
     fn_pickle = os.path.join(datadir, "cache.pkl")
     if os.path.exists(fn):
         logger.info("Loading cached dataset...")
+        check_torch_load_is_safe()
         corpus = torch.load(fn_pickle, weights_only=True)
     elif os.path.exists(fn):
         logger.info("Loading cached dataset from pickle...")

src/transformers/models/megatron_gpt2/checkpoint_reshaping_and_interoperability.py

+5

@@ -26,6 +26,7 @@

 from transformers import AutoTokenizer, GPT2Config
 from transformers.modeling_utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME
+from transformers.utils import check_torch_load_is_safe


 def add_checkpointing_args(parser):
@@ -275,6 +276,7 @@ def merge_transformers_sharded_states(path, num_checkpoints):
     state_dict = {}
     for i in range(1, num_checkpoints + 1):
         checkpoint_path = os.path.join(path, f"pytorch_model-{i:05d}-of-{num_checkpoints:05d}.bin")
+        check_torch_load_is_safe()
         current_chunk = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
         state_dict.update(current_chunk)
     return state_dict
@@ -298,6 +300,7 @@ def get_megatron_sharded_states(args, tp_size, pp_size, pp_rank):
             checkpoint_path = os.path.join(args.load_path, sub_dir_name, checkpoint_name)
             if os.path.isfile(checkpoint_path):
                 break
+        check_torch_load_is_safe()
         state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
         tp_state_dicts.append(state_dict)
     return tp_state_dicts
@@ -338,6 +341,7 @@ def convert_checkpoint_from_megatron_to_transformers(args):
            rank0_checkpoint_path = os.path.join(args.load_path, sub_dir, rank0_checkpoint_name)
            break
    print(f"Loading Megatron-LM checkpoint arguments from: {rank0_checkpoint_path}")
+    check_torch_load_is_safe()
    state_dict = torch.load(rank0_checkpoint_path, map_location="cpu", weights_only=True)
    megatron_args = state_dict.get("args", None)
    if megatron_args is None:
@@ -634,6 +638,7 @@ def convert_checkpoint_from_transformers_to_megatron(args):
    sub_dirs = [x for x in os.listdir(args.load_path) if x.startswith("pytorch_model")]
    if len(sub_dirs) == 1:
        checkpoint_name = "pytorch_model.bin"
+        check_torch_load_is_safe()
        state_dict = torch.load(os.path.join(args.load_path, checkpoint_name), map_location="cpu", weights_only=True)
    else:
        num_checkpoints = len(sub_dirs) - 1

src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py

+3 −1

@@ -41,6 +41,7 @@
 from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+    check_torch_load_is_safe,
     is_flash_attn_2_available,
     is_flash_attn_greater_or_equal_2_10,
     is_torch_flex_attn_available,
@@ -4391,7 +4392,8 @@ def enable_talker(self):
         self.has_talker = True

     def load_speakers(self, path):
-        for key, value in torch.load(path).items():
+        check_torch_load_is_safe()
+        for key, value in torch.load(path, weights_only=True).items():
             self.speaker_map[key] = value
         logger.info("Speaker {} loaded".format(list(self.speaker_map.keys())))
src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py

+3 −1

@@ -49,6 +49,7 @@
 from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+    check_torch_load_is_safe,
     is_flash_attn_2_available,
     is_flash_attn_greater_or_equal_2_10,
     logging,
@@ -4078,7 +4079,8 @@ def enable_talker(self):
         self.has_talker = True

     def load_speakers(self, path):
-        for key, value in torch.load(path).items():
+        check_torch_load_is_safe()
+        for key, value in torch.load(path, weights_only=True).items():
             self.speaker_map[key] = value
         logger.info("Speaker {} loaded".format(list(self.speaker_map.keys())))

src/transformers/models/wav2vec2/modeling_wav2vec2.py

+5

@@ -45,6 +45,7 @@
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     cached_file,
+    check_torch_load_is_safe,
     is_peft_available,
     is_safetensors_available,
     logging,
@@ -1589,6 +1590,7 @@ def load_adapter(self, target_lang: str, force_load=True, **kwargs):
                     cache_dir=cache_dir,
                 )

+                check_torch_load_is_safe()
                 state_dict = torch.load(
                     weight_path,
                     map_location="cpu",
@@ -1600,6 +1602,9 @@ def load_adapter(self, target_lang: str, force_load=True, **kwargs):
             # to the original exception.
             raise

+        except ValueError:
+            raise
+
         except Exception:
             # For any other exception, we throw a generic error.
             raise EnvironmentError(
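The bare `except ValueError: raise` added here matters because `check_torch_load_is_safe()` reports an outdated torch as a `ValueError`; without the dedicated clause, the catch-all `except Exception` below it would swallow the upgrade message and replace it with a generic `EnvironmentError` about a broken adapter file. A stripped-down illustration of the flow (the adapter filename and error text are hypothetical):

import torch

from transformers.utils import check_torch_load_is_safe

try:
    check_torch_load_is_safe()  # ValueError on torch < 2.6
    state_dict = torch.load("adapter.fra.bin", map_location="cpu", weights_only=True)
except ValueError:
    raise  # let the torch-upgrade message propagate unchanged
except Exception:
    # any other failure still surfaces as the generic adapter error
    raise EnvironmentError("Can't load the adapter weights. Make sure the file exists and is valid.")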

src/transformers/trainer.py

+13

@@ -147,6 +147,7 @@
     PushInProgress,
     PushToHubMixin,
     can_return_loss,
+    check_torch_load_is_safe,
     find_labels,
     is_accelerate_available,
     is_apex_available,
@@ -2831,6 +2832,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
                 logger.warning(
                     "Enabling FP16 and loading from smp < 1.10 checkpoint together is not supported."
                 )
+                check_torch_load_is_safe()
                 state_dict = torch.load(weights_file, map_location="cpu", weights_only=True)
                 # Required for smp to not auto-translate state_dict from hf to smp (is already smp).
                 state_dict["_smp_is_partial"] = False
@@ -2850,6 +2852,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
                 if self.args.save_safetensors and os.path.isfile(safe_weights_file):
                     state_dict = safetensors.torch.load_file(safe_weights_file, device="cpu")
                 else:
+                    check_torch_load_is_safe()
                     state_dict = torch.load(weights_file, map_location="cpu", weights_only=True)

                 # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
@@ -2944,6 +2947,7 @@ def _load_best_model(self):
                 if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
                     state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
                 else:
+                    check_torch_load_is_safe()
                     state_dict = torch.load(best_model_path, map_location="cpu", weights_only=True)

                 state_dict["_smp_is_partial"] = False
@@ -2999,6 +3003,7 @@ def _load_best_model(self):
                     if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
                         state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
                     else:
+                        check_torch_load_is_safe()
                         state_dict = torch.load(best_model_path, map_location="cpu", weights_only=True)

                     # If the model is on the GPU, it still works!
@@ -3354,6 +3359,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
             # deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
             if not isinstance(self.lr_scheduler, DeepSpeedSchedulerWrapper):
                 with warnings.catch_warnings(record=True) as caught_warnings:
+                    check_torch_load_is_safe()
                     self.lr_scheduler.load_state_dict(
                         torch.load(os.path.join(checkpoint, SCHEDULER_NAME), weights_only=True)
                     )
@@ -3386,6 +3392,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
         if is_torch_xla_available():
             # On TPU we have to take some extra precautions to properly load the states on the right device.
             if self.is_fsdp_xla_v1_enabled:
+                check_torch_load_is_safe()
                 optimizer_state = torch.load(
                     os.path.join(
                         checkpoint, f"rank{self.args.process_index}-of-{self.args.world_size}-{OPTIMIZER_NAME}"
@@ -3396,10 +3403,12 @@ def _load_optimizer_and_scheduler(self, checkpoint):
                 # We only need `optimizer` when resuming from checkpoint
                 optimizer_state = optimizer_state["optimizer"]
             else:
+                check_torch_load_is_safe()
                 optimizer_state = torch.load(
                     os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu", weights_only=True
                 )
             with warnings.catch_warnings(record=True) as caught_warnings:
+                check_torch_load_is_safe()
                 lr_scheduler_state = torch.load(
                     os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu", weights_only=True
                 )
@@ -3443,12 +3452,14 @@ def opt_load_hook(mod, opt):
                     **_get_fsdp_ckpt_kwargs(),
                 )
             else:
+                check_torch_load_is_safe()
                 self.optimizer.load_state_dict(
                     torch.load(
                         os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location, weights_only=True
                     )
                 )
             with warnings.catch_warnings(record=True) as caught_warnings:
+                check_torch_load_is_safe()
                 self.lr_scheduler.load_state_dict(
                     torch.load(os.path.join(checkpoint, SCHEDULER_NAME), weights_only=True)
                 )
@@ -3486,6 +3497,7 @@ def _load_scaler(self, checkpoint):
         # Load in scaler states
         if is_torch_xla_available():
             with warnings.catch_warnings(record=True) as caught_warnings:
+                check_torch_load_is_safe()
                 scaler_state = torch.load(
                     os.path.join(checkpoint, SCALER_NAME), map_location="cpu", weights_only=True
                 )
@@ -3494,6 +3506,7 @@ def _load_scaler(self, checkpoint):
             self.accelerator.scaler.load_state_dict(scaler_state)
         else:
             with warnings.catch_warnings(record=True) as caught_warnings:
+                check_torch_load_is_safe()
                 self.accelerator.scaler.load_state_dict(
                     torch.load(os.path.join(checkpoint, SCALER_NAME), weights_only=True)
                 )

src/transformers/utils/__init__.py

+1

@@ -115,6 +115,7 @@
     OptionalDependencyNotAvailable,
     _LazyModule,
     ccl_version,
+    check_torch_load_is_safe,
     direct_transformers_import,
     get_torch_version,
     is_accelerate_available,

src/transformers/utils/import_utils.py

+10

@@ -1387,6 +1387,16 @@ def is_rich_available():
     return _rich_available


+def check_torch_load_is_safe():
+    if not is_torch_greater_or_equal("2.6"):
+        raise ValueError(
+            "Due to a serious vulnerability issue in `torch.load`, even with `weights_only=True`, we now require users "
+            "to upgrade torch to at least v2.6 in order to use the function. This version restriction does not apply "
+            "when loading files with safetensors."
+            "\nSee the vulnerability report here https://nvd.nist.gov/vuln/detail/CVE-2025-32434"
+        )
+
+
 # docstyle-ignore
 AV_IMPORT_ERROR = """
 {0} requires the PyAv library but it was not found in your environment. You can install it with:
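The helper leans on `is_torch_greater_or_equal`, which already exists in this module. As a point of reference, a rough standalone equivalent of the same gate, assuming only `torch` and `packaging` are installed (names here are illustrative, not part of the commit):

import torch
from packaging import version

def torch_load_is_safe() -> bool:
    # CVE-2025-32434 is fixed in torch 2.6; anything older is treated as
    # unsafe even when weights_only=True is passed to torch.load.
    current = version.parse(version.parse(torch.__version__).base_version)
    return current >= version.parse("2.6")

if not torch_load_is_safe():
    raise ValueError("torch >= 2.6 is required for torch.load; upgrade torch or switch to safetensors.")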

tests/models/autoformer/test_modeling_autoformer.py

+2

@@ -21,6 +21,7 @@

 from transformers import is_torch_available
 from transformers.testing_utils import is_flaky, require_torch, slow, torch_device
+from transformers.utils import check_torch_load_is_safe

 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
@@ -414,6 +415,7 @@ def test_model_get_set_embeddings(self):

 def prepare_batch(filename="train-batch.pt"):
     file = hf_hub_download(repo_id="hf-internal-testing/tourism-monthly-batch", filename=filename, repo_type="dataset")
+    check_torch_load_is_safe()
     batch = torch.load(file, map_location=torch_device, weights_only=True)
     return batch
