From 41757226688a944de460e92f7684a056014053e1 Mon Sep 17 00:00:00 2001 From: Mark Carrington <31017244+MarkMpn@users.noreply.github.com> Date: Sat, 19 Oct 2024 09:00:18 +0100 Subject: [PATCH] Do not leak parameter definitions out of subqueries --- .../AdoProviderTests.cs | 16 +++++++++++++++ .../ExecutionPlan/FilterNode.cs | 18 ++++++++--------- .../ExecutionPlanBuilder.cs | 18 +++++++---------- MarkMpn.Sql4Cds.Engine/NodeContext.cs | 20 +++++++++++++++++-- 4 files changed, 50 insertions(+), 22 deletions(-) diff --git a/MarkMpn.Sql4Cds.Engine.Tests/AdoProviderTests.cs b/MarkMpn.Sql4Cds.Engine.Tests/AdoProviderTests.cs index 57a8aa1c..35cff7b2 100644 --- a/MarkMpn.Sql4Cds.Engine.Tests/AdoProviderTests.cs +++ b/MarkMpn.Sql4Cds.Engine.Tests/AdoProviderTests.cs @@ -855,6 +855,22 @@ public void AliasedTVF() } } + [TestMethod] + public void FilteredTVFWithSubqueryParameters() + { + using (var con = new Sql4CdsConnection(_localDataSources)) + using (var cmd = con.CreateCommand()) + { + cmd.CommandTimeout = 0; + cmd.CommandText = "SELECT * FROM SampleMessage((select '1')) WHERE OutputParam1 = '2'"; + + using (var reader = cmd.ExecuteReader()) + { + Assert.IsFalse(reader.Read()); + } + } + } + [TestMethod] public void CorrelatedNotExistsTypeConversion() { diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/FilterNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/FilterNode.cs index 813b840b..90fedcef 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/FilterNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/FilterNode.cs @@ -436,11 +436,11 @@ private bool FoldFiltersToInnerJoinSources(NodeCompilationContext context, IList if (join is NestedLoopNode loop && loop.OuterReferences != null) { - var innerParameterTypes = context.ParameterTypes - .Concat(loop.OuterReferences.Select(or => new KeyValuePair(or.Value, leftSchema.Schema[or.Key].Type))) + var innerParameterTypes = loop.OuterReferences + .Select(or => new KeyValuePair(or.Value, leftSchema.Schema[or.Key].Type)) .ToDictionary(p => p.Key, p => p.Value, StringComparer.OrdinalIgnoreCase); - rightContext = new NodeCompilationContext(context, innerParameterTypes); + rightContext = context.CreateChildContext(innerParameterTypes); } var rightSchema = join.RightSource.GetSchema(rightContext); @@ -798,11 +798,11 @@ private void ConvertOuterJoinsWithNonNullFiltersToInnerJoins(NodeCompilationCont { var leftSchema = join.LeftSource.GetSchema(context); - var innerParameterTypes = context.ParameterTypes - .Concat(loop.OuterReferences.Select(or => new KeyValuePair(or.Value, leftSchema.Schema[or.Key].Type))) + var innerParameterTypes = loop.OuterReferences + .Select(or => new KeyValuePair(or.Value, leftSchema.Schema[or.Key].Type)) .ToDictionary(p => p.Key, p => p.Value, StringComparer.OrdinalIgnoreCase); - outerContext = new NodeCompilationContext(context, innerParameterTypes); + outerContext = context.CreateChildContext(innerParameterTypes); } var outerSchema = outerSource.GetSchema(outerContext); @@ -1775,11 +1775,11 @@ source is AliasNode || if (join is NestedLoopNode loop && loop.OuterReferences != null && loop.OuterReferences.Count > 0) { var leftSchema = join.LeftSource.GetSchema(context); - var innerParameterTypes = context.ParameterTypes - .Concat(loop.OuterReferences.Select(or => new KeyValuePair(or.Value, leftSchema.Schema[or.Key].Type))) + var innerParameterTypes = loop.OuterReferences + .Select(or => new KeyValuePair(or.Value, leftSchema.Schema[or.Key].Type)) .ToDictionary(p => p.Key, p => p.Value, StringComparer.OrdinalIgnoreCase); - childContext = new NodeCompilationContext(context, innerParameterTypes); + childContext = context.CreateChildContext(innerParameterTypes); } foreach (var subSource in GetFoldableSources(join.RightSource, childContext)) diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlanBuilder.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlanBuilder.cs index e49f4519..2f38cc2b 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlanBuilder.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlanBuilder.cs @@ -615,7 +615,7 @@ private SelectNode ConvertRecursiveCTEQuery(QueryExpression queryExpression, INo queryExpression.Accept(cteReplacer); // Convert the modified query. - var childContext = new NodeCompilationContext(_nodeContext, outerReferences.ToDictionary(kvp => kvp.Value, kvp => anchorSchema.Schema[cteValidator.Name.EscapeIdentifier() + "." + kvp.Key.EscapeIdentifier()].Type, StringComparer.OrdinalIgnoreCase)); + var childContext = _nodeContext.CreateChildContext(outerReferences.ToDictionary(kvp => kvp.Value, kvp => anchorSchema.Schema[cteValidator.Name.EscapeIdentifier() + "." + kvp.Key.EscapeIdentifier()].Type, StringComparer.OrdinalIgnoreCase)); var converted = ConvertSelectStatement(queryExpression, null, null, null, childContext); converted.ExpandWildcardColumns(childContext); return converted; @@ -3261,8 +3261,7 @@ private IDataExecutionPlanNodeInternal ConvertInSubquery(IDataExecutionPlanNodeI lhsCol = lhsColNormalized.ToColumnReference(); } - var parameters = context.ParameterTypes == null ? new Dictionary(StringComparer.OrdinalIgnoreCase) : new Dictionary(context.ParameterTypes, StringComparer.OrdinalIgnoreCase); - var innerContext = new NodeCompilationContext(context, parameters); + var innerContext = context.CreateChildContext(null); var references = new Dictionary(); var innerQuery = ConvertSelectStatement(inSubquery.Subquery.QueryExpression, hints, schema, references, innerContext); @@ -3286,7 +3285,7 @@ private IDataExecutionPlanNodeInternal ConvertInSubquery(IDataExecutionPlanNodeI else { // We need the inner list to be distinct to avoid creating duplicates during the join - var innerSchema = innerQuery.Source.GetSchema(new NodeCompilationContext(Session, Options, parameters, Log)); + var innerSchema = innerQuery.Source.GetSchema(innerContext); if (innerQuery.ColumnSet[0].SourceColumn != innerSchema.PrimaryKey && !(innerQuery.Source is DistinctNode)) { innerQuery.Source = new DistinctNode @@ -3397,11 +3396,10 @@ private IDataExecutionPlanNodeInternal ConvertExistsSubquery(IDataExecutionPlanN var schema = source.GetSchema(context); // Each query of the format "EXISTS (SELECT * FROM source)" becomes a outer semi join - var parameters = context.ParameterTypes == null ? new Dictionary(StringComparer.OrdinalIgnoreCase) : new Dictionary(context.ParameterTypes, StringComparer.OrdinalIgnoreCase); - var innerContext = new NodeCompilationContext(context, parameters); + var innerContext = context.CreateChildContext(null); var references = new Dictionary(); var innerQuery = ConvertSelectStatement(existsSubquery.Subquery.QueryExpression, hints, schema, references, innerContext); - var innerSchema = innerQuery.Source.GetSchema(new NodeCompilationContext(Session, Options, parameters, Log)); + var innerSchema = innerQuery.Source.GetSchema(innerContext); var innerSchemaPrimaryKey = innerSchema.PrimaryKey; // Create the join @@ -4341,8 +4339,7 @@ private ColumnReferenceExpression ConvertScalarSubqueries(TSqlFragment expressio { var outerSchema = node.GetSchema(context); var outerReferences = new Dictionary(); - var innerParameterTypes = context.ParameterTypes == null ? new Dictionary(StringComparer.OrdinalIgnoreCase) : new Dictionary(context.ParameterTypes, StringComparer.OrdinalIgnoreCase); - var innerContext = new NodeCompilationContext(context, innerParameterTypes); + var innerContext = context.CreateChildContext(null); var subqueryPlan = ConvertSelectStatement(subquery.QueryExpression, hints, outerSchema, outerReferences, innerContext); // Scalar subquery must return exactly one column and one row @@ -5243,8 +5240,7 @@ private IDataExecutionPlanNodeInternal ConvertTableReference(TableReference refe // CROSS APPLY / OUTER APPLY - treat the second table as a correlated subquery var lhsSchema = lhs.GetSchema(context); lhsReferences = new Dictionary(); - var innerParameterTypes = context.ParameterTypes == null ? new Dictionary(StringComparer.OrdinalIgnoreCase) : new Dictionary(context.ParameterTypes, StringComparer.OrdinalIgnoreCase); - var innerContext = new NodeCompilationContext(context, innerParameterTypes); + var innerContext = context.CreateChildContext(null); var subqueryPlan = ConvertTableReference(unqualifiedJoin.SecondTableReference, hints, query, lhsSchema, lhsReferences, innerContext); rhs = subqueryPlan; diff --git a/MarkMpn.Sql4Cds.Engine/NodeContext.cs b/MarkMpn.Sql4Cds.Engine/NodeContext.cs index 27f792b1..b4db2294 100644 --- a/MarkMpn.Sql4Cds.Engine/NodeContext.cs +++ b/MarkMpn.Sql4Cds.Engine/NodeContext.cs @@ -54,13 +54,16 @@ public NodeCompilationContext( /// /// The parent context that this context is being created from /// The names and types of the parameters that are available to this section of the query - public NodeCompilationContext( + protected NodeCompilationContext( NodeCompilationContext parentContext, IDictionary parameterTypes) { + if (parentContext == null) + throw new ArgumentNullException(nameof(parentContext)); + Session = parentContext.Session; Options = parentContext.Options; - ParameterTypes = parameterTypes; + ParameterTypes = new LayeredDictionary(parentContext.ParameterTypes, parameterTypes); GlobalCalculations = parentContext.GlobalCalculations; Log = parentContext.Log; _parentContext = parentContext; @@ -108,6 +111,19 @@ public string GetExpressionName() return $"Expr{++_expressionCounter}"; } + /// + /// Creates a new as a child of this context + /// + /// Any additional parameters to add to the context + /// + public NodeCompilationContext CreateChildContext(IDictionary additionalParameters) + { + if (additionalParameters == null) + additionalParameters = new Dictionary(StringComparer.OrdinalIgnoreCase); + + return new NodeCompilationContext(this, additionalParameters); + } + internal void ResetGlobalCalculations() { GlobalCalculations.OuterReferences.Clear();