NF4 quantization slower on 0.3 vs 0.1 #642

Closed
ebsmothers opened this issue Aug 8, 2024 · 19 comments
Labels
bug Something isn't working

Comments

@ebsmothers
Contributor

Hi, we're observing a slowdown in our torchtune QLoRA recipe initialization after upgrading torchao from version 0.1 to 0.3 (I haven't checked 0.4 yet but will do so shortly). This was first pointed out in pytorch/torchtune#1246, and I believe the cause is some change in torchao.

Repro: from a torchtune git install

# Just some commit hash from right before we upgraded to 0.3
git checkout 52e328337579e9b84ba7f2448b29a6de7c5d8db3
pip install torchao==0.1

# Save time.perf_counter() on init and then log the delta with perf_counter()
# here: https://github.com/pytorch/torchtune/blob/0a407712eda252573326074d33af0a66c2d2990e/recipes/lora_finetune_single_device.py#L539
tune run lora_finetune_single_device --config llama3/8B_qlora_single_device
>>> 15.1960636760341

# Do the same on 0.3
pip install torchao==0.3
# also need to comment some quant APIs out to fix import errors
tune run lora_finetune_single_device --config llama3/8B_qlora_single_device
>>> 95.78260190901347
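For reference, the timing is just a time.perf_counter() delta around recipe setup; a minimal sketch of the instrumentation (exact placement in the recipe is approximate):

import time

# At the start of recipe setup
setup_start = time.perf_counter()

# ... model init and checkpoint loading run here, which is where the
# NF4 quantization happens ...

# Right before training starts (the linked line in lora_finetune_single_device.py)
print(f"Setup time: {time.perf_counter() - setup_start}")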
@msaroufim
Member

msaroufim commented Aug 8, 2024

@drisspg @weifengpy do you have any guesses as to what might be going on?

@ebsmothers in the meantime, did you confirm with a trace that the quantization is not happening on CPU? I recall issues mentioning high RAM usage.

@ebsmothers
Contributor Author

@msaroufim I haven't done that yet, but since it's on a fixed commit of torchtune my assumption is that it shouldn't matter (lmk if this is mistaken though)

@weifengpy
Contributor

weifengpy commented Aug 8, 2024

@drisspg @weifengpy do you have any guesses as to what might be going on?

To avoid spiking GPU memory during NF4 quantization, Driss landed a fix that quantizes in chunks (#196).

If we check out that PR and micro-benchmark before/after the change, we should be able to confirm it.

If that's confirmed, it's a feature 🍡 (not a bug). NF4 quantization is a one-time effort; it only affects TTFB (checkpoint loading), not the training QPS.
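A micro-benchmark along those lines could be as simple as the sketch below (assuming torchao's to_nf4 helper and an illustrative 8B-scale weight shape; run it once per torchao version and compare):

import time
import torch
from torchao.dtypes.nf4tensor import to_nf4

# Roughly the shape of one Llama-3-8B MLP weight, in bf16 on CPU,
# which is what checkpoint loading feeds into the NF4 path.
weight = torch.randn(4096, 14336, dtype=torch.bfloat16)

start = time.perf_counter()
nf4_weight = to_nf4(weight)  # block_size and scaler_block_size left at their defaults
print(f"to_nf4 took {time.perf_counter() - start:.2f}s")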

@msaroufim
Member

msaroufim commented Aug 8, 2024

Interesting. The issue, though, is that the quantization is taking on the order of 5 minutes relative to 10 seconds, which seems like a lot.

But your hint is a good one; I can micro-benchmark the initial NF4 quantization.

@weifengpy
Contributor

Interesting. The issue, though, is that the quantization is taking on the order of 5 minutes relative to 10 seconds, which seems like a lot.

Good point. Quantizing in chunks might not be the only reason.

@msaroufim
Member

msaroufim commented Aug 8, 2024

EDIT: Running a bisection now

My hypothesis is that something is wrong with the chunking PR 6081796.

Reproducing the results:

Cloned https://github.com/ebsmothers/ebs-torchtune/tree/debug with the same tune version

Timing 264s
torch                    2.4.0
torchao                  0.3.1

Timing 222s
torch                    2.4.0
torchao                  0.2.0

Timing 15s
torch                    2.4.0
torchao                  0.1

Running nf4 test timings

Still investigating, but I tried timing all the tests we have using pytest --durations=0 test/dtypes/test_nf4.py

And there are a few interesting hints: test_register_nf4_as_param and test_to_copy became 5x slower.

So this seems to point toward init indeed being a new issue; will confirm with a trace next.

## Using nightlies

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================================================================================= slowest durations =======================================================================================
10.57s call     test/dtypes/test_nf4.py::TestNF4Linear::test_smoketest_linear_compile_bfloat16
6.13s call     test/dtypes/test_nf4.py::TestNF4Linear::test_smoketest_linear_compile_float32
6.13s call     test/dtypes/test_nf4.py::TestQLoRA::test_qlora_fsdp2
5.07s call     test/dtypes/test_nf4.py::TestNF4Linear::test_smoketest_linear_compile_float16
2.26s call     test/dtypes/test_nf4.py::TestNF4Linear::test_chunk_size_equivalence_bfloat16_shape0_chunk_size_16
2.06s call     test/dtypes/test_nf4.py::TestNF4Linear::test_nf4_bnb_linear_bfloat16
1.51s call     test/dtypes/test_nf4.py::TestFSDPOps::test_torch_chunk_invalid_3d_input_size0
0.33s call     test/dtypes/test_nf4.py::TestFSDPOps::test_to_cuda
0.19s call     test/dtypes/test_nf4.py::TestFSDPOps::test_tensor_new_zeros_valid_input_size1
0.19s call     test/dtypes/test_nf4.py::TestFSDPOps::test_tensor_new_zeros_valid_input_size2
0.17s call     test/dtypes/test_nf4.py::TestFSDPOps::test_pin_memory
0.17s call     test/dtypes/test_nf4.py::TestFSDPOps::test_torch_chunk_valid_input_size2
0.16s call     test/dtypes/test_nf4.py::TestFSDPOps::test_tensor_slice_1d_invalid
0.16s call     test/dtypes/test_nf4.py::TestFSDPOps::test_tensor_as_strided_invalid_input_size0
0.16s call     test/dtypes/test_nf4.py::TestFSDPOps::test_tensor_as_strided_valid_input_size_262144
0.16s call     test/dtypes/test_nf4.py::TestFSDPOps::test_torch_chunk_valid_input_size1
0.16s call     test/dtypes/test_nf4.py::TestFSDPOps::test_tensor_new_zeros_invalid_input_size2
0.16s call     test/dtypes/test_nf4.py::TestFSDPOps::test_tensor_new_zeros_invalid_input_size1
0.16s call     test/dtypes/test_nf4.py::TestFSDPOps::test_torch_chunk_valid_input_size_262144
0.15s call     test/dtypes/test_nf4.py::TestNF4Linear::test_backward_dtype_match_float16
0.15s call     test/dtypes/test_nf4.py::TestFSDPOps::test_tensor_new_zeros_invalid_input_size_262144
0.14s call     test/dtypes/test_nf4.py::TestNF4Linear::test_backward_dtype_match_bfloat16
0.12s call     test/dtypes/test_nf4.py::TestNF4Linear::test_register_nf4_as_param_float16
0.10s call     test/dtypes/test_nf4.py::TestFSDPOps::test_tensor_as_strided_invalid_input_size1
0.10s call     test/dtypes/test_nf4.py::TestNF4Linear::test_register_nf4_as_param_bfloat16
0.10s call     test/dtypes/test_nf4.py::TestNF4Linear::test_register_nf4_as_param_float32
0.09s call     test/dtypes/test_nf4.py::TestFSDPOps::test_tensor_slice_valid_input_size2
0.08s call     test/dtypes/test_nf4.py::TestNF4Linear::test_backward_dtype_match_float32
0.07s call     test/dtypes/test_nf4.py::TestFSDPOps::test_tensor_slice_2d_invalid
0.07s call     test/dtypes/test_nf4.py::TestNF4Linear::test_output_dtype_match_bfloat16
0.07s call     test/dtypes/test_nf4.py::TestNF4Linear::test_output_dtype_match_float16
0.06s call     test/dtypes/test_nf4.py::TestFSDPOps::test_tensor_new_zeros_valid_input_size_262144
0.06s call     test/dtypes/test_nf4.py::TestFSDPOps::test_tensor_slice_valid_input_size_262144
0.05s call     test/dtypes/test_nf4.py::TestFSDPOps::test_tensor_slice_valid_input_size1
0.05s call     test/dtypes/test_nf4.py::TestNF4Linear::test_to_dtype_float32
0.05s call     test/dtypes/test_nf4.py::TestNF4Linear::test_to_copy_float32
0.04s call     test/dtypes/test_nf4.py::TestNF4Linear::test_to_copy_bfloat16
0.04s call     test/dtypes/test_nf4.py::TestNF4Linear::test_to_copy_device
0.04s call     test/dtypes/test_nf4.py::TestNF4Linear::test_to_dtype_float16
0.04s call     test/dtypes/test_nf4.py::TestNF4Linear::test_to_copy_float16
0.04s call     test/dtypes/test_nf4.py::TestNF4Linear::test_to_dtype_bfloat16
0.03s call     test/dtypes/test_nf4.py::TestFSDPOps::test_tensor_as_strided_valid_input_size2
0.03s call     test/dtypes/test_nf4.py::TestNF4Linear::test_output_dtype_match_float32
0.03s call     test/dtypes/test_nf4.py::TestFSDPOps::test_torch_chunk_invalid_divide_input_size2
0.03s call     test/dtypes/test_nf4.py::TestFSDPOps::test_tensor_view_valid_input_size1
0.02s call     test/dtypes/test_nf4.py::TestFSDPOps::test_torch_chunk_invalid_divide_input_size1
0.02s call     test/dtypes/test_nf4.py::TestFSDPOps::test_torch_chunk_invalid_divide_input_size_261632
0.02s call     test/dtypes/test_nf4.py::TestFSDPOps::test_tensor_view_invalid_input_size0
0.02s call     test/dtypes/test_nf4.py::TestFSDPOps::test_tensor_view_invalid_input_size1
0.01s call     test/dtypes/test_nf4.py::TestFSDPOps::test_tensor_as_strided_valid_input_size1
0.01s call     test/dtypes/test_nf4.py::TestFSDPOps::test_tensor_view_valid_input_size0
0.01s call     test/dtypes/test_nf4.py::TestNF4Linear::test_nf4_bnb_linear_float16
0.01s call     test/dtypes/test_nf4.py::TestNF4Linear::test_chunk_size_equivalence_float16_shape1_chunk_size_8
0.01s call     test/dtypes/test_nf4.py::TestNF4Linear::test_load_from_nf4_diff_meta_bfloat16
0.01s call     test/dtypes/test_nf4.py::TestNF4Linear::test_reconstruction_qlora_vs_bnb_bfloat16
0.01s call     test/dtypes/test_nf4.py::TestNF4Linear::test_chunk_size_equivalence_float32_shape1_chunk_size_8
0.01s call     test/dtypes/test_nf4.py::TestNF4Linear::test_nf4_bnb_linear_float32
0.01s call     test/dtypes/test_nf4.py::TestNF4Linear::test_chunk_size_equivalence_bfloat16_shape1_chunk_size_8
0.01s call     test/dtypes/test_nf4.py::TestNF4Linear::test_reconstruction_qlora_vs_bnb_float32
0.01s call     test/dtypes/test_nf4.py::TestNF4Linear::test_reconstruction_qlora_vs_bnb_float16
0.01s call     test/dtypes/test_nf4.py::TestFSDPOps::test_tensor_deepcopy_input_size1
0.01s call     test/dtypes/test_nf4.py::TestNF4Linear::test_load_from_nf4_diff_meta_float16
0.01s call     test/dtypes/test_nf4.py::TestNF4Linear::test_load_from_nf4_diff_meta_float32
0.01s call     test/dtypes/test_nf4.py::TestNF4Linear::test_chunk_size_equivalence_float32_shape1_chunk_size_16
0.01s call     test/dtypes/test_nf4.py::TestNF4Linear::test_chunk_size_equivalence_float16_shape1_chunk_size_16
0.01s call     test/dtypes/test_nf4.py::TestNF4Linear::test_chunk_size_equivalence_float16_shape0_chunk_size_8
0.01s call     test/dtypes/test_nf4.py::TestNF4Linear::test_chunk_size_equivalence_float32_shape0_chunk_size_8
0.01s call     test/dtypes/test_nf4.py::TestNF4Linear::test_chunk_size_equivalence_bfloat16_shape0_chunk_size_8
0.01s call     test/dtypes/test_nf4.py::TestNF4Linear::test_chunk_size_equivalence_bfloat16_shape1_chunk_size_16
0.01s call     test/dtypes/test_nf4.py::TestNF4Linear::test_load_from_nf4_same_meta_bfloat16
0.01s call     test/dtypes/test_nf4.py::TestNF4Linear::test_load_from_nf4_same_meta_float32
0.01s call     test/dtypes/test_nf4.py::TestNF4Linear::test_load_from_nf4_same_meta_float16

## Using ao 0.2

======================================================================================= slowest durations =======================================================================================
5.94s call     test/dtypes/test_nf4.py::TestNF4Linear::test_smoketest_linear_compile_bfloat16
2.46s call     test/dtypes/test_nf4.py::TestNF4Linear::test_nf4_bnb_linear_bfloat16
2.45s call     test/dtypes/test_nf4.py::TestNF4Linear::test_chunk_size_equivalence_bfloat16_shape0_chunk_size_16
1.61s call     test/dtypes/test_nf4.py::TestFSDPOps::test_torch_chunk_invalid_3d_input_size0
1.35s call     test/dtypes/test_nf4.py::TestNF4Linear::test_smoketest_linear_compile_float16
1.31s call     test/dtypes/test_nf4.py::TestNF4Linear::test_smoketest_linear_compile_float32
0.11s call     test/dtypes/test_nf4.py::TestNF4Linear::test_output_dtype_match_float32
0.06s call     test/dtypes/test_nf4.py::TestNF4Linear::test_backward_dtype_match_bfloat16
0.06s call     test/dtypes/test_nf4.py::TestFSDPOps::test_torch_chunk_valid_input_size1
0.05s call     test/dtypes/test_nf4.py::TestFSDPOps::test_pin_memory
0.04s call     test/dtypes/test_nf4.py::TestFSDPOps::test_to_cuda
0.03s call     test/dtypes/test_nf4.py::TestNF4Linear::test_backward_dtype_match_float16
0.03s call     test/dtypes/test_nf4.py::TestNF4Linear::test_output_dtype_match_float16
0.03s call     test/dtypes/test_nf4.py::TestFSDPOps::test_tensor_view_invalid_input_size1
0.03s call     test/dtypes/test_nf4.py::TestNF4Linear::test_to_dtype_bfloat16
0.02s call     test/dtypes/test_nf4.py::TestFSDPOps::test_torch_chunk_invalid_divide_input_size_261632
0.02s call     test/dtypes/test_nf4.py::TestFSDPOps::test_torch_chunk_invalid_divide_input_size2
0.02s call     test/dtypes/test_nf4.py::TestFSDPOps::test_torch_chunk_invalid_divide_input_size1
0.02s call     test/dtypes/test_nf4.py::TestNF4Linear::test_backward_dtype_match_float32
0.02s call     test/dtypes/test_nf4.py::TestFSDPOps::test_tensor_new_zeros_valid_input_size2
0.02s call     test/dtypes/test_nf4.py::TestNF4Linear::test_register_nf4_as_param_float16
0.02s call     test/dtypes/test_nf4.py::TestFSDPOps::test_tensor_as_strided_valid_input_size2
0.02s call     test/dtypes/test_nf4.py::TestNF4Linear::test_output_dtype_match_bfloat16
0.02s call     test/dtypes/test_nf4.py::TestFSDPOps::test_tensor_view_valid_input_size1
0.02s call     test/dtypes/test_nf4.py::TestNF4Linear::test_register_nf4_as_param_bfloat16
0.02s call     test/dtypes/test_nf4.py::TestFSDPOps::test_tensor_new_zeros_valid_input_size1
0.02s call     test/dtypes/test_nf4.py::TestFSDPOps::test_tensor_view_valid_input_size0
0.02s call     test/dtypes/test_nf4.py::TestFSDPOps::test_torch_chunk_valid_input_size_262144
0.02s call     test/dtypes/test_nf4.py::TestNF4Linear::test_register_nf4_as_param_float32
0.02s call     test/dtypes/test_nf4.py::TestFSDPOps::test_tensor_slice_valid_input_size2
0.02s call     test/dtypes/test_nf4.py::TestFSDPOps::test_tensor_as_strided_valid_input_size1
0.02s call     test/dtypes/test_nf4.py::TestFSDPOps::test_tensor_as_strided_invalid_input_size0
0.02s call     test/dtypes/test_nf4.py::TestFSDPOps::test_tensor_slice_1d_invalid
0.01s call     test/dtypes/test_nf4.py::TestFSDPOps::test_tensor_as_strided_invalid_input_size1
0.01s call     test/dtypes/test_nf4.py::TestFSDPOps::test_tensor_slice_2d_invalid
0.01s call     test/dtypes/test_nf4.py::TestFSDPOps::test_torch_chunk_valid_input_size2
0.01s call     test/dtypes/test_nf4.py::TestFSDPOps::test_tensor_as_strided_valid_input_size_262144
0.01s call     test/dtypes/test_nf4.py::TestFSDPOps::test_tensor_new_zeros_valid_input_size_262144
0.01s call     test/dtypes/test_nf4.py::TestFSDPOps::test_tensor_new_zeros_invalid_input_size1
0.01s call     test/dtypes/test_nf4.py::TestFSDPOps::test_tensor_new_zeros_invalid_input_size2
0.01s call     test/dtypes/test_nf4.py::TestFSDPOps::test_tensor_slice_valid_input_size_262144
0.01s call     test/dtypes/test_nf4.py::TestFSDPOps::test_tensor_view_invalid_input_size0
0.01s call     test/dtypes/test_nf4.py::TestFSDPOps::test_tensor_slice_valid_input_size1
0.01s call     test/dtypes/test_nf4.py::TestFSDPOps::test_tensor_new_zeros_invalid_input_size_262144
0.01s call     test/dtypes/test_nf4.py::TestNF4Linear::test_nf4_bnb_linear_float16
0.01s call     test/dtypes/test_nf4.py::TestNF4Linear::test_to_copy_float16
0.01s call     test/dtypes/test_nf4.py::TestNF4Linear::test_load_from_nf4_diff_meta_bfloat16
0.01s call     test/dtypes/test_nf4.py::TestNF4Linear::test_nf4_bnb_linear_float32
0.01s call     test/dtypes/test_nf4.py::TestNF4Linear::test_chunk_size_equivalence_float16_shape1_chunk_size_8
0.01s call     test/dtypes/test_nf4.py::TestNF4Linear::test_reconstruction_qlora_vs_bnb_bfloat16
0.01s call     test/dtypes/test_nf4.py::TestNF4Linear::test_chunk_size_equivalence_float32_shape1_chunk_size_8
0.01s call     test/dtypes/test_nf4.py::TestNF4Linear::test_chunk_size_equivalence_bfloat16_shape1_chunk_size_8
0.01s call     test/dtypes/test_nf4.py::TestNF4Linear::test_reconstruction_qlora_vs_bnb_float32
0.01s call     test/dtypes/test_nf4.py::TestNF4Linear::test_reconstruction_qlora_vs_bnb_float16
0.01s call     test/dtypes/test_nf4.py::TestNF4Linear::test_load_from_nf4_diff_meta_float16
0.01s call     test/dtypes/test_nf4.py::TestNF4Linear::test_load_from_nf4_diff_meta_float32
0.01s call     test/dtypes/test_nf4.py::TestNF4Linear::test_to_copy_bfloat16
0.01s call     test/dtypes/test_nf4.py::TestNF4Linear::test_to_dtype_float32
0.01s call     test/dtypes/test_nf4.py::TestNF4Linear::test_chunk_size_equivalence_bfloat16_shape0_chunk_size_8
0.01s call     test/dtypes/test_nf4.py::TestNF4Linear::test_chunk_size_equivalence_float16_shape0_chunk_size_8
0.01s call     test/dtypes/test_nf4.py::TestNF4Linear::test_chunk_size_equivalence_float32_shape1_chunk_size_16
0.01s call     test/dtypes/test_nf4.py::TestNF4Linear::test_chunk_size_equivalence_float16_shape1_chunk_size_16
0.01s call     test/dtypes/test_nf4.py::TestNF4Linear::test_chunk_size_equivalence_bfloat16_shape1_chunk_size_16
0.01s call     test/dtypes/test_nf4.py::TestNF4Linear::test_chunk_size_equivalence_float32_shape0_chunk_size_8
0.01s call     test/dtypes/test_nf4.py::TestNF4Linear::test_load_from_nf4_same_meta_float32
0.01s call     test/dtypes/test_nf4.py::TestNF4Linear::test_load_from_nf4_same_meta_bfloat16
0.01s call     test/dtypes/test_nf4.py::TestNF4Linear::test_load_from_nf4_same_meta_float16


## Using ao 0.1

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================================================================================= slowest durations =======================================================================================
7.04s call     test/dtypes/test_nf4.py::TestNF4Linear::test_smoketest_linear_compile
2.14s call     test/dtypes/test_nf4.py::TestNF4Linear::test_load_from_bfloat16
2.00s call     test/dtypes/test_nf4.py::TestNF4Linear::test_nf4_bnb_linear
0.07s call     test/dtypes/test_nf4.py::TestNF4Linear::test_backward_bf16
0.02s call     test/dtypes/test_nf4.py::TestNF4Linear::test_output_bf16
0.02s call     test/dtypes/test_nf4.py::TestNF4Linear::test_register_nf4_as_param
0.01s call     test/dtypes/test_nf4.py::TestNF4Linear::test_reconstruction_qlora_vs_bnb
0.01s call     test/dtypes/test_nf4.py::TestNF4Linear::test_load_from_nf4_diff_meta
0.01s call     test/dtypes/test_nf4.py::TestNF4Linear::test_to_copy
0.01s call     test/dtypes/test_nf4.py::TestNF4Linear::test_load_from_nf4_same_meta

(26 durations < 0.005s hidden.  Use -vv to show these durations.)

## Control-C

If I interrupt at any point during the long load times, the stack trace is always the same:

  File "/home/marksaroufim/.conda/envs/ao/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2183, in load
    module._load_from_state_dict(
  File "/home/marksaroufim/.conda/envs/ao/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2096, in _load_from_state_dict
    param.copy_(input_param)
  File "/home/marksaroufim/.conda/envs/ao/lib/python3.10/site-packages/torchao/dtypes/nf4tensor.py", line 795, in __torch_function__
    return func(*args, **kwargs)
  File "/home/marksaroufim/.conda/envs/ao/lib/python3.10/site-packages/torchao/dtypes/nf4tensor.py", line 776, in __torch_dispatch__
    return NF4_OPS_TABLE[func](func, args, kwargs)
  File "/home/marksaroufim/.conda/envs/ao/lib/python3.10/site-packages/torchao/dtypes/nf4tensor.py", line 335, in copy_
    copy_in_nf4 = NF4Tensor.from_tensor(
  File "/home/marksaroufim/.conda/envs/ao/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/marksaroufim/.conda/envs/ao/lib/python3.10/site-packages/torchao/dtypes/nf4tensor.py", line 517, in from_tensor
    quantized_data = cls.convert_to_norm_float_weight(
  File "/home/marksaroufim/.conda/envs/ao/lib/python3.10/site-packages/torchao/dtypes/nf4tensor.py", line 652, in convert_to_norm_float_weight
    quantized_blocks[start:end] = NF4Tensor.quantize_tensor_nearest(flattened[start:end], nf4).to(torch.uint8)
  File "/home/marksaroufim/.conda/envs/ao/lib/python3.10/site-packages/torchao/dtypes/nf4tensor.py", line 697, in quantize_tensor_nearest
    diff = (value - nf4).abs()
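That last frame is the per-block nearest-codebook search; paraphrased (not the exact torchao source), it is doing roughly:

import torch

def quantize_block_nearest(value: torch.Tensor, nf4: torch.Tensor) -> torch.Tensor:
    # value: a flattened block of scaled weights; nf4: the 16-entry codebook.
    # Broadcast against the codebook, take absolute differences, and pick the
    # closest entry. Cheap per block, but it runs for every block of every
    # weight during state-dict loading, which is why it dominates the trace.
    diff = (value.unsqueeze(-1) - nf4).abs()
    return diff.argmin(dim=-1).to(torch.uint8)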

@msaroufim
Member

msaroufim commented Aug 9, 2024

OK, actually this is a pretty funny bug: I ran a comprehensive git bisect, and it turns out the offending PR was one by @cpuhrsch that just bumps the ao version from 0.1 to 0.2 (#234).

I thought I had made a mistake, so I ran the bisect again and got the same result 😕

So I checked the torchtune code, and indeed there is a version guard on 0.2 here: https://github.com/pytorch/torchtune/blob/main/torchtune/modules/low_precision/_register_nf4_dispatch_ops.py#L23-L31
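For context, that guard follows the pattern sketched below (paraphrased, not the exact torchtune source; the version-parsing helper is illustrative):

import torchao
from packaging.version import Version

# Sketch: only register torchtune's in-place copy_ override for torchao < 0.2,
# on the assumption that newer torchao versions handle copy_ themselves.
should_define_inplace_copy = Version(torchao.__version__) < Version("0.2.0")

if should_define_inplace_copy:
    # register torchtune's aten.copy_.default override for NF4Tensor here
    ...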

So I figured something funny was going on. I'm not sure what the if condition was about, so I just switched

- should_define_inplace_copy = False
+ should_define_inplace_copy = True

And now model loading is fast again on ao nightlies:

(ao) [[email protected] ~/ao (main)]$ tune run lora_finetune_single_device --config llama3/8B_qlora_single_device
INFO:torchtune.utils.logging:Running LoRAFinetuneRecipeSingleDevice with resolved config:

batch_size: 2
checkpointer:
  _component_: torchtune.utils.FullModelMetaCheckpointer
  checkpoint_dir: /tmp/Meta-Llama-3-8B-Instruct/original/
  checkpoint_files:
  - consolidated.00.pth
  model_type: LLAMA3
  output_dir: /tmp/Meta-Llama-3-8B-Instruct/
  recipe_checkpoint: null
compile: false
dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
device: cuda
dtype: bf16
enable_activation_checkpointing: true
epochs: 1
gradient_accumulation_steps: 16
log_every_n_steps: 1
log_peak_memory_stats: false
loss:
  _component_: torch.nn.CrossEntropyLoss
lr_scheduler:
  _component_: torchtune.modules.get_cosine_schedule_with_warmup
  num_warmup_steps: 100
max_steps_per_epoch: null
metric_logger:
  _component_: torchtune.utils.metric_logging.DiskLogger
  log_dir: /tmp/qlora_finetune_output/
model:
  _component_: torchtune.models.llama3.qlora_llama3_8b
  apply_lora_to_mlp: true
  apply_lora_to_output: false
  lora_alpha: 16
  lora_attn_modules:
  - q_proj
  - v_proj
  - k_proj
  - output_proj
  lora_rank: 8
optimizer:
  _component_: torch.optim.AdamW
  lr: 0.0003
  weight_decay: 0.01
output_dir: /tmp/qlora_finetune_output/
profiler:
  _component_: torchtune.utils.setup_torch_profiler
  active_steps: 2
  cpu: true
  cuda: true
  enabled: false
  num_cycles: 1
  output_dir: /tmp/qlora_finetune_output//profiling_outputs
  profile_memory: false
  record_shapes: true
  wait_steps: 5
  warmup_steps: 5
  with_flops: false
  with_stack: false
resume_from_checkpoint: false
save_adapter_weights_only: false
seed: null
shuffle: true
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model

DEBUG:torchtune.utils.logging:Setting manual seed to local seed 490383513. Local seed is seed + rank = 490383513 + 0
Writing logs to /tmp/qlora_finetune_output/log_1723171696.txt
INFO:torchtune.utils.logging:Model is initialized with precision torch.bfloat16.
INFO:torchtune.utils.logging:Memory stats after model init:
        GPU peak memory allocation: 6.79 GB
        GPU peak memory reserved: 6.97 GB
        GPU peak memory active: 6.79 GB
INFO:torchtune.utils.logging:Tokenizer is initialized from file.
INFO:torchtune.utils.logging:Optimizer and loss are initialized.
INFO:torchtune.utils.logging:Loss is initialized.
INFO:torchtune.utils.logging:Dataset and Sampler are initialized.
INFO:torchtune.utils.logging:Learning rate scheduler is initialized.
WARNING:torchtune.utils.logging: Profiling disabled.
INFO:torchtune.utils.logging: Profiler config after instantiation: {'enabled': False}
Setup time: 16.385372227989137
(ao) [[email protected] ~/ao (main)]$ pip list
Package                  Version          Editable project location
------------------------ ---------------- -------------------------------------
aiohappyeyeballs         2.3.5
aiohttp                  3.10.2
aiosignal                1.3.1
antlr4-python3-runtime   4.9.3
async-timeout            4.0.3
attrs                    24.2.0
bitsandbytes             0.43.3
blobfile                 2.1.1
certifi                  2024.7.4
charset-normalizer       3.3.2
contourpy                1.2.1
cycler                   0.12.1
datasets                 2.20.0
dill                     0.3.8
exceptiongroup           1.2.2
expecttest               0.2.1
filelock                 3.15.4
fire                     0.6.0
fonttools                4.53.1
frozenlist               1.4.1
fsspec                   2024.5.0
huggingface-hub          0.24.5
hypothesis               6.110.1
idna                     3.7
iniconfig                2.0.0
Jinja2                   3.1.4
kiwisolver               1.4.5
lxml                     4.9.4
MarkupSafe               2.1.5
matplotlib               3.9.1.post1
mpmath                   1.3.0
multidict                6.0.5
multiprocess             0.70.16
networkx                 3.3
ninja                    1.11.1.1
numpy                    1.26.4
nvidia-cublas-cu12       12.1.3.1
nvidia-cuda-cupti-cu12   12.1.105
nvidia-cuda-nvrtc-cu12   12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12        9.1.0.70
nvidia-cufft-cu12        11.0.2.54
nvidia-curand-cu12       10.3.2.106
nvidia-cusolver-cu12     11.4.5.107
nvidia-cusparse-cu12     12.1.0.106
nvidia-nccl-cu12         2.20.5
nvidia-nvjitlink-cu12    12.6.20
nvidia-nvtx-cu12         12.1.105
omegaconf                2.3.0
packaging                24.1
pandas                   2.2.2
parameterized            0.9.0
pillow                   10.4.0
pip                      24.0
pluggy                   1.5.0
pyarrow                  17.0.0
pyarrow-hotfix           0.6
pycryptodomex            3.20.0
pyparsing                3.1.2
pytest                   7.4.0
python-dateutil          2.9.0.post0
pytz                     2024.1
PyYAML                   6.0.2
regex                    2024.7.24
requests                 2.32.3
safetensors              0.4.4
sentencepiece            0.2.0
setuptools               72.1.0
six                      1.16.0
sortedcontainers         2.4.0
sympy                    1.13.1
tabulate                 0.9.0
termcolor                2.4.0
tiktoken                 0.7.0
tokenizers               0.19.1
tomli                    2.0.1
torch                    2.4.0
torchao                  0.4.0+git433cd14
torchtune                0.0.0            /home/marksaroufim/test/ebs-torchtune
torchvision              0.19.0
tqdm                     4.66.5
transformers             4.44.0
triton                   3.0.0
typing_extensions        4.12.2
tzdata                   2024.1
unittest-xml-reporting   3.2.0
urllib3                  2.2.2
wheel                    0.43.0
xxhash                   3.4.1
yarl                     1.9.4
(ao) [[email protected] ~/ao (main)]$ 

EDIT: Removing high pri since this now has a fix

@gau-nernst
Collaborator

It seems to suggest that the aten.copy_.default impl in torchtune is more efficient than the one in torchao? Comparing the two, the most obvious difference I can see is that torchtune swaps the inner quantized data (something like a storage swap), while torchao does an extra step of copying the inner quantized data. That doesn't seem to account for such a big slowdown, though. Probably need to look into it more.

Also, is swapping the inner data of a tensor subclass generally safe, especially for torch.compile? If it is, it might be a good optimization for other subclasses, like the ones used in low-bit optimizers.

@ebsmothers
Contributor Author

Thanks @msaroufim for the comprehensive debugging here! On our end we can quickly update the version check to always define the inplace_copy op; that'll unblock us. But I agree with @gau-nernst -- it seems like the definition in torchtune is the faster one, then, since it's overriding the one in ao. I don't fully understand the version in ao, but I guess it's not in-place? It is actually interesting to me that we're able to get away with the in-place version in torchtune all the time (but I need to look more closely into where this op is actually used by us).

@msaroufim
Member

msaroufim commented Aug 9, 2024

Yeah, it makes sense to upstream the fastest thing to ao; we're happy to maintain that code. Whenever you have a slow day, feel free to send a PR our way.

@ebsmothers
Contributor Author

Just to follow up here: after a bit more investigation on pytorch/torchtune#1294, we are in an interesting place where ao's copy override is slow on single-device but works with distributed, while torchtune's in-place version is faster on single-device but does not work with distributed. The call to copy_ by FSDP here triggers a call to aten.to.dtype_layout from this line of the torchtune version, which is not implemented for NF4. I'm taking a look at this, but also tagging in @cpuhrsch and @rohan-varma since you implemented the original version in torchtune and probably have better ideas than me here.

@ebsmothers
Contributor Author

Defining

@implements([torch.ops.aten.to.dtype_layout])
def to_dtype_layout(func, *args, **kwargs):
    nf4tensor = args[0][0]
    updated_attrs = apply_to_inner_tensors(nf4tensor, func, args=[], kwargs=args[1])
    return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs))

at least gets us past the original error, but once we're done sharding the model, any NF4 params that are in DTensors have all NaNs in their _local_tensor field.

@joecummings

@ebsmothers is on PTO so I'm taking a look at this. @weifengpy @msaroufim my insanely naive approach would be to explicitly set the _local_tensor field, but this seems like attacking the symptom rather than the root cause.

What's the interaction here between FSDP2 and NF4Tensor?

@ebsmothers
Contributor Author

Bumping this issue since we still don't have a fix. In addition to defining to_dtype_layout as I did here, we also tried directly replacing the to_nf4 usage here with ref_tensor = NF4Tensor(*construct_nf4_args(dest_tensor)) to get rid of the .to(device) call. Both result in NaNs, though. Per @msaroufim's suggestion, tagging in @bdhirsh in case you have any ideas on what could be happening here.

@gau-nernst
Collaborator

Upon closer inspection, I think torchtune is faster because the original weight is moved to GPU before quantization, while torchao still performs quantization on the original device, i.e. CPU:

https://github.com/pytorch/torchtune/blob/f9f75bb563ecae371492a9d49da4a9f514c081b3/torchtune/modules/low_precision/_register_nf4_dispatch_ops.py#L45

if not isinstance(copy_in, NF4Tensor):
    copy_in_nf4 = NF4Tensor.from_tensor(
        copy_in, original.block_size, original.scaler_block_size
    )
    return original.copy_(copy_in_nf4)

This is because torchtune initializes the model on CUDA and loads a CPU state dict, thus triggering NF4Tensor_CUDA.copy_(BF16Tensor_CPU) (correct me if I'm wrong). A simple fix would be:

    if not isinstance(copy_in, NF4Tensor):
        copy_in_nf4 = NF4Tensor.from_tensor(
            copy_in.to(original.device), original.block_size, original.scaler_block_size
        )
        return original.copy_(copy_in_nf4)

This will raise peak GPU memory, but should be a good trade-off?
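A quick way to sanity-check the device effect (a sketch; assumes a CUDA device and an illustrative 8B-scale shape):

import time
import torch
from torchao.dtypes.nf4tensor import to_nf4

weight_cpu = torch.randn(4096, 14336, dtype=torch.bfloat16)

start = time.perf_counter()
to_nf4(weight_cpu)  # quantize on CPU, as the current copy_ path effectively does
print(f"CPU quantization: {time.perf_counter() - start:.2f}s")

weight_cuda = weight_cpu.to("cuda")
torch.cuda.synchronize()
start = time.perf_counter()
to_nf4(weight_cuda)  # quantize after moving to the GPU, as the proposed fix does
torch.cuda.synchronize()
print(f"CUDA quantization: {time.perf_counter() - start:.2f}s")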

@weifengpy
Contributor

Thanks for pinpointing the root cause!

@ebsmothers
Contributor Author

@gau-nernst thanks for figuring this out! This looks like exactly what we need. I confirmed that after making this change in ao, the QLoRA initialization time matches what we see using torchtune's in-place copy. Additionally, the FSDP2 recipe works when we use this version instead of the one in torchtune. I actually don't see any negative memory implications of this change either, so no reason not to make it from my perspective.

@msaroufim
Member

@ebsmothers mind then deleting some of the existing NF4 code in tune and making the patch here in ao? As long as CI is green, I suspect review will be speedy here.

@ebsmothers
Contributor Author

Opened #737 and updated pytorch/torchtune#1294 to delete torchtune's override
