@@ -126,6 +126,104 @@ def test_data_collator_with_padding(self):
        batch = data_collator(features)
        self.assertEqual(batch["input_ids"].shape, torch.Size([2, 8]))

+    def test_data_collator_with_flattening(self):
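+        # DataCollatorWithFlattening concatenates all examples into a single
+        # batch row with no padding; instead of an attention_mask it emits
+        # position_ids that restart at 0 at each sequence boundary.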
+        features = [
+            {"input_ids": [10, 11, 12]},
+            {"input_ids": [20, 21, 22, 23, 24, 25]},
+            {"input_ids": [30, 31, 32, 33, 34, 35, 36]},
+        ]
+
+        data_collator = DataCollatorWithFlattening(return_tensors="pt")
+        batch = data_collator(features)
+
+        for unexpected_key in [
+            "attention_mask",
+            "cu_seq_lens_k",
+            "cu_seq_lens_q",
+            "max_length_k",
+            "max_length_q",
+            "seq_idx",
+        ]:
+            self.assertNotIn(unexpected_key, batch)
+        self.assertIn("position_ids", batch)
+
+        self.assertEqual(batch["input_ids"].shape, torch.Size([1, 16]))
+        self.assertEqual(
+            batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36]
+        )
+        self.assertEqual(batch["position_ids"].shape, torch.Size([1, 16]))
+        self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6])
+
+    def test_data_collator_with_flattening_flash_attn_kwargs(self):
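+        # return_flash_attn_kwargs=True additionally returns the FlashAttention
+        # varlen kwargs: cumulative sequence lengths (cu_seq_lens_{q,k}) and the
+        # longest sequence length (max_length_{q,k}).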
+        features = [
+            {"input_ids": [10, 11, 12]},
+            {"input_ids": [20, 21, 22, 23, 24, 25]},
+            {"input_ids": [30, 31, 32, 33, 34, 35, 36]},
+        ]
+        data_collator = DataCollatorWithFlattening(return_tensors="pt", return_flash_attn_kwargs=True)
+        batch = data_collator(features)
+
+        for unexpected_key in [
+            "attention_mask",
+            "seq_idx",
+        ]:
+            self.assertNotIn(unexpected_key, batch)
+        for expected_key in [
+            "position_ids",
+            "cu_seq_lens_k",
+            "cu_seq_lens_q",
+            "max_length_k",
+            "max_length_q",
+        ]:
+            self.assertIn(expected_key, batch)
+
+        self.assertEqual(batch["input_ids"].shape, torch.Size([1, 16]))
+        self.assertEqual(
+            batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36]
+        )
+        self.assertEqual(batch["position_ids"].shape, torch.Size([1, 16]))
+        self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6])
+
+        self.assertEqual(batch["cu_seq_lens_k"].shape, torch.Size([4]))
+        self.assertEqual(batch["cu_seq_lens_k"].tolist(), [0, 3, 9, 16])
+        self.assertEqual(batch["cu_seq_lens_q"].shape, torch.Size([4]))
+        self.assertEqual(batch["cu_seq_lens_q"].tolist(), [0, 3, 9, 16])
+        # The flash attn max_length_{k,q} are simple python ints
+        self.assertEqual(batch["max_length_k"], 7)
+        self.assertEqual(batch["max_length_q"], 7)
+
+    def test_data_collator_with_flattening_seq_idx(self):
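+        # return_seq_idx=True labels every token with the index of the sequence
+        # it came from, without returning the flash attention kwargs.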
+        features = [
+            {"input_ids": [10, 11, 12]},
+            {"input_ids": [20, 21, 22, 23, 24, 25]},
+            {"input_ids": [30, 31, 32, 33, 34, 35, 36]},
+        ]
+        data_collator = DataCollatorWithFlattening(return_tensors="pt", return_seq_idx=True)
+        batch = data_collator(features)
+
+        for unexpected_key in [
+            "attention_mask",
+            "cu_seq_lens_k",
+            "cu_seq_lens_q",
+            "max_length_k",
+            "max_length_q",
+        ]:
+            self.assertNotIn(unexpected_key, batch)
+        for expected_key in [
+            "position_ids",
+            "seq_idx",
+        ]:
+            self.assertIn(expected_key, batch)
+
+        self.assertEqual(batch["input_ids"].shape, torch.Size([1, 16]))
+        self.assertEqual(
+            batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36]
+        )
+        self.assertEqual(batch["position_ids"].shape, torch.Size([1, 16]))
+        self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6])
+        self.assertEqual(batch["seq_idx"].shape, batch["input_ids"].shape)
+        self.assertEqual(batch["seq_idx"][0].tolist(), [0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2])
+

    def test_data_collator_for_token_classification(self):
        tokenizer = BertTokenizer(self.vocab_file)
        features = [
@@ -1803,14 +1901,96 @@ def test_data_collator_with_flattening(self):

        data_collator = DataCollatorWithFlattening(return_tensors="np")
        batch = data_collator(features)
+
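+        # Same expectations as the pt collator tests above, but on numpy
+        # arrays, so shapes are compared against plain tuples.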
+        for unexpected_key in [
+            "attention_mask",
+            "cu_seq_lens_k",
+            "cu_seq_lens_q",
+            "max_length_k",
+            "max_length_q",
+            "seq_idx",
+        ]:
+            self.assertNotIn(unexpected_key, batch)
+        self.assertIn("position_ids", batch)
+
+        self.assertEqual(batch["input_ids"].shape, (1, 16))
+        self.assertEqual(
+            batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36]
+        )
+        self.assertEqual(batch["position_ids"].shape, (1, 16))
+        self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6])
+
+    def test_data_collator_with_flattening_flash_attn_kwargs(self):
+        features = [
+            {"input_ids": [10, 11, 12]},
+            {"input_ids": [20, 21, 22, 23, 24, 25]},
+            {"input_ids": [30, 31, 32, 33, 34, 35, 36]},
+        ]
+
+        data_collator = DataCollatorWithFlattening(return_tensors="np", return_flash_attn_kwargs=True)
+        batch = data_collator(features)
+
+        for unexpected_key in [
+            "attention_mask",
+            "seq_idx",
+        ]:
+            self.assertNotIn(unexpected_key, batch)
+        for expected_key in [
+            "position_ids",
+            "cu_seq_lens_k",
+            "cu_seq_lens_q",
+            "max_length_k",
+            "max_length_q",
+        ]:
+            self.assertIn(expected_key, batch)
+
+        self.assertEqual(batch["input_ids"].shape, (1, 16))
+        self.assertEqual(
+            batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36]
+        )
+        self.assertEqual(batch["position_ids"].shape, (1, 16))
+        self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6])
+
+        self.assertEqual(batch["cu_seq_lens_k"].shape, (4,))
+        self.assertEqual(batch["cu_seq_lens_k"].tolist(), [0, 3, 9, 16])
+        self.assertEqual(batch["cu_seq_lens_q"].shape, (4,))
+        self.assertEqual(batch["cu_seq_lens_q"].tolist(), [0, 3, 9, 16])
+        # The flash attn max_length_{k,q} are simple python ints
+        self.assertEqual(batch["max_length_k"], 7)
+        self.assertEqual(batch["max_length_q"], 7)
+
+    def test_data_collator_with_flattening_seq_idx(self):
+        features = [
+            {"input_ids": [10, 11, 12]},
+            {"input_ids": [20, 21, 22, 23, 24, 25]},
+            {"input_ids": [30, 31, 32, 33, 34, 35, 36]},
+        ]
+
+        data_collator = DataCollatorWithFlattening(return_tensors="np", return_seq_idx=True)
+        batch = data_collator(features)
+
+        for unexpected_key in [
+            "attention_mask",
+            "cu_seq_lens_k",
+            "cu_seq_lens_q",
+            "max_length_k",
+            "max_length_q",
+        ]:
+            self.assertNotIn(unexpected_key, batch)
+        for expected_key in [
+            "position_ids",
+            "seq_idx",
+        ]:
+            self.assertIn(expected_key, batch)
+
        self.assertEqual(batch["input_ids"].shape, (1, 16))
        self.assertEqual(
            batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36]
        )
-        self.assertNotIn("attention_mask", batch)
-        self.assertIn("position_ids", batch)
        self.assertEqual(batch["position_ids"].shape, (1, 16))
        self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6])
+        self.assertEqual(batch["seq_idx"].shape, batch["input_ids"].shape)
+        self.assertEqual(batch["seq_idx"][0].tolist(), [0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2])

    def test_data_collator_for_token_classification(self):
        tokenizer = BertTokenizer(self.vocab_file)