Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[OpenACC] Add 'collapse' clause AST/basic Sema implementation #109461

Merged
merged 3 commits into from
Oct 1, 2024

Conversation

erichkeane
Copy link
Collaborator

The 'collapse' clause on a 'loop' construct is used to specify how many nested loops are associated with the 'loop' construct. It takes an optional 'force' tag, and an integer constant expression as arguments.

There are many other restrictions based on the contents of the loop/etc, but those are implemented in followup patches, for now, this patch just adds the AST node and does basic argument checking on the loop-count.

The 'collapse' clause on a 'loop' construct is used to specify how many
nested loops are associated with the 'loop' construct.  It takes an
optional 'force' tag, and an integer constant expression as arguments.

There are many other restrictions based on the contents of the loop/etc,
but those are implemented in followup patches, for now, this patch just
adds the AST node and does basic argument checking on the loop-count.
@llvmbot llvmbot added clang Clang issues not falling into any other category clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:modules C++20 modules and Clang Header Modules clang:as-a-library libclang and C++ API labels Sep 20, 2024
@llvmbot
Copy link
Member

llvmbot commented Sep 20, 2024

@llvm/pr-subscribers-clang-modules

Author: Erich Keane (erichkeane)

Changes

The 'collapse' clause on a 'loop' construct is used to specify how many nested loops are associated with the 'loop' construct. It takes an optional 'force' tag, and an integer constant expression as arguments.

There are many other restrictions based on the contents of the loop/etc, but those are implemented in followup patches, for now, this patch just adds the AST node and does basic argument checking on the loop-count.


Patch is 33.55 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/109461.diff

21 Files Affected:

  • (modified) clang/include/clang/AST/OpenACCClause.h (+26)
  • (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+3)
  • (modified) clang/include/clang/Basic/OpenACCClauses.def (+1)
  • (modified) clang/include/clang/Sema/SemaOpenACC.h (+27-1)
  • (modified) clang/lib/AST/OpenACCClause.cpp (+33-1)
  • (modified) clang/lib/AST/StmtProfile.cpp (+6)
  • (modified) clang/lib/AST/TextNodeDumper.cpp (+6)
  • (modified) clang/lib/Parse/ParseOpenACC.cpp (+15-4)
  • (modified) clang/lib/Sema/SemaOpenACC.cpp (+62)
  • (modified) clang/lib/Sema/TreeTransform.h (+25)
  • (modified) clang/lib/Serialization/ASTReader.cpp (+7-1)
  • (modified) clang/lib/Serialization/ASTWriter.cpp (+7-1)
  • (modified) clang/test/AST/ast-print-openacc-loop-construct.cpp (+25)
  • (modified) clang/test/ParserOpenACC/parse-clauses.c (+4-7)
  • (modified) clang/test/ParserOpenACC/parse-clauses.cpp (+14-4)
  • (modified) clang/test/SemaOpenACC/compute-construct-device_type-clause.c (+1-2)
  • (modified) clang/test/SemaOpenACC/loop-construct-auto_seq_independent-clauses.c (-6)
  • (added) clang/test/SemaOpenACC/loop-construct-collapse-ast.cpp (+150)
  • (added) clang/test/SemaOpenACC/loop-construct-collapse-clause.cpp (+117)
  • (modified) clang/test/SemaOpenACC/loop-construct-device_type-clause.c (-1)
  • (modified) clang/tools/libclang/CIndex.cpp (+3)
diff --git a/clang/include/clang/AST/OpenACCClause.h b/clang/include/clang/AST/OpenACCClause.h
index ea1ffbc7fd08b4..90f5b7fc9ab6f4 100644
--- a/clang/include/clang/AST/OpenACCClause.h
+++ b/clang/include/clang/AST/OpenACCClause.h
@@ -547,6 +547,32 @@ class OpenACCAsyncClause : public OpenACCClauseWithSingleIntExpr {
                                     SourceLocation EndLoc);
 };
 
+/// Represents a 'collapse' clause on a 'loop' construct. This clause takes an
+/// integer constant expression 'N' that represents how deep to collapse the
+/// construct. It also takes an optional 'force' tag that permits intervening
+/// code in the loops.
+class OpenACCCollapseClause : public OpenACCClauseWithSingleIntExpr {
+  bool HasForce = false;
+
+  OpenACCCollapseClause(SourceLocation BeginLoc, SourceLocation LParenLoc,
+                        bool HasForce, Expr *LoopCount, SourceLocation EndLoc);
+
+public:
+  const Expr *getLoopCount() const { return getIntExpr(); }
+  Expr *getLoopCount() { return getIntExpr(); }
+
+  bool hasForce() const { return HasForce; }
+
+  static bool classof(const OpenACCClause *C) {
+    return C->getClauseKind() == OpenACCClauseKind::Collapse;
+  }
+
+  static OpenACCCollapseClause *Create(const ASTContext &C,
+                                       SourceLocation BeginLoc,
+                                       SourceLocation LParenLoc, bool HasForce,
+                                       Expr *LoopCount, SourceLocation EndLoc);
+};
+
 /// Represents a clause with one or more 'var' objects, represented as an expr,
 /// as its arguments. Var-list is expected to be stored in trailing storage.
 /// For now, we're just storing the original expression in its entirety, unlike
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index ba813af960af6f..fb60477bc0c2e6 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -12606,6 +12606,9 @@ def note_acc_construct_here : Note<"'%0' construct is here">;
 def err_acc_loop_spec_conflict
     : Error<"OpenACC clause '%0' on '%1' construct conflicts with previous "
             "data dependence clause">;
+def err_acc_collapse_loop_count
+    : Error<"OpenACC 'collapse' clause loop count must be a %select{constant "
+            "expression|positive integer value, evaluated to %1}0">;
 
 // AMDGCN builtins diagnostics
 def err_amdgcn_global_load_lds_size_invalid_value : Error<"invalid size value">;
diff --git a/clang/include/clang/Basic/OpenACCClauses.def b/clang/include/clang/Basic/OpenACCClauses.def
index 85f4859925f0bc..19cdfe7672133b 100644
--- a/clang/include/clang/Basic/OpenACCClauses.def
+++ b/clang/include/clang/Basic/OpenACCClauses.def
@@ -24,6 +24,7 @@
 VISIT_CLAUSE(Auto)
 VISIT_CLAUSE(Async)
 VISIT_CLAUSE(Attach)
+VISIT_CLAUSE(Collapse)
 VISIT_CLAUSE(Copy)
 CLAUSE_ALIAS(PCopy, Copy, true)
 CLAUSE_ALIAS(PresentOrCopy, Copy, true)
diff --git a/clang/include/clang/Sema/SemaOpenACC.h b/clang/include/clang/Sema/SemaOpenACC.h
index 0ca76842e5f902..b4d54d95953046 100644
--- a/clang/include/clang/Sema/SemaOpenACC.h
+++ b/clang/include/clang/Sema/SemaOpenACC.h
@@ -87,9 +87,14 @@ class SemaOpenACC : public SemaBase {
       SmallVector<Expr *> VarList;
     };
 
