Custom Diffusion: RuntimeError: expected mat1 and mat2 to have the same dtype, but got: c10::Half != float #6879

Closed
rezkanas opened this issue Feb 6, 2024 · 16 comments · Fixed by #9217
Labels: bug (Something isn't working), training

Comments

@rezkanas commented Feb 6, 2024

Describe the bug

When running Custom Diffusion on my repository of 20 photos, I run into this error, which is related to a data type mismatch.

Reproduction

!accelerate launch train_custom_diffusion.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--class_data_dir=$class_data_dir \
--with_prior_preservation \
--prior_loss_weight=1.0 \
--class_prompt="person" \
--num_class_images=200 \
--instance_prompt="photo of a <new1> person" \
--resolution=512 \
--train_batch_size=2 \
--learning_rate=5e-6 \
--lr_warmup_steps=0 \
--max_train_steps=1200 \
--freeze_model=crossattn \
--scale_lr \
--hflip \
--use_8bit_adam \
--gradient_checkpointing \
--enable_xformers_memory_efficient_attention \
--modifier_token "<new1>" \
--validation_prompt="<new1> person sitting in a bucket"

Logs

/bin/bash: warning: setlocale: LC_ALL: cannot change locale (en_US.UTF-8)
/home/anasrezklinux/.local/lib/python3.10/site-packages/diffusers/utils/outputs.py:63: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  torch.utils._pytree._register_pytree_node(
/home/anasrezklinux/.local/lib/python3.10/site-packages/diffusers/utils/outputs.py:63: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  torch.utils._pytree._register_pytree_node(
02/06/2024 23:50:18 - INFO - __main__ - Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda

Mixed precision type: fp16

You are using a model of type clip_text_model to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
{'thresholding', 'variance_type', 'dynamic_thresholding_ratio', 'clip_sample_range', 'sample_max_value', 'timestep_spacing', 'rescale_betas_zero_snr'} was not found in config. Values will be initialized to default values.
{'scaling_factor', 'force_upcast'} was not found in config. Values will be initialized to default values.
{'mid_block_only_cross_attention', 'cross_attention_norm', 'encoder_hid_dim', 'encoder_hid_dim_type', 'reverse_transformer_layers_per_block', 'attention_type', 'time_embedding_act_fn', 'projection_class_embeddings_input_dim', 'time_embedding_dim', 'mid_block_type', 'transformer_layers_per_block', 'class_embed_type', 'conv_out_kernel', 'class_embeddings_concat', 'addition_time_embed_dim', 'addition_embed_type', 'addition_embed_type_num_heads', 'conv_in_kernel', 'time_embedding_type', 'resnet_skip_time_act', 'num_attention_heads', 'resnet_out_scale_factor', 'resnet_time_scale_shift', 'time_cond_proj_dim', 'timestep_post_act', 'dropout'} was not found in config. Values will be initialized to default values.
[42170]
02/06/2024 23:52:45 - INFO - __main__ - ***** Running training *****
02/06/2024 23:52:45 - INFO - __main__ -   Num examples = 200
02/06/2024 23:52:45 - INFO - __main__ -   Num batches each epoch = 100
02/06/2024 23:52:45 - INFO - __main__ -   Num Epochs = 12
02/06/2024 23:52:45 - INFO - __main__ -   Instantaneous batch size per device = 2
02/06/2024 23:52:45 - INFO - __main__ -   Total train batch size (w. parallel, distributed & accumulation) = 2
02/06/2024 23:52:45 - INFO - __main__ -   Gradient Accumulation steps = 1
02/06/2024 23:52:45 - INFO - __main__ -   Total optimization steps = 1200
Steps:   0%|                                           | 0/1200 [00:00<?, ?it/s]/home/anasrezklinux/.local/lib/python3.10/site-packages/torch/utils/checkpoint.py:460: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.
  warnings.warn(
Traceback (most recent call last):
  File "/home/anasrezklinux/test_pycharm_link/diffusers/examples/custom_diffusion/train_custom_diffusion.py", line 1350, in <module>
    main(args)
  File "/home/anasrezklinux/test_pycharm_link/diffusers/examples/custom_diffusion/train_custom_diffusion.py", line 1131, in main
    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
  File "/home/anasrezklinux/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/anasrezklinux/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/anasrezklinux/.local/lib/python3.10/site-packages/diffusers/models/unets/unet_2d_condition.py", line 1121, in forward
    sample, res_samples = downsample_block(
  File "/home/anasrezklinux/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/anasrezklinux/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/anasrezklinux/.local/lib/python3.10/site-packages/diffusers/models/unets/unet_2d_blocks.py", line 1189, in forward
    hidden_states = attn(
  File "/home/anasrezklinux/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/anasrezklinux/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/anasrezklinux/.local/lib/python3.10/site-packages/diffusers/models/transformers/transformer_2d.py", line 379, in forward
    hidden_states = torch.utils.checkpoint.checkpoint(
  File "/home/anasrezklinux/.local/lib/python3.10/site-packages/torch/_compile.py", line 24, in inner
    return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
  File "/home/anasrezklinux/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "/home/anasrezklinux/.local/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/home/anasrezklinux/.local/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 489, in checkpoint
    ret = function(*args, **kwargs)
  File "/home/anasrezklinux/.local/lib/python3.10/site-packages/diffusers/models/transformers/transformer_2d.py", line 374, in custom_forward
    return module(*inputs)
  File "/home/anasrezklinux/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/anasrezklinux/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/anasrezklinux/.local/lib/python3.10/site-packages/diffusers/models/attention.py", line 366, in forward
    attn_output = self.attn2(
  File "/home/anasrezklinux/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/anasrezklinux/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/anasrezklinux/.local/lib/python3.10/site-packages/diffusers/models/attention_processor.py", line 512, in forward
    return self.processor(
  File "/home/anasrezklinux/.local/lib/python3.10/site-packages/diffusers/models/attention_processor.py", line 1429, in __call__
    query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype)
  File "/home/anasrezklinux/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/anasrezklinux/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/anasrezklinux/.local/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 116, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: expected mat1 and mat2 to have the same dtype, but got: c10::Half != float
Steps:   0%|                                           | 0/1200 [00:22<?, ?it/s]
Traceback (most recent call last):
  File "/home/anasrezklinux/.local/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/home/anasrezklinux/.local/lib/python3.10/site-packages/accelerate/commands/accelerate_cli.py", line 47, in main
    args.func(args)
  File "/home/anasrezklinux/.local/lib/python3.10/site-packages/accelerate/commands/launch.py", line 1017, in launch_command
    simple_launcher(args)
  File "/home/anasrezklinux/.local/lib/python3.10/site-packages/accelerate/commands/launch.py", line 637, in simple_launcher
    raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
subprocess.CalledProcessError: Command '['/usr/bin/python3', 'train_custom_diffusion.py', '--pretrained_model_name_or_path=stabilityai/stable-diffusion-2-1', '--instance_data_dir=/mnt/c/Users/noobw/PycharmProjects/pythonProject/Anas', '--output_dir=/mnt/c/Users/noobw/PycharmProjects/pythonProject/custom_diffusion_anas', '--class_data_dir=/mnt/c/Users/noobw/PycharmProjects/pythonProject/custom_diffusion_anas/class_prior', '--with_prior_preservation', '--prior_loss_weight=1.0', '--class_prompt=person', '--num_class_images=200', '--instance_prompt=photo of a <new1> person', '--resolution=512', '--train_batch_size=2', '--learning_rate=5e-6', '--lr_warmup_steps=0', '--max_train_steps=1200', '--freeze_model=crossattn', '--scale_lr', '--hflip', '--use_8bit_adam', '--gradient_checkpointing', '--enable_xformers_memory_efficient_attention', '--modifier_token', '<new1>', '--validation_prompt=<new1> person sitting in a bucket']' returned non-zero exit status 1.

System Info

  • diffusers version: 0.26.1
  • Platform: Linux-5.15.133.1-microsoft-standard-WSL2-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • PyTorch version (GPU?): 2.2.0+cu121 (True)
  • Huggingface_hub version: 0.20.3
  • Transformers version: 4.37.0
  • Accelerate version: 0.25.0
  • xFormers version: 0.0.24
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Who can help?

@sayakpaul @patrickvonplaten

rezkanas added the bug label on Feb 6, 2024
@sayakpaul (Member)

It could be a dataloading problem, because I don't see this error when using the example from the README. I would suggest debugging the dataloader.
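
For anyone debugging along these lines, here is a minimal sketch of the kind of check meant here (the dict-of-tensors batch layout and the train_dataloader name are assumptions based on the training script and may differ):

import torch

def inspect_batch_dtypes(dataloader, num_batches=1):
    # Print the dtype and shape of every tensor the dataloader yields,
    # to spot fp16/fp32 mixes before they reach the UNet.
    for step, batch in enumerate(dataloader):
        if step >= num_batches:
            break
        for key, value in batch.items():
            if torch.is_tensor(value):
                print(f"step {step}: {key} -> {value.dtype}, shape {tuple(value.shape)}")

Calling inspect_batch_dtypes(train_dataloader) right after the dataloader is built would show whether the batches themselves carry mismatched dtypes.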

@rezkanas (Author) commented Feb 7, 2024

Hi again, I ran the code with the 'cat' example from the README file and it works fine. I also adjusted it to 'person' and it works.

The problem appears when you activate the --freeze_model crossattn attribute, as the README suggests when training on human faces:

https://github.com/huggingface/diffusers/blob/main/examples/custom_diffusion/README.md#training-on-human-faces

Can you try to reproduce it on your end? I am almost certain the bug is related to --freeze_model crossattn. Currently, I run the command below with no problem:

!accelerate launch train_custom_diffusion.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--class_data_dir=$class_data_dir \
--with_prior_preservation --prior_loss_weight=1.0 \
--class_prompt="person" --num_class_images=200 \
--instance_prompt="photo of a <new1> person" \
--resolution=512 \
--train_batch_size=2 \
--learning_rate=5e-6 \
--lr_warmup_steps=0 \
--max_train_steps=1200 \
--scale_lr --hflip --noaug \
--modifier_token "<new1>" \
--enable_xformers_memory_efficient_attention

@rezkanas (Author) commented Feb 7, 2024

When training with the command above, I encountered issue #4704, so I added the --no_safe_serialization attribute and it works. But now I am encountering a different error that I could not find in your bug-resolution history.

When running this:

!accelerate launch train_custom_diffusion.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--class_data_dir=$class_data_dir \
--with_prior_preservation --prior_loss_weight=1.0 \
--class_prompt="person" --num_class_images=200 \
--instance_prompt="photo of a <new1> person" \
--resolution=512 \
--train_batch_size=1 \
--learning_rate=5e-6 \
--lr_warmup_steps=0 \
--max_train_steps=1200 \
--scale_lr --hflip --noaug \
--no_safe_serialization \
--num_validation_images=5 \
--validation_prompt="<new1> person sitting in a bucket" \
--modifier_token "<new1>" \
--enable_xformers_memory_efficient_attention

The model finishes training and breaks only when loading the pipeline components:

Loading pipeline components...: 29%|███▋ | 2/7 [00:06<00:17, 3.45s/it]Loaded tokenizer as CLIPTokenizer from tokenizer subfolder of CompVis/stable-diffusion-v1-4.
Traceback (most recent call last):
  File "/home/anasrezklinux/.local/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/home/anasrezklinux/.local/lib/python3.10/site-packages/accelerate/commands/accelerate_cli.py", line 47, in main
    args.func(args)
  File "/home/anasrezklinux/.local/lib/python3.10/site-packages/accelerate/commands/launch.py", line 1017, in launch_command
    simple_launcher(args)
  File "/home/anasrezklinux/.local/lib/python3.10/site-packages/accelerate/commands/launch.py", line 637, in simple_launcher
    raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
subprocess.CalledProcessError: Command '['/usr/bin/python3', 'train_custom_diffusion.py', '--pretrained_model_name_or_path=CompVis/stable-diffusion-v1-4', '--instance_data_dir=/mnt/c/Users/noobw/PycharmProjects/pythonProject/Anas', '--output_dir=/mnt/c/Users/noobw/PycharmProjects/pythonProject/custom_diffusion_anas', '--class_data_dir=/mnt/c/Users/noobw/PycharmProjects/pythonProject/custom_diffusion_anas/class_prior', '--with_prior_preservation', '--prior_loss_weight=1.0', '--class_prompt=person', '--num_class_images=10', '--instance_prompt=photo of a person', '--resolution=512', '--train_batch_size=1', '--learning_rate=5e-6', '--lr_warmup_steps=0', '--max_train_steps=10', '--scale_lr', '--hflip', '--noaug', '--no_safe_serialization', '--num_validation_images=5', '--validation_prompt= cat sitting in a bucket', '--modifier_token', '', '--enable_xformers_memory_efficient_attention']' died with <Signals.SIGKILL: 9>.

@sayakpaul (Member)

The stack trace isn't informative enough. There is nothing there suggesting it's coming from the diffusers codebase.

@daeunni commented Feb 14, 2024

@rezkanas @sayakpaul Hi, did you manage to address this error? I have the same issue.

@rezkanas (Author) commented Feb 19, 2024

I resolved it as per #6879 (comment), but I encountered another problem: #6879 (comment).
Can you help, @yiyixuxu @sayakpaul?

@sayakpaul (Member) commented Feb 19, 2024

Will try to reproduce.

@rezkanas (Author)

I would still appreciate being able to train the model with --freeze_model crossattn activated, since I am training on faces :)

@sayakpaul (Member)

Hi.

I just tried the commands provided in the README of the example and I didn't run into any problems whatsoever. I am on PyTorch 2.2, with diffusers installed from main.

@rezkanas (Author)

I also have PyTorch 2.2 and installed diffusers from main. Do you run it with --freeze_model crossattn activated?

@shinnosukeono

Hi, I encountered a similar problem. I tried train_custom_diffusion.py with the cat dataset, and I got the error RuntimeError: mat1 and mat2 must have the same dtype, but got Float and Half.

Reproduction:

accelerate launch train_custom_diffusion.py \
    --pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4" \
    --instance_data_dir="data/cat" \
    --output_dir="tutorial_output" \
    --instance_prompt="photo of a <new1> cat" \
    --resolution=512 \
    --train_batch_size=1 \
    --learning_rate=5e-6 \
    --lr_warmup_steps=0 \
    --max_train_steps=250 \
    --scale_lr \
    --hflip \
    --modifier_token="<new1>" \
    --no_safe_serialization \
    --validation_prompt="<new1> cat sitting in a bucket"

Logs:

03/06/2024 15:18:04 - INFO - __main__ - Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda

Mixed precision type: fp16

You are using a model of type clip_text_model to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
{'thresholding', 'rescale_betas_zero_snr', 'timestep_spacing', 'dynamic_thresholding_ratio', 'variance_type', 'prediction_type', 'sample_max_value', 'clip_sample_range'} was not found in config. Values will be initialized to default values.
{'force_upcast', 'norm_num_groups', 'latents_std', 'latents_mean'} was not found in config. Values will be initialized to default values.
/home/shinnosuke.ono/miniforge3/envs/custom-diffusion/lib/python3.10/site-packages/diffusers/models/lora.py:300: FutureWarning: `LoRACompatibleConv` is deprecated and will be removed in version 1.0.0. Use of `LoRACompatibleConv` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`.
  deprecate("LoRACompatibleConv", "1.0.0", deprecation_message)
/home/shinnosuke.ono/miniforge3/envs/custom-diffusion/lib/python3.10/site-packages/diffusers/models/lora.py:387: FutureWarning: `LoRACompatibleLinear` is deprecated and will be removed in version 1.0.0. Use of `LoRACompatibleLinear` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`.
  deprecate("LoRACompatibleLinear", "1.0.0", deprecation_message)
{'transformer_layers_per_block', 'time_embedding_dim', 'resnet_skip_time_act', 'time_embedding_act_fn', 'resnet_out_scale_factor', 'timestep_post_act', 'projection_class_embeddings_input_dim', 'dual_cross_attention', 'attention_type', 'time_embedding_type', 'conv_out_kernel', 'encoder_hid_dim', 'resnet_time_scale_shift', 'class_embeddings_concat', 'addition_time_embed_dim', 'dropout', 'addition_embed_type', 'mid_block_type', 'use_linear_projection', 'class_embed_type', 'cross_attention_norm', 'num_class_embeds', 'reverse_transformer_layers_per_block', 'addition_embed_type_num_heads', 'num_attention_heads', 'upcast_attention', 'time_cond_proj_dim', 'mid_block_only_cross_attention', 'only_cross_attention', 'conv_in_kernel', 'encoder_hid_dim_type'} was not found in config. Values will be initialized to default values.
[42170]
03/06/2024 15:18:12 - INFO - __main__ - ***** Running training *****
03/06/2024 15:18:12 - INFO - __main__ -   Num examples = 5
03/06/2024 15:18:12 - INFO - __main__ -   Num batches each epoch = 5
03/06/2024 15:18:12 - INFO - __main__ -   Num Epochs = 50
03/06/2024 15:18:12 - INFO - __main__ -   Instantaneous batch size per device = 1
03/06/2024 15:18:12 - INFO - __main__ -   Total train batch size (w. parallel, distributed & accumulation) = 1
03/06/2024 15:18:12 - INFO - __main__ -   Gradient Accumulation steps = 1
03/06/2024 15:18:12 - INFO - __main__ -   Total optimization steps = 250
Steps:  20%|██████████████████████████▍                                                                                                         | 50/250 [00:18<00:58,  3.44it/s, loss=0.264, lr=5e-6]03/06/2024 15:18:30 - INFO - __main__ - Running validation... 
 Generating 2 images with prompt: <new1> cat sitting in a bucket.
{'image_encoder', 'requires_safety_checker'} was not found in config. Values will be initialized to default values.
                                                                                                                                                                                                     {'force_upcast', 'norm_num_groups', 'latents_std', 'latents_mean'} was not found in config. Values will be initialized to default values.                                        | 0/7 [00:00<?, ?it/s]
Loaded vae as AutoencoderKL from `vae` subfolder of CompVis/stable-diffusion-v1-4.
                                                                                                                                                                                                     {'timestep_spacing', 'prediction_type'} was not found in config. Values will be initialized to default values.                                                           | 1/7 [00:00<00:00,  6.58it/s]
Loaded scheduler as PNDMScheduler from `scheduler` subfolder of CompVis/stable-diffusion-v1-4.
Loaded feature_extractor as CLIPImageProcessor from `feature_extractor` subfolder of CompVis/stable-diffusion-v1-4.
Loaded safety_checker as StableDiffusionSafetyChecker from `safety_checker` subfolder of CompVis/stable-diffusion-v1-4.
Loading pipeline components...: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.08it/s]
{'thresholding', 'rescale_betas_zero_snr', 'solver_order', 'timestep_spacing', 'final_sigmas_type', 'use_lu_lambdas', 'dynamic_thresholding_ratio', 'algorithm_type', 'variance_type', 'lambda_min_clipped', 'prediction_type', 'sample_max_value', 'solver_type', 'use_karras_sigmas', 'euler_at_final', 'lower_order_final'} was not found in config. Values will be initialized to default values.
Traceback (most recent call last):
  File "/home/shinnosuke.ono/scripts/train_custom_diffusion.py", line 1355, in <module>
    main(args)
  File "/home/shinnosuke.ono/scripts/train_custom_diffusion.py", line 1255, in main
    images = [
  File "/home/shinnosuke.ono/scripts/train_custom_diffusion.py", line 1256, in <listcomp>
    pipeline(args.validation_prompt, num_inference_steps=25, generator=generator, eta=1.0).images[
  File "/home/shinnosuke.ono/miniforge3/envs/custom-diffusion/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/shinnosuke.ono/miniforge3/envs/custom-diffusion/lib/python3.10/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py", line 971, in __call__
    noise_pred = self.unet(
  File "/home/shinnosuke.ono/miniforge3/envs/custom-diffusion/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/shinnosuke.ono/miniforge3/envs/custom-diffusion/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/shinnosuke.ono/miniforge3/envs/custom-diffusion/lib/python3.10/site-packages/diffusers/models/unets/unet_2d_condition.py", line 1152, in forward
    emb = self.time_embedding(t_emb, timestep_cond)
  File "/home/shinnosuke.ono/miniforge3/envs/custom-diffusion/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/shinnosuke.ono/miniforge3/envs/custom-diffusion/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/shinnosuke.ono/miniforge3/envs/custom-diffusion/lib/python3.10/site-packages/diffusers/models/embeddings.py", line 228, in forward
    sample = self.linear_1(sample)
  File "/home/shinnosuke.ono/miniforge3/envs/custom-diffusion/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/shinnosuke.ono/miniforge3/envs/custom-diffusion/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/shinnosuke.ono/miniforge3/envs/custom-diffusion/lib/python3.10/site-packages/diffusers/models/lora.py", line 447, in forward
    out = super().forward(hidden_states)
  File "/home/shinnosuke.ono/miniforge3/envs/custom-diffusion/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 116, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 must have the same dtype, but got Float and Half
Steps:  20%|██████████████████████████▍                                                                                                         | 50/250 [00:20<01:22,  2.44it/s, loss=0.264, lr=5e-6]
Traceback (most recent call last):
  File "/home/shinnosuke.ono/miniforge3/envs/custom-diffusion/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/home/shinnosuke.ono/miniforge3/envs/custom-diffusion/lib/python3.10/site-packages/accelerate/commands/accelerate_cli.py", line 47, in main
    args.func(args)
  File "/home/shinnosuke.ono/miniforge3/envs/custom-diffusion/lib/python3.10/site-packages/accelerate/commands/launch.py", line 1023, in launch_command
    simple_launcher(args)
  File "/home/shinnosuke.ono/miniforge3/envs/custom-diffusion/lib/python3.10/site-packages/accelerate/commands/launch.py", line 643, in simple_launcher
    raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
subprocess.CalledProcessError: Command '['/home/shinnosuke.ono/miniforge3/envs/custom-diffusion/bin/python3.10', 'train_custom_diffusion.py', '--pretrained_model_name_or_path=CompVis/stable-diffusion-v1-4', '--instance_data_dir=../data/cat', '--output_dir=tutorial_output', '--instance_prompt=photo of a <new1> cat', '--resolution=512', '--train_batch_size=1', '--learning_rate=5e-6', '--lr_warmup_steps=0', '--max_train_steps=250', '--scale_lr', '--hflip', '--modifier_token=<new1>', '--no_safe_serialization', '--validation_prompt=<new1> cat sitting in a bucket']' returned non-zero exit status 1.

System Info:

  • Platform: #101~20.04.1-Ubuntu SMP Thu Nov 16 14:22:28 UTC 2023
  • Python: 3.10.13
  • PyTorch (GPU?): 2.2.0+cu121 (True)
  • Accelerate: 0.27.2
  • diffusers: 0.27.0.dev0
  • huggingface_hub: 0.21.3
  • Transformers: 4.38.2
  • xFormers: not installed

What I tried:

  • When I added --freeze_model="crossattn" to the arguments, I got RuntimeError: expected mat1 and mat2 to have the same dtype, but got: c10::Half != float, which is the same error as in this issue.
  • When I removed --validation_prompt=... from the arguments, the script finished successfully, but I am concerned about performance deteriorating if the model is trained this way. (See the sketch after this list.)
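
A hypothetical workaround for the validation-time crash (a sketch only, not the script's actual code nor necessarily the fix that later landed) is to force every pipeline component into a single dtype before generating images:

import torch

def unify_pipeline_dtype(pipeline, dtype=torch.float32):
    # Cast the UNet, VAE, and text encoder to one dtype so that
    # F.linear never sees a Float/Half mix during validation.
    for name in ("unet", "vae", "text_encoder"):
        module = getattr(pipeline, name, None)
        if module is not None:
            module.to(dtype=dtype)
    return pipeline

Calling unify_pipeline_dtype(pipeline) just before the validation loop should sidestep the mismatch, at the cost of fp32 memory.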

@sayakpaul (Member)

Cc: @nupurkmr9

@nupurkmr9 (Contributor)

@rezkanas @shinnosukeono, sorry for the delayed response. I believe the error might be caused by float16 training; without it, I am able to train with --freeze_model="crossattn" as well.
I will see if I can update the code to support float16 with full crossattn fine-tuning too. In the meantime, disabling it should work.
Thanks.
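
For context, the first traceback ends inside the trainable custom-diffusion projection (to_q_custom_diffusion): with fp16 mixed precision the UNet activations are half while those trainable weights stay float, so F.linear receives mismatched dtypes. A minimal standalone illustration of this (the shapes and layer are made up; this is not the script's code):

import torch

hidden_states = torch.randn(2, 77, 320, dtype=torch.float16)  # fp16 activations
to_q_custom = torch.nn.Linear(320, 320)  # trainable layer kept in fp32

# to_q_custom(hidden_states) would raise:
# RuntimeError: expected mat1 and mat2 to have the same dtype
query = to_q_custom(hidden_states.to(to_q_custom.weight.dtype))  # cast input first
print(query.dtype)  # torch.float32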

@shinnosukeono

@nupurkmr9 Thank you so much for your reply! I changed mixed_precision: 'fp16' to mixed_precision: 'no' in the accelerate configuration, and training now succeeds with the validation prompt. I hope this also helps @rezkanas.
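
For anyone who prefers not to edit the saved accelerate config file, the same setting can be passed per run with the launch flag (the trailing ellipsis stands for the usual training arguments):

accelerate launch --mixed_precision="no" train_custom_diffusion.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  ...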

github-actions bot commented May 7, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions bot added the stale label on May 7, 2024
@yiyixuxu (Collaborator) commented May 8, 2024

Gentle ping @nupurkmr9: are we able to support float16 with full crossattn fine-tuning as well? If not, let's maybe remove that option from the script.
