update packing #3751

Open · wants to merge 12 commits into main

swift/llm/dataset/utils.py: 41 changes (30 additions, 11 deletions)

@@ -1,8 +1,7 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import multiprocessing as mp
 import time
-from typing import Any, Callable, Dict, Optional, Union
-
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import numpy as np
 from datasets import Dataset as HfDataset
 from torch.utils.data import Dataset, IterableDataset
@@ -103,7 +102,8 @@ def __len__(self) -> int:

 class BasePackingDataset:

-    def __init__(self, template, dataset, num_workers: int = 1, *, packing_interval: int = 128, strict: bool = False):
+    def __init__(self, template, dataset, num_workers: int = 1, *,
+                 packing_interval: Optional[int] = None, strict: bool = False):
         template._packing = True
         self.template = template
         self.dataset = dataset
@@ -114,12 +114,29 @@ def __init__(self, template, dataset, num_workers: int = 1, *, packing_interval:
         self.workers = []

     @staticmethod
-    def calculate_matched_group(template, sequences, is_finished: bool = True):
+    def _pack(sequences, max_length: int):
+        # Greedy first-fit: traverse the sequences and place each one into the
+        # first bucket that can still accommodate it; if no existing bucket
+        # has enough remaining capacity, open a new bucket.
+        capacity = []
+        buckets = []
+        for item in sequences:
+            seq, seq_len = item
+            for i, capa in enumerate(capacity):
+                if capa >= seq_len:
+                    capacity[i] = capa - seq_len
+                    buckets[i].append(item)
+                    break
+            else:
+                capacity.append(max_length - seq_len)
+                buckets.append([item])
+        return buckets
+
+    @staticmethod
+    def pack_sequences(template, sequences: List[Tuple[Dict[str, Any], int]], is_finished: bool = True):
         if len(sequences) == 0:
             return [], []
         # https://arxiv.org/pdf/2404.10830
-        import binpacking
-        sequences = binpacking.to_constant_volume(sequences, template.max_length, weight_pos=1)
+        sequences = BasePackingDataset._pack(sequences, template.max_length)
         res = []
         if sequences and not is_finished:
             sequences, ret_sequences = sequences[:-1], sequences[-1]
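For reference, here is a standalone sketch of the first-fit heuristic added above, run on toy data (the payloads and `max_length` below are illustrative, not from this PR):

```python
# First-fit packing, mirroring BasePackingDataset._pack: each (payload, length)
# item goes into the first bucket with enough remaining capacity; if none
# fits, a new bucket is opened.

def first_fit_pack(sequences, max_length):
    capacity, buckets = [], []
    for item in sequences:
        _, seq_len = item
        for i, remaining in enumerate(capacity):
            if remaining >= seq_len:
                capacity[i] = remaining - seq_len
                buckets[i].append(item)
                break
        else:  # no existing bucket fits
            capacity.append(max_length - seq_len)
            buckets.append([item])
    return buckets

items = [('a', 6), ('b', 3), ('c', 5), ('d', 2), ('e', 7)]
for bucket in first_fit_pack(items, max_length=8):
    print([k for k, _ in bucket], sum(n for _, n in bucket))
# ['a', 'd'] 8
# ['b', 'c'] 8
# ['e'] 7
```

Compared with `binpacking.to_constant_volume`, which this PR drops, first-fit processes items in arrival order rather than sorting them by length first, removing the external dependency at the cost of a possibly looser packing.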
@@ -142,7 +159,7 @@ def _encode_data(self, data):

 class PackingDataset(BasePackingDataset, Dataset):

-    def __init__(self, template, dataset, num_workers: int = 1, *, packing_interval: int = 128, strict: bool = False):
+    def __init__(self, template, dataset, num_workers: int = 1, *, packing_interval: Optional[int] = 1024, strict: bool = False):
         super().__init__(template, dataset, num_workers, packing_interval=packing_interval, strict=strict)
         self.prog_bar = tqdm(total=len(dataset), dynamic_ncols=True, desc='Packing')
         self._queue = mp.Queue()
@@ -160,7 +177,9 @@ def __init__(self, template, dataset, num_workers: int = 1, *, packing_interval:

     def fetch_packing_data(self, res: Optional[list] = None):
         res = res or []
-        for _ in range(self.packing_interval):
+        i = 0
+        while self.packing_interval is None or i < self.packing_interval:
+            i += 1
             data = self._queue.get()
             if data is None:
                 self._terminated_workers += 1
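A minimal single-process sketch of the new fetch semantics (names here are hypothetical, and it assumes the elided code breaks out once every worker has sent its `None` sentinel, as the `is_finished` check in `get_packed_dataset` suggests): `packing_interval=None` drains the queue until all workers terminate, while an integer caps the number of items fetched per round.

```python
import queue

def fetch(q, packing_interval, num_workers):
    res, terminated, i = [], 0, 0
    while packing_interval is None or i < packing_interval:
        i += 1
        data = q.get()
        if data is None:  # sentinel: one worker finished
            terminated += 1
            if terminated == num_workers:
                break
            continue
        res.append(data)
    return res, terminated

q = queue.Queue()
for x in ['seq1', 'seq2', None, 'seq3', None]:  # two workers, one sentinel each
    q.put(x)
print(fetch(q, packing_interval=None, num_workers=2))
# (['seq1', 'seq2', 'seq3'], 2)
```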
@@ -178,7 +197,7 @@ def get_packed_dataset(self):
         while True:
             data = self.fetch_packing_data(data)
             is_finished = self._terminated_workers == self.num_workers
-            res, data = self.calculate_matched_group(self.template, data, is_finished=is_finished)
+            res, data = self.pack_sequences(self.template, data, is_finished=is_finished)
             result += res
             if is_finished:
                 break
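The `data` value threaded through this loop is the remainder that `pack_sequences` holds back: while more input may still arrive (`is_finished=False`), the last and typically least-full bucket is withheld so the next round can top it up, and only on the final round is everything flushed. A toy sketch of that carry-over step (bucket contents assumed for illustration):

```python
buckets = [[('a', 6), ('d', 2)], [('b', 3), ('c', 5)], [('e', 7)]]
is_finished = False
if buckets and not is_finished:
    buckets, carry = buckets[:-1], buckets[-1]  # hold back [('e', 7)]
else:
    carry = []
emitted = buckets       # full rows, ready to be encoded and yielded
leftover = list(carry)  # re-enters the next fetch/pack round as `data`
print(emitted, leftover)
```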
@@ -256,7 +275,7 @@ def __iter__(self):
         while True:
             self._put_data_in_queue(iterator)
             data = self._fetch_data_out_queue(data)
-            res, data = self.calculate_matched_group(self.template, data, is_finished=False)
+            res, data = self.pack_sequences(self.template, data, is_finished=False)
             yield from res

