Skip to content

[SPIR-V] Avoid emitting Int64 when loading Float64 #7073

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 29 additions & 53 deletions tools/clang/lib/SPIRV/RawBufferMethods.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,48 +117,32 @@ SpirvInstruction *RawBufferHandler::load64Bits(SpirvInstruction *buffer,
SpirvInstruction *ptr = nullptr;
auto *constUint0 =
spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0));
auto *constUint32 =
spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 32));

// Load the first word and increment index.
auto *index = address.getWordIndex(loc, range);

// Need to perform two 32-bit uint loads and construct a 64-bit value.

// Load the first 32-bit uint (word0).
ptr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, buffer,
{constUint0, index}, loc, range);
SpirvInstruction *word0 =
spvBuilder.createLoad(astContext.UnsignedIntTy, ptr, loc, range);
// Increment the base index
address.incrementWordIndex(loc, range);

// Load the second word and increment index.
index = address.getWordIndex(loc, range);
// Load the second 32-bit uint (word1).
ptr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, buffer,
{constUint0, index}, loc, range);
SpirvInstruction *word1 =
spvBuilder.createLoad(astContext.UnsignedIntTy, ptr, loc, range);

// Convert both word0 and word1 to 64-bit uints.
word0 = spvBuilder.createUnaryOp(
spv::Op::OpUConvert, astContext.UnsignedLongLongTy, word0, loc, range);
word1 = spvBuilder.createUnaryOp(
spv::Op::OpUConvert, astContext.UnsignedLongLongTy, word1, loc, range);

// Shift word1 to the left by 32 bits.
word1 = spvBuilder.createBinaryOp(spv::Op::OpShiftLeftLogical,
astContext.UnsignedLongLongTy, word1,
constUint32, loc, range);

// BitwiseOr word0 and word1.
result = spvBuilder.createBinaryOp(spv::Op::OpBitwiseOr,
astContext.UnsignedLongLongTy, word0,
word1, loc, range);
result = bitCastToNumericalOrBool(result, astContext.UnsignedLongLongTy,
target64BitType, loc, range);
result->setRValue();

address.incrementWordIndex(loc, range);

// Combine the 2 words into a composite, and bitcast into the destination
// type.
const auto uintVec2Type =
astContext.getExtVectorType(astContext.UnsignedIntTy, 2);
auto *operand = spvBuilder.createCompositeConstruct(
uintVec2Type, {word0, word1}, loc, range);
result = spvBuilder.createUnaryOp(spv::Op::OpBitcast, target64BitType,
operand, loc, range);
result->setRValue();
return result;
}

Expand Down Expand Up @@ -441,39 +425,31 @@ void RawBufferHandler::store64Bits(SpirvInstruction *value,
const auto loc = buffer->getSourceLocation();
auto *constUint0 =
spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0));
auto *constUint32 =
spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 32));

auto *index = address.getWordIndex(loc, range);
// Bitcast the source into a 32-bit words composite.
const auto uintVec2Type =
astContext.getExtVectorType(astContext.UnsignedIntTy, 2);
auto *tmp = spvBuilder.createUnaryOp(spv::Op::OpBitcast, uintVec2Type, value,
loc, range);

// The underlying element type of the ByteAddressBuffer is uint. So we
// need to store two 32-bit values.
// Extract the low and high word (careful! word order).
auto *A = spvBuilder.createCompositeExtract(astContext.UnsignedIntTy, tmp,
{0}, loc, range);
auto *B = spvBuilder.createCompositeExtract(astContext.UnsignedIntTy, tmp,
{1}, loc, range);

// Store the first word, and increment counter.
auto *index = address.getWordIndex(loc, range);
auto *ptr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, buffer,
{constUint0, index}, loc, range);
// First convert the 64-bit value to uint64_t. Then extract two 32-bit words
// from it.
value = bitCastToNumericalOrBool(value, valueType,
astContext.UnsignedLongLongTy, loc, range);

