diff --git a/internal/node.go b/internal/node.go index 16f0dc227..6699b030a 100644 --- a/internal/node.go +++ b/internal/node.go @@ -45,7 +45,7 @@ type Node struct { SynchronizedAfterSuiteProc1BodyHasContext bool ReportEachBody func(types.SpecReport) - ReportSuiteBody func(types.Report) + ReportSuiteBody func(SpecContext, types.Report) MarkedFocus bool MarkedPending bool @@ -333,7 +333,11 @@ func NewNode(deprecationTracker *types.DeprecationTracker, nodeType types.NodeTy } } else if nodeType.Is(types.NodeTypeReportBeforeSuite | types.NodeTypeReportAfterSuite) { if node.ReportSuiteBody == nil { - node.ReportSuiteBody = arg.(func(types.Report)) + if fn, ok := arg.(func(types.Report)); ok { + node.ReportSuiteBody = func(_ SpecContext, r types.Report) { fn(r) } + } else { + node.ReportSuiteBody = arg.(func(SpecContext, types.Report)) + } } else { appendError(types.GinkgoErrors.MultipleBodyFunctions(node.CodeLocation, nodeType)) trackedFunctionError = true diff --git a/internal/node_test.go b/internal/node_test.go index a26d16919..05641a65c 100644 --- a/internal/node_test.go +++ b/internal/node_test.go @@ -861,7 +861,7 @@ var _ = Describe("Node", func() { Ω(node.ID).Should(BeNumerically(">", 0)) Ω(node.NodeType).Should(Equal(types.NodeTypeReportAfterSuite)) - node.ReportSuiteBody(types.Report{}) + node.ReportSuiteBody(internal.NewSpecContext(nil), types.Report{}) Ω(didRun).Should(BeTrue()) Ω(node.CodeLocation).Should(Equal(cl)) @@ -885,7 +885,7 @@ var _ = Describe("Node", func() { Ω(node.ID).Should(BeNumerically(">", 0)) Ω(node.NodeType).Should(Equal(types.NodeTypeReportBeforeSuite)) - node.ReportSuiteBody(types.Report{}) + node.ReportSuiteBody(internal.NewSpecContext(nil), types.Report{}) Ω(didRun).Should(BeTrue()) Ω(node.CodeLocation).Should(Equal(cl)) diff --git a/internal/suite.go b/internal/suite.go index 2b4db48af..44b531ffd 100644 --- a/internal/suite.go +++ b/internal/suite.go @@ -762,7 +762,7 @@ func (suite *Suite) runReportSuiteNode(node Node, report types.Report) { report = report.Add(aggregatedReport) } - node.Body = func(SpecContext) { node.ReportSuiteBody(report) } + node.Body = func(ctx SpecContext) { node.ReportSuiteBody(ctx, report) } suite.currentSpecReport.State, suite.currentSpecReport.Failure = suite.runNode(node, time.Time{}, "") suite.currentSpecReport.EndTime = time.Now()