Skip to content

Commit e4636f0

Browse files
authored
[SPIR-V] Avoid emitting Int64 when loading Float64 (microsoft#7073)
When loading a Float64 from a raw buffer, we used an Int64, which required an additional capability, even if the code wasn't using any Int64. In practice, it seems most devices supporting Float64 do also support Int64, but this it doesn't have to. By changing the codegen a bit, we can avoid the Int64 value. Tested the word-order using a vulkan compute shader, and checking the returned value on the API side. ```hlsl double tmp = buffer.Load<double>(0); if (tmp == 12.0) buffer.Store<double>(0, 13.0); ``` Fixes microsoft#7038 --------- Signed-off-by: Nathan Gauër <[email protected]>
1 parent e52b6bc commit e4636f0

6 files changed

+229
-235
lines changed

tools/clang/lib/SPIRV/RawBufferMethods.cpp

Lines changed: 29 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -117,48 +117,32 @@ SpirvInstruction *RawBufferHandler::load64Bits(SpirvInstruction *buffer,
117117
SpirvInstruction *ptr = nullptr;
118118
auto *constUint0 =
119119
spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0));
120-
auto *constUint32 =
121-
spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 32));
122120

121+
// Load the first word and increment index.
123122
auto *index = address.getWordIndex(loc, range);
124-
125-
// Need to perform two 32-bit uint loads and construct a 64-bit value.
126-
127-
// Load the first 32-bit uint (word0).
128123
ptr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, buffer,
129124
{constUint0, index}, loc, range);
130125
SpirvInstruction *word0 =
131126
spvBuilder.createLoad(astContext.UnsignedIntTy, ptr, loc, range);
132-
// Increment the base index
133127
address.incrementWordIndex(loc, range);
128+
129+
// Load the second word and increment index.
134130
index = address.getWordIndex(loc, range);
135-
// Load the second 32-bit uint (word1).
136131
ptr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, buffer,
137132
{constUint0, index}, loc, range);
138133
SpirvInstruction *word1 =
139134
spvBuilder.createLoad(astContext.UnsignedIntTy, ptr, loc, range);
140-
141-
// Convert both word0 and word1 to 64-bit uints.
142-
word0 = spvBuilder.createUnaryOp(
143-
spv::Op::OpUConvert, astContext.UnsignedLongLongTy, word0, loc, range);
144-
word1 = spvBuilder.createUnaryOp(
145-
spv::Op::OpUConvert, astContext.UnsignedLongLongTy, word1, loc, range);
146-
147-
// Shift word1 to the left by 32 bits.
148-
word1 = spvBuilder.createBinaryOp(spv::Op::OpShiftLeftLogical,
149-
astContext.UnsignedLongLongTy, word1,
150-
constUint32, loc, range);
151-
152-
// BitwiseOr word0 and word1.
153-
result = spvBuilder.createBinaryOp(spv::Op::OpBitwiseOr,
154-
astContext.UnsignedLongLongTy, word0,
155-
word1, loc, range);
156-
result = bitCastToNumericalOrBool(result, astContext.UnsignedLongLongTy,
157-
target64BitType, loc, range);
158-
result->setRValue();
159-
160135
address.incrementWordIndex(loc, range);
161136

137+
// Combine the 2 words into a composite, and bitcast into the destination
138+
// type.
139+
const auto uintVec2Type =
140+
astContext.getExtVectorType(astContext.UnsignedIntTy, 2);
141+
auto *operand = spvBuilder.createCompositeConstruct(
142+
uintVec2Type, {word0, word1}, loc, range);
143+
result = spvBuilder.createUnaryOp(spv::Op::OpBitcast, target64BitType,
144+
operand, loc, range);
145+
result->setRValue();
162146
return result;
163147
}
164148

@@ -441,39 +425,31 @@ void RawBufferHandler::store64Bits(SpirvInstruction *value,
441425
const auto loc = buffer->getSourceLocation();
442426
auto *constUint0 =
443427
spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0));
444-
auto *constUint32 =
445-
spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 32));
446428

447-
auto *index = address.getWordIndex(loc, range);
429+
// Bitcast the source into a 32-bit words composite.
430+
const auto uintVec2Type =
431+
astContext.getExtVectorType(astContext.UnsignedIntTy, 2);
432+
auto *tmp = spvBuilder.createUnaryOp(spv::Op::OpBitcast, uintVec2Type, value,
433+
loc, range);
448434