// Use OpUConvert to perform truncation (produces the least significant bits).
SpirvInstruction *lsb = spvBuilder.createUnaryOp(
spv::Op::OpUConvert, astContext.UnsignedIntTy, value, loc, range);

// Shift uint64_t to the right by 32 bits and truncate to get the most
// significant bits.
SpirvInstruction *msb = spvBuilder.createUnaryOp(
spv::Op::OpUConvert, astContext.UnsignedIntTy,
spvBuilder.createBinaryOp(spv::Op::OpShiftRightLogical,
astContext.UnsignedLongLongTy, value,
constUint32, loc, range),
loc, range);

spvBuilder.createStore(ptr, lsb, loc, range);
spvBuilder.createStore(ptr, A, loc, range);
address.incrementWordIndex(loc, range);

// Store the second word, and increment counter.
index = address.getWordIndex(loc, range);
ptr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, buffer,
{constUint0, index}, loc, range);
spvBuilder.createStore(ptr, msb, loc, range);
spvBuilder.createStore(ptr, B, loc, range);
address.incrementWordIndex(loc, range);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// RUN: %dxc -T cs_6_0 -E main -O0 %s -spirv | FileCheck %s

// CHECK-NOT: OpCapability Int64
// CHECK-DAG: OpCapability Float64
// CHECK-NOT: OpCapability Int64

RWByteAddressBuffer buffer;

[numthreads(1, 1, 1)]
void main() {
double tmp;

// CHECK: [[addr1:%[0-9]+]] = OpShiftRightLogical %uint %uint_0 %uint_2
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buffer %uint_0 [[addr1]]
// CHECK: [[word0:%[0-9]+]] = OpLoad %uint [[ptr]]
// CHECK: [[addr2:%[0-9]+]] = OpIAdd %uint [[addr1]] %uint_1
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buffer %uint_0 [[addr2]]
// CHECK: [[word1:%[0-9]+]] = OpLoad %uint [[ptr]]
// CHECK: [[addr3:%[0-9]+]] = OpIAdd %uint [[addr2]] %uint_1
// CHECK: [[merge:%[0-9]+]] = OpCompositeConstruct %v2uint [[word0]] [[word1]]
// CHECK: [[value:%[0-9]+]] = OpBitcast %double [[merge]]
// CHECK: OpStore %tmp [[value]]
tmp = buffer.Load<double>(0);

// CHECK: [[value:%[0-9]+]] = OpLoad %double %tmp
// CHECK: [[merge:%[0-9]+]] = OpBitcast %v2uint [[value]]
// CHECK: [[word0:%[0-9]+]] = OpCompositeExtract %uint [[merge]] 0
// CHECK: [[word1:%[0-9]+]] = OpCompositeExtract %uint [[merge]] 1

// CHECK: [[addr1:%[0-9]+]] = OpShiftRightLogical %uint %uint_0 %uint_2
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buffer %uint_0 [[addr1]]
// CHECK: OpStore [[ptr]] [[word0]]
// CHECK: [[addr2:%[0-9]+]] = OpIAdd %uint [[addr1]] %uint_1
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buffer %uint_0 [[addr2]]
// CHECK: OpStore [[ptr]] [[word1]]
// CHECK: [[addr3:%[0-9]+]] = OpIAdd %uint [[addr2]] %uint_1
buffer.Store<double>(0, tmp);
}

Original file line number Diff line number Diff line change
Expand Up @@ -98,53 +98,46 @@ void main(uint3 tid : SV_DispatchThreadId)
// ********* 64-bit matrix ********************

