Skip to content

Commit

Permalink
[flang][cuda][NFC] Extract is cuda device attribute logic (#100809)
Browse files Browse the repository at this point in the history
  • Loading branch information
clementval authored Jul 26, 2024
1 parent dbb8b7a commit 0ff9259
Showing 1 changed file with 23 additions and 26 deletions.
49 changes: 23 additions & 26 deletions flang/include/flang/Evaluate/tools.h
Original file line number Diff line number Diff line change
Expand Up @@ -1243,22 +1243,30 @@ bool CheckForCoindexedObject(parser::ContextualMessages &,
const std::optional<ActualArgument> &, const std::string &procName,
const std::string &argName);

// Get the number of distinct symbols with CUDA attribute in the expression.
inline bool IsCUDADeviceSymbol(const Symbol &sym) {
if (const auto *details =
sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
if (details->cudaDataAttr() &&
*details->cudaDataAttr() != common::CUDADataAttr::Pinned) {
return true;
}
}
return false;
}

// Get the number of distinct symbols with CUDA device
// attribute in the expression.
template <typename A> inline int GetNbOfCUDADeviceSymbols(const A &expr) {
semantics::UnorderedSymbolSet symbols;
for (const Symbol &sym : CollectCudaSymbols(expr)) {
if (const auto *details =
sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
if (details->cudaDataAttr() &&
*details->cudaDataAttr() != common::CUDADataAttr::Pinned) {
symbols.insert(sym);
}
if (IsCUDADeviceSymbol(sym)) {
symbols.insert(sym);
}
}
return symbols.size();
}

// Check if any of the symbols part of the expression has a CUDA data
// Check if any of the symbols part of the expression has a CUDA device
// attribute.
template <typename A> inline bool HasCUDADeviceAttrs(const A &expr) {
return GetNbOfCUDADeviceSymbols(expr) > 0;
Expand All @@ -1270,26 +1278,15 @@ inline bool HasCUDAImplicitTransfer(const Expr<SomeType> &expr) {
unsigned hostSymbols{0};
unsigned deviceSymbols{0};
for (const Symbol &sym : CollectCudaSymbols(expr)) {
if (const auto *details =
sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
if (details->cudaDataAttr() &&
*details->cudaDataAttr() != common::CUDADataAttr::Pinned) {
++deviceSymbols;
} else {
if (sym.owner().IsDerivedType()) {
if (const auto *details =
sym.owner()
.GetSymbol()
->GetUltimate()
.detailsIf<semantics::ObjectEntityDetails>()) {
if (details->cudaDataAttr() &&
*details->cudaDataAttr() != common::CUDADataAttr::Pinned) {
++deviceSymbols;
}
}
if (IsCUDADeviceSymbol(sym)) {
++deviceSymbols;
} else {
if (sym.owner().IsDerivedType()) {
if (IsCUDADeviceSymbol(sym.owner().GetSymbol()->GetUltimate())) {
++deviceSymbols;
}
++hostSymbols;
}
++hostSymbols;
}
}
return hostSymbols > 0 && deviceSymbols > 0;
Expand Down

0 comments on commit 0ff9259

Please sign in to comment.