
WIP: Integrate SpQR + FSDP functionalities #840


Closed
Changes from all commits (47 commits)
bc33a49
Updated 8-bit optimizers to blocksize 256.
TimDettmers Aug 5, 2023
5672e15
Removed non-blockwise optimizers.
TimDettmers Aug 6, 2023
61eaf7b
Added fsdp test.
TimDettmers Aug 6, 2023
bda3722
Some more custom data type and fsdp test.
TimDettmers Aug 31, 2023
c28984c
Added simple swapping for Params4bit.
TimDettmers Aug 31, 2023
98b57ba
Added missing functions.
TimDettmers Aug 31, 2023
0076227
Small fixes to swapping logic
TimDettmers Aug 31, 2023
fe7e0a6
Switched subclasses.
TimDettmers Aug 31, 2023
6d15795
Working swapping.
TimDettmers Aug 31, 2023
dca5a7e
Added swapping benchmark.
TimDettmers Sep 1, 2023
91aa1f8
spqr: initial implementation outline for collaborative review
Titus-von-Koeller Sep 2, 2023
3d6ce45
Simple packing code with test that fails.
TimDettmers Sep 5, 2023
1a79a80
inline spqr quantized model loading, only formatting changed
Titus-von-Koeller Sep 6, 2023
a49fed2
refactor spqr easier for layer-wise operations
Titus-von-Koeller Sep 6, 2023
1fdd218
chore: improve project setup - dev environment, formatting + linting …
Titus-von-Koeller Sep 6, 2023
f78a6d1
Added 3-bit packing code with tests (all green).
TimDettmers Sep 6, 2023
050b345
Comments for permutation order quantization.
TimDettmers Sep 6, 2023
3bfd4af
Added row-wise correct kernel. Test needs to be adjusted.
TimDettmers Sep 7, 2023
4b89b6b
format + lint nn/modules.py
Titus-von-Koeller Sep 15, 2023
3ce14e9
Better FSDP tests. Failing mixed grads test.
TimDettmers Sep 18, 2023
9e23d14
Added manually wrapped mixed gradient test.
TimDettmers Sep 18, 2023
8278fca
Added fixes to Linear4bit for FSDP.
TimDettmers Sep 19, 2023
de5259d
spqr: packing, custom linear layer, improvements, bug in cuda
Titus-von-Koeller Sep 19, 2023
441960d
updated conda dev env
Titus-von-Koeller Sep 19, 2023
b11a663
improve linting config
Titus-von-Koeller Sep 19, 2023
ba1d319
add pytest config for default args, etc
Titus-von-Koeller Sep 19, 2023
361af91
Fixed indexing error in kPack3Bits.
TimDettmers Sep 20, 2023
b5acc7d
fsdp: outline nested module wrapping integration
Titus-von-Koeller Sep 25, 2023
a2461dd
fsdp: starting point for tests (Titus)
Titus-von-Koeller Oct 9, 2023
9a391f9
Added random port number to fsdp tests.
TimDettmers Oct 9, 2023
68ed65a
fsdp: preliminary implementation + tests
Titus-von-Koeller Oct 10, 2023
cb97b46
updated dev environment
Titus-von-Koeller Oct 10, 2023
bfc113a
added small improvements to nn introspection helpers
Titus-von-Koeller Oct 10, 2023
ca404d0
auto-formatting
Titus-von-Koeller Oct 10, 2023
b2f3082
delete superfluous files
Titus-von-Koeller Oct 10, 2023
b0d4249
fsdp test: add delay to allow GC to collect remains of old process group
Titus-von-Koeller Oct 11, 2023
ea0a096
fsdp: manual wrapping w/o use_orig_params test working
Titus-von-Koeller Oct 13, 2023
baec8cb
enable isort
Titus-von-Koeller Oct 13, 2023
27c7acd
auto-formatting
Titus-von-Koeller Oct 13, 2023
2e6ac0a
fsdp: add LoRA toy model
Titus-von-Koeller Oct 21, 2023
407b8c0
update dependencies
Titus-von-Koeller Oct 24, 2023
9c6212c
fix f-string
Titus-von-Koeller Oct 24, 2023
85246f0
setup: minor fixes + auto formatting
Titus-von-Koeller Oct 24, 2023
3449684
utils: polish replace_linear
Titus-von-Koeller Oct 25, 2023
a580011
polish + test replace_linear
Titus-von-Koeller Oct 25, 2023
aceeb1d
format utils.py
Titus-von-Koeller Oct 25, 2023
e38c474
utils: add further tests for replace_linear
Titus-von-Koeller Oct 25, 2023
13 changes: 13 additions & 0 deletions .style.yapf
@@ -0,0 +1,13 @@
[style]
ALIGN_CLOSING_BRACKET_WITH_VISUAL_INDENT = True
ALLOW_MULTILINE_LAMBDAS = True
BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = True
COLUMN_LIMIT = 88
COALESCE_BRACKETS = True
SPACE_BETWEEN_ENDING_COMMA_AND_CLOSING_BRACKET = True
SPACES_BEFORE_COMMENT = 2
SPLIT_BEFORE_BITWISE_OPERATOR = True
SPLIT_BEFORE_FIRST_ARGUMENT = True
SPLIT_BEFORE_LOGICAL_OPERATOR = True
SPLIT_BEFORE_NAMED_ASSIGNS = True
SPLIT_COMPLEX_COMPREHENSION = True
6 changes: 3 additions & 3 deletions Makefile
@@ -46,9 +46,9 @@ CC_CUDA11x += -gencode arch=compute_86,code=sm_86
CC_cublasLt110 := -gencode arch=compute_75,code=sm_75
CC_cublasLt110 += -gencode arch=compute_80,code=sm_80

-CC_cublasLt111 := -gencode arch=compute_75,code=sm_75
-CC_cublasLt111 += -gencode arch=compute_80,code=sm_80
-CC_cublasLt111 += -gencode arch=compute_86,code=sm_86
+#CC_cublasLt111 := -gencode arch=compute_75,code=sm_75
+#CC_cublasLt111 += -gencode arch=compute_80,code=sm_80
+CC_cublasLt111 := -gencode arch=compute_86,code=sm_86

CC_ADA_HOPPER := -gencode arch=compute_89,code=sm_89
CC_ADA_HOPPER += -gencode arch=compute_90,code=sm_90
254 changes: 166 additions & 88 deletions bitsandbytes/cuda_setup/main.py

Large diffs are not rendered by default.

Empty file.
188 changes: 188 additions & 0 deletions bitsandbytes/distributed/fsdp.py
@@ -0,0 +1,188 @@
"""
Fully Sharded Data Parallelism (FSDP) is a technique used to scale deep learning models across multiple GPUs or nodes, enabling training of models that are too large to fit into the memory of a single GPU.

FSDP shards (splits) the model's parameters, gradients, and optimizer state across the data parallel workers. This reduces the memory consumption on each worker, enabling training of models that are too large to fit into the memory of a single node by sharding the model across multiple nodes.
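
For orientation, a minimal usage sketch (illustrative only; it assumes that
`torch.distributed` has been initialized and that each rank has a CUDA device):

    import torch.nn as nn
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024)).cuda()
    fsdp_model = FSDP(model)  # parameters, gradients and optimizer state get sharded
    # training then proceeds as usual: out = fsdp_model(x); loss.backward(); optimizer.step()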

## Requirements for Selective Wrapping:
Note: look at the FSDP test in the last commit (`tests/test_optim.py`) to see how it wraps the model.

We need a mechanism that takes a PyTorch model (which can be a nested `torch.nn.Module`) and wraps the largest possible module-graph sub-trees based on the criteria below (a small sketch of the consistency check follows the list):

- wrap as many layers into one block as possible to reduce the overhead of communication
- only layers with same data types + grad types can be wrapped together

- therefore layers with `requires_grad=False` CANNOT be wrapped with layers with `requires_grad=True`
- `Linear4bit` is considered such a special type that it can only be wrapped together with other `Linear4bit` layers
- mixed sub-trees cannot be wrapped as a whole and can be skipped; instead, wrap the smaller wrappable sub-trees inside them
- not all layers need to be wrapped. Layernorms are usually not faster when wrapped. A good rule of thumb is that a layer needs at least 1M parameters to be worth wrapping
- the presence or absence of a bias requires no special consideration
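
A small sketch of the dtype / `requires_grad` consistency criterion mentioned above
(it assumes this module is importable as `bitsandbytes.distributed.fsdp`, matching the
path of this file):

    import torch.nn as nn
    from bitsandbytes.distributed.fsdp import parameters_all_consistent

    block = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 64))
    parameters_all_consistent(block.parameters())  # True: all fp32, all trainable

    block[1].requires_grad_(False)                 # freeze the second layer
    parameters_all_consistent(block.parameters())  # False: mixed requires_grad, so this
                                                   # sub-tree cannot be wrapped as one unit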

## Custom Auto Wrap Policy:
Custom auto wrap policy function for determining whether to wrap a given module
with Fully Sharded Data Parallel (FSDP) based on specific criteria.

This function is designed to be used as the `auto_wrap_policy` argument when
initializing an FSDP wrapper. It follows the API expected by FSDP for auto-wrap
policies and makes the wrapping decisions. Be aware that the boolean return value has
two different meanings, depending on whether the `recurse` parameter is `True` or
`False`; see below for further details.

Parameters:
- module (nn.Module): The module being considered for wrapping.
- recurse (bool): A flag indicating whether the function is being called during
the traversal down the module tree. If `True`, the function will always continue
the traversal by returning `True`.
- nonwrapped_numel (int): The number of elements in the module that are not yet
wrapped. This parameter is not used in the current implementation but is included
to fulfill the expected API.

Returns:
- bool: A boolean value indicating either (1) whether recursion should continue, if
called with `recurse=True` or (2) whether the given module should be wrapped, if
called with `recurse=False`.


How the recursion works:
The FSDP wrapper traverses the module tree, starting from the root module, and
calls this function for each module encountered. The `recurse` parameter indicates
whether the current call is part of the traversal down the tree. If `recurse` is
`True`, the function returns `True` to continue the traversal. When reaching a leaf
module, `recurse` is `False`, and the function makes a decision based on the specific
criteria whether to wrap the module. This way, the function is recursively called for
each module in the tree, allowing selective wrapping of the modules. Therefore, the
leaves are wrapped first, and the wrapping propagates up the tree.
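
A sketch of the two call modes (illustrative; FSDP supplies the real `nonwrapped_numel`,
it is passed as 0 here only to satisfy the signature):

    import torch.nn as nn

    block = nn.Linear(64, 64)
    # while walking down the tree, FSDP asks whether to keep recursing into `block`:
    bnb_fsdp_auto_wrap_policy(block, recurse=True, nonwrapped_numel=0)   # False: params consistent
    # at the decision point, it asks whether `block` itself should be wrapped:
    bnb_fsdp_auto_wrap_policy(block, recurse=False, nonwrapped_numel=0)  # True: wrap it

In practice the policy is simply passed to the wrapper, e.g.
`FSDP(model, auto_wrap_policy=bnb_fsdp_auto_wrap_policy)`.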
"""

from collections import deque
import functools
from typing import Dict, Iterable, Optional, OrderedDict, Set, Tuple, Type, Union

import torch
from torch.distributed.fsdp.wrap import (
_or_policy,
lambda_auto_wrap_policy,
transformer_auto_wrap_policy,
)
import torch.nn as nn
import torch.nn.functional as F

from bitsandbytes.nn import Linear4bit


def parameters_all_consistent(params: Iterable[torch.nn.Parameter]) -> bool:
"""
Check if all parameters in the iterable have the same dtype and requires_grad attributes.

Parameters:
- params (Iterable[torch.nn.Parameter]): An iterable of PyTorch parameters.

Returns:
- bool: True if all parameters are consistent, False otherwise.
"""
params_iter = iter(params)

try:
first_param = next(params_iter)
except StopIteration:
return True # No parameters to check

return all(
param.dtype == first_param.dtype
and param.requires_grad == first_param.requires_grad for param in params_iter)


from bitsandbytes.nn.helpers import debug


# TODO: still getting "ValueError: Must flatten tensors with uniform `requires_grad`
# when `use_orig_params=False`" when running `pytest tests/test_fsdp.py::test_fsdp_bnb`
# something must still be off with the custom auto wrap policy...
@debug
def bnb_fsdp_auto_wrap_policy(
module: nn.Module,
recurse: bool,
nonwrapped_numel: int,
debug: bool = True,
*args,
**kwargs,
) -> bool:
"""See the module doc string, section "Custom Auto Wrap Policy" for details..."""
def debug(): # TODO: remove this and the extraneous comments once this is working
print([(n, p.dtype, p.requires_grad) for n, p in module.named_parameters()])

if recurse:
# return True to recurse: we recurse until we hit a module w/ consistent params
return not parameters_all_consistent(module.parameters())
# if we're not recursing, we're evaluating if the module should be wrapped and
# therefore return True if the module has consistent params, as we're trying to
# wrap the largest possible module graph sub-trees based on this criterium
return parameters_all_consistent(module.parameters())

"""
Things to still integrate:
- min_num_params
- ignored_modules
"""


# TODO: this example policy will be removed later, we still need to integrate the
# min_num_params mechanism, either directly into bnb_fsdp_auto_wrap_policy or
# by using the FSDP `_or_policy`.
def size_based_auto_wrap_policy(
module: nn.Module,
recurse: bool,
nonwrapped_numel: int,
# Additional custom arguments
min_num_params: int = int(1e8),
force_leaf_modules: Optional[Set[Type[nn.Module]]] = None,
exclude_wrap_modules: Optional[Set[Type[nn.Module]]] = None,
) -> bool:
"""
A size-based auto wrap policy.

Args:
module (nn.Module): Current module being considered.
recurse (bool): If ``False``, then this function must decide whether
``module`` should be wrapped as an FSDP instance or not. If
``True``, then the function is still recursing down the module
tree as a part of the DFS.
nonwrapped_numel (int): Parameter numel not yet wrapped.

min_num_params (int): Customizable policy input that controls the size
threshold over which a module is ready to be wrapped. This is in
units of numel.
force_leaf_modules (Set[Type[nn.Module]]): Set of module types to keep
as leaves, i.e. their children will never be wrapped.
exclude_wrap_modules (Set[Type[nn.Module]]): Set of module types to be
excluded in wrapping.

Returns:
Whether ``module`` should be wrapped.
"""
force_leaf_modules = (
size_based_auto_wrap_policy.FORCE_LEAF_MODULES # type: ignore[attr-defined]
if force_leaf_modules is None else force_leaf_modules)
exclude_wrap_modules = (
size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES # type: ignore[attr-defined]
if exclude_wrap_modules is None else exclude_wrap_modules)

# Keep the argument `min_num_params` for BC for now, but it represents the
# minimum non-wrapped *numel* before triggering a wrapping
min_nonwrapped_numel = min_num_params
is_large = nonwrapped_numel >= min_nonwrapped_numel
if recurse:
# We should recurse if the module is big enough but not in force_leaf_modules list.
return is_large and not isinstance(module, tuple(force_leaf_modules))
else:
# If we are not recursing, determine if we should wrap.
return is_large and not isinstance(module, tuple(exclude_wrap_modules))


# Set these defaults as attributes on the size_based_auto_wrap_policy function and make them easy to import.
size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES = {
nn.ModuleList, nn.ModuleDict
} # type: ignore[attr-defined]
size_based_auto_wrap_policy.FORCE_LEAF_MODULES = {
nn.MultiheadAttention
} # type: ignore[attr-defined]
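

# A possible way to combine the two policies above, following the TODO next to
# `size_based_auto_wrap_policy`: FSDP's `_or_policy` (already imported at the top of
# this file) wraps a module as soon as any of the given policies returns True. This is
# only a sketch of that option; the helper name is illustrative and the 1M-parameter
# default merely follows the rule of thumb from the module docstring.
def make_combined_auto_wrap_policy(min_num_params: int = int(1e6)):
    size_policy = functools.partial(
        size_based_auto_wrap_policy, min_num_params=min_num_params)
    # `_or_policy` forwards (module, recurse, nonwrapped_numel) to every entry in
    # `policies` and returns True if any of them does.
    return functools.partial(
        _or_policy, policies=[bnb_fsdp_auto_wrap_policy, size_policy])


# Usage sketch: FSDP(model, auto_wrap_policy=make_combined_auto_wrap_policy())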