
Custom Diffusion : run train_custom_diffusion.py fails #4704


Closed
canberk17 opened this issue Aug 21, 2023 · 6 comments
Labels
bug Something isn't working stale Issues that haven't received updates

Comments

@canberk17
Contributor

canberk17 commented Aug 21, 2023

Describe the bug

Hey guys

I've been doing the tutorials on Hugging Face, love the content! I encountered an error while following this documentation on CustomDiffusion. I am using the provided cat dataset, and everything works fine during training, but once the training steps are done I get the following error message:

ValueError: Key `down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor` is invalid, expected torch.Tensor but received <class 'dict'>

I'm doing everything exactly as in the tutorial, but can't figure out why this is happening. I checked the following issues: #3284, #4235, #3221.

In this case, how can I avoid using safetensors, or how can I handle empty weight layers?
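For context, the traceback shows that safetensors.torch.save_file only accepts a flat mapping from string keys to tensors, while the state dict passed to it has a dict as the value for keys such as down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor. The sketch below is a hypothetical helper (not part of diffusers or safetensors) that flattens nested dicts into dotted keys so every value is a leaf. If, as the question suggests, these dict entries are placeholders for attention processors with no trained weights, they are empty and flattening simply drops them. Whether the resulting file can be loaded back correctly with load_attn_procs is an assumption that is not verified here.

```python
from typing import Any, Dict


def flatten_state_dict(state: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
    """Hypothetical helper: turn nested dicts into a flat {dotted_key: leaf} mapping.

    safetensors rejects any value that is not a tensor, so nested dicts must be
    flattened first. Leaves (e.g. torch.Tensor objects) are passed through unchanged,
    and empty nested dicts (processors without weights) produce no keys at all.
    """
    flat: Dict[str, Any] = {}
    for key, value in state.items():
        full_key = f"{prefix}.{key}" if prefix else key
        if isinstance(value, dict):
            # Recurse into the nested dict, extending the dotted key path.
            flat.update(flatten_state_dict(value, full_key))
        else:
            flat[full_key] = value
    return flat


if __name__ == "__main__":
    # Placeholder keys and string values stand in for real tensors here.
    nested = {
        "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor": {},
        "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor": {
            "example_weight": "tensor-placeholder",
        },
    }
    print(flatten_state_dict(nested))
```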

Reproduction

import os
os.environ['MODEL_NAME']="CompVis/stable-diffusion-v1-4"
os.environ['OUTPUT_DIR']="/content/stable_diffusion_weights/cat"
os.environ['INSTANCE_DIR']="/content/data/cat"

!accelerate launch train_custom_diffusion.py \
  --pretrained_model_name_or_path=$MODEL_NAME  \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --class_data_dir=./real_reg/samples_cat/ \
  --with_prior_preservation --real_prior --prior_loss_weight=1.0 \
  --class_prompt="cat" --num_class_images=200 \
  --instance_prompt="photo of a <new1> cat"  \
  --resolution=512  \
  --train_batch_size=1  \
  --learning_rate=1e-5  \
  --lr_warmup_steps=80 \
  --max_train_steps=250 \
  --enable_xformers_memory_efficient_attention\
  --scale_lr --hflip  \
  --modifier_token "<new1>" \
  --push_to_hub

Logs

2023-08-21 21:24:18.493777: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
08/21/2023 21:24:20 - INFO - __main__ - Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda

Mixed precision type: no

Downloading (…)okenizer_config.json: 100% 806/806 [00:00<00:00, 4.56MB/s]
Downloading (…)tokenizer/vocab.json: 100% 1.06M/1.06M [00:00<00:00, 1.12MB/s]
Downloading (…)tokenizer/merges.txt: 100% 525k/525k [00:00<00:00, 41.1MB/s]
Downloading (…)cial_tokens_map.json: 100% 472/472 [00:00<00:00, 2.57MB/s]
Downloading (…)_encoder/config.json: 100% 592/592 [00:00<00:00, 3.36MB/s]
You are using a model of type clip_text_model to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
Downloading (…)cheduler_config.json: 100% 313/313 [00:00<00:00, 1.89MB/s]
{'clip_sample_range', 'timestep_spacing', 'dynamic_thresholding_ratio', 'thresholding', 'variance_type', 'prediction_type', 'sample_max_value'} was not found in config. Values will be initialized to default values.
Downloading model.safetensors: 100% 492M/492M [00:01<00:00, 399MB/s]
Downloading (…)main/vae/config.json: 100% 551/551 [00:00<00:00, 2.41MB/s]
Downloading (…)ch_model.safetensors: 100% 335M/335M [00:00<00:00, 395MB/s]
{'force_upcast', 'norm_num_groups'} was not found in config. Values will be initialized to default values.
Downloading (…)ain/unet/config.json: 100% 743/743 [00:00<00:00, 4.28MB/s]
Downloading (…)ch_model.safetensors: 100% 3.44G/3.44G [00:10<00:00, 325MB/s]
{'class_embed_type', 'resnet_out_scale_factor', 'encoder_hid_dim_type', 'time_cond_proj_dim', 'time_embedding_type', 'time_embedding_act_fn', 'only_cross_attention', 'dual_cross_attention', 'addition_embed_type', 'class_embeddings_concat', 'resnet_skip_time_act', 'time_embedding_dim', 'mid_block_type', 'timestep_post_act', 'conv_out_kernel', 'upcast_attention', 'addition_embed_type_num_heads', 'encoder_hid_dim', 'use_linear_projection', 'addition_time_embed_dim', 'attention_type', 'transformer_layers_per_block', 'conv_in_kernel', 'resnet_time_scale_shift', 'mid_block_only_cross_attention', 'num_class_embeds', 'projection_class_embeddings_input_dim', 'num_attention_heads', 'cross_attention_norm'} was not found in config. Values will be initialized to default values.
[42170]
/content/diffusers/examples/custom_diffusion/train_custom_diffusion.py:880: DeprecationWarning: The 'warn' method is deprecated, use 'warning' instead
  logger.warn(
08/21/2023 21:24:49 - WARNING - __main__ - xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.
08/21/2023 21:24:51 - INFO - __main__ - ***** Running training *****
08/21/2023 21:24:51 - INFO - __main__ -   Num examples = 200
08/21/2023 21:24:51 - INFO - __main__ -   Num batches each epoch = 200
08/21/2023 21:24:51 - INFO - __main__ -   Num Epochs = 2
08/21/2023 21:24:51 - INFO - __main__ -   Instantaneous batch size per device = 1
08/21/2023 21:24:51 - INFO - __main__ -   Total train batch size (w. parallel, distributed & accumulation) = 1
08/21/2023 21:24:51 - INFO - __main__ -   Gradient Accumulation steps = 1
08/21/2023 21:24:51 - INFO - __main__ -   Total optimization steps = 250
Steps: 100% 250/250 [08:11<00:00,  1.99s/it, loss=3.73, lr=2e-5]08/21/2023 21:33:02 - INFO - accelerate.accelerator - Saving current state to /content/stable_diffusion_weights/cat/checkpoint-250
08/21/2023 21:33:02 - INFO - accelerate.checkpointing - Model weights saved in /content/stable_diffusion_weights/cat/checkpoint-250/pytorch_model.bin
08/21/2023 21:33:03 - INFO - accelerate.checkpointing - Model weights saved in /content/stable_diffusion_weights/cat/checkpoint-250/pytorch_model_1.bin
08/21/2023 21:33:04 - INFO - accelerate.checkpointing - Optimizer state saved in /content/stable_diffusion_weights/cat/checkpoint-250/optimizer.bin
08/21/2023 21:33:04 - INFO - accelerate.checkpointing - Scheduler state saved in /content/stable_diffusion_weights/cat/checkpoint-250/scheduler.bin
08/21/2023 21:33:04 - INFO - accelerate.checkpointing - Random states saved in /content/stable_diffusion_weights/cat/checkpoint-250/random_states_0.pkl
08/21/2023 21:33:04 - INFO - accelerate.checkpointing - Saving the state of AttnProcsLayers to /content/stable_diffusion_weights/cat/checkpoint-250/custom_checkpoint_0.pkl
08/21/2023 21:33:04 - INFO - __main__ - Saved state to /content/stable_diffusion_weights/cat/checkpoint-250
Steps: 100% 250/250 [08:13<00:00,  1.99s/it, loss=0.994, lr=2e-5]Traceback (most recent call last):
  File "/content/diffusers/examples/custom_diffusion/train_custom_diffusion.py", line 1330, in <module>
    main(args)
  File "/content/diffusers/examples/custom_diffusion/train_custom_diffusion.py", line 1258, in main
    unet.save_attn_procs(args.output_dir, safe_serialization=not args.no_safe_serialization)
  File "/content/diffusers/src/diffusers/loaders.py", line 568, in save_attn_procs
    save_function(state_dict, os.path.join(save_directory, weight_name))
  File "/content/diffusers/src/diffusers/loaders.py", line 534, in save_function
    return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
  File "/usr/local/lib/python3.10/dist-packages/safetensors/torch.py", line 282, in save_file
    serialize_file(_flatten(tensors), filename, metadata=metadata)
  File "/usr/local/lib/python3.10/dist-packages/safetensors/torch.py", line 433, in _flatten
    raise ValueError(f"Key `{k}` is invalid, expected torch.Tensor but received {type(v)}")
ValueError: Key `down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor` is invalid, expected torch.Tensor but received <class 'dict'>
Steps: 100% 250/250 [08:13<00:00,  1.98s/it, loss=0.994, lr=2e-5]
Traceback (most recent call last):
  File "/usr/local/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/usr/local/lib/python3.10/dist-packages/accelerate/commands/accelerate_cli.py", line 45, in main
    args.func(args)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/commands/launch.py", line 979, in launch_command
    simple_launcher(args)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/commands/launch.py", line 628, in simple_launcher
    raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
subprocess.CalledProcessError: Command '['/usr/bin/python3', 'train_custom_diffusion.py', '--pretrained_model_name_or_path=CompVis/stable-diffusion-v1-4', '--instance_data_dir=/content/data/cat', '--output_dir=/content/stable_diffusion_weights/cat', '--class_data_dir=./real_reg/samples_cat/', '--with_prior_preservation', '--real_prior', '--prior_loss_weight=1.0', '--class_prompt=cat', '--num_class_images=200', '--instance_prompt=photo of a <new1> cat', '--resolution=512', '--train_batch_size=1', '--learning_rate=1e-5', '--lr_warmup_steps=80', '--max_train_steps=250', '--enable_xformers_memory_efficient_attention', '--scale_lr', '--hflip', '--modifier_token', '<new1>', '--push_to_hub']' returned non-zero exit status 1.

System Info

  • diffusers version: 0.21.0.dev0
  • Platform: Linux-5.15.109+-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • PyTorch version (GPU?): 1.13.1+cu117 (True)
  • Huggingface_hub version: 0.16.4
  • Transformers version: 4.31.0
  • Accelerate version: 0.21.0
  • xFormers version: 0.0.16
  • Using GPU in script: Yes, NVIDIA-SMI 525.105.17 Driver Version: 525.105.17 CUDA Version: 12.0
  • Using distributed or parallel set-up in script?: <>

Who can help?

@sayakpaul @patrickvonplaten

@canberk17 canberk17 added the bug Something isn't working label Aug 21, 2023
@canberk17
Contributor Author

I tried using --no_safe_serialization to switch to the PyTorch format, as stated in the script's argument definitions at line 615.

But then I started getting the following error:

Traceback (most recent call last):
  File "/content/diffusers/examples/custom_diffusion/train_custom_diffusion.py", line 1330, in <module>
    main(args)
  File "/content/diffusers/examples/custom_diffusion/train_custom_diffusion.py", line 1312, in main
    images=images,
UnboundLocalError: local variable 'images' referenced before assignment
Steps: 100% 250/250 [08:22<00:00,  2.01s/it, loss=0.994, lr=2e-5]
Traceback (most recent call last):
  File "/usr/local/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/usr/local/lib/python3.10/dist-packages/accelerate/commands/accelerate_cli.py", line 45, in main
    args.func(args)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/commands/launch.py", line 979, in launch_command
    simple_launcher(args)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/commands/launch.py", line 628, in simple_launcher
    raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)

@canberk17 canberk17 changed the title run train_custom_diffusion.py failed Custom Diffusion Tutorial : run train_custom_diffusion.py fails Aug 22, 2023
@canberk17 canberk17 changed the title Custom Diffusion Tutorial : run train_custom_diffusion.py fails Custom Diffusion : run train_custom_diffusion.py fails Aug 22, 2023
@sayakpaul
Member

Do you have a Colab Notebook for us to look into? Cc'ing @DN6

@canberk17
Contributor Author

canberk17 commented Aug 22, 2023

Hey @sayakpaul, I managed to make it work.

After adding the --no_safe_serialization option to the initial training script from the tutorial, I started receiving an UnboundLocalError at the end of the training:

File "/content/diffusers/examples/custom_diffusion/train_custom_diffusion.py", line 1330, in <module>
    main(args)
  File "/content/diffusers/examples/custom_diffusion/train_custom_diffusion.py", line 1312, in main
    images=images,
UnboundLocalError: local variable 'images' referenced before assignment

Line 1312 is inside the conditional block if args.push_to_hub:. The error occurs because the images variable is only assigned inside two nested conditional blocks (lines 1232 and 1290), but it is referenced later regardless, so it is never bound when those conditions are not met.

Here's how images is defined:

line 1232

if accelerator.is_main_process:
    if args.validation_prompt is not None and global_step % args.validation_steps == 0:
        # ...

line 1290

if args.validation_prompt and args.num_validation_images > 0:
    images = [
        # ...
    ]

So when no validation prompt is given, images is never assigned, and accessing it raises an UnboundLocalError.

Basically, the first training script from the documentation leads to a ValueError on its own. To resolve this, I added the --no_safe_serialization argument, but then encountered the UnboundLocalError.
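The failure pattern can be reproduced in isolation. The sketch below is a simplified, hypothetical stand-in for the end of main() (the function and parameter names are illustrative, not taken from train_custom_diffusion.py): in finish_buggy, images is bound only when validation runs, so the push-to-hub branch fails otherwise. Binding images before the conditionals, as in finish_fixed, is one way a script-side fix could look; anything that consumes images afterwards would then need to tolerate None.

```python
from typing import List, Optional


def finish_buggy(validation_prompt: Optional[str], num_validation_images: int,
                 push_to_hub: bool) -> Optional[List[str]]:
    if validation_prompt and num_validation_images > 0:
        # Stand-in for the pipeline call that generates validation images.
        images = [f"{validation_prompt} #{i}" for i in range(num_validation_images)]
    if push_to_hub:
        # Raises UnboundLocalError when the validation branch above was skipped.
        return images
    return None


def finish_fixed(validation_prompt: Optional[str], num_validation_images: int,
                 push_to_hub: bool) -> Optional[List[str]]:
    images: Optional[List[str]] = None  # always bound, even if validation is skipped
    if validation_prompt and num_validation_images > 0:
        images = [f"{validation_prompt} #{i}" for i in range(num_validation_images)]
    if push_to_hub:
        # A model-card step would receive None instead of crashing.
        return images
    return None


if __name__ == "__main__":
    try:
        finish_buggy(None, 0, push_to_hub=True)
    except UnboundLocalError as err:
        print("buggy:", err)
    print("fixed:", finish_fixed(None, 0, push_to_hub=True))
```

Passing --validation_prompt and --num_validation_images, as described next, avoids the error from the command line because it makes the validation branch run, which is why the workaround holds without modifying the script.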

To fix the issue, the following arguments can be added to the training script:

 --num_validation_images=5 \
 --validation_prompt="<new1> cat sitting in a bucket" \
 --no_safe_serialization \

So the final configuration is as follows:

!accelerate launch train_custom_diffusion.py \
  --pretrained_model_name_or_path=$MODEL_NAME  \
  # ...
  --num_validation_images=5 \
  --validation_prompt="<new1> cat sitting in a bucket" \
  --max_train_steps=250 \
  --no_safe_serialization \
  --enable_xformers_memory_efficient_attention \
  --scale_lr --hflip  \
  --modifier_token "<new1>" \
  --push_to_hub
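Because this command spans many continuation lines, a single missing trailing backslash would split it into separate shell commands. One way to sidestep that, sketched here under the assumption that training is launched from a Python cell rather than with !, is to build the argument list in Python and hand it to subprocess.run. The flags are exactly those from the original reproduction plus the three added above; build_command and the LAUNCH_TRAINING environment variable are hypothetical names used only for this sketch.

```python
import os
import shlex
import subprocess
from typing import List


def build_command(model_name: str, instance_dir: str, output_dir: str) -> List[str]:
    """Hypothetical helper: the accelerate launch command as an argv list (no line continuations)."""
    return [
        "accelerate", "launch", "train_custom_diffusion.py",
        f"--pretrained_model_name_or_path={model_name}",
        f"--instance_data_dir={instance_dir}",
        f"--output_dir={output_dir}",
        "--class_data_dir=./real_reg/samples_cat/",
        "--with_prior_preservation", "--real_prior", "--prior_loss_weight=1.0",
        "--class_prompt=cat", "--num_class_images=200",
        "--instance_prompt=photo of a <new1> cat",
        "--resolution=512",
        "--train_batch_size=1",
        "--learning_rate=1e-5",
        "--lr_warmup_steps=80",
        "--max_train_steps=250",
        # The three arguments that avoided the ValueError and UnboundLocalError in this thread.
        "--num_validation_images=5",
        "--validation_prompt=<new1> cat sitting in a bucket",
        "--no_safe_serialization",
        "--enable_xformers_memory_efficient_attention",
        "--scale_lr", "--hflip",
        "--modifier_token", "<new1>",
        "--push_to_hub",
    ]


if __name__ == "__main__":
    cmd = build_command(
        os.environ.get("MODEL_NAME", "CompVis/stable-diffusion-v1-4"),
        os.environ.get("INSTANCE_DIR", "/content/data/cat"),
        os.environ.get("OUTPUT_DIR", "/content/stable_diffusion_weights/cat"),
    )
    print(shlex.join(cmd))  # inspect the exact command before running it
    if os.environ.get("LAUNCH_TRAINING") == "1":  # hypothetical opt-in switch
        subprocess.run(cmd, check=True)
```

Each list element becomes one argument, so the prompt strings need no shell quoting.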

Hope this helps

@sayakpaul
Member

Wonderful! Actually, safe serialization is currently not possible with Custom Diffusion training (Cc: @DN6).

Thanks for mentioning your changes to make it work. Would you like to open a PR? I think that will be very helpful.

@canberk17
Contributor Author

sent the PR

@github-actions
Contributor

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

2 participants