diff --git a/swift/llm/dataset/utils.py b/swift/llm/dataset/utils.py
index de4ce14f26..b5ca7d8413 100644
--- a/swift/llm/dataset/utils.py
+++ b/swift/llm/dataset/utils.py
@@ -1,8 +1,7 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import multiprocessing as mp
 import time
-from typing import Any, Callable, Dict, Optional, Union
-
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import numpy as np
 from datasets import Dataset as HfDataset
 from torch.utils.data import Dataset, IterableDataset
@@ -103,7 +102,8 @@ def __len__(self) -> int:
 
 class BasePackingDataset:
 
-    def __init__(self, template, dataset, num_workers: int = 1, *, packing_interval: int = 128, strict: bool = False):
+    def __init__(self, template, dataset, num_workers: int = 1, *,
+                 packing_interval: Optional[int] = None, strict: bool = False):
         template._packing = True
         self.template = template
         self.dataset = dataset
@@ -114,12 +114,29 @@ def __init__(self, template, dataset, num_workers: int = 1, *, packing_interval:
         self.workers = []
 
     @staticmethod
-    def calculate_matched_group(template, sequences, is_finished: bool = True):
+    def _pack(sequences, max_length: int):
+        # Greedy first-fit: traverse the sequences and place each one into the
+        # first bucket with enough remaining capacity; if no existing bucket
+        # can accommodate it, open a new bucket.
+        capacity = []
+        buckets = []
+        for item in sequences:
+            seq, seq_len = item
+            for i, capa in enumerate(capacity):
+                if capa >= seq_len:
+                    capacity[i] = capa - seq_len
+                    buckets[i].append(item)
+                    break
+            else:
+                capacity.append(max_length - seq_len)
+                buckets.append([item])
+        return buckets
+
+    @staticmethod
+    def pack_sequences(template, sequences: List[Tuple[Dict[str, Any], int]], is_finished: bool = True):
         if len(sequences) == 0:
             return [], []
-        # https://arxiv.org/pdf/2404.10830
-        import binpacking
-        sequences = binpacking.to_constant_volume(sequences, template.max_length, weight_pos=1)
+        sequences = BasePackingDataset._pack(sequences, template.max_length)
         res = []
         if sequences and not is_finished:
             sequences, ret_sequences = sequences[:-1], sequences[-1]
@@ -142,7 +159,7 @@ def _encode_data(self, data):
 
 class PackingDataset(BasePackingDataset, Dataset):
 
-    def __init__(self, template, dataset, num_workers: int = 1, *, packing_interval: int = 128, strict: bool = False):
+    def __init__(self, template, dataset, num_workers: int = 1, *, packing_interval: Optional[int] = 1024, strict: bool = False):
         super().__init__(template, dataset, num_workers, packing_interval=packing_interval, strict=strict)
         self.prog_bar = tqdm(total=len(dataset), dynamic_ncols=True, desc='Packing')
         self._queue = mp.Queue()
@@ -160,7 +177,9 @@ def __init__(self, template, dataset, num_workers: int = 1, *, packing_interval:
 
     def fetch_packing_data(self, res: Optional[list] = None):
         res = res or []
-        for _ in range(self.packing_interval):
+        i = 0
+        while self.packing_interval is None or i < self.packing_interval:
+            i += 1
             data = self._queue.get()
             if data is None:
                 self._terminated_workers += 1
@@ -178,7 +197,7 @@ def get_packed_dataset(self):
         while True:
             data = self.fetch_packing_data(data)
             is_finished = self._terminated_workers == self.num_workers
-            res, data = self.calculate_matched_group(self.template, data, is_finished=is_finished)
+            res, data = self.pack_sequences(self.template, data, is_finished=is_finished)
             result += res
             if is_finished:
                 break
@@ -256,7 +275,7 @@ def __iter__(self):
         while True:
             self._put_data_in_queue(iterator)
             data = self._fetch_data_out_queue(data)
-            res, data = self.calculate_matched_group(self.template, data, is_finished=False)
+            res, data = self.pack_sequences(self.template, data, is_finished=False)
             yield from res
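
Below is a standalone sketch of the greedy first-fit packing this patch introduces in BasePackingDataset._pack, useful for checking the bucketing behavior in isolation. The wrapper name first_fit_pack and the toy inputs are illustrative only; in the patch, items are (encoded_sample, length) tuples and the budget comes from template.max_length.

from typing import Any, Dict, List, Tuple


def first_fit_pack(sequences: List[Tuple[Dict[str, Any], int]], max_length: int) -> List[list]:
    capacity: List[int] = []  # remaining room in each open bucket
    buckets: List[list] = []  # packed groups of (sample, length) items
    for item in sequences:
        _, seq_len = item
        for i, capa in enumerate(capacity):
            if capa >= seq_len:  # first bucket with enough room wins
                capacity[i] = capa - seq_len
                buckets[i].append(item)
                break
        else:  # no existing bucket fits: open a new one
            capacity.append(max_length - seq_len)
            buckets.append([item])
    return buckets


# Toy run: lengths 7, 5, 3, 4 with max_length=8 pack into three buckets,
# each summing to at most 8.
items = [({'id': k}, n) for k, n in zip('abcd', (7, 5, 3, 4))]
print([[n for _, n in bucket] for bucket in first_fit_pack(items, 8)])
# -> [[7], [5, 3], [4]]

Compared with the binpacking.to_constant_volume call it replaces, this drops the external dependency and packs in a single pass over arrival order; and as the pack_sequences hunk shows, when is_finished is False the most recently opened (last) bucket is held back so it can keep filling during the next packing interval.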