NF4 quantization slower on 0.3 vs 0.1 #642
Comments
@drisspg @weifengpy do you have any guesses as to what might be going on? @ebsmothers in the meantime did you confirm that the quantization is not happening on CPU with a trace? I recall issues mentioning high RAM usage |
@msaroufim I haven't done that yet, but since it's on a fixed commit of torchtune my assumption is that it shouldn't matter (lmk if this is mistaken though) |
To avoid peak GPU memory spikes during NF4 quantization, Driss landed a fix that quantizes in chunks (#196). If we check out the PR and micro-benchmark before/after the change, we should be able to confirm. If that's confirmed, it's a feature 🍡 (not a bug): NF4 quantization is a one-time effort, so it only affects TTFB (checkpoint loading), not the training QPS. |
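To illustrate the chunking idea, here is a minimal sketch under assumed names (`quantize_block` is a hypothetical stand-in for the per-block NF4 quantization; this is not the PR's actual code):

```python
import torch

def quantize_in_chunks(weight: torch.Tensor, quantize_block, chunk_numel: int = 1 << 20):
    """Quantize a flattened weight chunk-by-chunk so temporary buffers stay small."""
    flat = weight.reshape(-1)
    out = []
    for start in range(0, flat.numel(), chunk_numel):
        # Only this slice is materialized in higher precision at any one time.
        out.append(quantize_block(flat[start:start + chunk_numel]))
    return out
```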
Interesting. The issue, though, is that the quantization is taking on the order of 5 min relative to 10 s, which seems like a lot. But your hint is a good one; I can micro-benchmark the initial NF4 quantization. |
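A micro-benchmark along these lines can time the one-time quantization cost on CPU vs. GPU using torchao's `to_nf4` helper (the tensor size and block parameters below are assumptions, not numbers from this thread):

```python
import time
import torch
from torchao.dtypes import to_nf4

weight = torch.randn(4096, 4096, dtype=torch.bfloat16)  # roughly one transformer linear

t0 = time.perf_counter()
to_nf4(weight, block_size=64, scaler_block_size=256)  # quantize on CPU
print(f"CPU quantization: {time.perf_counter() - t0:.3f}s")

if torch.cuda.is_available():
    weight_cuda = weight.cuda()
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    to_nf4(weight_cuda, block_size=64, scaler_block_size=256)  # quantize on GPU
    torch.cuda.synchronize()
    print(f"GPU quantization: {time.perf_counter() - t0:.3f}s")
```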
good point. quantizing in chunks might not be the only reason |
EDIT: Running a bisection now. My hypothesis is that something is wrong with the chunking PR (6081796).

Repro the results: cloned https://github.com/ebsmothers/ebs-torchtune/tree/debug with the same tune version.

Running NF4 test timings: still investigating, but I tried timing all the tests we have, and there are a few interesting hints. This seems to point towards init indeed being a new issue; I will confirm with a trace next.

Control-C: if I interrupt at any point during the long load times, the stack trace is always the same.
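One way to get that trace is torch.profiler; a sketch below, where `model` and `state_dict` are hypothetical stand-ins for the slow QLoRA checkpoint load. The table and Chrome trace should show whether the time is spent in CPU-side quantization ops:

```python
import torch
from torch.profiler import profile, ProfilerActivity

# Hypothetical: `model` is the QLoRA model, `state_dict` the checkpoint being loaded.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    model.load_state_dict(state_dict)

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))
prof.export_chrome_trace("nf4_load_trace.json")  # open in chrome://tracing or Perfetto
```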
|
Ok, actually this is a pretty funny bug. I ran a comprehensive git bisect, and it turns out the offending PR was one by @cpuhrsch that just bumps the ao version from 0.1 to 0.2 (#234). I thought I made a mistake, so I ran the bisect again: same result 😕 So I checked the torchtune code, and indeed there is a version guard on 0.2 here: https://github.com/pytorch/torchtune/blob/main/torchtune/modules/low_precision/_register_nf4_dispatch_ops.py#L23-L31. I figured something funny was going on; I'm not sure what the if condition was about, so I just switched

```diff
- should_define_inplace_copy = False
+ should_define_inplace_copy = True
```

and now model loading is fast again on ao nightlies.
EDIT: Removing high pri since this now has a fix |
It seems to suggest that torchtune's inplace version (which swaps the subclass's inner data) is faster than ao's copy override. Also, is swapping inner data for a tensor subclass generally safe, especially for torch.compile? If it is safe, it might be a good optimization for other subclasses, like the ones used in low-bit optimizers. |
Thanks @msaroufim for the comprehensive debugging here! On our end we can just quickly update the version check to always define the inplace_copy op; that'll unblock us. But I agree with @gau-nernst -- it seems like the definition in torchtune is the faster one, since it's overriding the one in ao. I don't fully understand the version in ao, but I guess it's not inplace? Actually it is interesting to me that we're able to get away with inplace in torchtune all the time (but I need to look more closely into where this op is actually used by us). |
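For context on what "inplace" means here, a hedged sketch (not torchtune's actual code; relying on `__tensor_flatten__` for the inner-tensor names is an assumption): re-quantize the incoming tensor, then swap the destination NF4Tensor's inner tensors instead of constructing and assigning a new parameter.

```python
import torch
from torchao.dtypes.nf4tensor import NF4Tensor

def inplace_copy_sketch(dest: NF4Tensor, src: torch.Tensor) -> NF4Tensor:
    # Quantize the source on the destination's device if it isn't NF4 already.
    src_nf4 = (
        src
        if isinstance(src, NF4Tensor)
        else NF4Tensor.from_tensor(src.to(dest.device), dest.block_size, dest.scaler_block_size)
    )
    # Swap inner tensors in place; names come from the subclass's flatten spec.
    inner_names, _ = dest.__tensor_flatten__()
    for name in inner_names:
        setattr(dest, name, getattr(src_nf4, name))
    return dest
```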
Yeah it makes sense to upstream the fastest thing to ao, we're happy to maintain that code. Whenever you have a slow day feel free to send a PR our way |
Just to follow up here, after a bit more investigation on pytorch/torchtune#1294 we are in an interesting place where ao's copy override is slow (on single-device) and works with distributed, while torchtune's inplace version is faster on single-device but does not work with distributed. The call to |
Defining
at least gets us past the original error, but once we're done sharding the model, any NF4 params that are in DTensors have all nans for their |
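A hypothetical check for the all-NaN shards described above (names and structure assumed), run after `fully_shard`:

```python
import torch
from torch.distributed._tensor import DTensor

def report_nan_inner_tensors(model: torch.nn.Module) -> None:
    for name, param in model.named_parameters():
        local = param.to_local() if isinstance(param, DTensor) else param
        if hasattr(local, "__tensor_flatten__"):  # NF4Tensor and other subclasses
            inner_names, _ = local.__tensor_flatten__()
            for inner in inner_names:
                t = getattr(local, inner)
                if t.is_floating_point() and torch.isnan(t).all():
                    print(f"{name}.{inner}: local shard is all NaN")
```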
@ebsmothers is on PTO so I'm taking a look at this. @weifengpy @msaroufim my insanely naive approach would be to explicitly set the
What's the interaction here between FSDP2 and NF4Tensor? |
Bumping this issue since we still don't have a fix here. In addition to defining |
Upon closer inspection, I think torchtune is faster because the original weight is moved to GPU before quantization, while torchao still performs quantization on the original device, i.e. CPU (see ao/torchao/dtypes/nf4tensor.py, lines 340 to 344 at 9860194). This is because torchtune inits the model on CUDA and loads a CPU state dict, so the copy path gets a CPU tensor and quantizes on CPU. Something like this would quantize on the destination device instead:

```python
if not isinstance(copy_in, NF4Tensor):
    copy_in_nf4 = NF4Tensor.from_tensor(
        copy_in.to(original.device), original.block_size, original.scaler_block_size
    )
    return original.copy_(copy_in_nf4)
```

This will raise peak GPU memory, but it should be a good trade-off? |
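For reference, a sketch of the pattern that hits this path (builder and checkpoint names are hypothetical): NF4 parameters are initialized on CUDA, the checkpoint is loaded onto CPU, and `load_state_dict` copies each CPU tensor into its CUDA NF4 parameter through the `copy_` override discussed above.

```python
import torch

model = build_qlora_model(device="cuda")              # hypothetical builder; NF4 params live on GPU
state_dict = torch.load("checkpoint.pt", map_location="cpu")
model.load_state_dict(state_dict)                     # each copy_ dispatches to the NF4 override
```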
Thanks for pinpointing the root cause! |
@gau-nernst thanks for figuring this out! This looks like exactly what we need. I confirmed that after making this change in ao, the QLoRA initialization time matches what we see using torchtune's inplace copy. Additionally, the FSDP2 recipe works when we use this version instead of the one in torchtune. I actually don't see any negative memory implications of this change either, so no reason not to make it from my perspective. |
@ebsmothers mind then deleting some of the existing NF4 code in tune and making the patch here in ao? As long as CI is green, I suspect the review will be speedy here. |
Opened #737 and updated pytorch/torchtune#1294 to delete torchtune's override |
Hi, we're observing a slowdown in our torchtune QLoRA recipe initialization after changing from version 0.1 to 0.3 (I haven't checked 0.4 yet but will do so shortly). This was first pointed out in pytorch/torchtune#1246 and I believe the cause is coming from some changes in torchao.
Repro: from a torchtune git install