@@ -117,48 +117,32 @@ SpirvInstruction *RawBufferHandler::load64Bits(SpirvInstruction *buffer,
117
117
SpirvInstruction *ptr = nullptr ;
118
118
auto *constUint0 =
119
119
spvBuilder.getConstantInt (astContext.UnsignedIntTy , llvm::APInt (32 , 0 ));
120
- auto *constUint32 =
121
- spvBuilder.getConstantInt (astContext.UnsignedIntTy , llvm::APInt (32 , 32 ));
122
120
121
+ // Load the first word and increment index.
123
122
auto *index = address.getWordIndex (loc, range);
124
-
125
- // Need to perform two 32-bit uint loads and construct a 64-bit value.
126
-
127
- // Load the first 32-bit uint (word0).
128
123
ptr = spvBuilder.createAccessChain (astContext.UnsignedIntTy , buffer,
129
124
{constUint0, index }, loc, range);
130
125
SpirvInstruction *word0 =
131
126
spvBuilder.createLoad (astContext.UnsignedIntTy , ptr, loc, range);
132
- // Increment the base index
133
127
address.incrementWordIndex (loc, range);
128
+
129
+ // Load the second word and increment index.
134
130
index = address.getWordIndex (loc, range);
135
- // Load the second 32-bit uint (word1).
136
131
ptr = spvBuilder.createAccessChain (astContext.UnsignedIntTy , buffer,
137
132
{constUint0, index }, loc, range);
138
133
SpirvInstruction *word1 =
139
134
spvBuilder.createLoad (astContext.UnsignedIntTy , ptr, loc, range);
140
-
141
- // Convert both word0 and word1 to 64-bit uints.
142
- word0 = spvBuilder.createUnaryOp (
143
- spv::Op::OpUConvert, astContext.UnsignedLongLongTy , word0, loc, range);
144
- word1 = spvBuilder.createUnaryOp (
145
- spv::Op::OpUConvert, astContext.UnsignedLongLongTy , word1, loc, range);
146
-
147
- // Shift word1 to the left by 32 bits.
148
- word1 = spvBuilder.createBinaryOp (spv::Op::OpShiftLeftLogical,
149
- astContext.UnsignedLongLongTy , word1,
150
- constUint32, loc, range);
151
-
152
- // BitwiseOr word0 and word1.
153
- result = spvBuilder.createBinaryOp (spv::Op::OpBitwiseOr,
154
- astContext.UnsignedLongLongTy , word0,
155
- word1, loc, range);
156
- result = bitCastToNumericalOrBool (result, astContext.UnsignedLongLongTy ,
157
- target64BitType, loc, range);
158
- result->setRValue ();
159
-
160
135
address.incrementWordIndex (loc, range);
161
136
137
+ // Combine the 2 words into a composite, and bitcast into the destination
138
+ // type.
139
+ const auto uintVec2Type =
140
+ astContext.getExtVectorType (astContext.UnsignedIntTy , 2 );
141
+ auto *operand = spvBuilder.createCompositeConstruct (
142
+ uintVec2Type, {word0, word1}, loc, range);
143
+ result = spvBuilder.createUnaryOp (spv::Op::OpBitcast, target64BitType,
144
+ operand, loc, range);
145
+ result->setRValue ();
162
146
return result;
163
147
}
164
148
@@ -441,39 +425,31 @@ void RawBufferHandler::store64Bits(SpirvInstruction *value,
441
425
const auto loc = buffer->getSourceLocation ();
442
426
auto *constUint0 =
443
427
spvBuilder.getConstantInt (astContext.UnsignedIntTy , llvm::APInt (32 , 0 ));
444
- auto *constUint32 =
445
- spvBuilder.getConstantInt (astContext.UnsignedIntTy , llvm::APInt (32 , 32 ));
446
428
447
- auto *index = address.getWordIndex (loc, range);
429
+ // Bitcast the source into a 32-bit words composite.
430
+ const auto uintVec2Type =
431
+ astContext.getExtVectorType (astContext.UnsignedIntTy , 2 );
432
+ auto *tmp = spvBuilder.createUnaryOp (spv::Op::OpBitcast, uintVec2Type, value,
433
+ loc, range);
448
434
449
- // The underlying element type of the ByteAddressBuffer is uint. So we
450
- // need to store two 32-bit values.
435
+ // Extract the low and high word (careful! word order).
436
+ auto *A = spvBuilder.createCompositeExtract (astContext.UnsignedIntTy , tmp,
437
+ {0 }, loc, range);
438
+ auto *B = spvBuilder.createCompositeExtract (astContext.UnsignedIntTy , tmp,
439
+ {1 }, loc, range);
440
+
441
+ // Store the first word, and increment counter.
442
+ auto *index = address.getWordIndex (loc, range);
451
443
auto *ptr = spvBuilder.createAccessChain (astContext.UnsignedIntTy , buffer,
452
444
{constUint0, index }, loc, range);
453
- // First convert the 64-bit value to uint64_t. Then extract two 32-bit words
454
- // from it.
455
- value = bitCastToNumericalOrBool (value, valueType,
456
- astContext.UnsignedLongLongTy , loc, range);
457
-
458
- // Use OpUConvert to perform truncation (produces the least significant bits).
459
- SpirvInstruction *lsb = spvBuilder.createUnaryOp (
460
- spv::Op::OpUConvert, astContext.UnsignedIntTy , value, loc, range);
461
-
462
- // Shift uint64_t to the right by 32 bits and truncate to get the most
463
- // significant bits.
464
- SpirvInstruction *msb = spvBuilder.createUnaryOp (
465
- spv::Op::OpUConvert, astContext.UnsignedIntTy ,
466
- spvBuilder.createBinaryOp (spv::Op::OpShiftRightLogical,
467
- astContext.UnsignedLongLongTy , value,
468
- constUint32, loc, range),
469
- loc, range);
470
-
471
- spvBuilder.createStore (ptr, lsb, loc, range);
445
+ spvBuilder.createStore (ptr, A, loc, range);
472
446
address.incrementWordIndex (loc, range);
447
+
448
+ // Store the second word, and increment counter.
473
449
index = address.getWordIndex (loc, range);
474
450
ptr = spvBuilder.createAccessChain (astContext.UnsignedIntTy , buffer,
475
451
{constUint0, index }, loc, range);
476
- spvBuilder.createStore (ptr, msb , loc, range);
452
+ spvBuilder.createStore (ptr, B , loc, range);
477
453
address.incrementWordIndex (loc, range);
478
454
}
479
455
0 commit comments