Skip to content

Full finetune < 16GB #527

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 37 commits into from
Mar 29, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
ea16903
Upd
rohan-varma Mar 15, 2024
2144c4f
Merge branch 'main' of github.com:pytorch-labs/torchtune
rohan-varma Mar 16, 2024
0104e36
Upd
rohan-varma Mar 16, 2024
f7d1439
Merge branch 'main' of github.com:pytorch-labs/torchtune
rohan-varma Mar 18, 2024
6479227
Merge branch 'main' of github.com:pytorch-labs/torchtune
rohan-varma Mar 19, 2024
ff69f09
Merge branch 'main' of github.com:pytorch-labs/torchtune
rohan-varma Mar 19, 2024
622d6fa
Full FT mem efficiency
rohan-varma Mar 19, 2024
899fa97
Upd
rohan-varma Mar 19, 2024
7770168
upd
rohan-varma Mar 19, 2024
ae4b943
Upd
rohan-varma Mar 22, 2024
2b5060f
Merge branch 'main' of github.com:pytorch/torchtune into full_ft_mem
rohan-varma Mar 22, 2024
eeb8edd
Upd
rohan-varma Mar 22, 2024
0fb1eba
Upd
rohan-varma Mar 22, 2024
0d85856
Upd
rohan-varma Mar 22, 2024
bb71c2e
Upd
rohan-varma Mar 23, 2024
57030f2
Merge branch 'main' of github.com:pytorch/torchtune into full_ft_mem
rohan-varma Mar 25, 2024
c4a58e2
upd
rohan-varma Mar 25, 2024
e2e0725
Merge branch 'main' of github.com:pytorch/torchtune into full_ft_mem
rohan-varma Mar 25, 2024
61371b4
upd
rohan-varma Mar 25, 2024
8b267dc
upd
rohan-varma Mar 25, 2024
9f5ac0b
upd
rohan-varma Mar 25, 2024
b3dea03
Upd
rohan-varma Mar 25, 2024
7fe9485
upd
rohan-varma Mar 25, 2024
23b10f7
upd
rohan-varma Mar 25, 2024
c911eed
CI
rohan-varma Mar 25, 2024
7325133
Merge branch 'main' of github.com:pytorch/torchtune into full_ft_mem
rohan-varma Mar 26, 2024
0e2d4fc
UPd
rohan-varma Mar 26, 2024
98f6fcc
Merge branch 'main' of github.com:pytorch/torchtune into full_ft_mem
rohan-varma Mar 27, 2024
a4292f1
upd
rohan-varma Mar 27, 2024
86839ad
upd
rohan-varma Mar 27, 2024
d9f8975
WIP - testing
rohan-varma Mar 27, 2024
9efe051
upd
rohan-varma Mar 28, 2024
f26d34b
upd
rohan-varma Mar 28, 2024
a3d850b
Merge branch 'main' of github.com:pytorch/torchtune into full_ft_mem
rohan-varma Mar 28, 2024
6207797
Upd
rohan-varma Mar 28, 2024
7597c68
Merge branch 'main' of github.com:pytorch/torchtune into full_ft_mem
rohan-varma Mar 29, 2024
6c9731d
Upd
rohan-varma Mar 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions recipes/configs/llama2/7B_full_single_device.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,11 @@ resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 2
epochs: 3
epochs: 1
optimizer:
_component_: torch.optim.SGD
_component_: bitsandbytes.optim.PagedAdamW
lr: 2e-5
optimizer_in_bwd: True
loss:
_component_: torch.nn.CrossEntropyLoss
max_steps_per_epoch: null
Expand All @@ -69,7 +70,7 @@ dtype: bf16

# Logging
metric_logger:
_component_: torchtune.utils.metric_logging.DiskLogger
log_dir: ${output_dir}
_component_: torchtune.utils.metric_logging.WandBLogger
project: foo
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if this is user facing, let's provide a real project name?

