Commit 503541d

add FlashAttentionKwargs and seq_idx to flat collator (#36456)

* add flash attn kwargs to flattening collator
* add return_seq_idx option
* doc string edits
* cleaner max len updates
* various fixes
* temp testing code
* return int32 seq_idx and FlashAttnKwargs
* DataCollatorIntegrationTest impl
* fix batch dims and dtypes
* fill out remaining collator tests
* test name change and fmt
* rm unused var
* fmt
* minor change
* fmt
* add missing pos_ids check
* consistent {np,pt,tf} tests
* split pt tests into 3, like np/tf tests
* mv comment, rename fa test
* remove batch dim comment
* simply wrapping
* compute cu_seq_len/max_length once
* fmt
* remove tf code
* rm warning
* move separator_id back to 2nd pos
* use cleaner lists in tests
* ret -> batch
* fmt
* attr ordering
* use py ints for max_length_{k,q}

1 parent 9ddcf5f commit 503541d

File tree: 3 files changed (+318, -13)
src/transformers/data/data_collator.py (+63, -11)
@@ -1974,9 +1974,11 @@ class DataCollatorWithFlattening(DefaultDataCollator):
     """
     Data collator used for padding free approach. Does the following:
 
-    - concatate the entire mini batch into single long sequence [1, total_tokens]
+    - concatenates the entire mini batch into single long sequence of shape [1, total_tokens]
     - uses `separator_id` to separate sequences within the concatenated `labels`, default value is -100
-    - no padding will be added, returns `input_ids`, `labels` and `position_ids`
+    - no padding will be added, returns `input_ids`, `labels` and `position_ids` by default
+    - optionally returns the kwargs contained in FlashAttentionKwargs
+    - optionally returns seq_idx indicating which sequence each token belongs to
 
     <Tip warning={true}>
 
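For illustration only (not part of the diff), here is roughly what the two new optional outputs contain for a toy flattened batch; the lengths match the ones used in the new unit tests further down:

# Minimal sketch, assuming three sequences of lengths 3, 6 and 7 (as in the tests below).
lengths = [3, 6, 7]

# seq_idx: one entry per token of the flattened [1, total_tokens] batch.
seq_idx = [i for i, n in enumerate(lengths) for _ in range(n)]
# -> [0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2]

# FlashAttention kwargs: cumulative sequence lengths plus the longest sequence length.
cu_seq_lens = [0]
for n in lengths:
    cu_seq_lens.append(cu_seq_lens[-1] + n)
# -> cu_seq_lens_q == cu_seq_lens_k == [0, 3, 9, 16]; max_length_q == max_length_k == 7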
@@ -1986,26 +1988,76 @@ class DataCollatorWithFlattening(DefaultDataCollator):
     </Tip>
     """
 
-    def __init__(self, *args, return_position_ids=True, separator_id=-100, **kwargs):
+    def __init__(
+        self,
+        *args,
+        return_position_ids=True,
+        separator_id=-100,
+        return_flash_attn_kwargs=False,
+        return_seq_idx=False,
+        **kwargs,
+    ):
         super().__init__(*args, **kwargs)
         self.return_position_ids = return_position_ids
         self.separator_id = separator_id
+        self.return_flash_attn_kwargs = return_flash_attn_kwargs
+        self.return_seq_idx = return_seq_idx
+        self._int_64_keys = {"labels", "position_ids", "input_ids"}
+        self._batch_dim_keys = {"labels", "position_ids", "input_ids", "seq_idx"}
+        self._py_int_keys = {"max_length_q", "max_length_k"}
 
     def __call__(self, features, return_tensors=None, separator_id=None):
         if return_tensors is None:
             return_tensors = self.return_tensors
         if separator_id is None:
             separator_id = self.separator_id
         is_labels_provided = "labels" in features[0]
-        ret = {"input_ids": [], "labels": []}
+        batch = {"input_ids": [], "labels": []}
         if self.return_position_ids:
-            ret.update({"position_ids": []})
-        for idx in range(0, len(features)):
-            ret["input_ids"] += features[idx]["input_ids"]
+            batch.update({"position_ids": []})
+        if self.return_seq_idx:
+            batch.update({"seq_idx": []})
+        if self.return_flash_attn_kwargs:
+            cu_seq_lens = [0]
+            max_length = 0
+        for seq_idx, sample in enumerate(features):
+            input_ids = sample["input_ids"]
+            batch["input_ids"] += input_ids
             if is_labels_provided:
-                ret["labels"] += [separator_id] + features[idx]["labels"][1:]
+                batch["labels"] += [separator_id] + sample["labels"][1:]
             else:
-                ret["labels"] += [separator_id] + features[idx]["input_ids"][1:]
+                batch["labels"] += [separator_id] + input_ids[1:]
             if self.return_position_ids:
-                ret["position_ids"] += list(range(len(features[idx]["input_ids"])))
-        return default_data_collator([ret], return_tensors)
+                batch["position_ids"] += list(range(len(input_ids)))
+            if self.return_seq_idx:
+                batch["seq_idx"] += [seq_idx for _ in range(len(input_ids))]
+            if self.return_flash_attn_kwargs:
+                cu_seq_lens.append(cu_seq_lens[-1] + len(input_ids))
+                max_length = max(max_length, len(input_ids))
+
+        if self.return_flash_attn_kwargs:
+            batch["cu_seq_lens_q"] = batch["cu_seq_lens_k"] = cu_seq_lens
+            batch["max_length_q"] = batch["max_length_k"] = max_length
+
+        # FlashAttentionKwargs and seq_idx are expected to be int32s.
+        if return_tensors == "pt":
+            import torch
+
+            data_cls = torch.tensor
+            dtype_64 = torch.int64
+            dtype_32 = torch.int32
+        elif return_tensors == "np":
+            data_cls = np.array
+            dtype_64 = np.int64
+            dtype_32 = np.int32
+        else:
+            raise ValueError(f'return_tensors must be one of ("pt", "np"), {return_tensors=} not suported')
+
+        for k, v in batch.items():
+            if k in self._batch_dim_keys:
+                v = [v]
+            # Flash attention max_len_{q,k} are python ints
+            if k not in self._py_int_keys:
+                batch[k] = data_cls(v, dtype=dtype_64 if k in self._int_64_keys else dtype_32)
+
+        return batch
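For reference (again not part of the diff), a minimal usage sketch of the updated collator; it assumes the two new flags can be combined, and the keys, shapes and dtypes noted in the comments are the ones asserted by the new tests further down:

from transformers import DataCollatorWithFlattening

features = [
    {"input_ids": [10, 11, 12]},
    {"input_ids": [20, 21, 22, 23, 24, 25]},
    {"input_ids": [30, 31, 32, 33, 34, 35, 36]},
]

collator = DataCollatorWithFlattening(
    return_tensors="pt", return_flash_attn_kwargs=True, return_seq_idx=True
)
batch = collator(features)

# input_ids / labels / position_ids: int64 tensors of shape [1, 16]
# seq_idx: int32 tensor of shape [1, 16] -> [[0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2]]
# cu_seq_lens_q == cu_seq_lens_k: int32 tensor [0, 3, 9, 16]
# max_length_q == max_length_k == 7 (plain Python ints, not tensors)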

tests/test_modeling_common.py (+73)
@@ -34,6 +34,7 @@
     AutoModel,
     AutoModelForCausalLM,
     AutoModelForSequenceClassification,
+    DataCollatorWithFlattening,
     PretrainedConfig,
     PreTrainedModel,
     is_torch_available,
@@ -4170,6 +4171,78 @@ def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
                 tol = torch.finfo(torch.float16).eps
                 torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
 
+    @require_flash_attn
+    @require_torch_gpu
+    @mark.flash_attn_test
+    @slow
+    def test_flash_attention_2_padding_matches_padding_free_with_position_ids_and_fa_kwargs(self):
+        if not self.has_attentions:
+            self.skipTest(reason="Model architecture does not support attentions")
+
+        max_new_tokens = 30
+
+        for model_class in self.all_generative_model_classes:
+            if not model_class._supports_flash_attn_2:
+                self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
+
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+            if 0 not in inputs_dict.get("attention_mask", []) or "attention_mask" not in inputs_dict:
+                self.skipTest("Model dummy inputs should contain padding in their attention mask")
+
+            dummy_input = inputs_dict[model_class.main_input_name]
+            if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+                dummy_input = dummy_input.to(torch.float16)
+
+            # make sure that all models have enough positions for generation
+            if hasattr(config, "max_position_embeddings"):
+                config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
+
+            model = model_class(config)
+            if "position_ids" not in inspect.signature(model.forward).parameters:
+                self.skipTest("Model does not support position_ids")
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                model.save_pretrained(tmpdirname)
+
+                # ensure left padding, to adapt for some models
+                if 0 in inputs_dict["attention_mask"][:, -1]:
+                    inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
+                dummy_attention_mask = inputs_dict["attention_mask"]
+                inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id
+
+                model = (
+                    model_class.from_pretrained(
+                        tmpdirname,
+                        torch_dtype=torch.float16,
+                        attn_implementation="flash_attention_2",
+                        low_cpu_mem_usage=True,
+                    )
+                    .to(torch_device)
+                    .eval()
+                )
+
+                # flatten
+                features = [
+                    {"input_ids": i[a.bool()].tolist()}
+                    for i, a in zip(inputs_dict["input_ids"], inputs_dict["attention_mask"])
+                ]
+
+                # add position_ids + fa_kwargs
+                data_collator = DataCollatorWithFlattening(return_tensors="pt", return_flash_attn_kwargs=True)
+                batch = data_collator(features)
+                batch_cuda = {k: t.cuda() if torch.is_tensor(t) else t for k, t in batch.items()}
+
+                res_padded = model(**inputs_dict)
+                res_padfree = model(**batch_cuda)
+
+                logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
+                logits_padfree = res_padfree.logits[0]
+
+                torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0)
+                # acceptable numerical instability
+                tol = torch.finfo(torch.float16).eps
+                torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
+
     @require_flash_attn
     @require_torch_gpu
     @mark.flash_attn_test

tests/trainer/test_data_collator.py (+182, -2)
@@ -126,6 +126,104 @@ def test_data_collator_with_padding(self):
         batch = data_collator(features)
         self.assertEqual(batch["input_ids"].shape, torch.Size([2, 8]))
 
+    def test_data_collator_with_flattening(self):
+        features = [
+            {"input_ids": [10, 11, 12]},
+            {"input_ids": [20, 21, 22, 23, 24, 25]},
+            {"input_ids": [30, 31, 32, 33, 34, 35, 36]},
+        ]
+
+        data_collator = DataCollatorWithFlattening(return_tensors="pt")
+        batch = data_collator(features)
+
+        for unexpected_key in [
+            "attention_mask",
+            "cu_seq_lens_k",
+            "cu_seq_lens_q",
+            "max_length_k",
+            "max_length_q",
+            "seq_idx",
+        ]:
+            self.assertNotIn(unexpected_key, batch)
+        self.assertIn("position_ids", batch)
+
+        self.assertEqual(batch["input_ids"].shape, torch.Size([1, 16]))
+        self.assertEqual(
+            batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36]
+        )
+        self.assertEqual(batch["position_ids"].shape, torch.Size([1, 16]))
+        self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6])
+
+    def test_data_collator_with_flattening_flash_attn_kwargs(self):
+        features = [
+            {"input_ids": [10, 11, 12]},
+            {"input_ids": [20, 21, 22, 23, 24, 25]},
+            {"input_ids": [30, 31, 32, 33, 34, 35, 36]},
+        ]
+        data_collator = DataCollatorWithFlattening(return_tensors="pt", return_flash_attn_kwargs=True)
+        batch = data_collator(features)
+
+        for unexpected_key in [
+            "attention_mask",
+            "seq_idx",
+        ]:
+            self.assertNotIn(unexpected_key, batch)
+        for expected_key in [
+            "position_ids",
+            "cu_seq_lens_k",
+            "cu_seq_lens_q",
+            "max_length_k",
+            "max_length_q",
+        ]:
+            self.assertIn(expected_key, batch)
+
+        self.assertEqual(batch["input_ids"].shape, torch.Size([1, 16]))
+        self.assertEqual(
+            batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36]
+        )
+        self.assertEqual(batch["position_ids"].shape, torch.Size([1, 16]))
+        self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6])
+
+        self.assertEqual(batch["cu_seq_lens_k"].shape, torch.Size([4]))
+        self.assertEqual(batch["cu_seq_lens_k"].tolist(), [0, 3, 9, 16])
+        self.assertEqual(batch["cu_seq_lens_q"].shape, torch.Size([4]))
+        self.assertEqual(batch["cu_seq_lens_q"].tolist(), [0, 3, 9, 16])
+        # The flash attn max_length_{k,q} are simple python ints
+        self.assertEqual(batch["max_length_k"], 7)
+        self.assertEqual(batch["max_length_q"], 7)
+
+    def test_data_collator_with_flattening_seq_idx(self):
+        features = [
+            {"input_ids": [10, 11, 12]},
+            {"input_ids": [20, 21, 22, 23, 24, 25]},
+            {"input_ids": [30, 31, 32, 33, 34, 35, 36]},
+        ]
+        data_collator = DataCollatorWithFlattening(return_tensors="pt", return_seq_idx=True)
+        batch = data_collator(features)
+
+        for unexpected_key in [
+            "attention_mask",
+            "cu_seq_lens_k",
+            "cu_seq_lens_q",
+            "max_length_k",
+            "max_length_q",
+        ]:
+            self.assertNotIn(unexpected_key, batch)
+        for expected_key in [
+            "position_ids",
+            "seq_idx",
+        ]:
+            self.assertIn(expected_key, batch)
+
+        self.assertEqual(batch["input_ids"].shape, torch.Size([1, 16]))
+        self.assertEqual(
+            batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36]
+        )
+        self.assertEqual(batch["position_ids"].shape, torch.Size([1, 16]))
+        self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6])
+        self.assertEqual(batch["seq_idx"].shape, batch["input_ids"].shape)
+        self.assertEqual(batch["seq_idx"][0].tolist(), [0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2])
+
     def test_data_collator_for_token_classification(self):
         tokenizer = BertTokenizer(self.vocab_file)
         features = [
@@ -1803,14 +1901,96 @@ def test_data_collator_with_flattening(self):
 
         data_collator = DataCollatorWithFlattening(return_tensors="np")
         batch = data_collator(features)
+
+        for unexpected_key in [
+            "attention_mask",
+            "cu_seq_lens_k",
+            "cu_seq_lens_q",
+            "max_length_k",
+            "max_length_q",
+            "seq_idx",
+        ]:
+            self.assertNotIn(unexpected_key, batch)
+        self.assertIn("position_ids", batch)
+
+        self.assertEqual(batch["input_ids"].shape, (1, 16))
+        self.assertEqual(
+            batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36]
+        )
+        self.assertEqual(batch["position_ids"].shape, (1, 16))
+        self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6])
+
+    def test_data_collator_with_flattening_flash_attn_kwargs(self):
+        features = [
+            {"input_ids": [10, 11, 12]},
+            {"input_ids": [20, 21, 22, 23, 24, 25]},
+            {"input_ids": [30, 31, 32, 33, 34, 35, 36]},
+        ]
+
+        data_collator = DataCollatorWithFlattening(return_tensors="np", return_flash_attn_kwargs=True)
+        batch = data_collator(features)
+
+        for unexpected_key in [
+            "attention_mask",
+            "seq_idx",
+        ]:
+            self.assertNotIn(unexpected_key, batch)
+        for expected_key in [
+            "position_ids",
+            "cu_seq_lens_k",
+            "cu_seq_lens_q",
+            "max_length_k",
+            "max_length_q",
+        ]:
+            self.assertIn(expected_key, batch)
+
+        self.assertEqual(batch["input_ids"].shape, (1, 16))
+        self.assertEqual(
+            batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36]
+        )
+        self.assertEqual(batch["position_ids"].shape, (1, 16))
+        self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6])
+
+        self.assertEqual(batch["cu_seq_lens_k"].shape, (4,))
+        self.assertEqual(batch["cu_seq_lens_k"].tolist(), [0, 3, 9, 16])
+        self.assertEqual(batch["cu_seq_lens_q"].shape, (4,))
+        self.assertEqual(batch["cu_seq_lens_q"].tolist(), [0, 3, 9, 16])
+        # The flash attn max_length_{k,q} are simple python ints
+        self.assertEqual(batch["max_length_k"], 7)
+        self.assertEqual(batch["max_length_q"], 7)
+
+    def test_data_collator_with_flattening_seq_idx(self):
+        features = [
+            {"input_ids": [10, 11, 12]},
+            {"input_ids": [20, 21, 22, 23, 24, 25]},
+            {"input_ids": [30, 31, 32, 33, 34, 35, 36]},
+        ]
+
+        data_collator = DataCollatorWithFlattening(return_tensors="np", return_seq_idx=True)
+        batch = data_collator(features)
+
+        for unexpected_key in [
+            "attention_mask",
+            "cu_seq_lens_k",
+            "cu_seq_lens_q",
+            "max_length_k",
+            "max_length_q",
+        ]:
+            self.assertNotIn(unexpected_key, batch)
+        for expected_key in [
+            "position_ids",
+            "seq_idx",
+        ]:
+            self.assertIn(expected_key, batch)
+
         self.assertEqual(batch["input_ids"].shape, (1, 16))
         self.assertEqual(
             batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36]
         )
-        self.assertNotIn("attention_mask", batch)
-        self.assertIn("position_ids", batch)
         self.assertEqual(batch["position_ids"].shape, (1, 16))
         self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6])
+        self.assertEqual(batch["seq_idx"].shape, batch["input_ids"].shape)
+        self.assertEqual(batch["seq_idx"][0].tolist(), [0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2])
 
     def test_data_collator_for_token_classification(self):
         tokenizer = BertTokenizer(self.vocab_file)