Skip to content

Commit 0ff9259

Browse files
authored
[flang][cuda][NFC] Extract is cuda device attribute logic (llvm#100809)
1 parent dbb8b7a commit 0ff9259

File tree

1 file changed

+23
-26
lines changed
  • flang/include/flang/Evaluate

1 file changed

+23
-26
lines changed

flang/include/flang/Evaluate/tools.h

Lines changed: 23 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1243,22 +1243,30 @@ bool CheckForCoindexedObject(parser::ContextualMessages &,
12431243
const std::optional<ActualArgument> &, const std::string &procName,
12441244
const std::string &argName);
12451245

1246-
// Get the number of distinct symbols with CUDA attribute in the expression.
1246+
inline bool IsCUDADeviceSymbol(const Symbol &sym) {
1247+
if (const auto *details =
1248+
sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
1249+
if (details->cudaDataAttr() &&
1250+
*details->cudaDataAttr() != common::CUDADataAttr::Pinned) {
1251+
return true;
1252+
}
1253+
}
1254+
return false;
1255+
}
1256+
1257+
// Get the number of distinct symbols with CUDA device
1258+
// attribute in the expression.
12471259
template <typename A> inline int GetNbOfCUDADeviceSymbols(const A &expr) {
12481260
semantics::UnorderedSymbolSet symbols;
12491261
for (const Symbol &sym : CollectCudaSymbols(expr)) {
1250-
if (const auto *details =
1251-
sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
1252-
if (details->cudaDataAttr() &&
1253-
*details->cudaDataAttr() != common::CUDADataAttr::Pinned) {
1254-
symbols.insert(sym);
1255-
}
1262+
if (IsCUDADeviceSymbol(sym)) {
1263+
symbols.insert(sym);
12561264
}
12571265
}
12581266
return symbols.size();
12591267
}
12601268

1261-
// Check if any of the symbols part of the expression has a CUDA data
1269+
// Check if any of the symbols part of the expression has a CUDA device
12621270
// attribute.
12631271
template <typename A> inline bool HasCUDADeviceAttrs(const A &expr) {
12641272
return GetNbOfCUDADeviceSymbols(expr) > 0;
@@ -1270,26 +1278,15 @@ inline bool HasCUDAImplicitTransfer(const Expr<SomeType> &expr) {
12701278
unsigned hostSymbols{0};
12711279
unsigned deviceSymbols{0};
12721280
for (const Symbol &sym : CollectCudaSymbols(expr)) {
1273-
if (const auto *details =
1274-
sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
1275-
if (details->cudaDataAttr() &&
1276-
*details->cudaDataAttr() != common::CUDADataAttr::Pinned) {
1277-
++deviceSymbols;
1278-
} else {
1279-
if (sym.owner().IsDerivedType()) {
1280-
if (const auto *details =
1281-
sym.owner()
1282-
.GetSymbol()
1283-
->GetUltimate()
1284-
.detailsIf<semantics::ObjectEntityDetails>()) {
1285-
if (details->cudaDataAttr() &&
1286-
*details->cudaDataAttr() != common::CUDADataAttr::Pinned) {
1287-
++deviceSymbols;
1288-
}
1289-
}
1281+
if (IsCUDADeviceSymbol(sym)) {
1282+
++deviceSymbols;
1283+
} else {
1284+
if (sym.owner().IsDerivedType()) {
1285+
if (IsCUDADeviceSymbol(sym.owner().GetSymbol()->GetUltimate())) {
1286+
++deviceSymbols;
12901287
}
1291-
++hostSymbols;
12921288
}
1289+
++hostSymbols;
12931290
}
12941291
}
12951292
return hostSymbols > 0 && deviceSymbols > 0;

0 commit comments

Comments
 (0)