output_dir: /tmp/alpaca-llama2-finetune
log_every_n_steps: null
84 changes: 68 additions & 16 deletions recipes/full_finetune_single_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import sys

from functools import partial
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Optional, Tuple, Union
from warnings import warn

import torch
Expand Down Expand Up @@ -80,11 +80,18 @@ def __init__(self, cfg: DictConfig) -> None:
# logging attributes
self._output_dir = cfg.output_dir
self._log_every_n_steps = cfg.log_every_n_steps if cfg.log_every_n_steps else 1
self._log_peak_memory_every_n_steps = 100
self._log_peak_memory_every_n_steps = 1
# Training cfg
self._resume_from_checkpoint = cfg.resume_from_checkpoint
self._gradient_accumulation_steps = cfg.gradient_accumulation_steps

# TODO: find a better place / way to perform validation of args that don't yet
# compose with each other.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

after the dataclasses were removed we lost a convenient place to perform validation - my thinking is we provide a convenient utility function(s) that does this for users that they just import in their recipe, but curious about your thoughts

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That could be one approach, a possible downside is that the helper just becomes this monolithic dumping ground where we're checking various configs and it just becomes a large swath of if statements. On the other hand, if we don't do something like this then we'll just have it spelled out in each recipe which will increase code bloat and maintainence overhead (would have to copy things into each recipe)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some piece of this is just not trying to do too much in a single recipe, right? This is part of the reason we split single-device and distributed recipes to begin with.

We can still do config validation, just have to check that fields are defined instead of naively just checking values of fields. Personally I would be in favor of some config validation utilities defined on a per-recipe basis under configs/ somewhere, but coupled with clear documentation in the recipe class of which intersections of features are not supported.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This makes sense. I've added some documentation in the recipe class for now. @RdoubleA , let's chat about Evan's suggestion here?