+    struct CollapseDetails {
+      bool IsForce;
+      Expr *LoopCount;
+    };
+
     std::variant<std::monostate, DefaultDetails, ConditionDetails,
                  IntExprDetails, VarListDetails, WaitDetails, DeviceTypeDetails,
-                 ReductionDetails>
+                 ReductionDetails, CollapseDetails>
         Details = std::monostate{};
 
   public:
@@ -246,6 +251,19 @@ class SemaOpenACC : public SemaBase {
       return std::get<VarListDetails>(Details).IsZero;
     }
 
+    bool isForce() const {
+      assert(ClauseKind == OpenACCClauseKind::Collapse &&
+             "Only 'collapse' has a force tag");
+      return std::get<CollapseDetails>(Details).IsForce;
+    }
+
+    Expr *getLoopCount() const {
+      assert(ClauseKind == OpenACCClauseKind::Collapse &&
+             "Only 'collapse' has a loop count");
+      return std::get<CollapseDetails>(Details).LoopCount;
+    }
+
+
     ArrayRef<DeviceTypeArgument> getDeviceTypeArchitectures() const {
       assert((ClauseKind == OpenACCClauseKind::DeviceType ||
               ClauseKind == OpenACCClauseKind::DType) &&
@@ -384,6 +402,12 @@ class SemaOpenACC : public SemaBase {
              "Only 'device_type'/'dtype' has a device-type-arg list");
       Details = DeviceTypeDetails{std::move(Archs)};
     }
+
+    void setCollapseDetails(bool IsForce, Expr *LoopCount) {
+      assert(ClauseKind == OpenACCClauseKind::Collapse &&
+             "Only 'collapse' has collapse details");
+      Details = CollapseDetails{IsForce, LoopCount};
+    }
   };
 
   SemaOpenACC(Sema &S);
@@ -448,6 +472,8 @@ class SemaOpenACC : public SemaBase {
                                    Expr *LowerBound,
                                    SourceLocation ColonLocFirst, Expr *Length,
                                    SourceLocation RBLoc);
+  /// Checks the loop depth value for a collapse clause.
+  ExprResult CheckCollapseLoopCount(Expr *LoopCount);
 
   /// Helper type for the registration/assignment of constructs that need to
   /// 'know' about their parent constructs and hold a reference to them, such as
