@@ -1243,22 +1243,30 @@ bool CheckForCoindexedObject(parser::ContextualMessages &,
1243
1243
const std::optional<ActualArgument> &, const std::string &procName,
1244
1244
const std::string &argName);
1245
1245
1246
- // Get the number of distinct symbols with CUDA attribute in the expression.
1246
// Reports whether `sym` carries a CUDA data attribute other than Pinned.
// The symbol is first resolved with GetUltimate(); the result is true only
// when the resolved symbol has ObjectEntityDetails, its cudaDataAttr() is set,
// and the attribute value differs from common::CUDADataAttr::Pinned. Every
// other case (no ObjectEntityDetails, no attribute, or Pinned) yields false.
+ inline bool IsCUDADeviceSymbol (const Symbol &sym) {
1247
+ if (const auto *details =
1248
+ sym.GetUltimate ().detailsIf <semantics::ObjectEntityDetails>()) {
1249
// cudaDataAttr() is tested for presence before it is dereferenced.
+ if (details->cudaDataAttr () &&
1250
+ *details->cudaDataAttr () != common::CUDADataAttr::Pinned) {
1251
+ return true ;
1252
+ }
1253
+ }
1254
+ return false ;
1255
+ }
1256
+
1257
+ // Get the number of distinct symbols with CUDA device
1258
+ // attribute in the expression.
1247
1259
template <typename A> inline int GetNbOfCUDADeviceSymbols (const A &expr) {
// Walks the symbols returned by CollectCudaSymbols(expr), stores the
// qualifying ones in a symbol set, and returns the size of that set.
// NOTE(review): presumably the set is what makes the count "distinct" -- confirm.
1248
1260
semantics::UnorderedSymbolSet symbols;
1249
1261
for (const Symbol &sym : CollectCudaSymbols (expr)) {
1250
// Removed ('-') lines: the inline ObjectEntityDetails / non-Pinned
// cudaDataAttr test that this change takes out of the loop body.
- if (const auto *details =
1251
- sym.GetUltimate ().detailsIf <semantics::ObjectEntityDetails>()) {
1252
- if (details->cudaDataAttr () &&
1253
- *details->cudaDataAttr () != common::CUDADataAttr::Pinned) {
1254
- symbols.insert (sym);
1255
- }
1262
// Added ('+') lines: the test is now delegated to IsCUDADeviceSymbol.
+ if (IsCUDADeviceSymbol (sym)) {
1263
+ symbols.insert (sym);
1256
1264
}
1257
1265
}
1258
1266
// The result of symbols.size() is implicitly converted to the int return type.
return symbols.size ();
1259
1267
}
1260
1268
1261
- // Check if any of the symbols part of the expression has a CUDA data
1269
+ // Check if any of the symbols part of the expression has a CUDA device
1262
1270
// attribute.
1263
1271
template <typename A> inline bool HasCUDADeviceAttrs (const A &expr) {
1264
1272
return GetNbOfCUDADeviceSymbols (expr) > 0 ;
@@ -1270,26 +1278,15 @@ inline bool HasCUDAImplicitTransfer(const Expr<SomeType> &expr) {
1270
1278
unsigned hostSymbols{0 };
1271
1279
unsigned deviceSymbols{0 };
1272
1280
for (const Symbol &sym : CollectCudaSymbols (expr)) {
1273
- if (const auto *details =
1274
- sym.GetUltimate ().detailsIf <semantics::ObjectEntityDetails>()) {
1275
- if (details->cudaDataAttr () &&
1276
- *details->cudaDataAttr () != common::CUDADataAttr::Pinned) {
1277
- ++deviceSymbols;
1278
- } else {
1279
- if (sym.owner ().IsDerivedType ()) {
1280
- if (const auto *details =
1281
- sym.owner ()
1282
- .GetSymbol ()
1283
- ->GetUltimate ()
1284
- .detailsIf <semantics::ObjectEntityDetails>()) {
1285
- if (details->cudaDataAttr () &&
1286
- *details->cudaDataAttr () != common::CUDADataAttr::Pinned) {
1287
- ++deviceSymbols;
1288
- }
1289
- }
1281
+ if (IsCUDADeviceSymbol (sym)) {
1282
+ ++deviceSymbols;
1283
+ } else {
1284
+ if (sym.owner ().IsDerivedType ()) {
1285
+ if (IsCUDADeviceSymbol (sym.owner ().GetSymbol ()->GetUltimate ())) {
1286
+ ++deviceSymbols;
1290
1287
}
1291
- ++hostSymbols;
1292
1288
}
1289
+ ++hostSymbols;
1293
1290
}
1294
1291
}
1295
1292
return hostSymbols > 0 && deviceSymbols > 0 ;
0 commit comments