Full finetune < 16GB #527

Merged: 37 commits, Mar 29, 2024
Commits
ea16903
Upd
rohan-varma Mar 15, 2024
2144c4f
Merge branch 'main' of github.com:pytorch-labs/torchtune
rohan-varma Mar 16, 2024
0104e36
Upd
rohan-varma Mar 16, 2024
f7d1439
Merge branch 'main' of github.com:pytorch-labs/torchtune
rohan-varma Mar 18, 2024
6479227
Merge branch 'main' of github.com:pytorch-labs/torchtune
rohan-varma Mar 19, 2024
ff69f09
Merge branch 'main' of github.com:pytorch-labs/torchtune
rohan-varma Mar 19, 2024
622d6fa
Full FT mem efficiency
rohan-varma Mar 19, 2024
899fa97
Upd
rohan-varma Mar 19, 2024
7770168
upd
rohan-varma Mar 19, 2024
ae4b943
Upd
rohan-varma Mar 22, 2024
2b5060f
Merge branch 'main' of github.com:pytorch/torchtune into full_ft_mem
rohan-varma Mar 22, 2024
eeb8edd
Upd
rohan-varma Mar 22, 2024
0fb1eba
Upd
rohan-varma Mar 22, 2024
0d85856
Upd
rohan-varma Mar 22, 2024
bb71c2e
Upd
rohan-varma Mar 23, 2024
57030f2
Merge branch 'main' of github.com:pytorch/torchtune into full_ft_mem
rohan-varma Mar 25, 2024
c4a58e2
upd
rohan-varma Mar 25, 2024
e2e0725
Merge branch 'main' of github.com:pytorch/torchtune into full_ft_mem
rohan-varma Mar 25, 2024
61371b4
upd
rohan-varma Mar 25, 2024
8b267dc
upd
rohan-varma Mar 25, 2024
9f5ac0b
upd
rohan-varma Mar 25, 2024
b3dea03
Upd
rohan-varma Mar 25, 2024
7fe9485
upd
rohan-varma Mar 25, 2024
23b10f7
upd
rohan-varma Mar 25, 2024
c911eed
CI
rohan-varma Mar 25, 2024
7325133
Merge branch 'main' of github.com:pytorch/torchtune into full_ft_mem
rohan-varma Mar 26, 2024
0e2d4fc
UPd
rohan-varma Mar 26, 2024
98f6fcc
Merge branch 'main' of github.com:pytorch/torchtune into full_ft_mem
rohan-varma Mar 27, 2024
a4292f1
upd
rohan-varma Mar 27, 2024
86839ad
upd
rohan-varma Mar 27, 2024
d9f8975
WIP - testing
rohan-varma Mar 27, 2024
9efe051
upd
rohan-varma Mar 28, 2024
f26d34b
upd
rohan-varma Mar 28, 2024
a3d850b
Merge branch 'main' of github.com:pytorch/torchtune into full_ft_mem
rohan-varma Mar 28, 2024
6207797
Upd
rohan-varma Mar 28, 2024
7597c68
Merge branch 'main' of github.com:pytorch/torchtune into full_ft_mem
rohan-varma Mar 29, 2024
6c9731d
Upd
rohan-varma Mar 29, 2024
5 changes: 4 additions & 1 deletion README.md
@@ -55,12 +55,15 @@ experience different peak memory utilization based on changes made in configuration
| 1 x RTX 4090 | QLoRA | [qlora_finetune_single_device](https://github.com/pytorch/torchtune/blob/main/recipes/configs/llama2/7B_qlora_single_device.yaml) | Llama-7B | 9.29 GB * |
| 2 x RTX 4090 | LoRA | [lora_finetune_distributed](https://github.com/pytorch/torchtune/blob/main/recipes/configs/llama2/7B_lora.yaml) | Llama-7B | 14.17 GB * |
| 1 x RTX 4090 | LoRA | [lora_finetune_single_device](https://github.com/pytorch/torchtune/blob/main/recipes/configs/llama2/7B_lora_single_device.yaml) | Llama-7B | 17.18 GB * |
| 1 x A6000 | Full finetune | [full_finetune_single_device](https://github.com/pytorch/torchtune/blob/main/recipes/configs/llama2/7B_full_single_device.yaml) | Llama-7B | 27.15 GB * |
| 1 x A6000 | Full finetune | [full_finetune_single_device](https://github.com/pytorch/torchtune/blob/main/recipes/configs/llama2/7B_full_single_device_low_memory.yaml) | Llama-7B | 15.97 GB * ^ |
| 4 x RTX 4090 | Full finetune | [full_finetune_distributed](https://github.com/pytorch/torchtune/blob/main/recipes/configs/llama2/7B_full.yaml) | Llama-7B | 12.01 GB * |


NOTE: * indicates an estimated metric based on experiments conducted on A100 GPUs with GPU memory artificially limited using [torch.cuda.set_per_process_memory_fraction API](https://pytorch.org/docs/stable/generated/torch.cuda.set_per_process_memory_fraction.html). Peak memory per GPU is as reported by `torch.cuda.max_memory_reserved()`. Please file an issue if you are not able to reproduce these results when running TorchTune on certain hardware.
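For reference, a minimal sketch of how such a memory-limited measurement could be reproduced; the 0.25 fraction and the toy model below are illustrative assumptions, not the exact harness used for the numbers above:

```python
import torch

# Illustrative sketch: artificially cap usable GPU memory to emulate a smaller card,
# then report peak reserved memory after one training step.
torch.cuda.set_per_process_memory_fraction(0.25, device=0)

model = torch.nn.Linear(4096, 4096, device="cuda")
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

loss = model(torch.randn(8, 4096, device="cuda")).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()

# Peak memory per GPU, as reported in the table above
print(f"peak reserved: {torch.cuda.max_memory_reserved() / 1024**3:.2f} GB")
```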

NOTE: ^ indicates the required use of third-party dependencies that are not installed with torchtune by default. In particular, for the most memory efficient full finetuning [configuration](https://github.com/pytorch/torchtune/blob/main/recipes/configs/llama2/7B_full_single_device_low_memory.yaml), [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) is required and can be installed via `pip install bitsandbytes`, after which the configuration
can be run successfully.

&nbsp;

---
1 change: 1 addition & 0 deletions recipes/configs/llama2/7B_full_single_device.yaml
@@ -56,6 +56,7 @@ loss:
_component_: torch.nn.CrossEntropyLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1
optimizer_in_bwd: False


# Training environment
Expand Down
76 changes: 76 additions & 0 deletions recipes/configs/llama2/7B_full_single_device_low_memory.yaml
@@ -0,0 +1,76 @@
# Config for single device full finetuning in full_finetune_single_device.py
# using a Llama2 7B model
#
# This config assumes that you've run the following command before launching
# this run:
# tune download --repo-id meta-llama/Llama-2-7b \
# --hf-token <HF_TOKEN> \
# --output-dir /tmp/llama2
#
# To launch on a single device, run the following command from root:
# tune --nnodes 1 --nproc_per_node 1 full_finetune_single_device \
# --config llama2/7B_full_single_device_low_memory \
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune --nnodes 1 --nproc_per_node 1 full_finetune_single_device \
# --config llama2/7B_full_single_device_low_memory \
# checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works only for training on single device.


# Tokenizer
tokenizer:
_component_: torchtune.models.llama2.llama2_tokenizer
path: /tmp/llama2/tokenizer.model

# Dataset
dataset:
_component_: torchtune.datasets.alpaca_dataset
train_on_input: True
seed: null
shuffle: True

# Model Arguments
model:
_component_: torchtune.models.llama2.llama2_7b

checkpointer:
_component_: torchtune.utils.FullModelMetaCheckpointer
checkpoint_dir: /tmp/llama2
checkpoint_files: [consolidated.00.pth]
recipe_checkpoint: null
output_dir: /tmp/llama2
model_type: LLAMA2
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 2
epochs: 1
optimizer:
_component_: bitsandbytes.optim.PagedAdamW
Contributor

bitsandbytes is not officially in our core requirements, right? What's our plan for handling things gracefully here?

Member Author

Good call out, I think we'll simply add bitsandbytes to our core deps.

Member Author

Added bitsandbytes as a core dep

Member Author

Removed bnb as a core dep after discussion
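With bnb kept out of the core deps, one way to fail gracefully is to check for the package before instantiating the optimizer. A minimal sketch, assuming the optimizer is instantiated from the config's component string; the helper name is hypothetical and not part of torchtune:

```python
import importlib.util

def _check_optional_optimizer_dep(component: str) -> None:
    """Hypothetical helper (not part of torchtune): raise a clear error if the
    config references a bitsandbytes optimizer but bitsandbytes is not installed."""
    if component.startswith("bitsandbytes") and importlib.util.find_spec("bitsandbytes") is None:
        raise RuntimeError(
            "This config uses a bitsandbytes optimizer but bitsandbytes is not installed. "
            "Install it with `pip install bitsandbytes` and re-run the recipe."
        )

# e.g. called before config.instantiate(cfg.optimizer, ...)
_check_optional_optimizer_dep("bitsandbytes.optim.PagedAdamW")
```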

lr: 2e-5
optimizer_in_bwd: True
loss:
_component_: torch.nn.CrossEntropyLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1


# Training environment
device: cuda

# Memory management
enable_activation_checkpointing: True

# Reduced precision
dtype: bf16

# Logging
metric_logger:
_component_: torchtune.utils.metric_logging.DiskLogger
log_dir: ${output_dir}
output_dir: /tmp/alpaca-llama2-finetune
log_every_n_steps: null
84 changes: 67 additions & 17 deletions recipes/full_finetune_single_device.py
@@ -44,6 +44,8 @@ class FullFinetuneRecipeSingleDevice(FTRecipeInterface):
hood. Setting up the env variables is handled by TorchRun.
- Training happens on CUDA (CPU training is not supported)
- Checkpoints are ONLY saved at epoch boundaries. Mid-epoch checkpointing is NOT supported.
- User can only use ONE of gradient accumulation or optimizer in backward. These features
currently do not work together.
- Datasets are Map-style and data fits in memory (not streamed).

The following configs can be used to run this recipe:
@@ -55,8 +57,9 @@ class FullFinetuneRecipeSingleDevice(FTRecipeInterface):
cfg (DictConfig): OmegaConf object parsed from yaml file

Raises:
ValueError: If ``dtype`` is set to fp16.
RuntimeError: If ``dtype`` is set to fp16.
RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16.
RuntimeError: If ``gradient_accumulation_steps > 1`` and ``optimizer_in_bwd`` is `True`.
"""

def __init__(self, cfg: DictConfig) -> None:
@@ -65,7 +68,7 @@ def __init__(self, cfg: DictConfig) -> None:
# Disable for fp16, as we haven't validated "full" fp16 with this recipe, nor
# enabled necessary features such as gradient scaling.
if self._dtype == torch.float16:
raise ValueError(
raise RuntimeError(
"full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead."
)

@@ -84,7 +87,14 @@ def __init__(self, cfg: DictConfig) -> None:
# Training cfg
self._resume_from_checkpoint = cfg.resume_from_checkpoint
self._gradient_accumulation_steps = cfg.gradient_accumulation_steps

self._optimizer_in_bwd = cfg.optimizer_in_bwd
# TODO: find a better place / way to perform validation of args that don't yet
# compose with each other.
Collaborator

After the dataclasses were removed we lost a convenient place to perform validation. My thinking is that we provide a convenient utility function (or functions) that does this for users, which they just import in their recipe, but I'm curious about your thoughts.

Member Author

That could be one approach; a possible downside is that the helper becomes a monolithic dumping ground where we check various configs and it just turns into a large swath of if statements. On the other hand, if we don't do something like this, the checks get spelled out in each recipe, which increases code bloat and maintenance overhead (we'd have to copy them into each recipe).

Contributor

Some piece of this is just not trying to do too much in a single recipe, right? This is part of the reason we split single-device and distributed recipes to begin with.

We can still do config validation; we just have to check that fields are defined instead of naively checking only their values. Personally I would be in favor of some config validation utilities defined on a per-recipe basis under configs/ somewhere, coupled with clear documentation in the recipe class of which intersections of features are not supported.

Member Author

This makes sense. I've added some documentation in the recipe class for now. @RdoubleA, let's chat about Evan's suggestion here?
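As a strawman for that discussion, a per-recipe validation helper might look roughly like the sketch below; the function name and the use of omegaconf's `get` are assumptions for illustration, not an existing torchtune utility:

```python
from omegaconf import DictConfig

def validate_full_finetune_single_device_cfg(cfg: DictConfig) -> None:
    """Hypothetical per-recipe validation helper: reject feature combinations the
    recipe does not support before any expensive setup work runs."""
    if cfg.get("gradient_accumulation_steps", 1) > 1 and cfg.get("optimizer_in_bwd", False):
        raise RuntimeError(
            "Gradient accumulation is not supported with optimizer in bwd. "
            "Please set gradient_accumulation_steps=1 or optimizer_in_bwd=False."
        )
    if cfg.get("dtype") == "fp16":
        raise RuntimeError(
            "full fp16 training is not supported with this recipe. "
            "Please use bf16 or fp32 instead."
        )
```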

if self._gradient_accumulation_steps > 1 and self._optimizer_in_bwd:
raise RuntimeError(
"Gradient accumulation is not supported with optimizer in bwd."
"Please set gradient_accumulation_steps=1, or optimizer_in_bwd=False."
)
# These are public properties which are updated by the checkpoint loader
# when ``resume_from_checkpoint`` is `True` or validated in tests
self.seed = utils.set_seed(seed=cfg.seed)
@@ -158,6 +168,7 @@ def setup(self, cfg: DictConfig) -> None:
# checkpoint. Transforming the opt state dict is handled by this method
self._optimizer = self._setup_optimizer(
cfg_optimizer=cfg.optimizer,
optimizer_in_bwd=cfg.optimizer_in_bwd,
opt_state_dict=(
ckpt_dict[utils.OPT_KEY] if self._resume_from_checkpoint else None
),
@@ -221,18 +232,46 @@ def _setup_model(
return model

def _setup_optimizer(
self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None
) -> Optimizer:
self,
cfg_optimizer: DictConfig,
optimizer_in_bwd: bool = False,
opt_state_dict: Optional[Dict[str, Any]] = None,
) -> Optional[Optimizer]:
"""
Set up the optimizer. This method also handles loading the optimizer state_dict, if specified.
"""
optimizer = config.instantiate(cfg_optimizer, self._model.parameters())

if opt_state_dict:
optimizer.load_state_dict(opt_state_dict)

log.info("Optimizer is initialized.")
return optimizer
if optimizer_in_bwd:
# Maintain a dict of optims for every parameter.
optim_dict = {
p: config.instantiate(cfg_optimizer, [p])
for p in self._model.parameters()
}
# Register optimizer step hooks on the model to run optimizer in backward.
utils.register_optim_in_bwd_hooks(model=self._model, optim_dict=optim_dict)
# Create a wrapper for checkpoint save/load of optimizer states when running in backward.
self._optim_ckpt_wrapper = utils.create_optim_in_bwd_wrapper(
model=self._model, optim_dict=optim_dict
)
# Load optimizer states. If optimizer states are being restored in an optimizer in backward
# run, these need to have been saved with the same setting. Cannot restore from runs that did not
# use optimizer in backward.
if opt_state_dict is not None:
try:
self._optim_ckpt_wrapper.load_state_dict(opt_state_dict)
except BaseException as e:
raise RuntimeError(
"Failed loading in-backward optimizer checkpoints."
"Please make sure run being restored from was using in-backward optimizer."
) from e
log.info("In-backward optimizers are set up.")
return None
else:
optimizer = config.instantiate(cfg_optimizer, self._model.parameters())

if opt_state_dict:
optimizer.load_state_dict(opt_state_dict)
log.info("Optimizer is initialized.")
return optimizer

def _setup_data(
self,
@@ -281,13 +320,16 @@ def save_checkpoint(self, epoch: int) -> None:
if epoch + 1 < self.total_epochs:
ckpt_dict.update(
{
utils.OPT_KEY: self._optimizer.state_dict(),
utils.SEED_KEY: self.seed,
utils.EPOCHS_KEY: self.epochs_run,
utils.TOTAL_EPOCHS_KEY: self.total_epochs,
utils.MAX_STEPS_KEY: self.max_steps_per_epoch,
}
)
if not self._optimizer_in_bwd:
ckpt_dict[utils.OPT_KEY] = self._optimizer.state_dict()
else:
ckpt_dict[utils.OPT_KEY] = self._optim_ckpt_wrapper.state_dict()
Collaborator

If the APIs will be the same for the optim ckpt wrapper, you could just call it self._optimizer and remove the if/else.

Member Author

I'm personally not an advocate for this, for the reason explained in the other comment. If folks feel this is better UX though, I'm happy to just add it.

Contributor

I'm inclined to agree with @RdoubleA. I definitely don't want to hide the actual behavior too much, but in this case we already have state_dict and other APIs defined, so we might as well cut down on branching (aside from where it's really needed, like in train). Honestly I'm not taking a super strong stance here, so fine either way.

Member Author

Makes sense, let's discuss in follow-up PRs.

self._checkpointer.save_checkpoint(
ckpt_dict,
epoch=epoch,
@@ -311,8 +353,8 @@ def train(self) -> None:
``max_steps_per_epoch``.
"""
# zero out the gradients before starting training
self._optimizer.zero_grad()

if not self._optimizer_in_bwd:
self._optimizer.zero_grad()
# self.epochs_run should be non-zero when we're resuming from a checkpoint
for curr_epoch in range(self.epochs_run, self.total_epochs):
# Update the sampler to ensure data is correctly shuffled across epochs
@@ -344,20 +386,28 @@ def train(self) -> None:
self._metric_logger.log_dict(
{
"loss": loss.item(),
"lr": self._optimizer.param_groups[0]["lr"],
# NOTE: for optim in backward, this assumes all optimizers have the same LR. This is currently
# true since we don't expose the ability to configure this yet.
"lr": (
self._optim_ckpt_wrapper.get_optim_key("lr")
if self._optimizer_in_bwd
else self._optimizer.param_groups[0]["lr"]
),
"gpu_resources": torch.cuda.memory_allocated(),
},
step=self.total_training_steps,
)

loss = loss / self._gradient_accumulation_steps
loss.backward()
if self._should_update_weights(idx):
if not self._optimizer_in_bwd and self._should_update_weights(idx):
self._optimizer.step()
self._optimizer.zero_grad(set_to_none=True)

# Update the number of steps when the weights are updated
self.total_training_steps += 1
elif self._optimizer_in_bwd:
self.total_training_steps += 1

# Log peak memory for iteration
if self.total_training_steps % self._log_peak_memory_every_n_steps == 0:
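For background on the recipe changes above, here is a minimal raw-PyTorch sketch of the optimizer-in-backward idea; it is illustrative only, not torchtune's utils.register_optim_in_bwd_hooks implementation, and it assumes PyTorch 2.1+ for register_post_accumulate_grad_hook:

```python
import torch

# Illustrative sketch: give each parameter its own optimizer and step it as soon as
# that parameter's gradient is accumulated in backward, then free the grad. This way
# full-model gradients and a monolithic optimizer step never have to coexist in memory.
model = torch.nn.Linear(10, 1)
optim_dict = {p: torch.optim.AdamW([p], lr=2e-5) for p in model.parameters()}

def _make_hook(param: torch.Tensor):
    def hook(*_):
        optim_dict[param].step()
        optim_dict[param].zero_grad(set_to_none=True)
    return hook

for p in model.parameters():
    p.register_post_accumulate_grad_hook(_make_hook(p))

# A single backward() now also applies the per-parameter optimizer updates.
model(torch.rand(2, 10)).sum().backward()
```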
7 changes: 5 additions & 2 deletions tests/recipes/test_full_finetune_single_device.py
@@ -51,15 +51,18 @@ def _fetch_expected_loss_values(self):
return [10.5074, 10.5563, 10.5152, 10.4851]

@pytest.mark.integration_test
def test_loss(self, tmpdir, monkeypatch):
@pytest.mark.parametrize(
"config", ["full_single_device_low_memory", "full_single_device"]
)
def test_loss(self, config, tmpdir, monkeypatch):
ckpt = "small_test_ckpt_meta"
ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
ckpt_dir = ckpt_path.parent
log_file = gen_log_file_name(tmpdir)

cmd = f"""
tune full_finetune_single_device
--config llama2/7B_full_single_device \
--config llama2/7B_{config} \
output_dir={tmpdir} \
checkpointer._component_=torchtune.utils.FullModelMetaCheckpointer
checkpointer.checkpoint_dir='{ckpt_dir}' \
91 changes: 91 additions & 0 deletions tests/torchtune/utils/test_optim_utils.py
@@ -0,0 +1,91 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch
from torchtune.utils import create_optim_in_bwd_wrapper, register_optim_in_bwd_hooks


def _run_dummy_step(model, wrapper):
with torch.no_grad():
for p in model.parameters():
p.grad = torch.rand_like(p)
for v in wrapper.optim_map.values():
v.step()
v.zero_grad()


def _validate_dicts(d1, d2):
if len(d1) != len(d2):
return False
for k, v in d1.items():
if k not in d2:
return False
if isinstance(v, dict):
return _validate_dicts(v, d2[k])
else:
if isinstance(v, torch.Tensor):
if not torch.allclose(v, d2[k]):
return False
elif v != d2[k]:
return False
return True


@pytest.fixture
def model():
return torch.nn.Linear(10, 1)


@pytest.fixture
def optim_dict(model):
return {p: torch.optim.AdamW([p], lr=0.01) for p in model.parameters()}


@pytest.fixture
def wrapper(model, optim_dict):
return create_optim_in_bwd_wrapper(model, optim_dict)


class TestOptimInBackward:
def test_state_dict_save_load(self, model, wrapper):
# Run a dummy step to create optimizer states
_run_dummy_step(model, wrapper)

sd = wrapper.state_dict()
new_optim_dict = create_optim_in_bwd_wrapper(
model, {p: torch.optim.AdamW([p], lr=0.01) for p in model.parameters()}
)
assert not _validate_dicts(sd, new_optim_dict.state_dict())
new_optim_dict.load_state_dict(sd)
assert _validate_dicts(sd, new_optim_dict.state_dict())

def test_missing_unexpected_param_load_raises(self, model, wrapper):
# Run a dummy step to create optimizer states
_run_dummy_step(model, wrapper)
sd = wrapper.state_dict()
new_optim_dict = create_optim_in_bwd_wrapper(
model, {p: torch.optim.AdamW([p], lr=0.01) for p in model.parameters()}
)
with pytest.raises(RuntimeError, match="Expected to load optimizer state"):
sd.pop(next(iter(sd.keys())))
new_optim_dict.load_state_dict(sd)

sd = wrapper.state_dict()
sd["new_key"] = 1234
with pytest.raises(RuntimeError, match="unexpected param"):
new_optim_dict.load_state_dict(sd)


class TestRegisterOptimHooks:
def test_register_optim_in_bwd_hooks(self, model, optim_dict):
register_optim_in_bwd_hooks(model, optim_dict)
# Ensure backward() updates the parameters and sets grads to None
orig_params = [p.clone().detach() for p in model.parameters()]
model(torch.rand(2, 10)).sum().backward()
for p, orig_p in zip(model.parameters(), orig_params):
assert not p.grad
assert not torch.allclose(p, orig_p)