diff --git a/clang/lib/AST/OpenACCClause.cpp b/clang/lib/AST/OpenACCClause.cpp
index 95089a9b79e267..fc8a4c28dc3fe4 100644
--- a/clang/lib/AST/OpenACCClause.cpp
+++ b/clang/lib/AST/OpenACCClause.cpp
@@ -43,7 +43,7 @@ bool OpenACCClauseWithCondition::classof(const OpenACCClause *C) {
 bool OpenACCClauseWithSingleIntExpr::classof(const OpenACCClause *C) {
   return OpenACCNumWorkersClause::classof(C) ||
          OpenACCVectorLengthClause::classof(C) ||
-         OpenACCAsyncClause::classof(C);
+         OpenACCCollapseClause::classof(C) || OpenACCAsyncClause::classof(C);
 }
 OpenACCDefaultClause *OpenACCDefaultClause::Create(const ASTContext &C,
                                                    OpenACCDefaultClauseKind K,
@@ -134,6 +134,30 @@ OpenACCNumWorkersClause::Create(const ASTContext &C, SourceLocation BeginLoc,
       OpenACCNumWorkersClause(BeginLoc, LParenLoc, IntExpr, EndLoc);
 }
 
+OpenACCCollapseClause::OpenACCCollapseClause(SourceLocation BeginLoc,
+                                             SourceLocation LParenLoc,
+                                             bool HasForce, Expr *LoopCount,
+                                             SourceLocation EndLoc)
+    : OpenACCClauseWithSingleIntExpr(OpenACCClauseKind::Collapse, BeginLoc,
+                                     LParenLoc, LoopCount, EndLoc),
+      HasForce(HasForce) {
+  assert(LoopCount && "LoopCount required");
+}
+
+OpenACCCollapseClause *
+OpenACCCollapseClause::Create(const ASTContext &C, SourceLocation BeginLoc,
+                              SourceLocation LParenLoc, bool HasForce,
+                              Expr *LoopCount, SourceLocation EndLoc) {
+  assert(LoopCount &&
+         (LoopCount->isInstantiationDependent() ||
+          (LoopCount->isIntegerConstantExpr(C))) &&
+         "Loop count not constant expression");
+  void *Mem =
+      C.Allocate(sizeof(OpenACCCollapseClause), alignof(OpenACCCollapseClause));
+  return new (Mem)
+      OpenACCCollapseClause(BeginLoc, LParenLoc, HasForce, LoopCount, EndLoc);
+}
+
 OpenACCVectorLengthClause::OpenACCVectorLengthClause(SourceLocation BeginLoc,
                                                      SourceLocation LParenLoc,
                                                      Expr *IntExpr,
@@ -550,3 +574,11 @@ void OpenACCClausePrinter::VisitIndependentClause(
 void OpenACCClausePrinter::VisitSeqClause(const OpenACCSeqClause &C) {
   OS << "seq";
 }
+
+void OpenACCClausePrinter::VisitCollapseClause(const OpenACCCollapseClause &C) {
+  OS << "collapse(";
+  if (C.hasForce())
+    OS << "force:";
+  printExpr(C.getLoopCount());
+  OS << ")";
+}
diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
index ad4281986f668e..c3812844ab8a31 100644
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -2558,6 +2558,12 @@ void OpenACCClauseProfiler::VisitNumWorkersClause(
   Profiler.VisitStmt(Clause.getIntExpr());
 }
 
+void OpenACCClauseProfiler::VisitCollapseClause(
+    const OpenACCCollapseClause &Clause) {
+  assert(Clause.getLoopCount() && "collapse clause requires a valid int expr");
+  Profiler.VisitStmt(Clause.getLoopCount());
+}
+
 void OpenACCClauseProfiler::VisitPrivateClause(
     const OpenACCPrivateClause &Clause) {
   for (auto *E : Clause.getVarList())
diff --git a/clang/lib/AST/TextNodeDumper.cpp b/clang/lib/AST/TextNodeDumper.cpp
index 3c51c746471829..8a74159c7c93e5 100644
--- a/clang/lib/AST/TextNodeDumper.cpp
+++ b/clang/lib/AST/TextNodeDumper.cpp
@@ -419,6 +419,12 @@ void TextNodeDumper::Visit(const OpenACCClause *C) {
       // but print 'clause' here so it is clear what is happening from the dump.
       OS << " clause";
       break;
+    case OpenACCClauseKind::Collapse:
+      OS << " clause";
+      if (cast<OpenACCCollapseClause>(C)->hasForce())
+        OS << ": force";
+      break;
+
     case OpenACCClauseKind::CopyIn:
     case OpenACCClauseKind::PCopyIn:
     case OpenACCClauseKind::PresentOrCopyIn:
diff --git a/clang/lib/Parse/ParseOpenACC.cpp b/clang/lib/Parse/ParseOpenACC.cpp
index 0261e8ea3c9b76..e66abd6873794e 100644
--- a/clang/lib/Parse/ParseOpenACC.cpp
+++ b/clang/lib/Parse/ParseOpenACC.cpp
@@ -976,14 +976,25 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
                                      /*IsReadOnly=*/false, /*IsZero=*/false);
       break;
     case OpenACCClauseKind::Collapse: {
-      tryParseAndConsumeSpecialTokenKind(*this, OpenACCSpecialTokenKind::Force,
-                                         ClauseKind);
-      ExprResult NumLoops =
+      bool HasForce = tryParseAndConsumeSpecialTokenKind(
+          *this, OpenACCSpecialTokenKind::Force, ClauseKind);
+      ExprResult LoopCount =
           getActions().CorrectDelayedTyposInExpr(ParseConstantExpression());
-      if (NumLoops.isInvalid()) {
+      if (LoopCount.isInvalid()) {
         Parens.skipToEnd();
         return OpenACCCanContinue();
       }
+
+      LoopCount = getActions().OpenACC().ActOnIntExpr(
+          OpenACCDirectiveKind::Invalid, ClauseKind,
+          LoopCount.get()->getBeginLoc(), LoopCount.get());
+
+      if (LoopCount.isInvalid()) {
+        Parens.skipToEnd();
+        return OpenACCCanContinue();
+      }
+
+      ParsedClause.setCollapseDetails(HasForce, LoopCount.get());
       break;
     }
     case OpenACCClauseKind::Bind: {
diff --git a/clang/lib/Sema/SemaOpenACC.cpp b/clang/lib/Sema/SemaOpenACC.cpp
index ecbcc19413dc61..6a4054b53fcb43 100644
--- a/clang/lib/Sema/SemaOpenACC.cpp
+++ b/clang/lib/Sema/SemaOpenACC.cpp
@@ -343,6 +343,18 @@ bool doesClauseApplyToDirective(OpenACCDirectiveKind DirectiveKind,
       return false;
     }
 
+  case OpenACCClauseKind::Collapse: {
+    switch (DirectiveKind) {
+    case OpenACCDirectiveKind::Loop:
+    case OpenACCDirectiveKind::ParallelLoop:
+    case OpenACCDirectiveKind::SerialLoop:
+    case OpenACCDirectiveKind::KernelsLoop:
+      return true;
+    default:
+      return false;
+    }
+  }
+
   default:
     // Do nothing so we can go to the 'unimplemented' diagnostic instead.
     return true;
@@ -1037,6 +1049,26 @@ OpenACCClause *SemaOpenACCClauseVisitor::VisitReductionClause(
       ValidVars, Clause.getEndLoc());
 }
 
+OpenACCClause *SemaOpenACCClauseVisitor::VisitCollapseClause(
+    SemaOpenACC::OpenACCParsedClause &Clause) {
+  // Duplicates here are not really sensible.  We could possible permit
+  // multiples if they all had the same value, but there isn't really a good
+  // reason to do so. Also, this simplifies the suppression of duplicates, in
+  // that we know if we 'find' one after instantiation, that it is the same
+  // clause, which simplifies instantiation/checking/etc.
+  if (checkAlreadyHasClauseOfKind(SemaRef, ExistingClauses, Clause))
+    return nullptr;
+
+  ExprResult LoopCount = SemaRef.CheckCollapseLoopCount(Clause.getLoopCount());
+
+  if (!LoopCount.isUsable())
+    return nullptr;
+
+  return OpenACCCollapseClause::Create(
+      Ctx, Clause.getBeginLoc(), Clause.getLParenLoc(),
+      Clause.isForce(), LoopCount.get(), Clause.getEndLoc());
+}
+
 } // namespace
 
 SemaOpenACC::SemaOpenACC(Sema &S) : SemaBase(S) {}
@@ -1273,6 +1305,9 @@ ExprResult SemaOpenACC::ActOnIntExpr(OpenACCDirectiveKind DK,
     }
   } IntExprDiagnoser(DK, CK, IntExpr);
 
+  if (!IntExpr)
+    return ExprError();
+
   ExprResult IntExprResult = SemaRef.PerformContextualImplicitConversion(
       Loc, IntExpr, IntExprDiagnoser);
   if (IntExprResult.isInvalid())
@@ -1583,6 +1618,33 @@ ExprResult SemaOpenACC::ActOnArraySectionExpr(Expr *Base, SourceLocation LBLoc,
                        OK_Ordinary, ColonLoc, RBLoc);
 }
 
+ExprResult SemaOpenACC::CheckCollapseLoopCount(Expr *LoopCount) {
+  if (!LoopCount)
+    return ExprError();
+
+  assert((LoopCount->isInstantiationDependent() ||
+          LoopCount->getType()->isIntegerType()) &&
+         "Loop argument non integer?");
+
+  // If this is dependent, there really isn't anything we can check.
+  if (LoopCount->isInstantiationDependent())
+    return ExprResult{LoopCount};
+
+  std::optional<llvm::APSInt> ICE =
+      LoopCount->getIntegerConstantExpr(getASTContext());
+
+  // OpenACC 3.3: 2.9.1
+  // The argument to the collapse clause must be a constant positive integer
+  // expression.
+  if (!ICE || *ICE <= 0) {
+    Diag(LoopCount->getBeginLoc(), diag::err_acc_collapse_loop_count)
+        << ICE.has_value() << ICE.value_or(llvm::APSInt{}).getExtValue();
+    return ExprError();
+  }
+
+  return ExprResult{LoopCount};
+}
+
 bool SemaOpenACC::ActOnStartStmtDirective(OpenACCDirectiveKind K,
                                           SourceLocation StartLoc) {
   SemaRef.DiscardCleanupsInEvaluationContext();
diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h
index 95ded5e59a9fa7..3dba8e53669e8d 100644
--- a/clang/lib/Sema/TreeTransform.h
+++ b/clang/lib/Sema/TreeTransform.h
@@ -11821,6 +11821,31 @@ void OpenACCClauseTransform<Derived>::VisitReductionClause(
       ParsedClause.getLParenLoc(), C.getReductionOp(), ValidVars,
       ParsedClause.getEndLoc());
 }
+
+template <typename Derived>
+void OpenACCClauseTransform<Derived>::VisitCollapseClause(
+    const OpenACCCollapseClause &C) {
+  Expr *LoopCount = const_cast<Expr *>(C.getLoopCount());
+  assert(LoopCount && "collapse clause constructed with invalid loop count");
+
+  ExprResult NewLoopCount = Self.TransformExpr(LoopCount);
+
+  NewLoopCount = Self.getSema().OpenACC().ActOnIntExpr(
+      OpenACCDirectiveKind::Invalid, ParsedClause.getClauseKind(),
+      NewLoopCount.get()->getBeginLoc(), NewLoopCount.get());
+
+  NewLoopCount =
+      Self.getSema().OpenACC().CheckCollapseLoopCount(NewLoopCount.get());
+
+  if (!NewLoopCount.isUsable())
+    return;
+
+  ParsedClause.setCollapseDetails(C.hasForce(), NewLoopCount.get());
+  NewClause = OpenACCCollapseClause::Create(
+      Self.getSema().getASTContext(), ParsedClause.getBeginLoc(),
+      ParsedClause.getLParenLoc(), ParsedClause.isForce(),
+      ParsedClause.getLoopCount(), ParsedClause.getEndLoc());
+}
 } // namespace
 template <typename Derived>
 OpenACCClause *TreeTransform<Derived>::TransformOpenACCClause(
diff --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp
index 7efcc81e194d95..7a0a6d81b21a58 100644
--- a/clang/lib/Serialization/ASTReader.cpp
+++ b/clang/lib/Serialization/ASTReader.cpp
@@ -12269,6 +12269,13 @@ OpenACCClause *ASTRecordReader::readOpenACCClause() {
     return OpenACCIndependentClause::Create(getContext(), BeginLoc, EndLoc);
   case OpenACCClauseKind::Auto:
     return OpenACCAutoClause::Create(getContext(), BeginLoc, EndLoc);
+  case OpenACCClauseKind::Collapse: {
+    SourceLocation LParenLoc = readSourceLocation();
+    bool HasForce = readBool();
+    Expr *LoopCount = readSubExpr();
+    return OpenACCCollapseClause::Create(getContext(), BeginLoc, LParenLoc,
+                                         HasForce, LoopCount, EndLoc);
+  }
 
   case OpenACCClauseKind::Finalize:
   case OpenACCClauseKind::IfPresent:
@@ -12282,7 +12289,6 @@ OpenACCClause *ASTRecordReader::readOpenACCClause() {
   case OpenACCClauseKind::DeviceResident:
   case OpenACCClauseKind::Host:
   case OpenACCClauseKind::Link:
-  case OpenACCClauseKind::Collapse:
   case OpenACCClauseKind::Bind:
   case OpenACCClauseKind::DeviceNum:
   case OpenACCClauseKind::DefaultAsync:
diff --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp
index 008bf571f847dc..2da0c94c9e4cbc 100644
--- a/clang/lib/Serialization/ASTWriter.cpp
+++ b/clang/lib/Serialization/ASTWriter.cpp
@@ -8126,6 +8126,13 @@ void ASTRecordWriter::writeOpenACCClause(const OpenACCClause *C) {
     // Nothing to do here, there is no additional information beyond the
     // begin/end loc and clause kind.
     return;
+  case OpenACCClauseKind::Collapse: {
+    const auto *CC = cast<OpenACCCollapseClause>(C);
+    writeSourceLocation(CC->getLParenLoc());
+    writeBool(CC->hasForce());
+    AddStmt(const_cast<Expr *>(CC->getLoopCount()));
+    return;
+  }
 
   case OpenACCClauseKind::Finalize:
   case OpenACCClauseKind::IfPresent:
@@ -8139,7 +8146,6 @@ void ASTRecordWriter::writeOpenACCClause(const OpenACCClause *C) {
   case OpenACCClauseKind::DeviceResident:
   case OpenACCClauseKind::Host:
   case OpenACCClauseKind::Link:
-  case OpenACCClauseKind::Collapse:
   case OpenACCClauseKind::Bind:
   case OpenACCClauseKind::DeviceNum:
   case OpenACCClauseKind::DefaultAsync:
diff --git a/clang/test/AST/ast-print-openacc-loop-construct.cpp b/clang/test/AST/ast-print-openacc-loop-construct.cpp
index cde302a66f3af7..ae1f7964f019eb 100644
--- a/clang/test/AST/ast-print-openacc-loop-construct.cpp
+++ b/clang/test/AST/ast-print-openacc-loop-construct.cpp
@@ -57,4 +57,29 @@ void foo() {
 // CHECK-NEXT: ;
 #pragma acc loop private(i, array[1], array, array[1:2])
   for(;;);
+
+// CHECK: #pragma acc loop collapse(1)
+// CHECK-NEXT: for (;;)
+// CHECK-NEXT: ;
+#pragma acc loop collapse(1)
+  for(;;);
+// CHECK: #pragma acc loop collapse(force:1)
+// CHECK-NEXT: for (;;)
+// CHECK-NEXT: ;
+#pragma acc loop collapse(force:1)
+  for(;;);
+// CHECK: #pragma acc loop collapse(2)
+// CHECK-NEXT: for (;;)
+// CHECK-NEXT: for (;;)
+// CHECK-NEXT: ;
+#pragma acc loop collapse(2)
+  for(;;)
+    for(;;);
+// CHECK: #pragma acc loop collapse(force:2)
+// CHECK-NEXT: for (;;)
+// CHECK-NEXT: for (;;)
+// CHECK-NEXT: ;
+#pragma acc loop collapse(force:2)
+  for(;;)
+    for(;;);
 }
diff --git a/clang/test/ParserOpenACC/parse-clauses.c b/clang/test/ParserOpenACC/parse-clauses.c
index 15c4554a31922a..6c9ce4ad5e1969 100644
--- a/clang/test/ParserOpenACC/parse-clauses.c
+++ b/clang/test/ParserOpenACC/parse-clauses.c
@@ -105,17 +105,14 @@ void func() {
 #pragma acc loop collapse(force:)
   for(;;){}
 
-  // expected-error@+2{{invalid tag 'unknown' on 'collapse' clause}}
-  // expected-warning@+1{{OpenACC clause 'collapse' not yet implemented, clause ignored}}
-#pragma acc loop collapse(unknown:5)
+  // expected-error@+1{{invalid tag 'unknown' on 'collapse' clause}}
+#pragma acc loop collapse(unknown:1)
   for(;;){}
 
-  // expected-warning@+1{{OpenACC clause 'collapse' not yet implemented, clause ignored}}
-#pragma acc loop collapse(force:5)
+#pragma acc loop collapse(force:1)
   for(;;){}
 
-  // expected-warning@+1{{OpenACC clause 'collapse' not yet implemented, clause ignored}}
-#pragma acc loop collapse(5)
+#pragma acc loop collapse(1)
   for(;;){}
 
   // expected-error@+2{{expected ')'}}
diff --git a/clang/test/ParserOpenACC/parse-clauses.cpp b/clang/test/ParserOpenACC/parse-clauses.cpp
index b7e252e892beab..9613530db77ddc 100644
--- a/clang/test/ParserOpenACC/parse-clauses.cpp
+++ b/clang/test/ParserOpenACC/parse-clauses.cpp
@@ -2,13 +2,23 @@
 
 template<unsigned I, typename T>
 void templ() {
-  // expected-warning@+1{{OpenACC clause 'collapse' not yet implemented, clause ignored}}
 #pragma acc loop collapse(I)
-  for(;;){}
+  for(;;)
+    for(;;)
+      for(;;)
+        for(;;)
+          for(;;)
+        ...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Sep 20, 2024

@llvm/pr-subscribers-clang

Author: Erich Keane (erichkeane)

Changes

The 'collapse' clause on a 'loop' construct is used to specify how many nested loops are associated with the 'loop' construct. It takes an optional 'force' tag, and an integer constant expression as arguments.

There are many other restrictions based on the contents of the loop/etc, but those are implemented in followup patches, for now, this patch just adds the AST node and does basic argument checking on the loop-count.


Patch is 33.55 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/109461.diff

21 Files Affected:

  • (modified) clang/include/clang/AST/OpenACCClause.h (+26)
  • (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+3)
  • (modified) clang/include/clang/Basic/OpenACCClauses.def (+1)
  • (modified) clang/include/clang/Sema/SemaOpenACC.h (+27-1)
  • (modified) clang/lib/AST/OpenACCClause.cpp (+33-1)
  • (modified) clang/lib/AST/StmtProfile.cpp (+6)
  • (modified) clang/lib/AST/TextNodeDumper.cpp (+6)
  • (modified) clang/lib/Parse/ParseOpenACC.cpp (+15-4)
  • (modified) clang/lib/Sema/SemaOpenACC.cpp (+62)
  • (modified) clang/lib/Sema/TreeTransform.h (+25)
  • (modified) clang/lib/Serialization/ASTReader.cpp (+7-1)
  • (modified) clang/lib/Serialization/ASTWriter.cpp (+7-1)
  • (modified) clang/test/AST/ast-print-openacc-loop-construct.cpp (+25)
  • (modified) clang/test/ParserOpenACC/parse-clauses.c (+4-7)
  • (modified) clang/test/ParserOpenACC/parse-clauses.cpp (+14-4)
  • (modified) clang/test/SemaOpenACC/compute-construct-device_type-clause.c (+1-2)
  • (modified) clang/test/SemaOpenACC/loop-construct-auto_seq_independent-clauses.c (-6)
  • (added) clang/test/SemaOpenACC/loop-construct-collapse-ast.cpp (+150)
  • (added) clang/test/SemaOpenACC/loop-construct-collapse-clause.cpp (+117)
  • (modified) clang/test/SemaOpenACC/loop-construct-device_type-clause.c (-1)
  • (modified) clang/tools/libclang/CIndex.cpp (+3)
diff --git a/clang/include/clang/AST/OpenACCClause.h b/clang/include/clang/AST/OpenACCClause.h
index ea1ffbc7fd08b4..90f5b7fc9ab6f4 100644
--- a/clang/include/clang/AST/OpenACCClause.h
+++ b/clang/include/clang/AST/OpenACCClause.h
@@ -547,6 +547,32 @@ class OpenACCAsyncClause : public OpenACCClauseWithSingleIntExpr {
                                     SourceLocation EndLoc);
 };
 
+/// Represents a 'collapse' clause on a 'loop' construct. This clause takes an
+/// integer constant expression 'N' that represents how deep to collapse the
+/// construct. It also takes an optional 'force' tag that permits intervening
+/// code in the loops.
+class OpenACCCollapseClause : public OpenACCClauseWithSingleIntExpr {
+  bool HasForce = false;
+
+  OpenACCCollapseClause(SourceLocation BeginLoc, SourceLocation LParenLoc,
+                        bool HasForce, Expr *LoopCount, SourceLocation EndLoc);
+
+public:
+  const Expr *getLoopCount() const { return getIntExpr(); }
+  Expr *getLoopCount() { return getIntExpr(); }
+
+  bool hasForce() const { return HasForce; }
+
+  static bool classof(const OpenACCClause *C) {
+    return C->getClauseKind() == OpenACCClauseKind::Collapse;
+  }
+
+  static OpenACCCollapseClause *Create(const ASTContext &C,
+                                       SourceLocation BeginLoc,
+                                       SourceLocation LParenLoc, bool HasForce,
+                                       Expr *LoopCount, SourceLocation EndLoc);
+};
+
 /// Represents a clause with one or more 'var' objects, represented as an expr,
 /// as its arguments. Var-list is expected to be stored in trailing storage.
 /// For now, we're just storing the original expression in its entirety, unlike
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index ba813af960af6f..fb60477bc0c2e6 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -12606,6 +12606,9 @@ def note_acc_construct_here : Note<"'%0' construct is here">;
 def err_acc_loop_spec_conflict
     : Error<"OpenACC clause '%0' on '%1' construct conflicts with previous "
             "data dependence clause">;
+def err_acc_collapse_loop_count
+    : Error<"OpenACC 'collapse' clause loop count must be a %select{constant "
+            "expression|positive integer value, evaluated to %1}0">;
 
 // AMDGCN builtins diagnostics
 def err_amdgcn_global_load_lds_size_invalid_value : Error<"invalid size value">;
diff --git a/clang/include/clang/Basic/OpenACCClauses.def b/clang/include/clang/Basic/OpenACCClauses.def
index 85f4859925f0bc..19cdfe7672133b 100644
--- a/clang/include/clang/Basic/OpenACCClauses.def
+++ b/clang/include/clang/Basic/OpenACCClauses.def
@@ -24,6 +24,7 @@
 VISIT_CLAUSE(Auto)
 VISIT_CLAUSE(Async)
 VISIT_CLAUSE(Attach)
+VISIT_CLAUSE(Collapse)
 VISIT_CLAUSE(Copy)
 CLAUSE_ALIAS(PCopy, Copy, true)
 CLAUSE_ALIAS(PresentOrCopy, Copy, true)
diff --git a/clang/include/clang/Sema/SemaOpenACC.h b/clang/include/clang/Sema/SemaOpenACC.h
index 0ca76842e5f902..b4d54d95953046 100644
--- a/clang/include/clang/Sema/SemaOpenACC.h
+++ b/clang/include/clang/Sema/SemaOpenACC.h
@@ -87,9 +87,14 @@ class SemaOpenACC : public SemaBase {
       SmallVector<Expr *> VarList;
     };
 
+    struct CollapseDetails {
+      bool IsForce;
+      Expr *LoopCount;
+    };
+
     std::variant<std::monostate, DefaultDetails, ConditionDetails,
                  IntExprDetails, VarListDetails, WaitDetails, DeviceTypeDetails,
-                 ReductionDetails>
+                 ReductionDetails, CollapseDetails>
         Details = std::monostate{};
 
   public:
@@ -246,6 +251,19 @@ class SemaOpenACC : public SemaBase {
       return std::get<VarListDetails>(Details).IsZero;
     }
 
+    bool isForce() const {
+      assert(ClauseKind == OpenACCClauseKind::Collapse &&
+             "Only 'collapse' has a force tag");
+      return std::get<CollapseDetails>(Details).IsForce;
+    }
+
+    Expr *getLoopCount() const {
+      assert(ClauseKind == OpenACCClauseKind::Collapse &&
+             "Only 'collapse' has a loop count");
+      return std::get<CollapseDetails>(Details).LoopCount;
+    }
+
+
     ArrayRef<DeviceTypeArgument> getDeviceTypeArchitectures() const {
       assert((ClauseKind == OpenACCClauseKind::DeviceType ||
               ClauseKind == OpenACCClauseKind::DType) &&
@@ -384,6 +402,12 @@ class SemaOpenACC : public SemaBase {
              "Only 'device_type'/'dtype' has a device-type-arg list");
       Details = DeviceTypeDetails{std::move(Archs)};
     }
+
+    void setCollapseDetails(bool IsForce, Expr *LoopCount) {
+      assert(ClauseKind == OpenACCClauseKind::Collapse &&
+             "Only 'collapse' has collapse details");
+      Details = CollapseDetails{IsForce, LoopCount};
+    }
   };
 
   SemaOpenACC(Sema &S);
@@ -448,6 +472,8 @@ class SemaOpenACC : public SemaBase {
                                    Expr *LowerBound,
                                    SourceLocation ColonLocFirst, Expr *Length,
                                    SourceLocation RBLoc);
+  /// Checks the loop depth value for a collapse clause.
+  ExprResult CheckCollapseLoopCount(Expr *LoopCount);
 
   /// Helper type for the registration/assignment of constructs that need to
   /// 'know' about their parent constructs and hold a reference to them, such as
diff --git a/clang/lib/AST/OpenACCClause.cpp b/clang/lib/AST/OpenACCClause.cpp
index 95089a9b79e267..fc8a4c28dc3fe4 100644
--- a/clang/lib/AST/OpenACCClause.cpp
+++ b/clang/lib/AST/OpenACCClause.cpp
@@ -43,7 +43,7 @@ bool OpenACCClauseWithCondition::classof(const OpenACCClause *C) {
 bool OpenACCClauseWithSingleIntExpr::classof(const OpenACCClause *C) {
   return OpenACCNumWorkersClause::classof(C) ||
          OpenACCVectorLengthClause::classof(C) ||
-         OpenACCAsyncClause::classof(C);
+         OpenACCCollapseClause::classof(C) || OpenACCAsyncClause::classof(C);
 }
 OpenACCDefaultClause *OpenACCDefaultClause::Create(const ASTContext &C,
                                                    OpenACCDefaultClauseKind K,
@@ -134,6 +134,30 @@ OpenACCNumWorkersClause::Create(const ASTContext &C, SourceLocation BeginLoc,
       OpenACCNumWorkersClause(BeginLoc, LParenLoc, IntExpr, EndLoc);
 }
 
+OpenACCCollapseClause::OpenACCCollapseClause(SourceLocation BeginLoc,
+                                             SourceLocation LParenLoc,
+                                             bool HasForce, Expr *LoopCount,
+                                             SourceLocation EndLoc)
+    : OpenACCClauseWithSingleIntExpr(OpenACCClauseKind::Collapse, BeginLoc,
+                                     LParenLoc, LoopCount, EndLoc),
+      HasForce(HasForce) {
+  assert(LoopCount && "LoopCount required");
+}
+
+OpenACCCollapseClause *
+OpenACCCollapseClause::Create(const ASTContext &C, SourceLocation BeginLoc,
+                              SourceLocation LParenLoc, bool HasForce,
+                              Expr *LoopCount, SourceLocation EndLoc) {
+  assert(LoopCount &&
+         (LoopCount->isInstantiationDependent() ||
+          (LoopCount->isIntegerConstantExpr(C))) &&
+         "Loop count not constant expression");
+  void *Mem =
+      C.Allocate(sizeof(OpenACCCollapseClause), alignof(OpenACCCollapseClause));
+  return new (Mem)
+      OpenACCCollapseClause(BeginLoc, LParenLoc, HasForce, LoopCount, EndLoc);
+}
+
 OpenACCVectorLengthClause::OpenACCVectorLengthClause(SourceLocation BeginLoc,
                                                      SourceLocation LParenLoc,
                                                      Expr *IntExpr,
@@ -550,3 +574,11 @@ void OpenACCClausePrinter::VisitIndependentClause(
 void OpenACCClausePrinter::VisitSeqClause(const OpenACCSeqClause &C) {
   OS << "seq";
 }
+
+void OpenACCClausePrinter::VisitCollapseClause(const OpenACCCollapseClause &C) {
+  OS << "collapse(";
+  if (C.hasForce())
+    OS << "force:";
+  printExpr(C.getLoopCount());
+  OS << ")";
+}
diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
index ad4281986f668e..c3812844ab8a31 100644
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -2558,6 +2558,12 @@ void OpenACCClauseProfiler::VisitNumWorkersClause(
   Profiler.VisitStmt(Clause.getIntExpr());
 }
 
+void OpenACCClauseProfiler::VisitCollapseClause(
+    const OpenACCCollapseClause &Clause) {
+  assert(Clause.getLoopCount() && "collapse clause requires a valid int expr");
+  Profiler.VisitStmt(Clause.getLoopCount());
+}
+
 void OpenACCClauseProfiler::VisitPrivateClause(
     const OpenACCPrivateClause &Clause) {
   for (auto *E : Clause.getVarList())
diff --git a/clang/lib/AST/TextNodeDumper.cpp b/clang/lib/AST/TextNodeDumper.cpp
index 3c51c746471829..8a74159c7c93e5 100644
--- a/clang/lib/AST/TextNodeDumper.cpp
+++ b/clang/lib/AST/TextNodeDumper.cpp
@@ -419,6 +419,12 @@ void TextNodeDumper::Visit(const OpenACCClause *C) {
       // but print 'clause' here so it is clear what is happening from the dump.
       OS << " clause";
       break;
+    case OpenACCClauseKind::Collapse:
+      OS << " clause";
+      if (cast<OpenACCCollapseClause>(C)->hasForce())
+        OS << ": force";
+      break;
+
     case OpenACCClauseKind::CopyIn:
     case OpenACCClauseKind::PCopyIn:
     case OpenACCClauseKind::PresentOrCopyIn:
diff --git a/clang/lib/Parse/ParseOpenACC.cpp b/clang/lib/Parse/ParseOpenACC.cpp
index 0261e8ea3c9b76..e66abd6873794e 100644
--- a/clang/lib/Parse/ParseOpenACC.cpp
+++ b/clang/lib/Parse/ParseOpenACC.cpp
@@ -976,14 +976,25 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
                                      /*IsReadOnly=*/false, /*IsZero=*/false);
       break;
     case OpenACCClauseKind::Collapse: {
-      tryParseAndConsumeSpecialTokenKind(*this, OpenACCSpecialTokenKind::Force,
-                                         ClauseKind);
-      ExprResult NumLoops =
+      bool HasForce = tryParseAndConsumeSpecialTokenKind(
+          *this, OpenACCSpecialTokenKind::Force, ClauseKind);
+      ExprResult LoopCount =
           getActions().CorrectDelayedTyposInExpr(ParseConstantExpression());
-      if (NumLoops.isInvalid()) {
+      if (LoopCount.isInvalid()) {
         Parens.skipToEnd();
         return OpenACCCanContinue();
       }
+
+      LoopCount = getActions().OpenACC().ActOnIntExpr(
+          OpenACCDirectiveKind::Invalid, ClauseKind,
+          LoopCount.get()->getBeginLoc(), LoopCount.get());
+
+      if (LoopCount.isInvalid()) {
+        Parens.skipToEnd();
+        return OpenACCCanContinue();
+      }
+
+      ParsedClause.setCollapseDetails(HasForce, LoopCount.get());
       break;
     }
     case OpenACCClauseKind::Bind: {
diff --git a/clang/lib/Sema/SemaOpenACC.cpp b/clang/lib/Sema/SemaOpenACC.cpp
index ecbcc19413dc61..6a4054b53fcb43 100644
--- a/clang/lib/Sema/SemaOpenACC.cpp
+++ b/clang/lib/Sema/SemaOpenACC.cpp
@@ -343,6 +343,18 @@ bool doesClauseApplyToDirective(OpenACCDirectiveKind DirectiveKind,
       return false;
     }
 
+  case OpenACCClauseKind::Collapse: {
+    switch (DirectiveKind) {
+    case OpenACCDirectiveKind::Loop:
+    case OpenACCDirectiveKind::ParallelLoop:
+    case OpenACCDirectiveKind::SerialLoop:
+    case OpenACCDirectiveKind::KernelsLoop:
+      return true;
+    default:
+      return false;
+    }
+  }
+
   default:
     // Do nothing so we can go to the 'unimplemented' diagnostic instead.
     return true;
@@ -1037,6 +1049,26 @@ OpenACCClause *SemaOpenACCClauseVisitor::VisitReductionClause(
       ValidVars, Clause.getEndLoc());
 }
 
+OpenACCClause *SemaOpenACCClauseVisitor::VisitCollapseClause(
+    SemaOpenACC::OpenACCParsedClause &Clause) {
+  // Duplicates here are not really sensible.  We could possible permit
+  // multiples if they all had the same value, but there isn't really a good
+  // reason to do so. Also, this simplifies the suppression of duplicates, in
+  // that we know if we 'find' one after instantiation, that it is the same
+  // clause, which simplifies instantiation/checking/etc.
+  if (checkAlreadyHasClauseOfKind(SemaRef, ExistingClauses, Clause))
+    return nullptr;
+
+  ExprResult LoopCount = SemaRef.CheckCollapseLoopCount(Clause.getLoopCount());
+
+  if (!LoopCount.isUsable())
+    return nullptr;
+
+  return OpenACCCollapseClause::Create(
+      Ctx, Clause.getBeginLoc(), Clause.getLParenLoc(),
+      Clause.isForce(), LoopCount.get(), Clause.getEndLoc());
+}
+
 } // namespace
 
 SemaOpenACC::SemaOpenACC(Sema &S) : SemaBase(S) {}
@@ -1273,6 +1305,9 @@ ExprResult SemaOpenACC::ActOnIntExpr(OpenACCDirectiveKind DK,
     }
   } IntExprDiagnoser(DK, CK, IntExpr);
 
+  if (!IntExpr)
+    return ExprError();
+
   ExprResult IntExprResult = SemaRef.PerformContextualImplicitConversion(
       Loc, IntExpr, IntExprDiagnoser);
   if (IntExprResult.isInvalid())
@@ -1583,6 +1618,33 @@ ExprResult SemaOpenACC::ActOnArraySectionExpr(Expr *Base, SourceLocation LBLoc,
                        OK_Ordinary, ColonLoc, RBLoc);
 }
 
+ExprResult SemaOpenACC::CheckCollapseLoopCount(Expr *LoopCount) {
+  if (!LoopCount)
+    return ExprError();
+
+  assert((LoopCount->isInstantiationDependent() ||
+          LoopCount->getType()->isIntegerType()) &&
+         "Loop argument non integer?");
+
+  // If this is dependent, there really isn't anything we can check.
+  if (LoopCount->isInstantiationDependent())
+    return ExprResult{LoopCount};
+
+  std::optional<llvm::APSInt> ICE =
+      LoopCount->getIntegerConstantExpr(getASTContext());
+
+  // OpenACC 3.3: 2.9.1
+  // The argument to the collapse clause must be a constant positive integer
+  // expression.
+  if (!ICE || *ICE <= 0) {
+    Diag(LoopCount->getBeginLoc(), diag::err_acc_collapse_loop_count)
+        << ICE.has_value() << ICE.value_or(llvm::APSInt{}).getExtValue();
+    return ExprError();
+  }
+
+  return ExprResult{LoopCount};
+}
+
 bool SemaOpenACC::ActOnStartStmtDirective(OpenACCDirectiveKind K,
                                           SourceLocation StartLoc) {
   SemaRef.DiscardCleanupsInEvaluationContext();
diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h
index 95ded5e59a9fa7..3dba8e53669e8d 100644
--- a/clang/lib/Sema/TreeTransform.h
+++ b/clang/lib/Sema/TreeTransform.h
@@ -11821,6 +11821,31 @@ void OpenACCClauseTransform<Derived>::VisitReductionClause(
       ParsedClause.getLParenLoc(), C.getReductionOp(), ValidVars,
       ParsedClause.getEndLoc());
 }
+
+template <typename Derived>
+void OpenACCClauseTransform<Derived>::VisitCollapseClause(
+    const OpenACCCollapseClause &C) {
+  Expr *LoopCount = const_cast<Expr *>(C.getLoopCount());
+  assert(LoopCount && "collapse clause constructed with invalid loop count");
+
+  ExprResult NewLoopCount = Self.TransformExpr(LoopCount);
+
+  NewLoopCount = Self.getSema().OpenACC().ActOnIntExpr(
+      OpenACCDirectiveKind::Invalid, ParsedClause.getClauseKind(),
+      NewLoopCount.get()->getBeginLoc(), NewLoopCount.get());
+
+  NewLoopCount =
+      Self.getSema().OpenACC().CheckCollapseLoopCount(NewLoopCount.get());
+
+  if (!NewLoopCount.isUsable())
+    return;
+
+  ParsedClause.setCollapseDetails(C.hasForce(), NewLoopCount.get());
+  NewClause = OpenACCCollapseClause::Create(
+      Self.getSema().getASTContext(), ParsedClause.getBeginLoc(),
+      ParsedClause.getLParenLoc(), ParsedClause.isForce(),
+      ParsedClause.getLoopCount(), ParsedClause.getEndLoc());
+}
 } // namespace
 template <typename Derived>
 OpenACCClause *TreeTransform<Derived>::TransformOpenACCClause(
diff --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp
index 7efcc81e194d95..7a0a6d81b21a58 100644
--- a/clang/lib/Serialization/ASTReader.cpp
+++ b/clang/lib/Serialization/ASTReader.cpp
@@ -12269,6 +12269,13 @@ OpenACCClause *ASTRecordReader::readOpenACCClause() {
     return OpenACCIndependentClause::Create(getContext(), BeginLoc, EndLoc);
   case OpenACCClauseKind::Auto:
     return OpenACCAutoClause::Create(getContext(), BeginLoc, EndLoc);
+  case OpenACCClauseKind::Collapse: {
+    SourceLocation LParenLoc = readSourceLocation();
+    bool HasForce = readBool();
+    Expr *LoopCount = readSubExpr();
+    return OpenACCCollapseClause::Create(getContext(), BeginLoc, LParenLoc,
+                                         HasForce, LoopCount, EndLoc);
+  }
 
   case OpenACCClauseKind::Finalize:
   case OpenACCClauseKind::IfPresent:
@@ -12282,7 +12289,6 @@ OpenACCClause *ASTRecordReader::readOpenACCClause() {
   case OpenACCClauseKind::DeviceResident:
   case OpenACCClauseKind::Host:
   case OpenACCClauseKind::Link:
-  case OpenACCClauseKind::Collapse:
   case OpenACCClauseKind::Bind:
   case OpenACCClauseKind::DeviceNum:
   case OpenACCClauseKind::DefaultAsync:
diff --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp
index 008bf571f847dc..2da0c94c9e4cbc 100644
--- a/clang/lib/Serialization/ASTWriter.cpp
+++ b/clang/lib/Serialization/ASTWriter.cpp
@@ -8126,6 +8126,13 @@ void ASTRecordWriter::writeOpenACCClause(const OpenACCClause *C) {
     // Nothing to do here, there is no additional information beyond the
     // begin/end loc and clause kind.
     return;
+  case OpenACCClauseKind::Collapse: {
+    const auto *CC = cast<OpenACCCollapseClause>(C);
+    writeSourceLocation(CC->getLParenLoc());
+    writeBool(CC->hasForce());
+    AddStmt(const_cast<Expr *>(CC->getLoopCount()));
+    return;
+  }
 
   case OpenACCClauseKind::Finalize:
   case OpenACCClauseKind::IfPresent:
@@ -8139,7 +8146,6 @@ void ASTRecordWriter::writeOpenACCClause(const OpenACCClause *C) {
   case OpenACCClauseKind::DeviceResident:
   case OpenACCClauseKind::Host:
   case OpenACCClauseKind::Link:
-  case OpenACCClauseKind::Collapse:
   case OpenACCClauseKind::Bind:
   case OpenACCClauseKind::DeviceNum:
   case OpenACCClauseKind::DefaultAsync:
diff --git a/clang/test/AST/ast-print-openacc-loop-construct.cpp b/clang/test/AST/ast-print-openacc-loop-construct.cpp
index cde302a66f3af7..ae1f7964f019eb 100644
--- a/clang/test/AST/ast-print-openacc-loop-construct.cpp
+++ b/clang/test/AST/ast-print-openacc-loop-construct.cpp
@@ -57,4 +57,29 @@ void foo() {
 // CHECK-NEXT: ;
 #pragma acc loop private(i, array[1], array, array[1:2])
   for(;;);
+
+// CHECK: #pragma acc loop collapse(1)
+// CHECK-NEXT: for (;;)
+// CHECK-NEXT: ;
+#pragma acc loop collapse(1)
+  for(;;);
+// CHECK: #pragma acc loop collapse(force:1)
+// CHECK-NEXT: for (;;)
+// CHECK-NEXT: ;
+#pragma acc loop collapse(force:1)
+  for(;;);
+// CHECK: #pragma acc loop collapse(2)
+// CHECK-NEXT: for (;;)
+// CHECK-NEXT: for (;;)
+// CHECK-NEXT: ;
+#pragma acc loop collapse(2)
+  for(;;)
+    for(;;);
+// CHECK: #pragma acc loop collapse(force:2)
+// CHECK-NEXT: for (;;)
+// CHECK-NEXT: for (;;)
+// CHECK-NEXT: ;
+#pragma acc loop collapse(force:2)
+  for(;;)
+    for(;;);
 }
diff --git a/clang/test/ParserOpenACC/parse-clauses.c b/clang/test/ParserOpenACC/parse-clauses.c
index 15c4554a31922a..6c9ce4ad5e1969 100644
--- a/clang/test/ParserOpenACC/parse-clauses.c
+++ b/clang/test/ParserOpenACC/parse-clauses.c
@@ -105,17 +105,14 @@ void func() {
 #pragma acc loop collapse(force:)
   for(;;){}
 
-  // expected-error@+2{{invalid tag 'unknown' on 'collapse' clause}}
-  // expected-warning@+1{{OpenACC clause 'collapse' not yet implemented, clause ignored}}
-#pragma acc loop collapse(unknown:5)
+  // expected-error@+1{{invalid tag 'unknown' on 'collapse' clause}}
+#pragma acc loop collapse(unknown:1)
   for(;;){}
 
-  // expected-warning@+1{{OpenACC clause 'collapse' not yet implemented, clause ignored}}
-#pragma acc loop collapse(force:5)
+#pragma acc loop collapse(force:1)
   for(;;){}
 
-  // expected-warning@+1{{OpenACC clause 'collapse' not yet implemented, clause ignored}}
-#pragma acc loop collapse(5)
+#pragma acc loop collapse(1)
   for(;;){}
 
   // expected-error@+2{{expected ')'}}
diff --git a/clang/test/ParserOpenACC/parse-clauses.cpp b/clang/test/ParserOpenACC/parse-clauses.cpp
index b7e252e892beab..9613530db77ddc 100644
--- a/clang/test/ParserOpenACC/parse-clauses.cpp
+++ b/clang/test/ParserOpenACC/parse-clauses.cpp
@@ -2,13 +2,23 @@
 
 template<unsigned I, typename T>
 void templ() {
-  // expected-warning@+1{{OpenACC clause 'collapse' not yet implemented, clause ignored}}
 #pragma acc loop collapse(I)
-  for(;;){}
+  for(;;)
+    for(;;)
+      for(;;)
+        for(;;)
+          for(;;)
+        ...
[truncated]

Copy link

github-actions bot commented Sep 20, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

Comment on lines +1633 to +1634
std::optional<llvm::APSInt> ICE =
LoopCount->getIntegerConstantExpr(getASTContext());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would wrap LoopCount in a ConstantExpr if it's not dependent, so that you don't have to evaluate every time.

(We don't seem to do that for other openacc nodes though)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thats a great idea! We don't really have any other constant-expressions in OpenACC implemented yet, so this is the first.

@erichkeane erichkeane merged commit 97da34e into llvm:main Oct 1, 2024
10 checks passed
VitaNuo pushed a commit to VitaNuo/llvm-project that referenced this pull request Oct 2, 2024
…09461)

The 'collapse' clause on a 'loop' construct is used to specify how many
nested loops are associated with the 'loop' construct. It takes an
optional 'force' tag, and an integer constant expression as arguments.

There are many other restrictions based on the contents of the loop/etc,
but those are implemented in followup patches, for now, this patch just
adds the AST node and does basic argument checking on the loop-count.
VitaNuo pushed a commit to VitaNuo/llvm-project that referenced this pull request Oct 2, 2024
…09461)

The 'collapse' clause on a 'loop' construct is used to specify how many
nested loops are associated with the 'loop' construct. It takes an
optional 'force' tag, and an integer constant expression as arguments.

There are many other restrictions based on the contents of the loop/etc,
but those are implemented in followup patches, for now, this patch just
adds the AST node and does basic argument checking on the loop-count.
Sterling-Augustine pushed a commit to Sterling-Augustine/llvm-project that referenced this pull request Oct 3, 2024
…09461)

The 'collapse' clause on a 'loop' construct is used to specify how many
nested loops are associated with the 'loop' construct. It takes an
optional 'force' tag, and an integer constant expression as arguments.

There are many other restrictions based on the contents of the loop/etc,
but those are implemented in followup patches, for now, this patch just
adds the AST node and does basic argument checking on the loop-count.
xgupta pushed a commit to xgupta/llvm-project that referenced this pull request Oct 4, 2024
…09461)

The 'collapse' clause on a 'loop' construct is used to specify how many
nested loops are associated with the 'loop' construct. It takes an
optional 'force' tag, and an integer constant expression as arguments.

There are many other restrictions based on the contents of the loop/etc,
but those are implemented in followup patches, for now, this patch just
adds the AST node and does basic argument checking on the loop-count.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
clang:as-a-library libclang and C++ API clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:modules C++20 modules and Clang Header Modules clang Clang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants