@@ -61,10 +61,22 @@ struct ur_kernel_handle_t_ {
61
61
using args_t = std::array<char , MaxParamBytes>;
62
62
using args_size_t = std::vector<size_t >;
63
63
using args_index_t = std::vector<void *>;
64
+ // / Storage shared by all args which is mem copied into when adding a new
65
+ // / argument.
64
66
args_t Storage;
67
+ // / Aligned size of each parameter, including padding.
65
68
args_size_t ParamSizes;
69
+ // / Byte offset into \p Storage allocation for each parameter.
66
70
args_index_t Indices;
67
- args_size_t OffsetPerIndex;
71
+ // / Aligned size in bytes for each local memory parameter after padding has
72
+ // / been added. Zero if the argument at the index isn't a local memory
73
+ // / argument.
74
+ args_size_t AlignedLocalMemSize;
75
+ // / Original size in bytes for each local memory parameter, prior to being
76
+ // / padded to appropriate alignment. Zero if the argument at the index
77
+ // / isn't a local memory argument.
78
+ args_size_t OriginalLocalMemSize;
79
+
68
80
// A struct to keep track of memargs so that we can do dependency analysis
69
81
// at urEnqueueKernelLaunch
70
82
struct mem_obj_arg {
@@ -93,7 +105,8 @@ struct ur_kernel_handle_t_ {
93
105
Indices.resize (Index + 2 , Indices.back ());
94
106
// Ensure enough space for the new argument
95
107
ParamSizes.resize (Index + 1 );
96
- OffsetPerIndex.resize (Index + 1 );
108
+ AlignedLocalMemSize.resize (Index + 1 );
109
+ OriginalLocalMemSize.resize (Index + 1 );
97
110
}
98
111
ParamSizes[Index] = Size ;
99
112
// calculate the insertion point on the array
@@ -102,28 +115,81 @@ struct ur_kernel_handle_t_ {
102
115
// Update the stored value for the argument
103
116
std::memcpy (&Storage[InsertPos], Arg, Size );
104
117
Indices[Index] = &Storage[InsertPos];
105
- OffsetPerIndex [Index] = LocalSize;
118
+ AlignedLocalMemSize [Index] = LocalSize;
106
119
}
107
120
108
- void addLocalArg (size_t Index, size_t Size ) {
109
- size_t LocalOffset = this ->getLocalSize ();
121
+ // / Returns the padded size and offset of a local memory argument.
122
+ // / Local memory arguments need to be padded if the alignment for the size
123
+ // / doesn't match the current offset into the kernel local data.
124
+ // / @param Index Kernel arg index.
125
+ // / @param Size User passed size of local parameter.
126
+ // / @return Tuple of (Aligned size, Aligned offset into local data).
127
+ std::pair<size_t , size_t > calcAlignedLocalArgument (size_t Index,
128
+ size_t Size ) {
129
+ // Store the unpadded size of the local argument
130
+ if (Index + 2 > Indices.size ()) {
131
+ AlignedLocalMemSize.resize (Index + 1 );
132
+ OriginalLocalMemSize.resize (Index + 1 );
133
+ }
134
+ OriginalLocalMemSize[Index] = Size ;
135
+
136
+ // Calculate the current starting offset into local data
137
+ const size_t LocalOffset = std::accumulate (
138
+ std::begin (AlignedLocalMemSize),
139
+ std::next (std::begin (AlignedLocalMemSize), Index), size_t {0 });
110
140
111
- // maximum required alignment is the size of the largest vector type
141
+ // Maximum required alignment is the size of the largest vector type
112
142
const size_t MaxAlignment = sizeof (double ) * 16 ;
113
143
114
- // for arguments smaller than the maximum alignment simply align to the
144
+ // For arguments smaller than the maximum alignment simply align to the
115
145
// size of the argument
116
146
const size_t Alignment = std::min (MaxAlignment, Size );
117
147
118
- // align the argument
148
+ // Align the argument
119
149
size_t AlignedLocalOffset = LocalOffset;
120
- size_t Pad = LocalOffset % Alignment;
150
+ const size_t Pad = LocalOffset % Alignment;
121
151
if (Pad != 0 ) {
122
152
AlignedLocalOffset += Alignment - Pad;
123
153
}
124
154
155
+ const size_t AlignedLocalSize = Size + (AlignedLocalOffset - LocalOffset);
156
+ return std::make_pair (AlignedLocalSize, AlignedLocalOffset);
157
+ }
158
+
159
+ void addLocalArg (size_t Index, size_t Size ) {
160
+ // Get the aligned argument size and offset into local data
161
+ auto [AlignedLocalSize, AlignedLocalOffset] =
162
+ calcAlignedLocalArgument (Index, Size );
163
+
164
+ // Store argument details
125
165
addArg (Index, sizeof (size_t ), (const void *)&(AlignedLocalOffset),
126
- Size + (AlignedLocalOffset - LocalOffset));
166
+ AlignedLocalSize);
167
+
168
+ // For every existing local argument which follows at later argument
169
+ // indices, update the offset and pointer into the kernel local memory.
170
+ // Required as padding will need to be recalculated.
171
+ const size_t NumArgs = Indices.size () - 1 ; // Accounts for implicit arg
172
+ for (auto SuccIndex = Index + 1 ; SuccIndex < NumArgs; SuccIndex++) {
173
+ const size_t OriginalLocalSize = OriginalLocalMemSize[SuccIndex];
174
+ if (OriginalLocalSize == 0 ) {
175
+ // Skip if successor argument isn't a local memory arg
176
+ continue ;
177
+ }
178
+
179
+ // Recalculate alignment
180
+ auto [SuccAlignedLocalSize, SuccAlignedLocalOffset] =
181
+ calcAlignedLocalArgument (SuccIndex, OriginalLocalSize);
182
+
183
+ // Store new local memory size
184
+ AlignedLocalMemSize[SuccIndex] = SuccAlignedLocalSize;
185
+
186
+ // Store new offset into local data
187
+ const size_t InsertPos =
188
+ std::accumulate (std::begin (ParamSizes),
189
+ std::begin (ParamSizes) + SuccIndex, size_t {0 });
190
+ std::memcpy (&Storage[InsertPos], &SuccAlignedLocalOffset,
191
+ sizeof (size_t ));
192
+ }
127
193
}
128
194
129
195
void addMemObjArg (int Index, ur_mem_handle_t hMem, ur_mem_flags_t Flags) {
@@ -145,15 +211,11 @@ struct ur_kernel_handle_t_ {
145
211
std::memcpy (ImplicitOffsetArgs, ImplicitOffset, Size );
146
212
}
147
213
148
- void clearLocalSize () {
149
- std::fill (std::begin (OffsetPerIndex), std::end (OffsetPerIndex), 0 );
150
- }
151
-
152
214
/// Read-only access to the per-argument pointer table (Indices).
const args_index_t &getIndices() const noexcept {
  return Indices;
}
153
215
154
216
uint32_t getLocalSize () const {
155
- return std::accumulate (std::begin (OffsetPerIndex ),
156
- std::end (OffsetPerIndex ), 0 );
217
+ return std::accumulate (std::begin (AlignedLocalMemSize ),
218
+ std::end (AlignedLocalMemSize ), 0 );
157
219
}
158
220
} Args;
159
221
@@ -240,7 +302,5 @@ struct ur_kernel_handle_t_ {
240
302
241
303
/// Forwards to Args.getLocalSize().
uint32_t getLocalSize() const noexcept {
  return Args.getLocalSize();
}
242
304
243
- void clearLocalSize () { Args.clearLocalSize (); }
244
-
245
305
/// Returns the stored RegsPerThread value.
size_t getRegsPerThread() const noexcept {
  return RegsPerThread;
}
246
306
};
0 commit comments