|
17 | 17 | #include "llvm/SYCLLowerIR/ESIMD/ESIMDUtils.h"
|
18 | 18 | #include "llvm/SYCLLowerIR/SYCLUtils.h"
|
19 | 19 |
|
| 20 | +#include "../../IR/ConstantsContext.h" |
20 | 21 | #include "llvm/ADT/DenseMap.h"
|
21 | 22 | #include "llvm/ADT/DenseSet.h"
|
22 | 23 | #include "llvm/ADT/SmallVector.h"
|
@@ -1142,9 +1143,8 @@ static uint64_t getIndexFromExtract(ExtractElementInst *EEI) {
|
1142 | 1143 | /// right before the given extract element instruction \p EEI using the result
|
1143 | 1144 | /// of vector load. The parameter \p IsVectorCall tells what version of GenX
|
1144 | 1145 | /// intrinsic (scalar or vector) to use to lower the load from SPIRV global.
|
1145 |
| -static Instruction *generateGenXCall(ExtractElementInst *EEI, |
1146 |
| - StringRef IntrinName, bool IsVectorCall) { |
1147 |
| - uint64_t IndexValue = getIndexFromExtract(EEI); |
| 1146 | +static Instruction *generateGenXCall(Instruction *EEI, StringRef IntrinName, |
| 1147 | + bool IsVectorCall, uint64_t IndexValue) { |
1148 | 1148 | std::string Suffix =
|
1149 | 1149 | IsVectorCall
|
1150 | 1150 | ? ".v3i32"
|
@@ -1244,39 +1244,58 @@ translateSpirvGlobalUses(LoadInst *LI, StringRef SpirvGlobalName,
|
1244 | 1244 | }
|
1245 | 1245 |
|
1246 | 1246 | // Only loads from _vector_ SPIRV globals reach here now. Their users are
|
1247 |
| - // expected to be ExtractElementInst only, and they are replaced in this loop. |
1248 |
| - // When loads from _scalar_ SPIRV globals are handled here as well, the users |
1249 |
| - // will not be replaced by new instructions, but the GenX call replacing the |
1250 |
| - // original load 'LI' should be inserted before each user. |
| 1247 | + // expected to be ExtractElementInst or TruncInst only, and they are replaced |
| 1248 | + // in this loop. When loads from _scalar_ SPIRV globals are handled here as |
| 1249 | + // well, the users will not be replaced by new instructions, but the GenX call |
| 1250 | + // replacing the original load 'LI' should be inserted before each user. |
1251 | 1251 | for (User *LU : LI->users()) {
|
1252 |
| - ExtractElementInst *EEI = cast<ExtractElementInst>(LU); |
| 1252 | + assert( |
| 1253 | + (isa<ExtractElementInst>(LU) || isa<TruncInst>(LU)) && |
| 1254 | + "SPIRV global users should be either ExtractElementInst or TruncInst"); |
| 1255 | + Instruction *EEI = cast<Instruction>(LU); |
1253 | 1256 | NewInst = nullptr;
|
1254 | 1257 |
|
| 1258 | + uint64_t IndexValue = 0; |
| 1259 | + if (isa<ExtractElementInst>(EEI)) { |
| 1260 | + IndexValue = getIndexFromExtract(cast<ExtractElementInst>(EEI)); |
| 1261 | + } else { |
| 1262 | + auto *GEPCE = cast<GetElementPtrConstantExpr>(LI->getPointerOperand()); |
| 1263 | + |
| 1264 | + IndexValue = cast<Constant>(GEPCE->getOperand(2)) |
| 1265 | + ->getUniqueInteger() |
| 1266 | + .getZExtValue(); |
| 1267 | + } |
| 1268 | + |
1255 | 1269 | if (SpirvGlobalName == "WorkgroupSize") {
|
1256 |
| - NewInst = generateGenXCall(EEI, "local.size", true); |
| 1270 | + NewInst = generateGenXCall(EEI, "local.size", true, IndexValue); |
1257 | 1271 | } else if (SpirvGlobalName == "LocalInvocationId") {
|
1258 |
| - NewInst = generateGenXCall(EEI, "local.id", true); |
| 1272 | + NewInst = generateGenXCall(EEI, "local.id", true, IndexValue); |
1259 | 1273 | } else if (SpirvGlobalName == "WorkgroupId") {
|
1260 |
| - NewInst = generateGenXCall(EEI, "group.id", false); |
| 1274 | + NewInst = generateGenXCall(EEI, "group.id", false, IndexValue); |
1261 | 1275 | } else if (SpirvGlobalName == "GlobalInvocationId") {
|
1262 | 1276 | // GlobalId = LocalId + WorkGroupSize * GroupId
|
1263 |
| - Instruction *LocalIdI = generateGenXCall(EEI, "local.id", true); |
1264 |
| - Instruction *WGSizeI = generateGenXCall(EEI, "local.size", true); |
1265 |
| - Instruction *GroupIdI = generateGenXCall(EEI, "group.id", false); |
| 1277 | + Instruction *LocalIdI = |
| 1278 | + generateGenXCall(EEI, "local.id", true, IndexValue); |
| 1279 | + Instruction *WGSizeI = |
| 1280 | + generateGenXCall(EEI, "local.size", true, IndexValue); |
| 1281 | + Instruction *GroupIdI = |
| 1282 | + generateGenXCall(EEI, "group.id", false, IndexValue); |
1266 | 1283 | Instruction *MulI =
|
1267 | 1284 | BinaryOperator::CreateMul(WGSizeI, GroupIdI, "mul", EEI);
|
1268 | 1285 | NewInst = BinaryOperator::CreateAdd(LocalIdI, MulI, "add", EEI);
|
1269 | 1286 | } else if (SpirvGlobalName == "GlobalSize") {
|
1270 | 1287 | // GlobalSize = WorkGroupSize * NumWorkGroups
|
1271 |
| - Instruction *WGSizeI = generateGenXCall(EEI, "local.size", true); |
1272 |
| - Instruction *NumWGI = generateGenXCall(EEI, "group.count", true); |
| 1288 | + Instruction *WGSizeI = |
| 1289 | + generateGenXCall(EEI, "local.size", true, IndexValue); |
| 1290 | + Instruction *NumWGI = |
| 1291 | + generateGenXCall(EEI, "group.count", true, IndexValue); |
1273 | 1292 | NewInst = BinaryOperator::CreateMul(WGSizeI, NumWGI, "mul", EEI);
|
1274 | 1293 | } else if (SpirvGlobalName == "GlobalOffset") {
|
1275 | 1294 | // TODO: Support GlobalOffset SPIRV intrinsics
|
1276 | 1295 | // Currently all users of load of GlobalOffset are replaced with 0.
|
1277 | 1296 | NewInst = llvm::Constant::getNullValue(EEI->getType());
|
1278 | 1297 | } else if (SpirvGlobalName == "NumWorkgroups") {
|
1279 |
| - NewInst = generateGenXCall(EEI, "group.count", true); |
| 1298 | + NewInst = generateGenXCall(EEI, "group.count", true, IndexValue); |
1280 | 1299 | }
|
1281 | 1300 |
|
1282 | 1301 | llvm::esimd::assert_and_diag(
|
@@ -1786,9 +1805,13 @@ size_t SYCLLowerESIMDPass::runOnFunction(Function &F,
|
1786 | 1805 | if (LI) {
|
1787 | 1806 | Value *LoadPtrOp = LI->getPointerOperand();
|
1788 | 1807 | Value *SpirvGlobal = nullptr;
|
1789 |
| - // Look through casts to find SPIRV builtin globals |
| 1808 | + // Look through constant expressions to find SPIRV builtin globals |
| 1809 | + // It may come with or without cast. |
1790 | 1810 | auto *CE = dyn_cast<ConstantExpr>(LoadPtrOp);
|
1791 |
| - if (CE) { |
| 1811 | + auto *GEPCE = dyn_cast<GetElementPtrConstantExpr>(LoadPtrOp); |
| 1812 | + if (GEPCE) { |
| 1813 | + SpirvGlobal = GEPCE->getOperand(0); |
| 1814 | + } else if (CE) { |
1792 | 1815 | assert(CE->isCast() && "ConstExpr should be a cast");
|
1793 | 1816 | SpirvGlobal = CE->getOperand(0);
|
1794 | 1817 | } else {
|
|
0 commit comments