
Commit e7ea6ea

Merge pull request huggingface#13 from pytorch-tpu/jonbolin-llama-spmd

Support LLaMA training through SPMD

2 parents 6112b1c + 7b90ea5, commit e7ea6ea

2 files changed: +96 −4 lines

examples/pytorch/language-modeling/run_clm.py

Lines changed: 62 additions & 0 deletions
@@ -34,6 +34,7 @@
 import torch
 from datasets import load_dataset
 
+import torch_xla.debug.profiler as xp
 import transformers
 from transformers import (
     CONFIG_MAPPING,
@@ -139,6 +140,30 @@ class ModelArguments:
             )
         },
     )
+    spmd_grad_chkpt: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "Apply gradient checkpointing to the model"
+            )
+        },
+    )
+    spmd_fsdp_sharding: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "Will apply XLA SPMD to run FSDP"
+            )
+        },
+    )
+    spmd_batch_sharding: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "Will apply XLA SPMD to shard the input along the batch dimension"
+            )
+        },
+    )
 
     def __post_init__(self):
         if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
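For reference, a minimal sketch (not part of the commit) of how these boolean dataclass fields surface as command-line flags through HfArgumentParser; the SpmdArgs class and the argument list below are illustrative assumptions, not code from run_clm.py.

# Illustrative only: HfArgumentParser turns bool dataclass fields with default=False
# into optional switches, so the new options can be enabled directly on the CLI.
from dataclasses import dataclass, field
from transformers import HfArgumentParser

@dataclass
class SpmdArgs:  # hypothetical stand-in for the new ModelArguments fields
    spmd_grad_chkpt: bool = field(default=False)
    spmd_fsdp_sharding: bool = field(default=False)
    spmd_batch_sharding: bool = field(default=False)

# e.g. python run_clm.py ... --spmd_fsdp_sharding --spmd_grad_chkpt
(args,) = HfArgumentParser(SpmdArgs).parse_args_into_dataclasses(
    ["--spmd_fsdp_sharding", "--spmd_grad_chkpt"]
)
assert args.spmd_fsdp_sharding and args.spmd_grad_chkpt and not args.spmd_batch_sharding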
@@ -238,6 +263,9 @@ def main():
     else:
         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
 
+    training_args.spmd_batch_sharding = model_args.spmd_batch_sharding or model_args.spmd_fsdp_sharding
+    training_args.spmd_fsdp_sharding = model_args.spmd_fsdp_sharding
+
     # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
     # information sent is the one passed as arguments along with your Python/PyTorch versions.
     send_example_telemetry("run_clm", model_args, data_args)
@@ -285,6 +313,10 @@ def main():
     # Set seed before initializing model.
     set_seed(training_args.seed)
 
+    server = xp.start_server(9012)
+    logger.info(f'Profiling server started: {str(server)}')
+
+
     # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
     # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
     # (the dataset will be downloaded automatically from the datasets Hub).
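As a usage note, here is a minimal sketch (not part of the diff) of capturing a trace on demand from the profiler server started above; the log directory and duration are assumptions.

# Capture a 20-second profile from the xp.start_server(9012) instance while training runs;
# the resulting trace can be inspected with TensorBoard's profile plugin.
import torch_xla.debug.profiler as xp

xp.trace('127.0.0.1:9012', '/tmp/xla_profile', duration_ms=20000)  # address, path, and duration are illustrative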
@@ -430,6 +462,36 @@ def main():
     if len(tokenizer) > embedding_size:
         model.resize_token_embeddings(len(tokenizer))
 
+    import torch_xla.core.xla_model as xm
+    import torch_xla.experimental.xla_sharding as xs
+    import torch_xla.runtime as xr
+    num_devices = xr.global_device_count()
+    device_ids = torch.arange(num_devices)
+    print('Using dtype', model_args.torch_dtype)
+    model = model.to(xm.xla_device(), dtype=getattr(torch, model_args.torch_dtype))
+
+    if model_args.spmd_grad_chkpt:
+        print("Applying gradient checkpointing")
+        from torch_xla.distributed.fsdp import checkpoint_module
+        for i, block in enumerate(model.model.layers):
+            # LLaMA-specific
+            model.model.layers[i] = checkpoint_module(block)
+
+    if model_args.spmd_fsdp_sharding:
+        print('Applying FSDP sharding to all parameters')
+        for name, param in model.named_parameters():
+            # Shard all parameters along a single axis
+            print('> Sharding tensor', name)
+
+            # Shard along the largest dimension
+            import numpy as np
+            max_dim = np.argmax(param.shape)
+            shape = [1] * len(param.shape)
+            shape[max_dim] = num_devices
+            mesh = xs.HybridMesh(ici_mesh_shape=tuple(shape))
+            xs.mark_sharding(param, mesh, range(len(param.shape)))
+
+
     # Preprocessing the datasets.
     # First we tokenize all the texts.
     if training_args.do_train:
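To make the sharding rule above concrete, here is a small numpy-only sketch (assuming 4 devices and a hypothetical LLaMA projection shape, neither taken from the commit) of the mesh shape the loop derives for one parameter.

# Each parameter is sharded along its largest dimension: the derived mesh has size 1 on
# every axis except the argmax axis, which spans all devices.
import numpy as np

num_devices = 4                          # assumed device count
param_shape = (4096, 11008)              # hypothetical LLaMA MLP weight shape
max_dim = int(np.argmax(param_shape))    # -> 1, the largest dimension
mesh_shape = [1] * len(param_shape)
mesh_shape[max_dim] = num_devices        # -> [1, 4]
print(tuple(mesh_shape))                 # each device would hold a (4096, 2752) shard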

src/transformers/trainer.py

Lines changed: 34 additions & 4 deletions
@@ -29,8 +29,10 @@
 import sys
 import time
 import warnings
+import torch_xla.debug.profiler as xp
 from collections.abc import Mapping
 from pathlib import Path
+from threading import Thread
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
 
 
@@ -162,6 +164,7 @@
 import datasets
 
 if is_torch_tpu_available(check_device=False):
+    import torch_xla
     import torch_xla.core.xla_model as xm
     import torch_xla.debug.metrics as met
 
@@ -838,7 +841,8 @@ def get_train_dataloader(self) -> DataLoader:
             dataloader_params["drop_last"] = self.args.dataloader_drop_last
             dataloader_params["worker_init_fn"] = seed_worker
 
-        return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
+        # TODO(jonbolin): Disabling Accelerate on the dataloader (`Unknown device SPMD:0`)
+        return DataLoader(train_dataset, **dataloader_params)
 
     def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
         # Deprecated code
@@ -1444,6 +1448,21 @@ def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}):
 
         return model
 
+    def _xla_sharded_dataloader(self, dataloader):
+        if is_torch_tpu_available():
+            import torch_xla.distributed.parallel_loader as pl
+            sharding_spec = None
+            if self.args.spmd_batch_sharding:
+                import torch_xla.experimental.xla_sharding as xs
+                import torch_xla.runtime as xr
+                num_devices = xr.global_device_count()
+                device_ids = np.arange(num_devices)
+                mesh = xs.Mesh(device_ids, (num_devices, 1))
+                sharding_spec = xs.ShardingSpec(mesh, (0, 1))
+            return pl.MpDeviceLoader(dataloader, self.args.device, input_sharding=sharding_spec, loader_prefetch_size=self.args.train_batch_size, device_prefetch_size=4)
+        else:
+            return dataloader
+
     def train(
         self,
         resume_from_checkpoint: Optional[Union[str, bool]] = None,
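For intuition, a small sketch (assumptions: 8 devices and a per-host batch of 64, not values from the commit, and not exercising real TPU state) of what the (num_devices, 1) mesh with partition spec (0, 1) means for the batches fed through MpDeviceLoader.

# The mesh has one axis spanning all devices plus one trivial axis; spec (0, 1) maps the
# batch dimension onto the device axis, so each device receives batch/num_devices examples.
num_devices = 8                      # assumed
per_host_batch, seq_len = 64, 2048   # assumed input shape
mesh_shape = (num_devices, 1)
partition_spec = (0, 1)              # dim 0 -> mesh axis 0 (sharded), dim 1 -> axis 1 (replicated)
per_device_batch = per_host_batch // mesh_shape[partition_spec[0]]
print((per_device_batch, seq_len))   # -> (8, 2048) shard per device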
@@ -1537,7 +1556,7 @@ def _inner_training_loop(
         self._train_batch_size = batch_size
         logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")
         # Data loader and number of training steps
-        train_dataloader = self.get_train_dataloader()
+        train_dataloader = self._xla_sharded_dataloader(self.get_train_dataloader())
 
         # Setting up training control variables:
         # number of training epochs: num_train_epochs
@@ -1771,7 +1790,13 @@ def _inner_training_loop(
                 rng_to_sync = True
 
             step = -1
+            profile_step = int(os.environ.get('PROFILE_STEP', -1))
+            profile_epoch = int(os.environ.get('PROFILE_EPOCH', -1))
+            profile_duration = int(os.environ.get('PROFILE_DURATION_MS', 20000))
+            profile_logdir = os.environ.get('PROFILE_LOGDIR', None)
             for step, inputs in enumerate(epoch_iterator):
+                if step == 0 and epoch == 0:
+                    print('input sharding', {k: (v.shape, torch_xla._XLAC._get_xla_sharding_spec(v)) for k, v in inputs.items()})
                 total_batched_samples += 1
                 if rng_to_sync:
                     self._load_rng_state(resume_from_checkpoint)
@@ -1792,6 +1817,10 @@ def _inner_training_loop(
                 if step % args.gradient_accumulation_steps == 0:
                     self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
 
+                if step == profile_step and epoch == profile_epoch:
+                    trace = lambda: xp.trace('127.0.0.1:9012', profile_logdir or tempfile.mkdtemp(), profile_duration or 20000)
+                    Thread(target=trace).start()
+
                 with self.accelerator.accumulate(model):
                     tr_loss_step = self.training_step(model, inputs)
 
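As a usage sketch (the paths and step numbers are assumptions, not from the commit), the new environment variables drive the trace capture: once the chosen epoch and step are reached, xp.trace runs on a background thread against the profiler server listening on port 9012.

# Set these before launching run_clm.py; the defaults of -1 leave profiling disabled.
import os

os.environ["PROFILE_EPOCH"] = "0"                   # epoch to profile (assumed)
os.environ["PROFILE_STEP"] = "10"                   # step within that epoch (assumed)
os.environ["PROFILE_DURATION_MS"] = "30000"         # trace length in milliseconds
os.environ["PROFILE_LOGDIR"] = "/tmp/xla_profile"   # hypothetical output directory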
@@ -2199,7 +2228,8 @@ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for
             self.log(logs)
 
         metrics = None
-        if self.control.should_evaluate:
+        # TODO(jonbolin): Disabling eval loop
+        if False: # self.control.should_evaluate:
             if isinstance(self.eval_dataset, dict):
                 metrics = {}
                 for eval_dataset_name, eval_dataset in self.eval_dataset.items():
@@ -2914,7 +2944,7 @@ def evaluate(
         # memory metrics - must set up as early as possible
         self._memory_tracker.start()
 
-        eval_dataloader = self.get_eval_dataloader(eval_dataset)
+        eval_dataloader = self._xla_sharded_dataloader(self.get_eval_dataloader(eval_dataset))
         start_time = time.time()
 
         eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
