diff --git a/MarkMpn.Sql4Cds.Engine.Tests/CteTests.cs b/MarkMpn.Sql4Cds.Engine.Tests/CteTests.cs index 8622a8d5..fdd47156 100644 --- a/MarkMpn.Sql4Cds.Engine.Tests/CteTests.cs +++ b/MarkMpn.Sql4Cds.Engine.Tests/CteTests.cs @@ -512,11 +512,69 @@ UNION ALL var concat = AssertNode(spoolProducer.Source); var depth0 = AssertNode(concat.Sources[0]); var anchor = AssertNode(depth0.Source); + + AssertFetchXml(anchor, @" + + + + + + + + + + "); + var assert = AssertNode(concat.Sources[1]); var nestedLoop = AssertNode(assert.Source); var depthPlus1 = AssertNode(nestedLoop.LeftSource); - var spoolConsumer = AssertNode(depthPlus1); + var spoolConsumer = AssertNode(depthPlus1.Source); var children = AssertNode(nestedLoop.RightSource); + + AssertFetchXml(children, @" + + + + + + + + + + "); + } + + [TestMethod] + public void FactorialCalc() + { + using (var con = new Sql4CdsConnection(_localDataSource)) + using (var cmd = con.CreateCommand()) + { + cmd.CommandText = @" + WITH Factorial (N, Factorial) AS ( + SELECT 1, 1 + UNION ALL + SELECT N + 1, (N + 1) * Factorial FROM Factorial WHERE N < 5) + SELECT N, Factorial FROM Factorial"; + + using (var reader = cmd.ExecuteReader()) + { + var n = 1; + var factorial = 1; + + while (n <= 5) + { + Assert.IsTrue(reader.Read()); + Assert.AreEqual(n, reader.GetInt32(0)); + Assert.AreEqual(factorial, reader.GetInt32(1)); + + n++; + factorial *= n; + } + + Assert.IsFalse(reader.Read()); + } + } } private T AssertNode(IExecutionPlanNode node) where T : IExecutionPlanNode diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlanBuilder.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlanBuilder.cs index 9a82a025..a9cf153a 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlanBuilder.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlanBuilder.cs @@ -290,7 +290,7 @@ private void ConvertStatement(TSqlStatement statement, ExecutionPlanOptimizer op var anchorSchema = anchorQuery.GetSchema(_nodeContext); foreach (var col in anchorSchema.Schema) - recurseLoop.OuterReferences[col.Key] = "@" + _nodeContext.GetExpressionName(); + recurseLoop.OuterReferences[col.Key.SplitMultiPartIdentifier().Last()] = "@" + _nodeContext.GetExpressionName(); if (cteValidator.RecursiveQueries.Count > 1) { @@ -437,7 +437,7 @@ private void ConvertStatement(TSqlStatement statement, ExecutionPlanOptimizer op private SelectNode ConvertRecursiveCTEQuery(QueryExpression queryExpression, INodeSchema anchorSchema, CteValidatorVisitor cteValidator, Dictionary outerReferences) { // Convert the query using the anchor query as a subquery to check for ambiguous column names - ConvertSelectStatement(queryExpression, null, null, null, _nodeContext); + ConvertSelectStatement(queryExpression.Clone(), null, null, null, _nodeContext); // Remove recursive references from the FROM clause, moving join predicates to the WHERE clause // If the recursive reference was in an unqualified join, replace it with (SELECT @Expr1, @Expr2) AS cte (field1, field2) @@ -446,7 +446,7 @@ private SelectNode ConvertRecursiveCTEQuery(QueryExpression queryExpression, INo queryExpression.Accept(cteReplacer); // Convert the modified query. - var childContext = new NodeCompilationContext(_nodeContext, outerReferences.ToDictionary(kvp => kvp.Value, kvp => anchorSchema.Schema[kvp.Key].Type, StringComparer.OrdinalIgnoreCase)); + var childContext = new NodeCompilationContext(_nodeContext, outerReferences.ToDictionary(kvp => kvp.Value, kvp => anchorSchema.Schema[cteValidator.Name.EscapeIdentifier() + "." + kvp.Key].Type, StringComparer.OrdinalIgnoreCase)); return ConvertSelectStatement(queryExpression, null, null, null, childContext); } @@ -2081,7 +2081,7 @@ private SelectNode ConvertSelectQuerySpec(QuerySpecification querySpec, IList() } } : ConvertFromClause(querySpec.FromClause, hints, querySpec, outerSchema, outerReferences, context); + var node = querySpec.FromClause == null || querySpec.FromClause.TableReferences.Count == 0 ? new ConstantScanNode { Values = { new Dictionary() } } : ConvertFromClause(querySpec.FromClause, hints, querySpec, outerSchema, outerReferences, context); var logicalSchema = node.GetSchema(context); node = ConvertInSubqueries(node, hints, querySpec, context, outerSchema, outerReferences); diff --git a/MarkMpn.Sql4Cds.Engine/TSqlFragmentExtensions.cs b/MarkMpn.Sql4Cds.Engine/TSqlFragmentExtensions.cs index 20e6c7d3..87875986 100644 --- a/MarkMpn.Sql4Cds.Engine/TSqlFragmentExtensions.cs +++ b/MarkMpn.Sql4Cds.Engine/TSqlFragmentExtensions.cs @@ -455,7 +455,91 @@ public static T Clone(this T fragment) where T : TSqlFragment }; } - throw new NotSupportedQueryFragmentException("Unhandled expression type", fragment); + if (fragment is QuerySpecification querySpec) + { + var clone = new QuerySpecification + { + ForClause = querySpec.ForClause?.Clone(), + FromClause = querySpec.FromClause?.Clone(), + GroupByClause = querySpec.GroupByClause?.Clone(), + HavingClause = querySpec.HavingClause?.Clone(), + OffsetClause = querySpec.OffsetClause?.Clone(), + OrderByClause = querySpec.OrderByClause?.Clone(), + TopRowFilter = querySpec.TopRowFilter?.Clone(), + UniqueRowFilter = querySpec.UniqueRowFilter, + WhereClause = querySpec.WhereClause?.Clone(), + WindowClause = querySpec.WindowClause?.Clone() + }; + + foreach (var selectElement in querySpec.SelectElements) + clone.SelectElements.Add(selectElement.Clone()); + + return (T)(object)clone; + } + + if (fragment is FromClause from) + { + var clone = new FromClause(); + + foreach (var predict in from.PredictTableReference) + clone.PredictTableReference.Add(predict.Clone()); + + foreach (var table in from.TableReferences) + clone.TableReferences.Add(table.Clone()); + + return (T)(object)clone; + } + + if (fragment is NamedTableReference tableRef) + { + var clone = new NamedTableReference + { + Alias = tableRef.Alias?.Clone(), + ForPath = tableRef.ForPath, + SchemaObject = tableRef.SchemaObject.Clone(), + TableSampleClause = tableRef.TableSampleClause?.Clone(), + TemporalClause = tableRef.TemporalClause?.Clone(), + }; + + foreach (var hint in tableRef.TableHints) + clone.TableHints.Add(hint.Clone()); + + return (T)(object)clone; + } + + if (fragment is GroupByClause groupBy) + { + var clone = new GroupByClause + { + All = groupBy.All, + GroupByOption = groupBy.GroupByOption + }; + + foreach (var groupBySpec in groupBy.GroupingSpecifications) + clone.GroupingSpecifications.Add(groupBySpec.Clone()); + + return (T)(object)clone; + } + + if (fragment is WhereClause where) + { + return (T)(object)new WhereClause + { + Cursor = where.Cursor?.Clone(), + SearchCondition = where.SearchCondition?.Clone() + }; + } + + if (fragment is SelectScalarExpression selectScalarExpression) + { + return (T)(object)new SelectScalarExpression + { + ColumnName = selectScalarExpression.ColumnName?.Clone(), + Expression = selectScalarExpression.Expression?.Clone() + }; + } + + throw new NotSupportedQueryFragmentException("Unhandled expression type " + fragment.GetType().Name, fragment); } } } diff --git a/MarkMpn.Sql4Cds.Engine/Visitors/RemoveRecursiveCTETableReferencesVisitor.cs b/MarkMpn.Sql4Cds.Engine/Visitors/RemoveRecursiveCTETableReferencesVisitor.cs index 342b27c3..7e1a763e 100644 --- a/MarkMpn.Sql4Cds.Engine/Visitors/RemoveRecursiveCTETableReferencesVisitor.cs +++ b/MarkMpn.Sql4Cds.Engine/Visitors/RemoveRecursiveCTETableReferencesVisitor.cs @@ -169,7 +169,7 @@ public override void ExplicitVisit(QuerySpecification node) { // Replace references to the recursive CTE columns with variables var rewrites = _outerReferences - .SelectMany(kvp => new[] { kvp, new KeyValuePair(kvp.Key.Split('.')[1], kvp.Value) }) + .SelectMany(kvp => new[] { kvp, new KeyValuePair(_name.EscapeIdentifier() + "." + kvp.Key, kvp.Value) }) .ToDictionary(kvp => (ScalarExpression)kvp.Key.ToColumnReference(), kvp => (ScalarExpression)new VariableReference { Name = kvp.Value }); node.Accept(new RewriteVisitor(rewrites));