Skip to content

Commit

Permalink
Do not leak parameter definitions out of subqueries
Browse files Browse the repository at this point in the history
  • Loading branch information
MarkMpn committed Oct 19, 2024
1 parent 5026771 commit 4175722
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 22 deletions.
16 changes: 16 additions & 0 deletions MarkMpn.Sql4Cds.Engine.Tests/AdoProviderTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -855,6 +855,22 @@ public void AliasedTVF()
}
}

[TestMethod]
public void FilteredTVFWithSubqueryParameters()
{
using (var con = new Sql4CdsConnection(_localDataSources))
using (var cmd = con.CreateCommand())
{
cmd.CommandTimeout = 0;
cmd.CommandText = "SELECT * FROM SampleMessage((select '1')) WHERE OutputParam1 = '2'";

using (var reader = cmd.ExecuteReader())
{
Assert.IsFalse(reader.Read());
}
}
}

[TestMethod]
public void CorrelatedNotExistsTypeConversion()
{
Expand Down
18 changes: 9 additions & 9 deletions MarkMpn.Sql4Cds.Engine/ExecutionPlan/FilterNode.cs
Original file line number Diff line number Diff line change
Expand Up @@ -436,11 +436,11 @@ private bool FoldFiltersToInnerJoinSources(NodeCompilationContext context, IList

if (join is NestedLoopNode loop && loop.OuterReferences != null)
{
var innerParameterTypes = context.ParameterTypes
.Concat(loop.OuterReferences.Select(or => new KeyValuePair<string, DataTypeReference>(or.Value, leftSchema.Schema[or.Key].Type)))
var innerParameterTypes = loop.OuterReferences
.Select(or => new KeyValuePair<string, DataTypeReference>(or.Value, leftSchema.Schema[or.Key].Type))
.ToDictionary(p => p.Key, p => p.Value, StringComparer.OrdinalIgnoreCase);

rightContext = new NodeCompilationContext(context, innerParameterTypes);
rightContext = context.CreateChildContext(innerParameterTypes);
}

var rightSchema = join.RightSource.GetSchema(rightContext);
Expand Down Expand Up @@ -798,11 +798,11 @@ private void ConvertOuterJoinsWithNonNullFiltersToInnerJoins(NodeCompilationCont
{
var leftSchema = join.LeftSource.GetSchema(context);

var innerParameterTypes = context.ParameterTypes
.Concat(loop.OuterReferences.Select(or => new KeyValuePair<string, DataTypeReference>(or.Value, leftSchema.Schema[or.Key].Type)))
var innerParameterTypes = loop.OuterReferences
.Select(or => new KeyValuePair<string, DataTypeReference>(or.Value, leftSchema.Schema[or.Key].Type))
.ToDictionary(p => p.Key, p => p.Value, StringComparer.OrdinalIgnoreCase);

outerContext = new NodeCompilationContext(context, innerParameterTypes);
outerContext = context.CreateChildContext(innerParameterTypes);
}

var outerSchema = outerSource.GetSchema(outerContext);
Expand Down Expand Up @@ -1775,11 +1775,11 @@ source is AliasNode ||
if (join is NestedLoopNode loop && loop.OuterReferences != null && loop.OuterReferences.Count > 0)
{
var leftSchema = join.LeftSource.GetSchema(context);
var innerParameterTypes = context.ParameterTypes
.Concat(loop.OuterReferences.Select(or => new KeyValuePair<string,DataTypeReference>(or.Value, leftSchema.Schema[or.Key].Type)))
var innerParameterTypes = loop.OuterReferences
.Select(or => new KeyValuePair<string,DataTypeReference>(or.Value, leftSchema.Schema[or.Key].Type))
.ToDictionary(p => p.Key, p => p.Value, StringComparer.OrdinalIgnoreCase);

childContext = new NodeCompilationContext(context, innerParameterTypes);
childContext = context.CreateChildContext(innerParameterTypes);
}

foreach (var subSource in GetFoldableSources(join.RightSource, childContext))
Expand Down
18 changes: 7 additions & 11 deletions MarkMpn.Sql4Cds.Engine/ExecutionPlanBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,7 @@ private SelectNode ConvertRecursiveCTEQuery(QueryExpression queryExpression, INo
queryExpression.Accept(cteReplacer);

// Convert the modified query.
var childContext = new NodeCompilationContext(_nodeContext, outerReferences.ToDictionary(kvp => kvp.Value, kvp => anchorSchema.Schema[cteValidator.Name.EscapeIdentifier() + "." + kvp.Key.EscapeIdentifier()].Type, StringComparer.OrdinalIgnoreCase));
var childContext = _nodeContext.CreateChildContext(outerReferences.ToDictionary(kvp => kvp.Value, kvp => anchorSchema.Schema[cteValidator.Name.EscapeIdentifier() + "." + kvp.Key.EscapeIdentifier()].Type, StringComparer.OrdinalIgnoreCase));
var converted = ConvertSelectStatement(queryExpression, null, null, null, childContext);
converted.ExpandWildcardColumns(childContext);
return converted;
Expand Down Expand Up @@ -3261,8 +3261,7 @@ private IDataExecutionPlanNodeInternal ConvertInSubquery(IDataExecutionPlanNodeI
lhsCol = lhsColNormalized.ToColumnReference();
}

var parameters = context.ParameterTypes == null ? new Dictionary<string, DataTypeReference>(StringComparer.OrdinalIgnoreCase) : new Dictionary<string, DataTypeReference>(context.ParameterTypes, StringComparer.OrdinalIgnoreCase);
var innerContext = new NodeCompilationContext(context, parameters);
var innerContext = context.CreateChildContext(null);
var references = new Dictionary<string, string>();
var innerQuery = ConvertSelectStatement(inSubquery.Subquery.QueryExpression, hints, schema, references, innerContext);

Expand All @@ -3286,7 +3285,7 @@ private IDataExecutionPlanNodeInternal ConvertInSubquery(IDataExecutionPlanNodeI
else
{
// We need the inner list to be distinct to avoid creating duplicates during the join
var innerSchema = innerQuery.Source.GetSchema(new NodeCompilationContext(Session, Options, parameters, Log));
var innerSchema = innerQuery.Source.GetSchema(innerContext);
if (innerQuery.ColumnSet[0].SourceColumn != innerSchema.PrimaryKey && !(innerQuery.Source is DistinctNode))
{
innerQuery.Source = new DistinctNode
Expand Down Expand Up @@ -3397,11 +3396,10 @@ private IDataExecutionPlanNodeInternal ConvertExistsSubquery(IDataExecutionPlanN
var schema = source.GetSchema(context);

// Each query of the format "EXISTS (SELECT * FROM source)" becomes a outer semi join
var parameters = context.ParameterTypes == null ? new Dictionary<string, DataTypeReference>(StringComparer.OrdinalIgnoreCase) : new Dictionary<string, DataTypeReference>(context.ParameterTypes, StringComparer.OrdinalIgnoreCase);
var innerContext = new NodeCompilationContext(context, parameters);
var innerContext = context.CreateChildContext(null);
var references = new Dictionary<string, string>();
var innerQuery = ConvertSelectStatement(existsSubquery.Subquery.QueryExpression, hints, schema, references, innerContext);
var innerSchema = innerQuery.Source.GetSchema(new NodeCompilationContext(Session, Options, parameters, Log));
var innerSchema = innerQuery.Source.GetSchema(innerContext);
var innerSchemaPrimaryKey = innerSchema.PrimaryKey;

// Create the join
Expand Down Expand Up @@ -4341,8 +4339,7 @@ private ColumnReferenceExpression ConvertScalarSubqueries(TSqlFragment expressio
{
var outerSchema = node.GetSchema(context);
var outerReferences = new Dictionary<string, string>();
var innerParameterTypes = context.ParameterTypes == null ? new Dictionary<string, DataTypeReference>(StringComparer.OrdinalIgnoreCase) : new Dictionary<string, DataTypeReference>(context.ParameterTypes, StringComparer.OrdinalIgnoreCase);
var innerContext = new NodeCompilationContext(context, innerParameterTypes);
var innerContext = context.CreateChildContext(null);
var subqueryPlan = ConvertSelectStatement(subquery.QueryExpression, hints, outerSchema, outerReferences, innerContext);

// Scalar subquery must return exactly one column and one row
Expand Down Expand Up @@ -5243,8 +5240,7 @@ private IDataExecutionPlanNodeInternal ConvertTableReference(TableReference refe
// CROSS APPLY / OUTER APPLY - treat the second table as a correlated subquery
var lhsSchema = lhs.GetSchema(context);
lhsReferences = new Dictionary<string, string>();
var innerParameterTypes = context.ParameterTypes == null ? new Dictionary<string, DataTypeReference>(StringComparer.OrdinalIgnoreCase) : new Dictionary<string, DataTypeReference>(context.ParameterTypes, StringComparer.OrdinalIgnoreCase);
var innerContext = new NodeCompilationContext(context, innerParameterTypes);
var innerContext = context.CreateChildContext(null);
var subqueryPlan = ConvertTableReference(unqualifiedJoin.SecondTableReference, hints, query, lhsSchema, lhsReferences, innerContext);
rhs = subqueryPlan;

Expand Down
20 changes: 18 additions & 2 deletions MarkMpn.Sql4Cds.Engine/NodeContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,16 @@ public NodeCompilationContext(
/// </summary>
/// <param name="parentContext">The parent context that this context is being created from</param>
/// <param name="parameterTypes">The names and types of the parameters that are available to this section of the query</param>
public NodeCompilationContext(
protected NodeCompilationContext(
NodeCompilationContext parentContext,
IDictionary<string, DataTypeReference> parameterTypes)
{
if (parentContext == null)
throw new ArgumentNullException(nameof(parentContext));

Session = parentContext.Session;
Options = parentContext.Options;
ParameterTypes = parameterTypes;
ParameterTypes = new LayeredDictionary<string, DataTypeReference>(parentContext.ParameterTypes, parameterTypes);
GlobalCalculations = parentContext.GlobalCalculations;
Log = parentContext.Log;
_parentContext = parentContext;
Expand Down Expand Up @@ -108,6 +111,19 @@ public string GetExpressionName()
return $"Expr{++_expressionCounter}";
}

/// <summary>
/// Creates a new <see cref="NodeCompilationContext"/> as a child of this context
/// </summary>
/// <param name="additionalParameters">Any additional parameters to add to the context</param>
/// <returns></returns>
public NodeCompilationContext CreateChildContext(IDictionary<string, DataTypeReference> additionalParameters)
{
if (additionalParameters == null)
additionalParameters = new Dictionary<string, DataTypeReference>(StringComparer.OrdinalIgnoreCase);

return new NodeCompilationContext(this, additionalParameters);
}

internal void ResetGlobalCalculations()
{
GlobalCalculations.OuterReferences.Clear();
Expand Down

0 comments on commit 4175722

Please sign in to comment.