// CHECK: [[index_1:%[0-9]+]] = OpShiftRightLogical %uint [[addr0_1:%[0-9]+]] %uint_2
// CHECK: [[ptr_11:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 [[index_1]]
// CHECK: [[word0_2:%[0-9]+]] = OpLoad %uint [[ptr_11]]
// CHECK: [[index_1_2:%[0-9]+]] = OpIAdd %uint [[index_1]] %uint_1
// CHECK: [[ptr_12:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 [[index_1_2]]
// CHECK: [[word1_3:%[0-9]+]] = OpLoad %uint [[ptr_12]]
// CHECK: [[word0_ulong:%[0-9]+]] = OpUConvert %ulong [[word0_2]]
// CHECK: [[word1_ulong:%[0-9]+]] = OpUConvert %ulong [[word1_3]]
// CHECK: [[word1_ulong_shifted:%[0-9]+]] = OpShiftLeftLogical %ulong [[word1_ulong]] %uint_32
// CHECK: [[val0_ulong:%[0-9]+]] = OpBitwiseOr %ulong [[word0_ulong]] [[word1_ulong_shifted]]
// CHECK: [[val0_1:%[0-9]+]] = OpBitcast %double [[val0_ulong]]
// CHECK: [[index_2_2:%[0-9]+]] = OpIAdd %uint [[index_1_2]] %uint_1
// CHECK: [[ptr_13:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 [[index_2_2]]
// CHECK: [[word2_2:%[0-9]+]] = OpLoad %uint [[ptr_13]]
// CHECK: [[index_3_0:%[0-9]+]] = OpIAdd %uint [[index_2_2]] %uint_1
// CHECK: [[ptr_14:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 [[index_3_0]]
// CHECK: [[word3_0:%[0-9]+]] = OpLoad %uint [[ptr_14]]
// CHECK: [[word2_ulong:%[0-9]+]] = OpUConvert %ulong [[word2_2]]
// CHECK: [[word3_ulong:%[0-9]+]] = OpUConvert %ulong [[word3_0]]
// CHECK: [[word3_ulong_shifted:%[0-9]+]] = OpShiftLeftLogical %ulong [[word3_ulong]] %uint_32
// CHECK: [[val1_ulong:%[0-9]+]] = OpBitwiseOr %ulong [[word2_ulong]] [[word3_ulong_shifted]]
// CHECK: [[val1_1:%[0-9]+]] = OpBitcast %double [[val1_ulong]]
// CHECK: [[index_4_0:%[0-9]+]] = OpIAdd %uint [[index_3_0]] %uint_1
// CHECK: [[ptr_15:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 [[index_4_0]]
// CHECK: [[word4_0:%[0-9]+]] = OpLoad %uint [[ptr_15]]
// CHECK: [[index_5_0:%[0-9]+]] = OpIAdd %uint [[index_4_0]] %uint_1
// CHECK: [[ptr_16:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 [[index_5_0]]
// CHECK: [[word5_0:%[0-9]+]] = OpLoad %uint [[ptr_16]]
// CHECK: [[word4_ulong:%[0-9]+]] = OpUConvert %ulong [[word4_0]]
// CHECK: [[word5_ulong:%[0-9]+]] = OpUConvert %ulong [[word5_0]]
// CHECK: [[word5_ulong_shifted:%[0-9]+]] = OpShiftLeftLogical %ulong [[word5_ulong]] %uint_32
// CHECK: [[val2_ulong:%[0-9]+]] = OpBitwiseOr %ulong [[word4_ulong]] [[word5_ulong_shifted]]
// CHECK: [[val2_1:%[0-9]+]] = OpBitcast %double [[val2_ulong]]
// CHECK: [[ptr_11:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 [[index_1]]
// CHECK: [[word0_2:%[0-9]+]] = OpLoad %uint [[ptr_11]]
// CHECK: [[index_1_2:%[0-9]+]] = OpIAdd %uint [[index_1]] %uint_1
// CHECK: [[ptr_12:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 [[index_1_2]]
// CHECK: [[word1_3:%[0-9]+]] = OpLoad %uint [[ptr_12]]
// CHECK: [[index_2_2:%[0-9]+]] = OpIAdd %uint [[index_1_2]] %uint_1
// CHECK: [[merge:%[0-9]+]] = OpCompositeConstruct %v2uint [[word0_2]] [[word1_3]]
// CHECK: [[val0_1:%[0-9]+]] = OpBitcast %double [[merge]]

