Commit 12f65ee

yao-matrix and ydshieh authored
enable cpu offloading for Bark on xpu (#37599)
* enable cpu offloading of bark modeling on XPU
* remove debug print
* fix style
* fix review comments
* enhance test
* update
* add deprecate message
* update
* update
* trigger CI

Signed-off-by: YAO Matrix <[email protected]>
Co-authored-by: ydshieh <[email protected]>
1 parent 4f9893c · commit 12f65ee

File tree: 4 files changed (+55 -17 lines)

src/transformers/models/bark/modeling_bark.py (+28 -7)

@@ -15,6 +15,7 @@
 """PyTorch BARK model."""
 
 import math
+import warnings
 from typing import Dict, Optional, Tuple, Union
 
 import numpy as np
@@ -36,6 +37,7 @@
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     is_accelerate_available,
+    is_torch_accelerator_available,
     logging,
 )
 from ..auto import AutoModel
@@ -1598,26 +1600,45 @@ def device(self) -> torch.device:
         ):
             return torch.device(module._hf_hook.execution_device)
 
-    def enable_cpu_offload(self, gpu_id: Optional[int] = 0):
+    def enable_cpu_offload(
+        self,
+        accelerator_id: Optional[int] = 0,
+        **kwargs,
+    ):
         r"""
         Offloads all sub-models to CPU using accelerate, reducing memory usage with a low impact on performance. This
-        method moves one whole sub-model at a time to the GPU when it is used, and the sub-model remains in GPU until
-        the next sub-model runs.
+        method moves one whole sub-model at a time to the accelerator when it is used, and the sub-model remains in accelerator until the next sub-model runs.
 
         Args:
-            gpu_id (`int`, *optional*, defaults to 0):
-                GPU id on which the sub-models will be loaded and offloaded.
+            accelerator_id (`int`, *optional*, defaults to 0):
+                accelerator id on which the sub-models will be loaded and offloaded. This argument is deprecated.
+            kwargs (`dict`, *optional*):
+                additional keyword arguments:
+                    `gpu_id`: accelerator id on which the sub-models will be loaded and offloaded.
         """
         if is_accelerate_available():
             from accelerate import cpu_offload_with_hook
         else:
             raise ImportError("`enable_model_cpu_offload` requires `accelerate`.")
 
-        device = torch.device(f"cuda:{gpu_id}")
+        gpu_id = kwargs.get("gpu_id", 0)
+
+        if gpu_id != 0:
+            warnings.warn(
+                "The argument `gpu_id` is deprecated and will be removed in version 4.54.0 of Transformers. Please use `accelerator_id` instead.",
+                FutureWarning,
+            )
+            accelerator_id = gpu_id
+
+        device_type = "cuda"
+        if is_torch_accelerator_available():
+            device_type = torch.accelerator.current_accelerator().type
+        device = torch.device(f"{device_type}:{accelerator_id}")
 
+        torch_accelerator_module = getattr(torch, device_type)
         if self.device.type != "cpu":
             self.to("cpu")
-            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+            torch_accelerator_module.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
 
         # this layer is used outside the first foward pass of semantic so need to be loaded before semantic
         self.semantic.input_embeds_layer, _ = cpu_offload_with_hook(self.semantic.input_embeds_layer, device)
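
For context, a minimal usage sketch of the updated method (not part of the commit; the checkpoint name and voice preset are illustrative). The point of the change is that the same `enable_cpu_offload()` call now targets whichever backend `torch.accelerator` reports, e.g. CUDA or XPU, instead of being hard-coded to CUDA:

```python
# Minimal sketch, assuming `accelerate` is installed and a CUDA or XPU device is present.
from transformers import AutoProcessor, BarkModel

processor = AutoProcessor.from_pretrained("suno/bark-small")
model = BarkModel.from_pretrained("suno/bark-small")

# Sub-models stay on CPU and are moved to accelerator 0 (cuda:0, xpu:0, ...) one at a time.
model.enable_cpu_offload(accelerator_id=0)

# The old keyword still works for now but emits a FutureWarning when it is not 0:
# model.enable_cpu_offload(gpu_id=1)

inputs = processor("Hello, my dog is cute", voice_preset="v2/en_speaker_6")
audio_array = model.generate(**inputs)
```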

src/transformers/utils/__init__.py (+1)

@@ -211,6 +211,7 @@
     is_tiktoken_available,
     is_timm_available,
     is_tokenizers_available,
+    is_torch_accelerator_available,
     is_torch_available,
     is_torch_bf16_available,
     is_torch_bf16_available_on_device,
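
The re-export above simply makes the new helper part of the public `transformers.utils` namespace, so downstream code (including the Bark change) can do, for example:

```python
from transformers.utils import is_torch_accelerator_available

# True on torch builds that expose the torch.accelerator namespace (roughly PyTorch 2.6+).
print(is_torch_accelerator_available())
```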

src/transformers/utils/import_utils.py (+17 -5)

@@ -346,16 +346,28 @@ def is_accelerate_available(min_version: str = ACCELERATE_MIN_VERSION):
     return _accelerate_available and version.parse(_accelerate_version) >= version.parse(min_version)
 
 
+def is_torch_accelerator_available():
+    if is_torch_available():
+        import torch
+
+        return hasattr(torch, "accelerator")
+
+    return False
+
+
 def is_torch_deterministic():
     """
     Check whether pytorch uses deterministic algorithms by looking if torch.set_deterministic_debug_mode() is set to 1 or 2"
     """
-    import torch
+    if is_torch_available():
+        import torch
 
-    if torch.get_deterministic_debug_mode() == 0:
-        return False
-    else:
-        return True
+        if torch.get_deterministic_debug_mode() == 0:
+            return False
+        else:
+            return True
+
+    return False
 
 
 def is_hadamard_available():
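
To illustrate the pattern this helper enables, here is a short sketch mirroring the `enable_cpu_offload` change above (not additional library code; the `is not None` guard is an extra defensive check that the commit itself does not use):

```python
# Resolve a device-agnostic "accelerator module" (torch.cuda, torch.xpu, ...), falling back
# to CUDA on torch builds that predate torch.accelerator.
import torch

from transformers.utils import is_torch_accelerator_available

device_type = "cuda"
if is_torch_accelerator_available() and torch.accelerator.current_accelerator() is not None:
    device_type = torch.accelerator.current_accelerator().type  # e.g. "cuda" or "xpu"

torch_accelerator_module = getattr(torch, device_type)
print(device_type, torch_accelerator_module)
```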

tests/models/bark/test_modeling_bark.py (+9 -5)

@@ -36,6 +36,7 @@
 from transformers.testing_utils import (
     require_flash_attn,
     require_torch,
+    require_torch_accelerator,
     require_torch_fp16,
     require_torch_gpu,
     slow,
@@ -1056,7 +1057,8 @@ def processor(self):
     def inputs(self):
         input_ids = self.processor("In the light of the moon, a little egg lay on a leaf", voice_preset="en_speaker_6")
 
-        input_ids = input_ids.to(torch_device)
+        for k, v in input_ids.items():
+            input_ids[k] = v.to(torch_device)
 
         return input_ids
 
@@ -1295,7 +1297,7 @@ def test_generate_end_to_end_with_sub_models_args(self):
             len(output_ids_with_min_eos_p[0, :].tolist()), len(output_ids_without_min_eos_p[0, :].tolist())
         )
 
-    @require_torch_gpu
+    @require_torch_accelerator
     @slow
     def test_generate_end_to_end_with_offload(self):
         input_ids = self.inputs
@@ -1304,15 +1306,17 @@ def test_generate_end_to_end_with_offload(self):
         # standard generation
         output_with_no_offload = self.model.generate(**input_ids, do_sample=False, temperature=1.0)
 
-        torch.cuda.empty_cache()
+        torch_accelerator_module = getattr(torch, torch_device, torch.cuda)
+
+        torch_accelerator_module.empty_cache()
 
-        memory_before_offload = torch.cuda.memory_allocated()
+        memory_before_offload = torch_accelerator_module.memory_allocated()
         model_memory_footprint = self.model.get_memory_footprint()
 
         # activate cpu offload
         self.model.enable_cpu_offload()
 
-        memory_after_offload = torch.cuda.memory_allocated()
+        memory_after_offload = torch_accelerator_module.memory_allocated()
 
         # checks if the model have been offloaded
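For reference, a condensed sketch of the device-agnostic check the updated test performs (names and the final comparison are illustrative; the test's actual assertion is not shown in this diff):

```python
import torch
from transformers import BarkModel


def offload_frees_accelerator_memory(model: BarkModel, device_type: str) -> bool:
    """Illustrative helper, not the test's exact code. Assumes `model` currently
    resides on the accelerator named by `device_type` (e.g. "cuda" or "xpu")."""
    accelerator = getattr(torch, device_type, torch.cuda)  # torch.cuda, torch.xpu, ...
    accelerator.empty_cache()

    memory_before = accelerator.memory_allocated()
    model.enable_cpu_offload()  # moves sub-models back to CPU and installs offload hooks
    memory_after = accelerator.memory_allocated()

    # After offload, most of the model's footprint should have left the device.
    return memory_after < memory_before
```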