449-
// The underlying element type of the ByteAddressBuffer is uint. So we
450-
// need to store two 32-bit values.
435+
// Extract the low and high word (careful! word order).
436+
auto *A = spvBuilder.createCompositeExtract(astContext.UnsignedIntTy, tmp,
437+
{0}, loc, range);
438+
auto *B = spvBuilder.createCompositeExtract(astContext.UnsignedIntTy, tmp,
439+
{1}, loc, range);
440+
441+
// Store the first word, and increment counter.
442+
auto *index = address.getWordIndex(loc, range);
451443
auto *ptr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, buffer,
452444
{constUint0, index}, loc, range);
453-
// First convert the 64-bit value to uint64_t. Then extract two 32-bit words
454-
// from it.
455-
value = bitCastToNumericalOrBool(value, valueType,
456-
astContext.UnsignedLongLongTy, loc, range);
457-
458-
// Use OpUConvert to perform truncation (produces the least significant bits).
459-
SpirvInstruction *lsb = spvBuilder.createUnaryOp(
460-
spv::Op::OpUConvert, astContext.UnsignedIntTy, value, loc, range);
461-
462-
// Shift uint64_t to the right by 32 bits and truncate to get the most
463-
// significant bits.
464-
SpirvInstruction *msb = spvBuilder.createUnaryOp(
465-
spv::Op::OpUConvert, astContext.UnsignedIntTy,
466-
spvBuilder.createBinaryOp(spv::Op::OpShiftRightLogical,
467-
astContext.UnsignedLongLongTy, value,
468-
constUint32, loc, range),
469-
loc, range);
470-
471-
spvBuilder.createStore(ptr, lsb, loc, range);
445+
spvBuilder.createStore(ptr, A, loc, range);
472446
address.incrementWordIndex(loc, range);
447+
448+
// Store the second word, and increment counter.
473449
index = address.getWordIndex(loc, range);
474450
ptr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, buffer,
475451
{constUint0, index}, loc, range);
476-
spvBuilder.createStore(ptr, msb, loc, range);
452+
spvBuilder.createStore(ptr, B, loc, range);
477453
address.incrementWordIndex(loc, range);
478454
}
479455

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// RUN: %dxc -T cs_6_0 -E main -O0 %s -spirv | FileCheck %s
2+
3+
// CHECK-NOT: OpCapability Int64
4+
// CHECK-DAG: OpCapability Float64
5+
// CHECK-NOT: OpCapability Int64
6+
7+
RWByteAddressBuffer buffer;
8+
9+
[numthreads(1, 1, 1)]
10+
void main() {
11+
double tmp;
12+
13+
// CHECK: [[addr1:%[0-9]+]] = OpShiftRightLogical %uint %uint_0 %uint_2
14+
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buffer %uint_0 [[addr1]]
15+
// CHECK: [[word0:%[0-9]+]] = OpLoad %uint [[ptr]]
16+
// CHECK: [[addr2:%[0-9]+]] = OpIAdd %uint [[addr1]] %uint_1
17+
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buffer %uint_0 [[addr2]]
18+
// CHECK: [[word1:%[0-9]+]] = OpLoad %uint [[ptr]]
19+
// CHECK: [[addr3:%[0-9]+]] = OpIAdd %uint [[addr2]] %uint_1
20+
// CHECK: [[merge:%[0-9]+]] = OpCompositeConstruct %v2uint [[word0]] [[word1]]
21+
// CHECK: [[value:%[0-9]+]] = OpBitcast %double [[merge]]
22+
// CHECK: OpStore %tmp [[value]]
23+
tmp = buffer.Load<double>(0);
24+
25+
// CHECK: [[value:%[0-9]+]] = OpLoad %double %tmp
26+
// CHECK: [[merge:%[0-9]+]] = OpBitcast %v2uint [[value]]
27+
// CHECK: [[word0:%[0-9]+]] = OpCompositeExtract %uint [[merge]] 0
28+
// CHECK: [[word1:%[0-9]+]] = OpCompositeExtract %uint [[merge]] 1
29+
30+
// CHECK: [[addr1:%[0-9]+]] = OpShiftRightLogical %uint %uint_0 %uint_2
31+
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buffer %uint_0 [[addr1]]
32+
// CHECK: OpStore [[ptr]] [[word0]]
33+
// CHECK: [[addr2:%[0-9]+]] = OpIAdd %uint [[addr1]] %uint_1
34+
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buffer %uint_0 [[addr2]]
35+
// CHECK: OpStore [[ptr]] [[word1]]
36+
// CHECK: [[addr3:%[0-9]+]] = OpIAdd %uint [[addr2]] %uint_1
37+
buffer.Store<double>(0, tmp);
38+
}
39+

tools/clang/test/CodeGenSPIRV/method.byte-address-buffer.templated-load.matrix.hlsl

Lines changed: 36 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -98,53 +98,46 @@ void main(uint3 tid : SV_DispatchThreadId)
9898
// ********* 64-bit matrix ********************
9999