if self._gradient_accumulation_steps > 1 and self._optimizer_in_bwd:
raise ValueError(
"Gradient accumulation is not supported with optimizer in bwd."
"Please set gradient_accumulation_steps=1, or optimizer_in_bwd=False."
)
# These are public properties which are updated by the checkpoint loader
# when ``resume_from_checkpoint`` is `True` or validated in tests
self.seed = utils.set_seed(seed=cfg.seed)
Expand Down Expand Up @@ -158,6 +165,7 @@ def setup(self, cfg: DictConfig) -> None:
# checkpoint. Transforming the opt state dict is handled by this method
self._optimizer = self._setup_optimizer(
cfg_optimizer=cfg.optimizer,
optimizer_in_bwd=cfg.optimizer_in_bwd,
opt_state_dict=(
ckpt_dict[utils.OPT_KEY] if self._resume_from_checkpoint else None
),
Expand Down Expand Up @@ -221,18 +229,55 @@ def _setup_model(
return model

def _setup_optimizer(
self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None
) -> Optimizer:
self,
cfg_optimizer: DictConfig,
optimizer_in_bwd: bool = False,
opt_state_dict: Optional[Dict[str, Any]] = None,
) -> Union[Optimizer, None]:
"""
Set up the optimizer. This method also handles loading the optimizer state_dict, if specified.
"""
optimizer = config.instantiate(cfg_optimizer, self._model.parameters())

if opt_state_dict:
optimizer.load_state_dict(opt_state_dict)

log.info("Optimizer is initialized.")
return optimizer
if optimizer_in_bwd:
self._optimizer_in_bwd = True
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

set this in init?

# Maintain a dict of optims for every parameter.
# TODO (rohan-varma): check foreach arg
optim_dict = {
p: config.instantiate(cfg_optimizer, [p])
for p in self._model.parameters()
}
self._optim_ckpt_wrapper = utils.OptimizerInBackwardWrapper({
n: optim_dict[p] for n, p in self._model.named_parameters()
})
def optim_step(param) -> None:
optim_dict[param].step()
optim_dict[param].zero_grad()

for p in self._model.parameters():
p.register_post_accumulate_grad_hook(optim_step)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would say put the optimizer in backward logic in a separate utility, it uses some non-intuitive logic that may confuse users

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree and see the reasoning in general, but IMO a downside of something like a utils.setup_optim_in_backward(optim_config, model) is that there is a side-effect of hooks getting registered on the model. I want things that modify state to be as explicit as possible, so I could do something like register_optimizer_in_backward_hooks and make_optim_checkpoint_wrapper - more utilities / components than a monolithic thing that configures the entire optimizer in backward. @kartikayk , @ebsmothers what do you think?


# Load optimizer states. If optimizer states are being restored in an optimizer in backward
# run, these need to have been saved with the same setting. Cannot restore from runs that did not
# use optimizer in backward.
if opt_state_dict is not None:
try:
self._optim_ckpt_wrapper.load_state_dict(opt_state_dict)
print(f"RV: successfully loaded in backward state", flush=True)
except BaseException as e:
raise RuntimeError(
"Failed loading in-backward optimizer checkpoints."
"Please make sure run being restored from was using in-backward optimizer."
f"Original error {str(e)}"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the from e should take care of surfacing the original error

Copy link
Member Author

@rohan-varma rohan-varma Mar 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually tried this, and it didn't - i.e. I didn't see the exception from e in this error. Not sure if I just messed something up.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me continue checking this though

) from e
log.info("In-backward optimizers are set up.")
return None
else:
self._optimizer_in_bwd = False
optimizer = config.instantiate(cfg_optimizer, self._model.parameters())

if opt_state_dict:
optimizer.load_state_dict(opt_state_dict)
log.info("Optimizer is initialized.")
return optimizer

def _setup_data(
self,
Expand Down Expand Up @@ -281,13 +326,16 @@ def save_checkpoint(self, epoch: int) -> None:
if epoch + 1 < self.total_epochs:
ckpt_dict.update(
{
utils.OPT_KEY: self._optimizer.state_dict(),
utils.SEED_KEY: self.seed,
utils.EPOCHS_KEY: self.epochs_run,
utils.TOTAL_EPOCHS_KEY: self.total_epochs,
utils.MAX_STEPS_KEY: self.max_steps_per_epoch,
}
)
if not self._optimizer_in_bwd:
ckpt_dict[utils.OPT_KEY] = self._optimizer.state_dict()
else:
ckpt_dict[utils.OPT_KEY] = self._optim_ckpt_wrapper.state_dict()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if the APIs will be the same for the optim ckpt wrapper, you could just call it self._optimizer and remove the if else

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm personally not an advocate for this due to reason explained in the other comment. If folks feel like this is better UX though, I'm happy to just add it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm inclined to agree with @RdoubleA. Definitely don't want to hide the actual behavior too much, but in this case we already have state_dict and other APIs defined, we might as well just cut down on branching (aside from where it's really needed, like in train). But honestly not taking a super strong stance here so fine either way

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense - let's discuss in follow up PRs.

self._checkpointer.save_checkpoint(
ckpt_dict,
epoch=epoch,
Expand All @@ -311,8 +359,8 @@ def train(self) -> None:
``max_steps_per_epoch``.
"""
# zero out the gradients before starting training
self._optimizer.zero_grad()

if not self._optimizer_in_bwd:
self._optimizer.zero_grad()
# self.epochs_run should be non-zero when we're resuming from a checkpoint
for curr_epoch in range(self.epochs_run, self.total_epochs):
# Update the sampler to ensure data is correctly shuffled across epochs
Expand Down Expand Up @@ -344,20 +392,24 @@ def train(self) -> None:
self._metric_logger.log_dict(
{
"loss": loss.item(),
"lr": self._optimizer.param_groups[0]["lr"],
# NOTE: for optim in backward, this assumes all optimizers have the same LR. This is currently
# true since we don't expose the ability to configure this yet.
"lr": list(self._optim_ckpt_wrapper.optim_map.values())[0].param_groups[0]["lr"] if self._optimizer_in_bwd else self._optimizer.param_groups[0]["lr"],
"gpu_resources": torch.cuda.memory_allocated(),
},
step=self.total_training_steps,
)

loss = loss / self._gradient_accumulation_steps
loss.backward()
if self._should_update_weights(idx):
if not self._optimizer_in_bwd and self._should_update_weights(idx):
self._optimizer.step()
self._optimizer.zero_grad(set_to_none=True)

# Update the number of steps when the weights are updated
self.total_training_steps += 1
elif self._optimizer_in_bwd:
self.total_training_steps += 1

# Log peak memory for iteration
if self.total_training_steps % self._log_peak_memory_every_n_steps == 0:
Expand Down
2 changes: 2 additions & 0 deletions torchtune/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
FullModelMetaCheckpointer,
FullModelTorchTuneCheckpointer,
ModelType,
OptimizerInBackwardWrapper,
)
from ._device import get_device
from ._distributed import ( # noqa
Expand Down Expand Up @@ -77,4 +78,5 @@
"validate_expected_param_dtype",
"TuneArgumentParser",
"CheckpointableDataLoader",
"OptimizerInBackwardWrapper",
]
2 changes: 1 addition & 1 deletion torchtune/utils/_checkpointing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@
FullModelMetaCheckpointer,
FullModelTorchTuneCheckpointer,
)
from ._checkpointer_utils import ModelType # noqa
from ._checkpointer_utils import ModelType, OptimizerInBackwardWrapper # noqa
53 changes: 53 additions & 0 deletions torchtune/utils/_checkpointing/_checkpointer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,59 @@
class ModelType(Enum):
LLAMA2 = "llama2"

class OptimizerInBackwardWrapper:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"""
A bare-bones class meant for checkpoint save and load for optimizers running
in backward. Usage is limited to the following:

optim_dict = {
p: config.instantiate(cfg_optimizer, [p])
for p in self._model.parameters()
}
# Save checkpoint
ckpt = OptimizerInBackwardWrapper(optim_dict).state_dict()
torch.save("/tmp/optim_ckpt", ckpt)
# Load checkpoint
placeholder_optim_dict = {
p: config.instantiate(cfg_optimizer, [p])
for p in self._model.parameters()
}
wrapper = OptimInBackwardWrapper(placeholder_optim_dict)
# load_state_dict expects a dict produced by this class's
# state_dict method.
wrapper.load_state_dict(torch.load("/tmp/optim_ckpt"))
# placeholder_optim_dict now has updated optimizer states.

NOTE: This wrapper is only meant to be used for single-device use cases.
Distributed use cases such as FSDP, which require specialized
optimizer state checkpointing, are not supported.

"""
def __init__(self, optim_map: Dict[str, torch.optim.Optimizer]):
self.optim_map = optim_map

def state_dict(self):
return {
p: opt.state_dict() for p, opt in self.optim_map.items()
}


def load_state_dict(self, optim_ckpt_map: Dict[str, Any]):
params_covered = set()
for param_name in optim_ckpt_map.keys():
if param_name not in self.optim_map:
raise RuntimeError(
f"Trying to load optimizer state for unexpected param {param_name}"
)
self.optim_map[param_name].load_state_dict(optim_ckpt_map[param_name])
params_covered.add(param_name)
# Ensure all params have been loaded into, report missing params
missing_params = set(self.optim_map.keys()) - params_covered
if missing_params:
raise RuntimeError(
f"Expected to load optimizer state for params {missing_params}!"
)


def get_path(input_dir: Path, filename: str, missing_ok: bool = False) -> Path:
"""
Expand Down
1 change: 0 additions & 1 deletion torchtune/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
TOTAL_EPOCHS_KEY,
)


def _contains_fsdp(model: nn.Module) -> bool:
"""
Checks if the model contains FSDP.
Expand Down