Skip to content

Commit 6c98e0e

Browse files
Merge pull request #1952 from kbenzie/benie/bounds-checking-off-by-default
Make USM parameter bounds checking configurable
2 parents 26f1dfc + 216d30e commit 6c98e0e

File tree

4 files changed

+133
-68
lines changed

4 files changed

+133
-68
lines changed

scripts/core/INTRO.rst

+2
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,8 @@ Layers currently included with the runtime are as follows:
295295
- Description
296296
* - UR_LAYER_PARAMETER_VALIDATION
297297
- Enables non-adapter-specific parameter validation (e.g. checking for null values).
298+
* - UR_LAYER_BOUNDS_CHECKING
299+
- Enables non-adapter-specific bounds checking of USM allocations for enqueued commands. Automatically enables UR_LAYER_PARAMETER_VALIDATION.
298300
* - UR_LAYER_LEAK_CHECKING
299301
- Performs some leak checking for API calls involving object creation/destruction.
300302
* - UR_LAYER_LIFETIME_VALIDATION

scripts/templates/valddi.cpp.mako

+13-1
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,16 @@ namespace ur_validation_layer
5757
{
5858
%for key, values in sorted_param_checks:
5959
%for val in values:
60-
if( ${val} )
60+
%if 'boundsError' in val:
61+
if ( getContext()->enableBoundsChecking ) {
62+
if ( ${val} ) {
63+
return ${key};
64+
}
65+
}
66+
%else:
67+
if ( ${val} )
6168
return ${key};
69+
%endif
6270

6371
%endfor
6472
%endfor
@@ -178,9 +186,13 @@ namespace ur_validation_layer
178186

179187
if (enabledLayerNames.count(nameFullValidation)) {
180188
enableParameterValidation = true;
189+
enableBoundsChecking = true;
181190
enableLeakChecking = true;
182191
enableLifetimeValidation = true;
183192
} else {
193+
if (enabledLayerNames.count(nameBoundsChecking)) {
194+
enableBoundsChecking = true;
195+
}
184196
if (enabledLayerNames.count(nameParameterValidation)) {
185197
enableParameterValidation = true;
186198
}

source/loader/layers/validation/ur_valddi.cpp

+114-66
Original file line numberDiff line numberDiff line change
@@ -4822,9 +4822,11 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferRead(
48224822
return UR_RESULT_ERROR_INVALID_EVENT_WAIT_LIST;
48234823
}
48244824

4825-
if (auto boundsError = bounds(hBuffer, offset, size);
4826-
boundsError != UR_RESULT_SUCCESS) {
4827-
return boundsError;
4825+
if (getContext()->enableBoundsChecking) {
4826+
if (auto boundsError = bounds(hBuffer, offset, size);
4827+
boundsError != UR_RESULT_SUCCESS) {
4828+
return boundsError;
4829+
}
48284830
}
48294831

48304832
if (phEventWaitList != NULL && numEventsInWaitList > 0) {
@@ -4902,9 +4904,11 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferWrite(
49024904
return UR_RESULT_ERROR_INVALID_EVENT_WAIT_LIST;
49034905
}
49044906

4905-
if (auto boundsError = bounds(hBuffer, offset, size);
4906-
boundsError != UR_RESULT_SUCCESS) {
4907-
return boundsError;
4907+
if (getContext()->enableBoundsChecking) {
4908+
if (auto boundsError = bounds(hBuffer, offset, size);
4909+
boundsError != UR_RESULT_SUCCESS) {
4910+
return boundsError;
4911+
}
49084912
}
49094913

49104914
if (phEventWaitList != NULL && numEventsInWaitList > 0) {
@@ -5033,9 +5037,11 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferReadRect(
50335037
return UR_RESULT_ERROR_INVALID_SIZE;
50345038
}
50355039

5036-
if (auto boundsError = bounds(hBuffer, bufferOrigin, region);
5037-
boundsError != UR_RESULT_SUCCESS) {
5038-
return boundsError;
5040+
if (getContext()->enableBoundsChecking) {
5041+
if (auto boundsError = bounds(hBuffer, bufferOrigin, region);
5042+
boundsError != UR_RESULT_SUCCESS) {
5043+
return boundsError;
5044+
}
50395045
}
50405046

50415047
if (phEventWaitList != NULL && numEventsInWaitList > 0) {
@@ -5168,9 +5174,11 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferWriteRect(
51685174
return UR_RESULT_ERROR_INVALID_SIZE;
51695175
}
51705176

5171-
if (auto boundsError = bounds(hBuffer, bufferOrigin, region);
5172-
boundsError != UR_RESULT_SUCCESS) {
5173-
return boundsError;
5177+
if (getContext()->enableBoundsChecking) {
5178+
if (auto boundsError = bounds(hBuffer, bufferOrigin, region);
5179+
boundsError != UR_RESULT_SUCCESS) {
5180+
return boundsError;
5181+
}
51745182
}
51755183

51765184
if (phEventWaitList != NULL && numEventsInWaitList > 0) {
@@ -5248,14 +5256,18 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferCopy(
52485256
return UR_RESULT_ERROR_INVALID_EVENT_WAIT_LIST;
52495257
}
52505258

5251-
if (auto boundsError = bounds(hBufferSrc, srcOffset, size);
5252-
boundsError != UR_RESULT_SUCCESS) {
5253-
return boundsError;
5259+
if (getContext()->enableBoundsChecking) {
5260+
if (auto boundsError = bounds(hBufferSrc, srcOffset, size);
5261+
boundsError != UR_RESULT_SUCCESS) {
5262+
return boundsError;
5263+
}
52545264
}
52555265

5256-
if (auto boundsError = bounds(hBufferDst, dstOffset, size);
5257-
boundsError != UR_RESULT_SUCCESS) {
5258-
return boundsError;
5266+
if (getContext()->enableBoundsChecking) {
5267+
if (auto boundsError = bounds(hBufferDst, dstOffset, size);
5268+
boundsError != UR_RESULT_SUCCESS) {
5269+
return boundsError;
5270+
}
52595271
}
52605272

52615273
if (phEventWaitList != NULL && numEventsInWaitList > 0) {
@@ -5383,14 +5395,18 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferCopyRect(
53835395
return UR_RESULT_ERROR_INVALID_SIZE;
53845396
}
53855397

5386-
if (auto boundsError = bounds(hBufferSrc, srcOrigin, region);
5387-
boundsError != UR_RESULT_SUCCESS) {
5388-
return boundsError;
5398+
if (getContext()->enableBoundsChecking) {
5399+
if (auto boundsError = bounds(hBufferSrc, srcOrigin, region);
5400+
boundsError != UR_RESULT_SUCCESS) {
5401+
return boundsError;
5402+
}
53895403
}
53905404

5391-
if (auto boundsError = bounds(hBufferDst, dstOrigin, region);
5392-
boundsError != UR_RESULT_SUCCESS) {
5393-
return boundsError;
5405+
if (getContext()->enableBoundsChecking) {
5406+
if (auto boundsError = bounds(hBufferDst, dstOrigin, region);
5407+
boundsError != UR_RESULT_SUCCESS) {
5408+
return boundsError;
5409+
}
53945410
}
53955411

53965412
if (phEventWaitList != NULL && numEventsInWaitList > 0) {
@@ -5492,9 +5508,11 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferFill(
54925508
return UR_RESULT_ERROR_INVALID_SIZE;
54935509
}
54945510

5495-
if (auto boundsError = bounds(hBuffer, offset, size);
5496-
boundsError != UR_RESULT_SUCCESS) {
5497-
return boundsError;
5511+
if (getContext()->enableBoundsChecking) {
5512+
if (auto boundsError = bounds(hBuffer, offset, size);
5513+
boundsError != UR_RESULT_SUCCESS) {
5514+
return boundsError;
5515+
}
54985516
}
54995517

55005518
if (phEventWaitList != NULL && numEventsInWaitList > 0) {
@@ -5579,9 +5597,11 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemImageRead(
55795597
return UR_RESULT_ERROR_INVALID_SIZE;
55805598
}
55815599

5582-
if (auto boundsError = boundsImage(hImage, origin, region);
5583-
boundsError != UR_RESULT_SUCCESS) {
5584-
return boundsError;
5600+
if (getContext()->enableBoundsChecking) {
5601+
if (auto boundsError = boundsImage(hImage, origin, region);
5602+
boundsError != UR_RESULT_SUCCESS) {
5603+
return boundsError;
5604+
}
55855605
}
55865606

55875607
if (phEventWaitList != NULL && numEventsInWaitList > 0) {
@@ -5667,9 +5687,11 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemImageWrite(
56675687
return UR_RESULT_ERROR_INVALID_SIZE;
56685688
}
56695689

5670-
if (auto boundsError = boundsImage(hImage, origin, region);
5671-
boundsError != UR_RESULT_SUCCESS) {
5672-
return boundsError;
5690+
if (getContext()->enableBoundsChecking) {
5691+
if (auto boundsError = boundsImage(hImage, origin, region);
5692+
boundsError != UR_RESULT_SUCCESS) {
5693+
return boundsError;
5694+
}
56735695
}
56745696

56755697
if (phEventWaitList != NULL && numEventsInWaitList > 0) {
@@ -5756,14 +5778,18 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemImageCopy(
57565778
return UR_RESULT_ERROR_INVALID_SIZE;
57575779
}
57585780

5759-
if (auto boundsError = boundsImage(hImageSrc, srcOrigin, region);
5760-
boundsError != UR_RESULT_SUCCESS) {
5761-
return boundsError;
5781+
if (getContext()->enableBoundsChecking) {
5782+
if (auto boundsError = boundsImage(hImageSrc, srcOrigin, region);
5783+
boundsError != UR_RESULT_SUCCESS) {
5784+
return boundsError;
5785+
}
57625786
}
57635787

5764-
if (auto boundsError = boundsImage(hImageDst, dstOrigin, region);
5765-
boundsError != UR_RESULT_SUCCESS) {
5766-
return boundsError;
5788+
if (getContext()->enableBoundsChecking) {
5789+
if (auto boundsError = boundsImage(hImageDst, dstOrigin, region);
5790+
boundsError != UR_RESULT_SUCCESS) {
5791+
return boundsError;
5792+
}
57675793
}
57685794

57695795
if (phEventWaitList != NULL && numEventsInWaitList > 0) {
@@ -5850,9 +5876,11 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferMap(
58505876
return UR_RESULT_ERROR_INVALID_EVENT_WAIT_LIST;
58515877
}
58525878

5853-
if (auto boundsError = bounds(hBuffer, offset, size);
5854-
boundsError != UR_RESULT_SUCCESS) {
5855-
return boundsError;
5879+
if (getContext()->enableBoundsChecking) {
5880+
if (auto boundsError = bounds(hBuffer, offset, size);
5881+
boundsError != UR_RESULT_SUCCESS) {
5882+
return boundsError;
5883+
}
58565884
}
58575885

58585886
if (phEventWaitList != NULL && numEventsInWaitList > 0) {
@@ -6012,9 +6040,11 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueUSMFill(
60126040
return UR_RESULT_ERROR_INVALID_EVENT_WAIT_LIST;
60136041
}
60146042

6015-
if (auto boundsError = bounds(hQueue, pMem, 0, size);
6016-
boundsError != UR_RESULT_SUCCESS) {
6017-
return boundsError;
6043+
if (getContext()->enableBoundsChecking) {
6044+
if (auto boundsError = bounds(hQueue, pMem, 0, size);
6045+
boundsError != UR_RESULT_SUCCESS) {
6046+
return boundsError;
6047+
}
60186048
}
60196049

60206050
if (phEventWaitList != NULL && numEventsInWaitList > 0) {
@@ -6089,14 +6119,18 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueUSMMemcpy(
60896119
return UR_RESULT_ERROR_INVALID_EVENT_WAIT_LIST;
60906120
}
60916121

6092-
if (auto boundsError = bounds(hQueue, pDst, 0, size);
6093-
boundsError != UR_RESULT_SUCCESS) {
6094-
return boundsError;
6122+
if (getContext()->enableBoundsChecking) {
6123+
if (auto boundsError = bounds(hQueue, pDst, 0, size);
6124+
boundsError != UR_RESULT_SUCCESS) {
6125+
return boundsError;
6126+
}
60956127
}
60966128

6097-
if (auto boundsError = bounds(hQueue, pSrc, 0, size);
6098-
boundsError != UR_RESULT_SUCCESS) {
6099-
return boundsError;
6129+
if (getContext()->enableBoundsChecking) {
6130+
if (auto boundsError = bounds(hQueue, pSrc, 0, size);
6131+
boundsError != UR_RESULT_SUCCESS) {
6132+
return boundsError;
6133+
}
61006134
}
61016135

61026136
if (phEventWaitList != NULL && numEventsInWaitList > 0) {
@@ -6169,9 +6203,11 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueUSMPrefetch(
61696203
return UR_RESULT_ERROR_INVALID_EVENT_WAIT_LIST;
61706204
}
61716205

6172-
if (auto boundsError = bounds(hQueue, pMem, 0, size);
6173-
boundsError != UR_RESULT_SUCCESS) {
6174-
return boundsError;
6206+
if (getContext()->enableBoundsChecking) {
6207+
if (auto boundsError = bounds(hQueue, pMem, 0, size);
6208+
boundsError != UR_RESULT_SUCCESS) {
6209+
return boundsError;
6210+
}
61756211
}
61766212

61776213
if (phEventWaitList != NULL && numEventsInWaitList > 0) {
@@ -6230,9 +6266,11 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueUSMAdvise(
62306266
return UR_RESULT_ERROR_INVALID_SIZE;
62316267
}
62326268

6233-
if (auto boundsError = bounds(hQueue, pMem, 0, size);
6234-
boundsError != UR_RESULT_SUCCESS) {
6235-
return boundsError;
6269+
if (getContext()->enableBoundsChecking) {
6270+
if (auto boundsError = bounds(hQueue, pMem, 0, size);
6271+
boundsError != UR_RESULT_SUCCESS) {
6272+
return boundsError;
6273+
}
62366274
}
62376275
}
62386276

@@ -6332,9 +6370,11 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueUSMFill2D(
63326370
return UR_RESULT_ERROR_INVALID_EVENT_WAIT_LIST;
63336371
}
63346372

6335-
if (auto boundsError = bounds(hQueue, pMem, 0, pitch * height);
6336-
boundsError != UR_RESULT_SUCCESS) {
6337-
return boundsError;
6373+
if (getContext()->enableBoundsChecking) {
6374+
if (auto boundsError = bounds(hQueue, pMem, 0, pitch * height);
6375+
boundsError != UR_RESULT_SUCCESS) {
6376+
return boundsError;
6377+
}
63386378
}
63396379

63406380
if (phEventWaitList != NULL && numEventsInWaitList > 0) {
@@ -6431,14 +6471,18 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueUSMMemcpy2D(
64316471
return UR_RESULT_ERROR_INVALID_EVENT_WAIT_LIST;
64326472
}
64336473

6434-
if (auto boundsError = bounds(hQueue, pDst, 0, dstPitch * height);
6435-
boundsError != UR_RESULT_SUCCESS) {
6436-
return boundsError;
6474+
if (getContext()->enableBoundsChecking) {
6475+
if (auto boundsError = bounds(hQueue, pDst, 0, dstPitch * height);
6476+
boundsError != UR_RESULT_SUCCESS) {
6477+
return boundsError;
6478+
}
64376479
}
64386480

6439-
if (auto boundsError = bounds(hQueue, pSrc, 0, srcPitch * height);
6440-
boundsError != UR_RESULT_SUCCESS) {
6441-
return boundsError;
6481+
if (getContext()->enableBoundsChecking) {
6482+
if (auto boundsError = bounds(hQueue, pSrc, 0, srcPitch * height);
6483+
boundsError != UR_RESULT_SUCCESS) {
6484+
return boundsError;
6485+
}
64426486
}
64436487

64446488
if (phEventWaitList != NULL && numEventsInWaitList > 0) {
@@ -10997,9 +11041,13 @@ ur_result_t context_t::init(ur_dditable_t *dditable,
1099711041

1099811042
if (enabledLayerNames.count(nameFullValidation)) {
1099911043
enableParameterValidation = true;
11044+
enableBoundsChecking = true;
1100011045
enableLeakChecking = true;
1100111046
enableLifetimeValidation = true;
1100211047
} else {
11048+
if (enabledLayerNames.count(nameBoundsChecking)) {
11049+
enableBoundsChecking = true;
11050+
}
1100311051
if (enabledLayerNames.count(nameParameterValidation)) {
1100411052
enableParameterValidation = true;
1100511053
}

source/loader/layers/validation/ur_validation_layer.hpp

+4-1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class __urdlllocal context_t : public proxy_layer_context_t,
2424
public AtomicSingleton<context_t> {
2525
public:
2626
bool enableParameterValidation = false;
27+
bool enableBoundsChecking = false;
2728
bool enableLeakChecking = false;
2829
bool enableLifetimeValidation = false;
2930
logger::Logger logger;
@@ -35,7 +36,7 @@ class __urdlllocal context_t : public proxy_layer_context_t,
3536

3637
static std::vector<std::string> getNames() {
3738
return {nameFullValidation, nameParameterValidation, nameLeakChecking,
38-
nameLifetimeValidation};
39+
nameBoundsChecking, nameLifetimeValidation};
3940
}
4041
ur_result_t init(ur_dditable_t *dditable,
4142
const std::set<std::string> &enabledLayerNames,
@@ -49,6 +50,8 @@ class __urdlllocal context_t : public proxy_layer_context_t,
4950
"UR_LAYER_FULL_VALIDATION";
5051
inline static const std::string nameParameterValidation =
5152
"UR_LAYER_PARAMETER_VALIDATION";
53+
inline static const std::string nameBoundsChecking =
54+
"UR_LAYER_BOUNDS_CHECKING";
5255
inline static const std::string nameLeakChecking = "UR_LAYER_LEAK_CHECKING";
5356
inline static const std::string nameLifetimeValidation =
5457
"UR_LAYER_LIFETIME_VALIDATION";

0 commit comments

Comments
 (0)