100100
// CHECK: [[index_1:%[0-9]+]] = OpShiftRightLogical %uint [[addr0_1:%[0-9]+]] %uint_2
101-
// CHECK: [[ptr_11:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 [[index_1]]
102-
// CHECK: [[word0_2:%[0-9]+]] = OpLoad %uint [[ptr_11]]
103-
// CHECK: [[index_1_2:%[0-9]+]] = OpIAdd %uint [[index_1]] %uint_1
104-
// CHECK: [[ptr_12:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 [[index_1_2]]
105-
// CHECK: [[word1_3:%[0-9]+]] = OpLoad %uint [[ptr_12]]
106-
// CHECK: [[word0_ulong:%[0-9]+]] = OpUConvert %ulong [[word0_2]]
107-
// CHECK: [[word1_ulong:%[0-9]+]] = OpUConvert %ulong [[word1_3]]
108-
// CHECK: [[word1_ulong_shifted:%[0-9]+]] = OpShiftLeftLogical %ulong [[word1_ulong]] %uint_32
109-
// CHECK: [[val0_ulong:%[0-9]+]] = OpBitwiseOr %ulong [[word0_ulong]] [[word1_ulong_shifted]]
110-
// CHECK: [[val0_1:%[0-9]+]] = OpBitcast %double [[val0_ulong]]
111-
// CHECK: [[index_2_2:%[0-9]+]] = OpIAdd %uint [[index_1_2]] %uint_1
112-
// CHECK: [[ptr_13:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 [[index_2_2]]
113-
// CHECK: [[word2_2:%[0-9]+]] = OpLoad %uint [[ptr_13]]
114-
// CHECK: [[index_3_0:%[0-9]+]] = OpIAdd %uint [[index_2_2]] %uint_1
115-
// CHECK: [[ptr_14:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 [[index_3_0]]
116-
// CHECK: [[word3_0:%[0-9]+]] = OpLoad %uint [[ptr_14]]
117-
// CHECK: [[word2_ulong:%[0-9]+]] = OpUConvert %ulong [[word2_2]]
118-
// CHECK: [[word3_ulong:%[0-9]+]] = OpUConvert %ulong [[word3_0]]
119-
// CHECK: [[word3_ulong_shifted:%[0-9]+]] = OpShiftLeftLogical %ulong [[word3_ulong]] %uint_32
120-
// CHECK: [[val1_ulong:%[0-9]+]] = OpBitwiseOr %ulong [[word2_ulong]] [[word3_ulong_shifted]]
121-
// CHECK: [[val1_1:%[0-9]+]] = OpBitcast %double [[val1_ulong]]
122-
// CHECK: [[index_4_0:%[0-9]+]] = OpIAdd %uint [[index_3_0]] %uint_1
123-
// CHECK: [[ptr_15:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 [[index_4_0]]
124-
// CHECK: [[word4_0:%[0-9]+]] = OpLoad %uint [[ptr_15]]
125-
// CHECK: [[index_5_0:%[0-9]+]] = OpIAdd %uint [[index_4_0]] %uint_1
126-
// CHECK: [[ptr_16:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 [[index_5_0]]
127-
// CHECK: [[word5_0:%[0-9]+]] = OpLoad %uint [[ptr_16]]
128-
// CHECK: [[word4_ulong:%[0-9]+]] = OpUConvert %ulong [[word4_0]]
129-
// CHECK: [[word5_ulong:%[0-9]+]] = OpUConvert %ulong [[word5_0]]
130-
// CHECK: [[word5_ulong_shifted:%[0-9]+]] = OpShiftLeftLogical %ulong [[word5_ulong]] %uint_32
131-
// CHECK: [[val2_ulong:%[0-9]+]] = OpBitwiseOr %ulong [[word4_ulong]] [[word5_ulong_shifted]]
132-
// CHECK: [[val2_1:%[0-9]+]] = OpBitcast %double [[val2_ulong]]
101+
// CHECK: [[ptr_11:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 [[index_1]]
102+
// CHECK: [[word0_2:%[0-9]+]] = OpLoad %uint [[ptr_11]]
103+
// CHECK: [[index_1_2:%[0-9]+]] = OpIAdd %uint [[index_1]] %uint_1
104+
// CHECK: [[ptr_12:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 [[index_1_2]]
105+
// CHECK: [[word1_3:%[0-9]+]] = OpLoad %uint [[ptr_12]]
106+
// CHECK: [[index_2_2:%[0-9]+]] = OpIAdd %uint [[index_1_2]] %uint_1
107+
// CHECK: [[merge:%[0-9]+]] = OpCompositeConstruct %v2uint [[word0_2]] [[word1_3]]
108+
// CHECK: [[val0_1:%[0-9]+]] = OpBitcast %double [[merge]]
109+
110+
// CHECK: [[ptr_13:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 [[index_2_2]]
111+
// CHECK: [[word2_2:%[0-9]+]] = OpLoad %uint [[ptr_13]]
112+
// CHECK: [[index_3_0:%[0-9]+]] = OpIAdd %uint [[index_2_2]] %uint_1
113+
// CHECK: [[ptr_14:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 [[index_3_0]]
114+
// CHECK: [[word3_0:%[0-9]+]] = OpLoad %uint [[ptr_14]]
115+
// CHECK: [[index_4_0:%[0-9]+]] = OpIAdd %uint [[index_3_0]] %uint_1
116+
// CHECK: [[merge:%[0-9]+]] = OpCompositeConstruct %v2uint [[word2_2]] [[word3_0]]
117+
// CHECK: [[val1_1:%[0-9]+]] = OpBitcast %double [[merge]]
118+
119+
// CHECK: [[ptr_15:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 [[index_4_0]]
120+
// CHECK: [[word4_0:%[0-9]+]] = OpLoad %uint [[ptr_15]]
121+
// CHECK: [[index_5_0:%[0-9]+]] = OpIAdd %uint [[index_4_0]] %uint_1
122+
// CHECK: [[ptr_16:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 [[index_5_0]]
123+
// CHECK: [[word5_0:%[0-9]+]] = OpLoad %uint [[ptr_16]]
133124
// CHECK: [[index_6:%[0-9]+]] = OpIAdd %uint [[index_5_0]] %uint_1
134-
// CHECK: [[ptr_17:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 [[index_6]]
125+
// CHECK: [[merge:%[0-9]+]] = OpCompositeConstruct %v2uint [[word4_0]] [[word5_0]]
126+
// CHECK: [[val2_1:%[0-9]+]] = OpBitcast %double [[merge]]
127+
128+
// CHECK: [[ptr_17:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 [[index_6]]
135129
// CHECK: [[word6:%[0-9]+]] = OpLoad %uint [[ptr_17]]
136130
// CHECK: [[index_7:%[0-9]+]] = OpIAdd %uint [[index_6]] %uint_1
137-
// CHECK: [[ptr_18:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 [[index_7]]
131+
// CHECK: [[ptr_18:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 [[index_7]]
138132
// CHECK: [[word7:%[0-9]+]] = OpLoad %uint [[ptr_18]]
139-
// CHECK: [[word6_ulong:%[0-9]+]] = OpUConvert %ulong [[word6]]
140-
// CHECK: [[word7_ulong:%[0-9]+]] = OpUConvert %ulong [[word7]]
141-
// CHECK: [[word7_ulong_shifted:%[0-9]+]] = OpShiftLeftLogical %ulong [[word7_ulong]] %uint_32
142-
// CHECK: [[val3_ulong:%[0-9]+]] = OpBitwiseOr %ulong [[word6_ulong]] [[word7_ulong_shifted]]
143-
// CHECK: [[val3_1:%[0-9]+]] = OpBitcast %double [[val3_ulong]]
144-
// CHECK: [[row0_1:%[0-9]+]] = OpCompositeConstruct %v2double [[val0_1]] [[val2_1]]
145-
// CHECK: [[row1_1:%[0-9]+]] = OpCompositeConstruct %v2double [[val1_1]] [[val3_1]]
146-
// CHECK: [[matrix_1:%[0-9]+]] = OpCompositeConstruct %mat2v2double [[row0_1]] [[row1_1]]
147-
// CHECK: OpStore %f64 [[matrix_1]]
133+
// CHECK: [[index_8:%[0-9]+]] = OpIAdd %uint [[index_7]] %uint_1
134+
// CHECK: [[merge:%[0-9]+]] = OpCompositeConstruct %v2uint [[word6]] [[word7]]
135+
// CHECK: [[val3_1:%[0-9]+]] = OpBitcast %double [[merge]]
136+
137+
// CHECK: [[row0_1:%[0-9]+]] = OpCompositeConstruct %v2double [[val0_1]] [[val2_1]]
138+
// CHECK: [[row1_1:%[0-9]+]] = OpCompositeConstruct %v2double [[val1_1]] [[val3_1]]
139+
// CHECK: [[matrix_1:%[0-9]+]] = OpCompositeConstruct %mat2v2double [[row0_1]] [[row1_1]]
140+
// CHECK: OpStore %f64 [[matrix_1]]
148141
float64_t2x2 f64 = buf.Load<float64_t2x2>(tid.x);
149142

150143
// ********* array of matrices ********************

0 commit comments

Comments
 (0)