
Commit 11a7fd3

align llama auto_parallel dataloader with manual_parallel (#8639)

1 parent 25d2140

1 file changed: llm/auto_parallel/llama/run_pretrain_auto.py (+18 −0)
@@ -393,6 +393,24 @@ def _wrap_for_dist_loader(self, train_dataloader):
         dist_loader._input_keys = ["input_ids", "labels"]
         return dist_loader
 
+    def _get_train_sampler(self) -> Optional[paddle.io.Sampler]:
+        if self.train_dataset is None:
+            return None
+
+        total_batch_size_per_acc_step = self.args.per_device_train_batch_size * self.args.dataset_world_size
+        total_batch_size = total_batch_size_per_acc_step
+
+        # In llm/llama/run_pretrain.py, it uses paddlenlp.utils.batch_sampler.DistributedBatchSampler,
+        # which does no shuffle when shuffle is set True.
+        sampler = paddle.io.BatchSampler(
+            dataset=self.train_dataset,
+            shuffle=False,
+            batch_size=total_batch_size,
+            drop_last=self.args.dataloader_drop_last,
+        )
+        sampler._acc_steps = self.args.gradient_accumulation_steps
+        return sampler
+
 
 def print_config(args, key=""):
     """
