Skip to content

Commit 97abdd2

Browse files
authored
make tensors contiguous before passing to safetensors (#10761)
fix contiguous bug
1 parent 051ebc3 commit 97abdd2

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

src/diffusers/models/modeling_utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -549,7 +549,7 @@ def save_pretrained(
549549
os.remove(full_filename)
550550

551551
for filename, tensors in state_dict_split.filename_to_tensors.items():
552-
shard = {tensor: state_dict[tensor] for tensor in tensors}
552+
shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
553553
filepath = os.path.join(save_directory, filename)
554554
if safe_serialization:
555555
# At some point we will need to deal better with save_function (used for TPU and other distributed

0 commit comments

Comments
 (0)