29
29
import sys
30
30
import time
31
31
import warnings
32
+ import torch_xla .debug .profiler as xp
32
33
from collections .abc import Mapping
33
34
from pathlib import Path
35
+ from threading import Thread
34
36
from typing import TYPE_CHECKING , Any , Callable , Dict , List , Optional , Tuple , Union
35
37
36
38
162
164
import datasets
163
165
164
166
if is_torch_tpu_available (check_device = False ):
167
+ import torch_xla
165
168
import torch_xla .core .xla_model as xm
166
169
import torch_xla .debug .metrics as met
167
170
@@ -838,7 +841,8 @@ def get_train_dataloader(self) -> DataLoader:
838
841
dataloader_params ["drop_last" ] = self .args .dataloader_drop_last
839
842
dataloader_params ["worker_init_fn" ] = seed_worker
840
843
841
- return self .accelerator .prepare (DataLoader (train_dataset , ** dataloader_params ))
844
+ # TODO(jonbolin): Disabling Accelerate on the dataloader (`Unknown device SPMD:0`)
845
+ return DataLoader (train_dataset , ** dataloader_params )
842
846
843
847
def _get_eval_sampler (self , eval_dataset : Dataset ) -> Optional [torch .utils .data .Sampler ]:
844
848
# Deprecated code
@@ -1444,6 +1448,21 @@ def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}):
1444
1448
1445
1449
return model
1446
1450
1451
def _xla_sharded_dataloader(self, dataloader):
    """Wrap ``dataloader`` in a torch_xla ``MpDeviceLoader`` when a TPU is available.

    Args:
        dataloader: The data loader to wrap (e.g. the result of
            ``get_train_dataloader()`` or ``get_eval_dataloader()``).

    Returns:
        ``dataloader`` unchanged when ``is_torch_tpu_available()`` is false.
        Otherwise a ``MpDeviceLoader`` that moves batches to ``self.args.device``.
        When ``self.args.spmd_batch_sharding`` is truthy, the loader is also given
        an input sharding spec built over all global devices.
    """
    if not is_torch_tpu_available():
        return dataloader

    # The parallel loader is needed on every TPU path, not only when batch
    # sharding is enabled, so import it before the sharding check. Importing it
    # inside the sharding branch left `pl` unbound when sharding was disabled.
    import torch_xla.distributed.parallel_loader as pl

    sharding_spec = None
    if self.args.spmd_batch_sharding:
        import torch_xla.experimental.xla_sharding as xs
        import torch_xla.runtime as xr

        # Build a (num_devices, 1) mesh over every global device and map tensor
        # dimension 0 (presumably the batch dimension) onto the first mesh axis.
        num_devices = xr.global_device_count()
        device_ids = np.arange(num_devices)
        mesh = xs.Mesh(device_ids, (num_devices, 1))
        sharding_spec = xs.ShardingSpec(mesh, (0, 1))

    return pl.MpDeviceLoader(
        dataloader,
        self.args.device,
        input_sharding=sharding_spec,
        loader_prefetch_size=self.args.train_batch_size,
        device_prefetch_size=4,
    )
1465
+
1447
1466
def train (
1448
1467
self ,
1449
1468
resume_from_checkpoint : Optional [Union [str , bool ]] = None ,
@@ -1537,7 +1556,7 @@ def _inner_training_loop(
1537
1556
self ._train_batch_size = batch_size
1538
1557
logger .debug (f"Currently training with a batch size of: { self ._train_batch_size } " )
1539
1558
# Data loader and number of training steps
1540
- train_dataloader = self .get_train_dataloader ()
1559
+ train_dataloader = self ._xla_sharded_dataloader ( self . get_train_dataloader () )
1541
1560
1542
1561
# Setting up training control variables:
1543
1562
# number of training epochs: num_train_epochs
@@ -1771,7 +1790,13 @@ def _inner_training_loop(
1771
1790
rng_to_sync = True
1772
1791
1773
1792
step = - 1
1793
+ profile_step = int (os .environ .get ('PROFILE_STEP' , - 1 ))
1794
+ profile_epoch = int (os .environ .get ('PROFILE_EPOCH' , - 1 ))
1795
+ profile_duration = int (os .environ .get ('PROFILE_DURATION_MS' , 20000 ))
1796
+ profile_logdir = os .environ .get ('PROFILE_LOGDIR' , None )
1774
1797
for step , inputs in enumerate (epoch_iterator ):
1798
+ if step == 0 and epoch == 0 :
1799
+ print ('input sharding' , {k : (v .shape , torch_xla ._XLAC ._get_xla_sharding_spec (v )) for k , v in inputs .items ()})
1775
1800
total_batched_samples += 1
1776
1801
if rng_to_sync :
1777
1802
self ._load_rng_state (resume_from_checkpoint )
@@ -1792,6 +1817,10 @@ def _inner_training_loop(
1792
1817
if step % args .gradient_accumulation_steps == 0 :
1793
1818
self .control = self .callback_handler .on_step_begin (args , self .state , self .control )
1794
1819
1820
+ if step == profile_step and epoch == profile_epoch :
1821
+ trace = lambda : xp .trace ('127.0.0.1:9012' , profile_logdir or tempfile .mkdtemp (), profile_duration or 20000 )
1822
+ Thread (target = trace ).start ()
1823
+
1795
1824
with self .accelerator .accumulate (model ):
1796
1825
tr_loss_step = self .training_step (model , inputs )
1797
1826
@@ -2199,7 +2228,8 @@ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for
2199
2228
self .log (logs )
2200
2229
2201
2230
metrics = None
2202
- if self .control .should_evaluate :
2231
+ # TODO(jonbolin): Disabling eval loop
2232
+ if False : # self.control.should_evaluate:
2203
2233
if isinstance (self .eval_dataset , dict ):
2204
2234
metrics = {}
2205
2235
for eval_dataset_name , eval_dataset in self .eval_dataset .items ():
@@ -2914,7 +2944,7 @@ def evaluate(
2914
2944
# memory metrics - must set up as early as possible
2915
2945
self ._memory_tracker .start ()
2916
2946
2917
- eval_dataloader = self .get_eval_dataloader (eval_dataset )
2947
+ eval_dataloader = self ._xla_sharded_dataloader ( self . get_eval_dataloader (eval_dataset ) )
2918
2948
start_time = time .time ()
2919
2949
2920
2950
eval_loop = self .prediction_loop if self .args .use_legacy_prediction_loop else self .evaluation_loop
0 commit comments