enable torchao test cases on XPU and switch to device agnostic APIs for test cases #11654

Open · wants to merge 9 commits into main
22 changes: 12 additions & 10 deletions src/diffusers/quantizers/quantization_config.py
@@ -493,7 +493,7 @@ def __init__(self, quant_type: str, modules_to_not_convert: Optional[List[str]]
TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys():
is_floating_quant_type = self.quant_type.startswith("float") or self.quant_type.startswith("fp")
if is_floating_quant_type and not self._is_cuda_capability_atleast_8_9():
if is_floating_quant_type and not self._is_xpu_or_cuda_capability_atleast_8_9():
raise ValueError(
f"Requested quantization type: {self.quant_type} is not supported on GPUs with CUDA capability <= 8.9. You "
f"can check the CUDA capability of your GPU using `torch.cuda.get_device_capability()`."
@@ -645,7 +645,7 @@ def generate_fpx_quantization_types(bits: int):
QUANTIZATION_TYPES.update(INT8_QUANTIZATION_TYPES)
QUANTIZATION_TYPES.update(UINTX_QUANTIZATION_DTYPES)

if cls._is_cuda_capability_atleast_8_9():
if cls._is_xpu_or_cuda_capability_atleast_8_9():
QUANTIZATION_TYPES.update(FLOATX_QUANTIZATION_TYPES)

return QUANTIZATION_TYPES
@@ -655,14 +655,16 @@ def generate_fpx_quantization_types(bits: int):
)

@staticmethod
def _is_cuda_capability_atleast_8_9() -> bool:
if not torch.cuda.is_available():
raise RuntimeError("TorchAO requires a CUDA compatible GPU and installation of PyTorch.")

major, minor = torch.cuda.get_device_capability()
if major == 8:
return minor >= 9
return major >= 9
def _is_xpu_or_cuda_capability_atleast_8_9() -> bool:
if torch.cuda.is_available():
major, minor = torch.cuda.get_device_capability()
if major == 8:
return minor >= 9
return major >= 9
elif torch.xpu.is_available():
return True
else:
raise RuntimeError("TorchAO requires a CUDA compatible GPU or Intel XPU and installation of PyTorch.")

Contributor Author

Only check the device capability when the device is CUDA; non-CUDA devices should be checked in separate utilities. With the original implementation, non-CUDA devices (like XPU) would have been skipped here.

Member

We should probably still raise an error if torchao is being used with MPS or other devices; otherwise it leads to an obscure error somewhere deep in the code that most users will not understand.

Contributor Author

@a-r-r-o-w, I enhanced this utility per your comments; please take another look, thanks.

def get_apply_tensor_subclass(self):
TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
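To illustrate the discussion above, a minimal usage sketch of how the capability check surfaces when building a config (illustrative only: the quant type string is an example, and the exact errors depend on the installed torchao/diffusers versions):

```python
from diffusers import TorchAoConfig

# float* / fp* quant types go through the capability check above:
# - CUDA GPU with compute capability >= 8.9: the config builds normally
# - older CUDA GPU: ValueError (the float types are excluded from the map)
# - Intel XPU: accepted, the check returns True
# - other devices (e.g. MPS): RuntimeError from
#   _is_xpu_or_cuda_capability_atleast_8_9()
config = TorchAoConfig(quant_type="float8wo_e4m3")
```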
4 changes: 1 addition & 3 deletions src/diffusers/utils/testing_utils.py
@@ -300,9 +300,7 @@ def require_torch_gpu(test_case):

def require_torch_cuda_compatibility(expected_compute_capability):
def decorator(test_case):
if not torch.cuda.is_available():
return unittest.skip(test_case)
else:
if torch.cuda.is_available():
Contributor Author

Only check compatibility for CUDA devices; for non-CUDA devices, just pass. Non-CUDA devices that need a compatibility check should handle it themselves.

current_compute_capability = get_torch_cuda_device_capability()
return unittest.skipUnless(
float(current_compute_capability) == float(expected_compute_capability),
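A rough sketch of how the adjusted decorator behaves (hypothetical test case; the capability value 7.0 is just an example):

```python
import unittest

from diffusers.utils.testing_utils import require_torch_cuda_compatibility


class ExampleTests(unittest.TestCase):
    # On CUDA, the test runs only when the device's compute capability matches
    # the expected value; on non-CUDA devices (e.g. XPU) the test is no longer
    # skipped here, and any device-specific gating is handled elsewhere.
    @require_torch_cuda_compatibility(7.0)
    def test_example(self):
        self.assertTrue(True)
```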
@@ -21,6 +21,7 @@

from diffusers import ConsistencyDecoderVAE, StableDiffusionPipeline
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
load_image,
slow,
@@ -162,13 +163,13 @@ def setUp(self):
# clean up the VRAM before each test
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

@torch.no_grad()
def test_encode_decode(self):
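The setUp/tearDown changes in this and the following test files all use the same device-agnostic cleanup; conceptually the helper dispatches on the active device, roughly like this simplified sketch (not the actual testing_utils implementation):

```python
import torch


def backend_empty_cache(device: str) -> None:
    # Simplified dispatch on the device string exposed as `torch_device`;
    # the real helper in diffusers.utils.testing_utils covers more backends.
    if device == "cuda":
        torch.cuda.empty_cache()
    elif device == "xpu":
        torch.xpu.empty_cache()
    elif device == "mps":
        torch.mps.empty_cache()
```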
3 changes: 2 additions & 1 deletion tests/models/unets/test_models_unet_2d.py
@@ -22,6 +22,7 @@
from diffusers import UNet2DModel
from diffusers.utils import logging
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
require_torch_accelerator,
@@ -229,7 +230,7 @@ def test_from_pretrained_accelerate_wont_change_results(self):

# two models don't need to stay in the device at the same time
del model_accelerate
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
gc.collect()

model_normal_load, _ = UNet2DModel.from_pretrained(
5 changes: 2 additions & 3 deletions tests/models/unets/test_models_unet_2d_condition.py
@@ -46,7 +46,6 @@
require_peft_backend,
require_torch_accelerator,
require_torch_accelerator_with_fp16,
require_torch_gpu,
skip_mps,
slow,
torch_all_close,
@@ -980,13 +979,13 @@ def test_ip_adapter_plus(self):
assert sample2.allclose(sample5, atol=1e-4, rtol=1e-4)
assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4)

@require_torch_gpu
@parameterized.expand(
[
("hf-internal-testing/unet2d-sharded-dummy", None),
("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"),
]
)
@require_torch_accelerator
def test_load_sharded_checkpoint_from_hub(self, repo_id, variant):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
loaded_model = self.model_class.from_pretrained(repo_id, variant=variant)
@@ -996,13 +995,13 @@ def test_load_sharded_checkpoint_from_hub(self, repo_id, variant):
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)

@require_torch_gpu
@parameterized.expand(
[
("hf-internal-testing/unet2d-sharded-dummy-subfolder", None),
("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"),
]
)
@require_torch_accelerator
def test_load_sharded_checkpoint_from_hub_subfolder(self, repo_id, variant):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
loaded_model = self.model_class.from_pretrained(repo_id, subfolder="unet", variant=variant)
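The decorator swap above lets the sharded-checkpoint tests run on any supported PyTorch accelerator rather than CUDA only; illustrative usage on a hypothetical test:

```python
from diffusers.utils.testing_utils import require_torch_accelerator


# Skips the test unless some supported accelerator (CUDA, XPU, ...) is
# available, instead of requiring a CUDA GPU specifically.
@require_torch_accelerator
def test_runs_on_any_accelerator():
    assert True
```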
5 changes: 3 additions & 2 deletions tests/pipelines/allegro/test_allegro.py
@@ -24,6 +24,7 @@

from diffusers import AllegroPipeline, AllegroTransformer3DModel, AutoencoderKLAllegro, DDIMScheduler
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
require_hf_hub_version_greater,
@@ -341,12 +342,12 @@ class AllegroPipelineIntegrationTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

def test_allegro(self):
generator = torch.Generator("cpu").manual_seed(0)
10 changes: 5 additions & 5 deletions tests/pipelines/audioldm/test_audioldm.py
@@ -37,7 +37,7 @@
UNet2DConditionModel,
)
from diffusers.utils import is_xformers_available
from diffusers.utils.testing_utils import enable_full_determinism, nightly, torch_device
from diffusers.utils.testing_utils import backend_empty_cache, enable_full_determinism, nightly, torch_device

from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS, TEXT_TO_AUDIO_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
@@ -378,12 +378,12 @@ class AudioLDMPipelineSlowTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
@@ -423,12 +423,12 @@ class AudioLDMPipelineNightlyTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
12 changes: 9 additions & 3 deletions tests/pipelines/audioldm2/test_audioldm2.py
@@ -45,7 +45,13 @@
LMSDiscreteScheduler,
PNDMScheduler,
)
from diffusers.utils.testing_utils import enable_full_determinism, is_torch_version, nightly, torch_device
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
is_torch_version,
nightly,
torch_device,
)

from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS, TEXT_TO_AUDIO_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
@@ -540,12 +546,12 @@ class AudioLDM2PipelineSlowTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
5 changes: 3 additions & 2 deletions tests/pipelines/cogvideo/test_cogvideox.py
@@ -22,6 +22,7 @@

from diffusers import AutoencoderKLCogVideoX, CogVideoXPipeline, CogVideoXTransformer3DModel, DDIMScheduler
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
require_torch_accelerator,
@@ -334,12 +335,12 @@ class CogVideoXPipelineIntegrationTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

def test_cogvideox(self):
generator = torch.Generator("cpu").manual_seed(0)
5 changes: 3 additions & 2 deletions tests/pipelines/cogview3/test_cogview3plus.py
@@ -22,6 +22,7 @@

from diffusers import AutoencoderKL, CogVideoXDDIMScheduler, CogView3PlusPipeline, CogView3PlusTransformer2DModel
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
require_torch_accelerator,
@@ -244,12 +245,12 @@ class CogView3PlusPipelineIntegrationTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

def test_cogview3plus(self):
generator = torch.Generator("cpu").manual_seed(0)
5 changes: 3 additions & 2 deletions tests/pipelines/controlnet/test_controlnet_img2img.py
@@ -36,6 +36,7 @@
from diffusers.utils import load_image
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
load_numpy,
@@ -412,12 +413,12 @@ class ControlNetImg2ImgPipelineSlowTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

def test_canny(self):
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
5 changes: 3 additions & 2 deletions tests/pipelines/controlnet/test_controlnet_inpaint.py
@@ -36,6 +36,7 @@
from diffusers.utils import load_image
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
load_numpy,
@@ -464,12 +465,12 @@ class ControlNetInpaintPipelineSlowTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

def test_canny(self):
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
2 changes: 1 addition & 1 deletion tests/pipelines/controlnet_sd3/test_controlnet_sd3.py
@@ -221,7 +221,7 @@ def test_xformers_attention_forwardGenerator_pass(self):

@slow
@require_big_accelerator
@pytest.mark.big_gpu_with_torch_cuda
@pytest.mark.big_accelerator
class StableDiffusion3ControlNetPipelineSlowTests(unittest.TestCase):
pipeline_class = StableDiffusion3ControlNetPipeline

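The `big_gpu_with_torch_cuda` → `big_accelerator` rename also implies the marker is registered under its new name; a hedged sketch of what that registration might look like (the actual location and wording in the repo may differ):

```python
# conftest.py (hypothetical location) — register the renamed marker so
# `pytest -m big_accelerator` can select these suites on CUDA or XPU.
def pytest_configure(config):
    config.addinivalue_line(
        "markers",
        "big_accelerator: tests that require a large accelerator (e.g. CUDA GPU or Intel XPU)",
    )
```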
3 changes: 2 additions & 1 deletion tests/pipelines/deepfloyd_if/test_if.py
@@ -25,6 +25,7 @@
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import (
backend_empty_cache,
backend_max_memory_allocated,
backend_reset_max_memory_allocated,
backend_reset_peak_memory_stats,
load_numpy,
@@ -135,7 +136,7 @@ def test_if_text_to_image(self):

image = output.images[0]

mem_bytes = torch.cuda.max_memory_allocated()
mem_bytes = backend_max_memory_allocated(torch_device)
assert mem_bytes < 12 * 10**9

expected_image = load_numpy(
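As with `backend_empty_cache`, the memory assertion above relies on a device-agnostic helper; a simplified sketch of the idea (illustrative, not the actual testing_utils implementation):

```python
import torch


def backend_max_memory_allocated(device: str) -> int:
    # Peak allocated memory for the active backend; the real helper in
    # diffusers.utils.testing_utils handles additional devices.
    if device == "cuda":
        return torch.cuda.max_memory_allocated()
    elif device == "xpu":
        return torch.xpu.max_memory_allocated()
    return 0  # assumption: backends without memory stats report 0 here
```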
3 changes: 2 additions & 1 deletion tests/pipelines/deepfloyd_if/test_if_img2img.py
@@ -24,6 +24,7 @@
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import (
backend_empty_cache,
backend_max_memory_allocated,
backend_reset_max_memory_allocated,
backend_reset_peak_memory_stats,
floats_tensor,
@@ -151,7 +152,7 @@ def test_if_img2img(self):
)
image = output.images[0]

mem_bytes = torch.cuda.max_memory_allocated()
mem_bytes = backend_max_memory_allocated(torch_device)
assert mem_bytes < 12 * 10**9

expected_image = load_numpy(
4 changes: 2 additions & 2 deletions tests/pipelines/flux/test_pipeline_flux.py
@@ -224,7 +224,7 @@ def test_flux_true_cfg(self):

@nightly
@require_big_accelerator
@pytest.mark.big_gpu_with_torch_cuda
@pytest.mark.big_accelerator
class FluxPipelineSlowTests(unittest.TestCase):
pipeline_class = FluxPipeline
repo_id = "black-forest-labs/FLUX.1-schnell"
@@ -312,7 +312,7 @@ def test_flux_inference(self):

@slow
@require_big_accelerator
@pytest.mark.big_gpu_with_torch_cuda
@pytest.mark.big_accelerator
class FluxIPAdapterPipelineSlowTests(unittest.TestCase):
pipeline_class = FluxPipeline
repo_id = "black-forest-labs/FLUX.1-dev"
2 changes: 1 addition & 1 deletion tests/pipelines/flux/test_pipeline_flux_redux.py
@@ -19,7 +19,7 @@

@slow
@require_big_accelerator
@pytest.mark.big_gpu_with_torch_cuda
@pytest.mark.big_accelerator
class FluxReduxSlowTests(unittest.TestCase):
pipeline_class = FluxPriorReduxPipeline
repo_id = "black-forest-labs/FLUX.1-Redux-dev"