-
Notifications
You must be signed in to change notification settings - Fork 619
Full finetune < 16GB #527
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Full finetune < 16GB #527
Changes from 17 commits
ea16903
2144c4f
0104e36
f7d1439
6479227
ff69f09
622d6fa
899fa97
7770168
ae4b943
2b5060f
eeb8edd
0fb1eba
0d85856
bb71c2e
57030f2
c4a58e2
e2e0725
61371b4
8b267dc
9f5ac0b
b3dea03
7fe9485
23b10f7
c911eed
7325133
0e2d4fc
98f6fcc
a4292f1
86839ad
d9f8975
9efe051
f26d34b
a3d850b
6207797
7597c68
6c9731d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,7 +7,7 @@ | |
import sys | ||
|
||
from functools import partial | ||
from typing import Any, Dict, Optional, Tuple | ||
from typing import Any, Dict, Optional, Tuple, Union | ||
from warnings import warn | ||
|
||
import torch | ||
|
@@ -80,11 +80,18 @@ def __init__(self, cfg: DictConfig) -> None: | |
# logging attributes | ||
self._output_dir = cfg.output_dir | ||
self._log_every_n_steps = cfg.log_every_n_steps if cfg.log_every_n_steps else 1 | ||
self._log_peak_memory_every_n_steps = 100 | ||
self._log_peak_memory_every_n_steps = 1 | ||
rohan-varma marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# Training cfg | ||
self._resume_from_checkpoint = cfg.resume_from_checkpoint | ||
self._gradient_accumulation_steps = cfg.gradient_accumulation_steps | ||
|
||
# TODO: find a better place / way to perform validation of args that don't yet | ||
# compose with each other. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. after the dataclasses were removed we lost a convenient place to perform validation - my thinking is we provide a convenient utility function(s) that does this for users that they just import in their recipe, but curious about your thoughts There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That could be one approach, a possible downside is that the helper just becomes this monolithic dumping ground where we're checking various configs and it just becomes a large swath of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Some piece of this is just not trying to do too much in a single recipe, right? This is part of the reason we split single-device and distributed recipes to begin with. We can still do config validation, just have to check that fields are defined instead of naively just checking values of fields. Personally I would be in favor of some config validation utilities defined on a per-recipe basis under configs/ somewhere, but coupled with clear documentation in the recipe class of which intersections of features are not supported. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This makes sense. I've added some documentation in the recipe class for now. @RdoubleA , let's chat about Evan's suggestion here? |
||
if self._gradient_accumulation_steps > 1 and self._optimizer_in_bwd: | ||
raise ValueError( | ||
"Gradient accumulation is not supported with optimizer in bwd." | ||
"Please set gradient_accumulation_steps=1, or optimizer_in_bwd=False." | ||
) | ||
# These are public properties which are updated by the checkpoint loader | ||
# when ``resume_from_checkpoint`` is `True` or validated in tests | ||
self.seed = utils.set_seed(seed=cfg.seed) | ||
|
@@ -158,6 +165,7 @@ def setup(self, cfg: DictConfig) -> None: | |
# checkpoint. Transforming the opt state dict is handled by this method | ||
self._optimizer = self._setup_optimizer( | ||
cfg_optimizer=cfg.optimizer, | ||
optimizer_in_bwd=cfg.optimizer_in_bwd, | ||
opt_state_dict=( | ||
ckpt_dict[utils.OPT_KEY] if self._resume_from_checkpoint else None | ||
), | ||
|
@@ -221,18 +229,55 @@ def _setup_model( | |
return model | ||
|
||
def _setup_optimizer( | ||
self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None | ||
) -> Optimizer: | ||
self, | ||
cfg_optimizer: DictConfig, | ||
optimizer_in_bwd: bool = False, | ||
opt_state_dict: Optional[Dict[str, Any]] = None, | ||
) -> Union[Optimizer, None]: | ||
rohan-varma marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
Set up the optimizer. This method also handles loading the optimizer state_dict, if specified. | ||
""" | ||
optimizer = config.instantiate(cfg_optimizer, self._model.parameters()) | ||
|
||
if opt_state_dict: | ||
optimizer.load_state_dict(opt_state_dict) | ||
|
||
log.info("Optimizer is initialized.") | ||
return optimizer | ||
if optimizer_in_bwd: | ||
self._optimizer_in_bwd = True | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. set this in init? |
||
# Maintain a dict of optims for every parameter. | ||
# TODO (rohan-varma): check foreach arg | ||
rohan-varma marked this conversation as resolved.
Show resolved
Hide resolved
|
||
optim_dict = { | ||
p: config.instantiate(cfg_optimizer, [p]) | ||
for p in self._model.parameters() | ||
} | ||
self._optim_ckpt_wrapper = utils.OptimizerInBackwardWrapper({ | ||
n: optim_dict[p] for n, p in self._model.named_parameters() | ||
}) | ||
def optim_step(param) -> None: | ||
optim_dict[param].step() | ||
optim_dict[param].zero_grad() | ||
|
||
for p in self._model.parameters(): | ||
p.register_post_accumulate_grad_hook(optim_step) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would say put the optimizer in backward logic in a separate utility, it uses some non-intuitive logic that may confuse users There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree and see the reasoning in general, but IMO a downside of something like a |
||
|
||
# Load optimizer states. If optimizer states are being restored in an optimizer in backward | ||
# run, these need to have been saved with the same setting. Cannot restore from runs that did not | ||
# use optimizer in backward. | ||
if opt_state_dict is not None: | ||
try: | ||
self._optim_ckpt_wrapper.load_state_dict(opt_state_dict) | ||
print(f"RV: successfully loaded in backward state", flush=True) | ||
except BaseException as e: | ||
raise RuntimeError( | ||
"Failed loading in-backward optimizer checkpoints." | ||
"Please make sure run being restored from was using in-backward optimizer." | ||
f"Original error {str(e)}" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I actually tried this, and it didn't - i.e. I didn't see the exception from There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let me continue checking this though |
||
) from e | ||
log.info("In-backward optimizers are set up.") | ||
return None | ||
else: | ||
self._optimizer_in_bwd = False | ||
optimizer = config.instantiate(cfg_optimizer, self._model.parameters()) | ||
|
||
if opt_state_dict: | ||
optimizer.load_state_dict(opt_state_dict) | ||
log.info("Optimizer is initialized.") | ||
return optimizer | ||
|
||
def _setup_data( | ||
self, | ||
|
@@ -281,13 +326,16 @@ def save_checkpoint(self, epoch: int) -> None: | |
if epoch + 1 < self.total_epochs: | ||
ckpt_dict.update( | ||
{ | ||
utils.OPT_KEY: self._optimizer.state_dict(), | ||
utils.SEED_KEY: self.seed, | ||
utils.EPOCHS_KEY: self.epochs_run, | ||
utils.TOTAL_EPOCHS_KEY: self.total_epochs, | ||
utils.MAX_STEPS_KEY: self.max_steps_per_epoch, | ||
} | ||
) | ||
if not self._optimizer_in_bwd: | ||
ckpt_dict[utils.OPT_KEY] = self._optimizer.state_dict() | ||
else: | ||
ckpt_dict[utils.OPT_KEY] = self._optim_ckpt_wrapper.state_dict() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if the APIs will be the same for the optim ckpt wrapper, you could just call it There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm personally not an advocate for this due to reason explained in the other comment. If folks feel like this is better UX though, I'm happy to just add it. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm inclined to agree with @RdoubleA. Definitely don't want to hide the actual behavior too much, but in this case we already have state_dict and other APIs defined, we might as well just cut down on branching (aside from where it's really needed, like in train). But honestly not taking a super strong stance here so fine either way There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Makes sense - let's discuss in follow up PRs. |
||
self._checkpointer.save_checkpoint( | ||
ckpt_dict, | ||
epoch=epoch, | ||
|
@@ -311,8 +359,8 @@ def train(self) -> None: | |
``max_steps_per_epoch``. | ||
""" | ||
# zero out the gradients before starting training | ||
self._optimizer.zero_grad() | ||
|
||
if not self._optimizer_in_bwd: | ||
self._optimizer.zero_grad() | ||
# self.epochs_run should be non-zero when we're resuming from a checkpoint | ||
for curr_epoch in range(self.epochs_run, self.total_epochs): | ||
# Update the sampler to ensure data is correctly shuffled across epochs | ||
|
@@ -344,20 +392,24 @@ def train(self) -> None: | |
self._metric_logger.log_dict( | ||
{ | ||
"loss": loss.item(), | ||
"lr": self._optimizer.param_groups[0]["lr"], | ||
# NOTE: for optim in backward, this assumes all optimizers have the same LR. This is currently | ||
# true since we don't expose the ability to configure this yet. | ||
"lr": list(self._optim_ckpt_wrapper.optim_map.values())[0].param_groups[0]["lr"] if self._optimizer_in_bwd else self._optimizer.param_groups[0]["lr"], | ||
"gpu_resources": torch.cuda.memory_allocated(), | ||
}, | ||
step=self.total_training_steps, | ||
) | ||
|
||
loss = loss / self._gradient_accumulation_steps | ||
loss.backward() | ||
if self._should_update_weights(idx): | ||
if not self._optimizer_in_bwd and self._should_update_weights(idx): | ||
self._optimizer.step() | ||
self._optimizer.zero_grad(set_to_none=True) | ||
|
||
# Update the number of steps when the weights are updated | ||
self.total_training_steps += 1 | ||
elif self._optimizer_in_bwd: | ||
self.total_training_steps += 1 | ||
|
||
# Log peak memory for iteration | ||
if self.total_training_steps % self._log_peak_memory_every_n_steps == 0: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,6 +14,59 @@ | |
class ModelType(Enum): | ||
LLAMA2 = "llama2" | ||
|
||
class OptimizerInBackwardWrapper: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. cc @janeyx99 |
||
""" | ||
A bare-bones class meant for checkpoint save and load for optimizers running | ||
in backward. Usage is limited to the following: | ||
|
||
optim_dict = { | ||
p: config.instantiate(cfg_optimizer, [p]) | ||
for p in self._model.parameters() | ||
} | ||
# Save checkpoint | ||
ckpt = OptimizerInBackwardWrapper(optim_dict).state_dict() | ||
torch.save("/tmp/optim_ckpt", ckpt) | ||
# Load checkpoint | ||
placeholder_optim_dict = { | ||
p: config.instantiate(cfg_optimizer, [p]) | ||
for p in self._model.parameters() | ||
} | ||
wrapper = OptimInBackwardWrapper(placeholder_optim_dict) | ||
# load_state_dict expects a dict produced by this class's | ||
# state_dict method. | ||
wrapper.load_state_dict(torch.load("/tmp/optim_ckpt")) | ||
# placeholder_optim_dict now has updated optimizer states. | ||
|
||
NOTE: This wrapper is only meant to be used for single-device use cases. | ||
Distributed use cases such as FSDP, which require specialized | ||
optimizer state checkpointing, are not supported. | ||
rohan-varma marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
""" | ||
def __init__(self, optim_map: Dict[str, torch.optim.Optimizer]): | ||
self.optim_map = optim_map | ||
|
||
def state_dict(self): | ||
return { | ||
p: opt.state_dict() for p, opt in self.optim_map.items() | ||
} | ||
|
||
|
||
def load_state_dict(self, optim_ckpt_map: Dict[str, Any]): | ||
params_covered = set() | ||
for param_name in optim_ckpt_map.keys(): | ||
if param_name not in self.optim_map: | ||
raise RuntimeError( | ||
f"Trying to load optimizer state for unexpected param {param_name}" | ||
) | ||
self.optim_map[param_name].load_state_dict(optim_ckpt_map[param_name]) | ||
params_covered.add(param_name) | ||
# Ensure all params have been loaded into, report missing params | ||
missing_params = set(self.optim_map.keys()) - params_covered | ||
if missing_params: | ||
raise RuntimeError( | ||
f"Expected to load optimizer state for params {missing_params}!" | ||
) | ||
|
||
|
||
def get_path(input_dir: Path, filename: str, missing_ok: bool = False) -> Path: | ||
""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if this is user facing, let's provide a real project name?