@@ -162,10 +162,96 @@ bool group__get_linear_id() {
162
162
return Pass;
163
163
}
164
164
165
+ // Tests group::async_work_group_copy()
166
+ bool group__async_work_group_copy () {
167
+ std::cout << " +++ Running group::async_work_group_copy() test...\n " ;
168
+ constexpr int DIMS = 2 ;
169
+ bool Pass = true ;
170
+
171
+ std::vector<std::pair<range<DIMS>, range<DIMS>>> ranges;
172
+ ranges.push_back ({{3 , 1 }, {2 , 3 }});
173
+ ranges.push_back ({{1 , 3 }, {3 , 2 }});
174
+
175
+ for (const auto &i : ranges) {
176
+ const auto LocalRange = i.first ;
177
+ const auto GroupRange = i.second ;
178
+ const range<DIMS> GlobalRange = LocalRange * GroupRange;
179
+ using DataType = vec<size_t , DIMS>;
180
+ const int DataLen = GlobalRange.size ();
181
+ std::unique_ptr<DataType[]> Data (new DataType[DataLen]);
182
+ std::memset (Data.get (), 0 , DataLen * sizeof (DataType));
183
+
184
+ try {
185
+ buffer<DataType, 1 > Buf (Data.get (), DataLen);
186
+ queue Q (AsyncHandler{});
187
+
188
+ Q.submit ([&](handler &cgh) {
189
+ auto AccGlobal = Buf.get_access <access ::mode::read_write>(cgh);
190
+ accessor<DataType, DIMS, access ::mode::read_write,
191
+ access ::target::local>
192
+ AccLocal (LocalRange, cgh);
193
+
194
+ cgh.parallel_for <class group__async_work_group_copy >(
195
+ nd_range<2 >{GlobalRange, LocalRange},
196
+ [=](nd_item<DIMS> I) {
197
+ const auto Group = I.get_group ();
198
+ const auto NumElem = AccLocal.get_count ();
199
+ const auto Off = Group[0 ] * I.get_group_range (1 ) * NumElem +
200
+ Group[1 ] * I.get_local_range (1 );
201
+ auto PtrGlobal = AccGlobal.get_pointer () + Off;
202
+ auto PtrLocal = AccLocal.get_pointer ();
203
+ if (I.get_local_range (0 ) == 1 ) {
204
+ Group.async_work_group_copy (PtrLocal, PtrGlobal, NumElem);
205
+ } else {
206
+ Group.async_work_group_copy (PtrLocal, PtrGlobal, NumElem,
207
+ I.get_global_range (1 ));
208
+ }
209
+ AccLocal[I.get_local_id ()][0 ] += I.get_global_id (0 );
210
+ AccLocal[I.get_local_id ()][1 ] += I.get_global_id (1 );
211
+ if (I.get_local_range (0 ) == 1 ) {
212
+ Group.async_work_group_copy (PtrGlobal, PtrLocal, NumElem);
213
+ } else {
214
+ Group.async_work_group_copy (PtrGlobal, PtrLocal, NumElem,
215
+ I.get_global_range (1 ));
216
+ }
217
+ });
218
+ });
219
+ } catch (cl::sycl::exception const &E) {
220
+ std::cout << " SYCL exception caught: " << E.what () << ' \n ' ;
221
+ return 2 ;
222
+ }
223
+ const size_t SIZE_Y = GlobalRange.get (0 );
224
+ const size_t SIZE_X = GlobalRange.get (1 );
225
+ int ErrCnt = 0 ;
226
+
227
+ for (size_t Y = 0 ; Y < SIZE_Y; Y++) {
228
+ for (size_t X = 0 ; X < SIZE_X; X++) {
229
+ const size_t Ind = Y * SIZE_X + X;
230
+ const auto Test0 = Data[Ind][0 ];
231
+ const auto Test1 = Data[Ind][1 ];
232
+ const auto Gold0 = Y;
233
+ const auto Gold1 = X;
234
+ const bool Ok = (Test0 == Gold0 && Test1 == Gold1);
235
+ Pass &= Ok;
236
+
237
+ if (!Ok && ErrCnt++ < 10 ) {
238
+ std::cout << " *** ERROR at [" << Y << " ][" << X << " ]: " ;
239
+ std::cout << Test0 << " " << Test1 << " != " ;
240
+ std::cout << Gold0 << " " << Gold1 << " \n " ;
241
+ }
242
+ }
243
+ }
244
+ }
245
+ if (Pass)
246
+ std::cout << " pass\n " ;
247
+ return Pass;
248
+ }
249
+
165
250
int main () {
166
251
bool Pass = 1 ;
167
252
Pass &= group__get_group_range ();
168
253
Pass &= group__get_linear_id ();
254
+ Pass &= group__async_work_group_copy ();
169
255
170
256
if (!Pass) {
171
257
std::cout << " FAILED\n " ;
0 commit comments