Skip to content

Commit

Permalink
Recursion progress
Browse files Browse the repository at this point in the history
  • Loading branch information
MarkMpn committed Sep 18, 2023
1 parent ca8c03f commit d2e072f
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 7 deletions.
60 changes: 59 additions & 1 deletion MarkMpn.Sql4Cds.Engine.Tests/CteTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -512,11 +512,69 @@ UNION ALL
var concat = AssertNode<ConcatenateNode>(spoolProducer.Source);
var depth0 = AssertNode<ComputeScalarNode>(concat.Sources[0]);
var anchor = AssertNode<FetchXmlScan>(depth0.Source);

AssertFetchXml(anchor, @"
<fetch>
<entity name='contact'>
<attribute name='contactid' />
<attribute name='firstname' />
<attribute name='lastname' />
<filter>
<condition attribute=""firstname"" operator=""eq"" value=""Mark"" />
</filter>
</entity>
</fetch>");

var assert = AssertNode<AssertNode>(concat.Sources[1]);
var nestedLoop = AssertNode<NestedLoopNode>(assert.Source);
var depthPlus1 = AssertNode<ComputeScalarNode>(nestedLoop.LeftSource);
var spoolConsumer = AssertNode<TableSpoolNode>(depthPlus1);
var spoolConsumer = AssertNode<TableSpoolNode>(depthPlus1.Source);
var children = AssertNode<FetchXmlScan>(nestedLoop.RightSource);

AssertFetchXml(children, @"
<fetch xmlns:generator=""MarkMpn.SQL4CDS"">
<entity name='contact'>
<attribute name='contactid' />
<attribute name='firstname' />
<attribute name='lastname' />
<filter>
<condition attribute=""parentcustomerid"" operator=""eq"" value=""@Expr3"" generator:IsVariable=""true"" />
</filter>
</entity>
</fetch>");
}

[TestMethod]
public void FactorialCalc()
{
using (var con = new Sql4CdsConnection(_localDataSource))
using (var cmd = con.CreateCommand())
{
cmd.CommandText = @"
WITH Factorial (N, Factorial) AS (
SELECT 1, 1
UNION ALL
SELECT N + 1, (N + 1) * Factorial FROM Factorial WHERE N < 5)
SELECT N, Factorial FROM Factorial";

using (var reader = cmd.ExecuteReader())
{
var n = 1;
var factorial = 1;

while (n <= 5)
{
Assert.IsTrue(reader.Read());
Assert.AreEqual(n, reader.GetInt32(0));
Assert.AreEqual(factorial, reader.GetInt32(1));

n++;
factorial *= n;
}

Assert.IsFalse(reader.Read());
}
}
}

private T AssertNode<T>(IExecutionPlanNode node) where T : IExecutionPlanNode
Expand Down
8 changes: 4 additions & 4 deletions MarkMpn.Sql4Cds.Engine/ExecutionPlanBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ private void ConvertStatement(TSqlStatement statement, ExecutionPlanOptimizer op
var anchorSchema = anchorQuery.GetSchema(_nodeContext);

foreach (var col in anchorSchema.Schema)
recurseLoop.OuterReferences[col.Key] = "@" + _nodeContext.GetExpressionName();
recurseLoop.OuterReferences[col.Key.SplitMultiPartIdentifier().Last()] = "@" + _nodeContext.GetExpressionName();

if (cteValidator.RecursiveQueries.Count > 1)
{
Expand Down Expand Up @@ -437,7 +437,7 @@ private void ConvertStatement(TSqlStatement statement, ExecutionPlanOptimizer op
private SelectNode ConvertRecursiveCTEQuery(QueryExpression queryExpression, INodeSchema anchorSchema, CteValidatorVisitor cteValidator, Dictionary<string, string> outerReferences)
{
// Convert the query using the anchor query as a subquery to check for ambiguous column names
ConvertSelectStatement(queryExpression, null, null, null, _nodeContext);
ConvertSelectStatement(queryExpression.Clone(), null, null, null, _nodeContext);

// Remove recursive references from the FROM clause, moving join predicates to the WHERE clause
// If the recursive reference was in an unqualified join, replace it with (SELECT @Expr1, @Expr2) AS cte (field1, field2)
Expand All @@ -446,7 +446,7 @@ private SelectNode ConvertRecursiveCTEQuery(QueryExpression queryExpression, INo
queryExpression.Accept(cteReplacer);

// Convert the modified query.
var childContext = new NodeCompilationContext(_nodeContext, outerReferences.ToDictionary(kvp => kvp.Value, kvp => anchorSchema.Schema[kvp.Key].Type, StringComparer.OrdinalIgnoreCase));
var childContext = new NodeCompilationContext(_nodeContext, outerReferences.ToDictionary(kvp => kvp.Value, kvp => anchorSchema.Schema[cteValidator.Name.EscapeIdentifier() + "." + kvp.Key].Type, StringComparer.OrdinalIgnoreCase));
return ConvertSelectStatement(queryExpression, null, null, null, childContext);
}

Expand Down Expand Up @@ -2081,7 +2081,7 @@ private SelectNode ConvertSelectQuerySpec(QuerySpecification querySpec, IList<Op
}

// Each table in the FROM clause starts as a separate FetchXmlScan node. Add appropriate join nodes
var node = querySpec.FromClause == null ? new ConstantScanNode { Values = { new Dictionary<string, ScalarExpression>() } } : ConvertFromClause(querySpec.FromClause, hints, querySpec, outerSchema, outerReferences, context);
var node = querySpec.FromClause == null || querySpec.FromClause.TableReferences.Count == 0 ? new ConstantScanNode { Values = { new Dictionary<string, ScalarExpression>() } } : ConvertFromClause(querySpec.FromClause, hints, querySpec, outerSchema, outerReferences, context);
var logicalSchema = node.GetSchema(context);

node = ConvertInSubqueries(node, hints, querySpec, context, outerSchema, outerReferences);
Expand Down
86 changes: 85 additions & 1 deletion MarkMpn.Sql4Cds.Engine/TSqlFragmentExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,91 @@ public static T Clone<T>(this T fragment) where T : TSqlFragment
};
}

throw new NotSupportedQueryFragmentException("Unhandled expression type", fragment);
if (fragment is QuerySpecification querySpec)
{
var clone = new QuerySpecification
{
ForClause = querySpec.ForClause?.Clone(),
FromClause = querySpec.FromClause?.Clone(),
GroupByClause = querySpec.GroupByClause?.Clone(),
HavingClause = querySpec.HavingClause?.Clone(),
OffsetClause = querySpec.OffsetClause?.Clone(),
OrderByClause = querySpec.OrderByClause?.Clone(),
TopRowFilter = querySpec.TopRowFilter?.Clone(),
UniqueRowFilter = querySpec.UniqueRowFilter,
WhereClause = querySpec.WhereClause?.Clone(),
WindowClause = querySpec.WindowClause?.Clone()
};

foreach (var selectElement in querySpec.SelectElements)
clone.SelectElements.Add(selectElement.Clone());

return (T)(object)clone;
}

if (fragment is FromClause from)
{
var clone = new FromClause();

foreach (var predict in from.PredictTableReference)
clone.PredictTableReference.Add(predict.Clone());

foreach (var table in from.TableReferences)
clone.TableReferences.Add(table.Clone());

return (T)(object)clone;
}

if (fragment is NamedTableReference tableRef)
{
var clone = new NamedTableReference
{
Alias = tableRef.Alias?.Clone(),
ForPath = tableRef.ForPath,
SchemaObject = tableRef.SchemaObject.Clone(),
TableSampleClause = tableRef.TableSampleClause?.Clone(),
TemporalClause = tableRef.TemporalClause?.Clone(),
};

foreach (var hint in tableRef.TableHints)
clone.TableHints.Add(hint.Clone());

return (T)(object)clone;
}

if (fragment is GroupByClause groupBy)
{
var clone = new GroupByClause
{
All = groupBy.All,
GroupByOption = groupBy.GroupByOption
};

foreach (var groupBySpec in groupBy.GroupingSpecifications)
clone.GroupingSpecifications.Add(groupBySpec.Clone());

return (T)(object)clone;
}

if (fragment is WhereClause where)
{
return (T)(object)new WhereClause
{
Cursor = where.Cursor?.Clone(),
SearchCondition = where.SearchCondition?.Clone()
};
}

if (fragment is SelectScalarExpression selectScalarExpression)
{
return (T)(object)new SelectScalarExpression
{
ColumnName = selectScalarExpression.ColumnName?.Clone(),
Expression = selectScalarExpression.Expression?.Clone()
};
}

throw new NotSupportedQueryFragmentException("Unhandled expression type " + fragment.GetType().Name, fragment);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ public override void ExplicitVisit(QuerySpecification node)
{
// Replace references to the recursive CTE columns with variables
var rewrites = _outerReferences
.SelectMany(kvp => new[] { kvp, new KeyValuePair<string, string>(kvp.Key.Split('.')[1], kvp.Value) })
.SelectMany(kvp => new[] { kvp, new KeyValuePair<string, string>(_name.EscapeIdentifier() + "." + kvp.Key, kvp.Value) })
.ToDictionary(kvp => (ScalarExpression)kvp.Key.ToColumnReference(), kvp => (ScalarExpression)new VariableReference { Name = kvp.Value });

node.Accept(new RewriteVisitor(rewrites));
Expand Down

0 comments on commit d2e072f

Please sign in to comment.