diff --git a/examples/dreambooth/train_dreambooth_lora_sana.py b/examples/dreambooth/train_dreambooth_lora_sana.py
index 9e69bd6a668b..798980e86b5e 100644
--- a/examples/dreambooth/train_dreambooth_lora_sana.py
+++ b/examples/dreambooth/train_dreambooth_lora_sana.py
@@ -995,7 +995,8 @@ def main(args):
     if args.enable_npu_flash_attention:
         if is_torch_npu_available():
            logger.info("npu flash attention enabled.")
-            transformer.enable_npu_flash_attention()
+            for block in transformer.transformer_blocks:
+                block.attn2.set_use_npu_flash_attention(True)
         else:
             raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu device ")
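
For context: the change switches from the model-level `transformer.enable_npu_flash_attention()` call to toggling the flag per block, and only on the cross-attention module (`attn2`), presumably because SANA's self-attention uses a linear-attention processor that the NPU flash-attention path does not apply to. Below is a minimal standalone sketch of the same per-block toggle, assuming an NPU device with torch_npu installed; the checkpoint id is a placeholder, and the `transformer_blocks`/`attn2` attribute layout is taken from the diff above rather than verified independently.

```python
# Minimal sketch of the per-block NPU flash-attention toggle from the diff.
# Assumptions: torch_npu is installed on an NPU device, and the checkpoint id
# below is a placeholder for an actual SANA transformer checkpoint.
from diffusers import SanaTransformer2DModel
from diffusers.utils.import_utils import is_torch_npu_available

transformer = SanaTransformer2DModel.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers",  # placeholder repo id
    subfolder="transformer",
)

if is_torch_npu_available():
    # Mirror the training script: enable NPU flash attention only on the
    # cross-attention (`attn2`) of each transformer block.
    for block in transformer.transformer_blocks:
        block.attn2.set_use_npu_flash_attention(True)
else:
    raise ValueError("NPU flash attention requires torch_npu and an NPU device.")
```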