// CHECK: [[ptr_13:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 [[index_2_2]]
// CHECK: [[word2_2:%[0-9]+]] = OpLoad %uint [[ptr_13]]
// CHECK: [[index_3_0:%[0-9]+]] = OpIAdd %uint [[index_2_2]] %uint_1
// CHECK: [[ptr_14:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 [[index_3_0]]
// CHECK: [[word3_0:%[0-9]+]] = OpLoad %uint [[ptr_14]]
// CHECK: [[index_4_0:%[0-9]+]] = OpIAdd %uint [[index_3_0]] %uint_1
// CHECK: [[merge:%[0-9]+]] = OpCompositeConstruct %v2uint [[word2_2]] [[word3_0]]
// CHECK: [[val1_1:%[0-9]+]] = OpBitcast %double [[merge]]

// CHECK: [[ptr_15:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 [[index_4_0]]
// CHECK: [[word4_0:%[0-9]+]] = OpLoad %uint [[ptr_15]]
// CHECK: [[index_5_0:%[0-9]+]] = OpIAdd %uint [[index_4_0]] %uint_1
// CHECK: [[ptr_16:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 [[index_5_0]]
// CHECK: [[word5_0:%[0-9]+]] = OpLoad %uint [[ptr_16]]
// CHECK: [[index_6:%[0-9]+]] = OpIAdd %uint [[index_5_0]] %uint_1
// CHECK: [[ptr_17:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 [[index_6]]
// CHECK: [[merge:%[0-9]+]] = OpCompositeConstruct %v2uint [[word4_0]] [[word5_0]]
// CHECK: [[val2_1:%[0-9]+]] = OpBitcast %double [[merge]]

// CHECK: [[ptr_17:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 [[index_6]]
// CHECK: [[word6:%[0-9]+]] = OpLoad %uint [[ptr_17]]
// CHECK: [[index_7:%[0-9]+]] = OpIAdd %uint [[index_6]] %uint_1
// CHECK: [[ptr_18:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 [[index_7]]
// CHECK: [[ptr_18:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 [[index_7]]
// CHECK: [[word7:%[0-9]+]] = OpLoad %uint [[ptr_18]]
// CHECK: [[word6_ulong:%[0-9]+]] = OpUConvert %ulong [[word6]]
// CHECK: [[word7_ulong:%[0-9]+]] = OpUConvert %ulong [[word7]]
// CHECK: [[word7_ulong_shifted:%[0-9]+]] = OpShiftLeftLogical %ulong [[word7_ulong]] %uint_32
// CHECK: [[val3_ulong:%[0-9]+]] = OpBitwiseOr %ulong [[word6_ulong]] [[word7_ulong_shifted]]
// CHECK: [[val3_1:%[0-9]+]] = OpBitcast %double [[val3_ulong]]
// CHECK: [[row0_1:%[0-9]+]] = OpCompositeConstruct %v2double [[val0_1]] [[val2_1]]
// CHECK: [[row1_1:%[0-9]+]] = OpCompositeConstruct %v2double [[val1_1]] [[val3_1]]
// CHECK: [[matrix_1:%[0-9]+]] = OpCompositeConstruct %mat2v2double [[row0_1]] [[row1_1]]
// CHECK: OpStore %f64 [[matrix_1]]
// CHECK: [[index_8:%[0-9]+]] = OpIAdd %uint [[index_7]] %uint_1
// CHECK: [[merge:%[0-9]+]] = OpCompositeConstruct %v2uint [[word6]] [[word7]]
// CHECK: [[val3_1:%[0-9]+]] = OpBitcast %double [[merge]]

// CHECK: [[row0_1:%[0-9]+]] = OpCompositeConstruct %v2double [[val0_1]] [[val2_1]]
// CHECK: [[row1_1:%[0-9]+]] = OpCompositeConstruct %v2double [[val1_1]] [[val3_1]]
// CHECK: [[matrix_1:%[0-9]+]] = OpCompositeConstruct %mat2v2double [[row0_1]] [[row1_1]]
// CHECK: OpStore %f64 [[matrix_1]]
float64_t2x2 f64 = buf.Load<float64_t2x2>(tid.x);

// ********* array of matrices ********************
Expand Down
Loading
Loading