diff --git a/MarkMpn.Sql4Cds.Engine.Tests/AdoProviderTests.cs b/MarkMpn.Sql4Cds.Engine.Tests/AdoProviderTests.cs index 94ad21d9..5ee77684 100644 --- a/MarkMpn.Sql4Cds.Engine.Tests/AdoProviderTests.cs +++ b/MarkMpn.Sql4Cds.Engine.Tests/AdoProviderTests.cs @@ -2466,5 +2466,394 @@ FROM account AS a } } } + + [TestMethod] + public void DateTimeStringLiterals() + { + using (var con = new Sql4CdsConnection(_localDataSources)) + using (var cmd = con.CreateCommand()) + { + // https://learn.microsoft.com/en-us/sql/t-sql/data-types/date-transact-sql?view=sql-server-ver16#examples + cmd.CommandText = @" +SELECT + CAST('2022-05-08 12:35:29.1234567 +12:15' AS TIME(7)) AS 'time', + CAST('2022-05-08 12:35:29.1234567 +12:15' AS DATE) AS 'date', + CAST('2022-05-08 12:35:29.123' AS SMALLDATETIME) AS 'smalldatetime', + CAST('2022-05-08 12:35:29.123' AS DATETIME) AS 'datetime', + CAST('2022-05-08 12:35:29.1234567 +12:15' AS DATETIME2(7)) AS 'datetime2', + CAST('2022-05-08 12:35:29.1234567 +12:15' AS DATETIMEOFFSET(7)) AS 'datetimeoffset';"; + + using (var reader = cmd.ExecuteReader()) + { + Assert.IsTrue(reader.Read()); + Assert.AreEqual(new TimeSpan(12, 35, 29) + TimeSpan.FromTicks(1234567), reader.GetValue(0)); + Assert.AreEqual(new DateTime(2022, 5, 8), reader.GetValue(1)); + Assert.AreEqual(new DateTime(2022, 5, 8, 12, 35, 0), reader.GetValue(2)); + Assert.AreEqual(new DateTime(2022, 5, 8, 12, 35, 29, 123), reader.GetValue(3)); + Assert.AreEqual(new DateTime(2022, 5, 8, 12, 35, 29) + TimeSpan.FromTicks(1234567), reader.GetValue(4)); + Assert.AreEqual(new DateTimeOffset(new DateTime(2022, 5, 8, 12, 35, 29) + TimeSpan.FromTicks(1234567), new TimeSpan(12, 15, 0)), reader.GetValue(5)); + Assert.IsFalse(reader.Read()); + } + } + } + + [DataTestMethod] + [DataRow("datetime")] + [DataRow("smalldatetime")] + public void DateTimeToNumeric(string type) + { + using (var con = new Sql4CdsConnection(_localDataSources)) + using (var cmd = con.CreateCommand()) + { + cmd.CommandText = $@" +declare @dt {type} = '2024-10-04 12:01:00' +select cast(@dt as float), cast(@dt as int)"; + + using (var reader = cmd.ExecuteReader()) + { + Assert.IsTrue(reader.Read()); + Assert.AreEqual(45567.500694444439, reader.GetDouble(0)); + Assert.AreEqual(45568, reader.GetInt32(1)); + Assert.IsFalse(reader.Read()); + } + } + } + + [DataTestMethod] + [DataRow("year", 2007)] + [DataRow("yyyy", 2007)] + [DataRow("yy", 2007)] + [DataRow("quarter", 4)] + [DataRow("qq", 4)] + [DataRow("q", 4)] + [DataRow("month", 10)] + [DataRow("mm", 10)] + [DataRow("m", 10)] + [DataRow("dayofyear", 303)] + [DataRow("dy", 303)] + [DataRow("y", 303)] + [DataRow("day", 30)] + [DataRow("dd", 30)] + [DataRow("d", 30)] + [DataRow("week", 44)] + [DataRow("wk", 44)] + [DataRow("ww", 44)] + [DataRow("weekday", 3)] + [DataRow("dw", 3)] + [DataRow("hour", 12)] + [DataRow("hh", 12)] + [DataRow("minute", 15)] + [DataRow("n", 15)] + [DataRow("second", 32)] + [DataRow("ss", 32)] + [DataRow("s", 32)] + [DataRow("millisecond", 123)] + [DataRow("ms", 123)] + [DataRow("microsecond", 123456)] + [DataRow("mcs", 123456)] + [DataRow("nanosecond", 123456700)] + [DataRow("ns", 123456700)] + [DataRow("tzoffset", 310)] + [DataRow("tz", 310)] + [DataRow("iso_week", 44)] + [DataRow("isowk", 44)] + [DataRow("isoww", 44)] + public void DatePartExamples1(string datepart, int expected) + { + // https://learn.microsoft.com/en-us/sql/t-sql/functions/datepart-transact-sql?view=sql-server-ver16#return-value + using (var con = new Sql4CdsConnection(_localDataSources)) + using (var cmd = con.CreateCommand()) + { + cmd.CommandText = $"SELECT DATEPART({datepart}, '2007-10-30 12:15:32.1234567 +05:10')"; + Assert.AreEqual(expected, cmd.ExecuteScalar()); + } + } + + [DataTestMethod] + // https://learn.microsoft.com/en-us/sql/t-sql/functions/datepart-transact-sql?view=sql-server-ver16#week-and-weekday-datepart-arguments + // Using default SET DATEFIRST 7 + [DataRow("week", "'2007-04-21 '", 16)] + [DataRow("weekday", "'2007-04-21 '", 7)] + // https://learn.microsoft.com/en-us/sql/t-sql/functions/datepart-transact-sql?view=sql-server-ver16#default-returned-for-a-datepart-that-isnt-in-a-date-argument + [DataRow("year", "'12:10:30.123'", 1900)] + [DataRow("month", "'12:10:30.123'", 1)] + [DataRow("day", "'12:10:30.123'", 1)] + [DataRow("dayofyear", "'12:10:30.123'", 1)] + [DataRow("weekday", "'12:10:30.123'", 2)] + // https://learn.microsoft.com/en-us/sql/t-sql/functions/datepart-transact-sql?view=sql-server-ver16#fractional-seconds + [DataRow("millisecond", "'00:00:01.1234567'", 123)] + [DataRow("microsecond", "'00:00:01.1234567'", 123456)] + [DataRow("nanosecond", "'00:00:01.1234567'", 123456700)] + // https://learn.microsoft.com/en-us/sql/t-sql/functions/datepart-transact-sql?view=sql-server-ver16#examples + [DataRow("year", "0", 1900)] + [DataRow("month", "0", 1)] + [DataRow("day", "0", 1)] + public void DatePartExamples2(string datepart, string date, int expected) + { + // Assorted examples from https://learn.microsoft.com/en-us/sql/t-sql/functions/datepart-transact-sql?view=sql-server-ver16 + using (var con = new Sql4CdsConnection(_localDataSources)) + using (var cmd = con.CreateCommand()) + { + cmd.CommandText = $"SELECT DATEPART({datepart}, {date})"; + Assert.AreEqual(expected, cmd.ExecuteScalar()); + } + } + + [DataTestMethod] + [DataRow("date", "day")] + [DataRow("smalldatetime", "hour")] + [DataRow("datetime", "hour")] + [DataRow("datetime2", "hour")] + [DataRow("datetimeoffset", "hour")] + [DataRow("time", "hour")] + [DataRow("varchar(100)", "hour")] + public void DateAddReturnsOriginalDataType(string type, string datePart) + { + using (var con = new Sql4CdsConnection(_localDataSources)) + using (var cmd = con.CreateCommand()) + { + cmd.CommandText = $@" +DECLARE @dt {type} = '2024-10-04 12:01:02'; +SELECT DATEADD({datePart}, 1, @dt);" ; + + if (type == "varchar(100)") + type = "datetime"; + + using (var reader = cmd.ExecuteReader()) + { + var schema = reader.GetSchemaTable(); + + Assert.AreEqual(type, schema.Rows[0]["DataTypeName"]); + + Assert.IsTrue(reader.Read()); + + switch (type) + { + case "date": + Assert.AreEqual(new DateTime(2024, 10, 5), reader.GetValue(0)); + break; + + case "smalldatetime": + Assert.AreEqual(new DateTime(2024, 10, 4, 13, 1, 0), reader.GetValue(0)); + break; + + case "datetime": + case "datetime2": + Assert.AreEqual(new DateTime(2024, 10, 4, 13, 1, 2), reader.GetValue(0)); + break; + + case "datetimeoffset": + Assert.AreEqual(new DateTimeOffset(2024, 10, 4, 13, 1, 2, TimeSpan.Zero), reader.GetValue(0)); + break; + + case "time": + Assert.AreEqual(new TimeSpan(13, 1, 2), reader.GetValue(0)); + break; + } + + Assert.IsFalse(reader.Read()); + } + } + } + + [DataTestMethod] + [DataRow("date", "month")] + [DataRow("smalldatetime", "hour")] + [DataRow("datetime", "hour")] + [DataRow("datetime2", "hour")] + [DataRow("datetimeoffset", "hour")] + [DataRow("time", "hour")] + [DataRow("varchar(100)", "hour")] + public void DateTruncReturnsOriginalDataType(string type, string datePart) + { + using (var con = new Sql4CdsConnection(_localDataSources)) + using (var cmd = con.CreateCommand()) + { + cmd.CommandText = $@" +DECLARE @dt {type} = '2024-10-04 12:01:02'; +SELECT DATETRUNC({datePart}, @dt);"; + + if (type == "varchar(100)") + type = "datetime2"; + + using (var reader = cmd.ExecuteReader()) + { + var schema = reader.GetSchemaTable(); + + Assert.AreEqual(type, schema.Rows[0]["DataTypeName"]); + + Assert.IsTrue(reader.Read()); + + switch (type) + { + case "date": + Assert.AreEqual(new DateTime(2024, 10, 1), reader.GetValue(0)); + break; + + case "smalldatetime": + Assert.AreEqual(new DateTime(2024, 10, 4, 12, 0, 0), reader.GetValue(0)); + break; + + case "datetime": + case "datetime2": + Assert.AreEqual(new DateTime(2024, 10, 4, 12, 0, 0), reader.GetValue(0)); + break; + + case "datetimeoffset": + Assert.AreEqual(new DateTimeOffset(2024, 10, 4, 12, 0, 0, TimeSpan.Zero), reader.GetValue(0)); + break; + + case "time": + Assert.AreEqual(new TimeSpan(12, 0, 0), reader.GetValue(0)); + break; + } + + Assert.IsFalse(reader.Read()); + } + } + } + + [TestMethod] + public void DateDiffString() + { + // https://learn.microsoft.com/en-us/sql/t-sql/functions/datediff-transact-sql?view=sql-server-ver16#i-finding-difference-between-startdate-and-enddate-as-date-parts-strings + using (var con = new Sql4CdsConnection(_localDataSources)) + using (var cmd = con.CreateCommand()) + { + cmd.CommandText = @" +-- DOES NOT ACCOUNT FOR LEAP YEARS +DECLARE @date1 DATETIME, @date2 DATETIME, @result VARCHAR(100); +DECLARE @years INT, @months INT, @days INT, + @hours INT, @minutes INT, @seconds INT, @milliseconds INT; + +SET @date1 = '1900-01-01 00:00:00.000' +SET @date2 = '2018-12-12 07:08:01.123' + +SELECT @years = DATEDIFF(yy, @date1, @date2) +IF DATEADD(yy, -@years, @date2) < @date1 +SELECT @years = @years-1 +SET @date2 = DATEADD(yy, -@years, @date2) + +SELECT @months = DATEDIFF(mm, @date1, @date2) +IF DATEADD(mm, -@months, @date2) < @date1 +SELECT @months=@months-1 +SET @date2= DATEADD(mm, -@months, @date2) + +SELECT @days=DATEDIFF(dd, @date1, @date2) +IF DATEADD(dd, -@days, @date2) < @date1 +SELECT @days=@days-1 +SET @date2= DATEADD(dd, -@days, @date2) + +SELECT @hours=DATEDIFF(hh, @date1, @date2) +IF DATEADD(hh, -@hours, @date2) < @date1 +SELECT @hours=@hours-1 +SET @date2= DATEADD(hh, -@hours, @date2) + +SELECT @minutes=DATEDIFF(mi, @date1, @date2) +IF DATEADD(mi, -@minutes, @date2) < @date1 +SELECT @minutes=@minutes-1 +SET @date2= DATEADD(mi, -@minutes, @date2) + +SELECT @seconds=DATEDIFF(s, @date1, @date2) +IF DATEADD(s, -@seconds, @date2) < @date1 +SELECT @seconds=@seconds-1 +SET @date2= DATEADD(s, -@seconds, @date2) + +SELECT @milliseconds=DATEDIFF(ms, @date1, @date2) + +SELECT @result= ISNULL(CAST(NULLIF(@years,0) AS VARCHAR(10)) + ' years,','') + + ISNULL(' ' + CAST(NULLIF(@months,0) AS VARCHAR(10)) + ' months,','') + + ISNULL(' ' + CAST(NULLIF(@days,0) AS VARCHAR(10)) + ' days,','') + + ISNULL(' ' + CAST(NULLIF(@hours,0) AS VARCHAR(10)) + ' hours,','') + + ISNULL(' ' + CAST(@minutes AS VARCHAR(10)) + ' minutes and','') + + ISNULL(' ' + CAST(@seconds AS VARCHAR(10)) + + CASE + WHEN @milliseconds > 0 + THEN '.' + CAST(@milliseconds AS VARCHAR(10)) + ELSE '' + END + + ' seconds','') + +SELECT @result"; + + var actual = (string)cmd.ExecuteScalar(); + Assert.AreEqual("118 years, 11 months, 11 days, 7 hours, 8 minutes and 1.123 seconds", actual); + } + } + + [DataTestMethod] + [DataRow("datetimeoffset", ".1234567")] + [DataRow("datetimeoffset(4)", ".1235")] + [DataRow("datetimeoffset(0)", "")] + public void DateTimeOffsetToString(string type, string suffix) + { + using (var con = new Sql4CdsConnection(_localDataSources)) + using (var cmd = con.CreateCommand()) + { + cmd.CommandText = $"SELECT CAST(CAST('2024-10-04 12:01:02.1234567 +05:10' AS {type}) AS VARCHAR(100))"; + + var actual = (string)cmd.ExecuteScalar(); + Assert.AreEqual($"2024-10-04 12:01:02{suffix} +05:10", actual); + } + } + + [DataTestMethod] + [DataRow("mdy", "2003-01-02")] + [DataRow("dmy", "2003-02-01")] + [DataRow("ymd", "2001-02-03")] + [DataRow("ydm", "2001-03-02")] + [DataRow("myd", "2002-01-03")] + [DataRow("dym", "2002-03-01")] + public void SetDataFormat(string format, string expected) + { + using (var con = new Sql4CdsConnection(_localDataSources)) + using (var cmd = con.CreateCommand()) + { + cmd.CommandText = $@" +SET DATEFORMAT {format}; +SELECT CAST('01/02/03' AS DATETIME)"; + + var actual = (DateTime)cmd.ExecuteScalar(); + Assert.AreEqual(DateTime.ParseExact(expected, "yyyy-MM-dd", CultureInfo.InvariantCulture), actual); + } + } + + [TestMethod] + public void ErrorNumberPersistedBetweenExecutions() + { + using (var con = new Sql4CdsConnection(_localDataSources)) + { + using (var cmd = con.CreateCommand()) + { + try + { + cmd.CommandText = "SELECT 1/0"; + cmd.ExecuteScalar(); + Assert.Fail(); + } + catch (Sql4CdsException ex) + { + if (ex.Number != 8134) + Assert.Fail(); + } + } + + using (var cmd = con.CreateCommand()) + { + // Error should be persisted in the connection session from the previous command + cmd.CommandText = "SELECT @@ERROR"; + var error = (int)cmd.ExecuteScalar(); + Assert.AreEqual(8134, error); + } + + using (var cmd = con.CreateCommand()) + { + // Error should be reset by the previous execution + cmd.CommandText = "SELECT @@ERROR"; + var error = (int)cmd.ExecuteScalar(); + Assert.AreEqual(0, error); + } + } + } } } diff --git a/MarkMpn.Sql4Cds.Engine.Tests/CteTests.cs b/MarkMpn.Sql4Cds.Engine.Tests/CteTests.cs index 65b2ad49..c29082a4 100644 --- a/MarkMpn.Sql4Cds.Engine.Tests/CteTests.cs +++ b/MarkMpn.Sql4Cds.Engine.Tests/CteTests.cs @@ -46,6 +46,8 @@ public class CteTests : FakeXrmEasyTestsBase, IQueryExecutionOptions bool IQueryExecutionOptions.BypassCustomPlugins => false; + public event EventHandler PrimaryDataSourceChanged; + void IQueryExecutionOptions.ConfirmInsert(ConfirmDmlStatementEventArgs e) { } @@ -78,7 +80,7 @@ void IQueryExecutionOptions.Progress(double? progress, string message) [TestMethod] public void SimpleSelect() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" WITH cte AS (SELECT accountid, name FROM account) @@ -102,7 +104,7 @@ WITH cte AS (SELECT accountid, name FROM account) [TestMethod] public void ColumnAliases() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" WITH cte (id, n) AS (SELECT accountid, name FROM account) @@ -126,7 +128,7 @@ WITH cte (id, n) AS (SELECT accountid, name FROM account) [TestMethod] public void MultipleAnchorQueries() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" WITH cte (id, n) AS (SELECT accountid, name FROM account UNION ALL select contactid, fullname FROM contact) @@ -159,7 +161,7 @@ WITH cte (id, n) AS (SELECT accountid, name FROM account UNION ALL select contac [TestMethod] public void MergeFilters() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" WITH cte AS (SELECT contactid, firstname, lastname FROM contact WHERE firstname = 'Mark') @@ -188,7 +190,7 @@ WITH cte AS (SELECT contactid, firstname, lastname FROM contact WHERE firstname [TestMethod] public void MultipleReferencesWithAliases() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" WITH cte AS (SELECT contactid, firstname, lastname FROM contact WHERE firstname = 'Mark') @@ -226,7 +228,7 @@ WITH cte AS (SELECT contactid, firstname, lastname FROM contact WHERE firstname [TestMethod] public void MultipleReferencesInUnionAll() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" WITH cte AS (SELECT contactid, firstname, lastname FROM contact WHERE firstname = 'Mark') @@ -262,7 +264,7 @@ WITH cte AS (SELECT contactid, firstname, lastname FROM contact WHERE firstname [ExpectedException(typeof(NotSupportedQueryFragmentException))] public void MultipleRecursiveReferences() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" WITH cte AS ( @@ -279,7 +281,7 @@ UNION ALL [ExpectedException(typeof(NotSupportedQueryFragmentException))] public void HintsOnRecursiveReference() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" WITH cte AS ( @@ -296,7 +298,7 @@ UNION ALL [ExpectedException(typeof(NotSupportedQueryFragmentException))] public void RecursionWithoutUnionAll() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" WITH cte AS ( @@ -313,7 +315,7 @@ SELECT cte.* FROM cte [ExpectedException(typeof(QueryParseException))] public void OrderByWithoutTop() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" WITH cte AS ( @@ -331,7 +333,7 @@ ORDER BY firstname [ExpectedException(typeof(NotSupportedQueryFragmentException))] public void GroupByOnRecursiveReference() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" WITH cte AS ( @@ -348,7 +350,7 @@ UNION ALL [ExpectedException(typeof(NotSupportedQueryFragmentException))] public void AggregateOnRecursiveReference() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" WITH cte AS ( @@ -365,7 +367,7 @@ SELECT MIN(contactid), MIN(firstname), MIN(lastname) FROM cte [ExpectedException(typeof(NotSupportedQueryFragmentException))] public void TopOnRecursiveReference() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" WITH cte AS ( @@ -382,7 +384,7 @@ SELECT TOP 10 cte.* FROM cte [ExpectedException(typeof(NotSupportedQueryFragmentException))] public void OuterJoinOnRecursiveReference() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" WITH cte AS ( @@ -399,7 +401,7 @@ UNION ALL [ExpectedException(typeof(NotSupportedQueryFragmentException))] public void SubqueryOnRecursiveReference() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" WITH cte AS ( @@ -416,7 +418,7 @@ UNION ALL [ExpectedException(typeof(NotSupportedQueryFragmentException))] public void IncorrectColumnCount() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" WITH cte (id, fname) AS ( @@ -431,7 +433,7 @@ WITH cte (id, fname) AS ( [ExpectedException(typeof(NotSupportedQueryFragmentException))] public void AnonymousColumn() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" WITH cte AS ( @@ -446,7 +448,7 @@ WITH cte AS ( [ExpectedException(typeof(NotSupportedQueryFragmentException))] public void MissingAnchorQuery() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" WITH cte (x, y) AS ( @@ -460,7 +462,7 @@ WITH cte (x, y) AS ( [TestMethod] public void AliasedAnonymousColumn() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" WITH cte (id, fname, lname) AS ( @@ -498,7 +500,7 @@ WITH cte (id, fname, lname) AS ( [TestMethod] public void SelectStarFromValues() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" WITH source_data_cte AS ( @@ -535,7 +537,7 @@ WITH source_data_cte AS ( [TestMethod] public void SimpleRecursion() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" WITH cte AS ( @@ -675,7 +677,7 @@ union all [TestMethod] public void Under() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" WITH account_hierarchical(accountid) AS ( @@ -706,7 +708,7 @@ UNION ALL [TestMethod] public void EqOrUnder() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" WITH account_hierarchical(accountid) AS ( @@ -737,7 +739,7 @@ UNION ALL [TestMethod] public void Above() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" WITH account_hierarchical(accountid, parentaccountid) AS ( @@ -768,7 +770,7 @@ UNION ALL [TestMethod] public void EqOrAbove() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" WITH account_hierarchical(accountid, parentaccountid) AS ( diff --git a/MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanNodeTests.cs b/MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanNodeTests.cs index 86dbaf33..cf5050d0 100644 --- a/MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanNodeTests.cs +++ b/MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanNodeTests.cs @@ -37,7 +37,7 @@ public void ConstantScanTest() Alias = "test" }; - var results = node.Execute(new NodeExecutionContext(_localDataSources, new StubOptions(), null, null, null)).ToArray(); + var results = node.Execute(new NodeExecutionContext(new SessionContext(_localDataSources, new StubOptions()), new StubOptions(), null, null, null)).ToArray(); Assert.AreEqual(1, results.Length); Assert.AreEqual("Mark", ((SqlString)results[0]["test.firstname"]).Value); @@ -84,7 +84,7 @@ public void FilterNodeTest() } }; - var results = node.Execute(new NodeExecutionContext(_localDataSources, new StubOptions(), null, null, null)).ToArray(); + var results = node.Execute(new NodeExecutionContext(new SessionContext(_localDataSources, new StubOptions()), new StubOptions(), null, null, null)).ToArray(); Assert.AreEqual(1, results.Length); Assert.AreEqual("Mark", ((SqlString)results[0]["test.firstname"]).Value); @@ -161,7 +161,7 @@ public void MergeJoinInnerTest() JoinType = QualifiedJoinType.Inner }; - var results = node.Execute(new NodeExecutionContext(_localDataSources, new StubOptions(), null, null, null)).ToArray(); + var results = node.Execute(new NodeExecutionContext(new SessionContext(_localDataSources, new StubOptions()), new StubOptions(), null, null, null)).ToArray(); Assert.AreEqual(2, results.Length); Assert.AreEqual("Mark", ((SqlString)results[0]["f.firstname"]).Value); @@ -241,7 +241,7 @@ public void MergeJoinLeftOuterTest() JoinType = QualifiedJoinType.LeftOuter }; - var results = node.Execute(new NodeExecutionContext(_localDataSources, new StubOptions(), null, null, null)).ToArray(); + var results = node.Execute(new NodeExecutionContext(new SessionContext(_localDataSources, new StubOptions()), new StubOptions(), null, null, null)).ToArray(); Assert.AreEqual(3, results.Length); Assert.AreEqual("Mark", ((SqlString)results[0]["f.firstname"]).Value); @@ -323,7 +323,7 @@ public void MergeJoinRightOuterTest() JoinType = QualifiedJoinType.RightOuter }; - var results = node.Execute(new NodeExecutionContext(_localDataSources, new StubOptions(), null, null, null)).ToArray(); + var results = node.Execute(new NodeExecutionContext(new SessionContext(_localDataSources, new StubOptions()), new StubOptions(), null, null, null)).ToArray(); Assert.AreEqual(3, results.Length); Assert.AreEqual("Mark", ((SqlString)results[0]["f.firstname"]).Value); @@ -362,7 +362,7 @@ public void AssertionTest() ErrorMessage = "Only Mark is allowed" }; - var results = node.Execute(new NodeExecutionContext(_localDataSources, new StubOptions(), null, null, null)).GetEnumerator(); + var results = node.Execute(new NodeExecutionContext(new SessionContext(_localDataSources, new StubOptions()), new StubOptions(), null, null, null)).GetEnumerator(); Assert.IsTrue(results.MoveNext()); Assert.AreEqual("Mark", results.Current.GetAttributeValue("test.name").Value); @@ -420,7 +420,7 @@ public void ComputeScalarTest() } }; - var results = node.Execute(new NodeExecutionContext(_localDataSources, new StubOptions(), null, null, null)) + var results = node.Execute(new NodeExecutionContext(new SessionContext(_localDataSources, new StubOptions()), new StubOptions(), null, null, null)) .Select(e => e.GetAttributeValue("mul").Value) .ToArray(); @@ -462,7 +462,7 @@ public void DistinctTest() } }; - var results = node.Execute(new NodeExecutionContext(_localDataSources, new StubOptions(), null, null, null)) + var results = node.Execute(new NodeExecutionContext(new SessionContext(_localDataSources, new StubOptions()), new StubOptions(), null, null, null)) .Select(e => e.GetAttributeValue("test.value1").Value) .ToArray(); @@ -504,7 +504,7 @@ public void DistinctCaseInsensitiveTest() } }; - var results = node.Execute(new NodeExecutionContext(_localDataSources, new StubOptions(), null, null, null)) + var results = node.Execute(new NodeExecutionContext(new SessionContext(_localDataSources, new StubOptions()), new StubOptions(), null, null, null)) .Select(e => e.GetAttributeValue("test.value1").Value) .ToArray(); @@ -562,7 +562,7 @@ public void SortNodeTest() } }; - var results = node.Execute(new NodeExecutionContext(_localDataSources, new StubOptions(), null, null, null)) + var results = node.Execute(new NodeExecutionContext(new SessionContext(_localDataSources, new StubOptions()), new StubOptions(), null, null, null)) .Select(e => e.GetAttributeValue("test.expectedorder").Value) .ToArray(); @@ -627,7 +627,7 @@ public void SortNodePresortedTest() } }; - var results = node.Execute(new NodeExecutionContext(_localDataSources, new StubOptions(), null, null, null)) + var results = node.Execute(new NodeExecutionContext(new SessionContext(_localDataSources, new StubOptions()), new StubOptions(), null, null, null)) .Select(e => e.GetAttributeValue("test.expectedorder").Value) .ToArray(); @@ -659,11 +659,11 @@ public void TableSpoolTest() var spool = new TableSpoolNode { Source = source }; - var results1 = spool.Execute(new NodeExecutionContext(_localDataSources, new StubOptions(), null, null, null)) + var results1 = spool.Execute(new NodeExecutionContext(new SessionContext(_localDataSources, new StubOptions()), new StubOptions(), null, null, null)) .Select(e => e.GetAttributeValue("test.value1").Value) .ToArray(); - var results2 = spool.Execute(new NodeExecutionContext(_localDataSources, new StubOptions(), null, null, null)) + var results2 = spool.Execute(new NodeExecutionContext(new SessionContext(_localDataSources, new StubOptions()), new StubOptions(), null, null, null)) .Select(e => e.GetAttributeValue("test.value1").Value) .ToArray(); @@ -708,7 +708,7 @@ public void CaseInsenstiveHashMatchAggregateNodeTest() } }; - var results = spool.Execute(new NodeExecutionContext(_localDataSources, new StubOptions(), null, null, null)) + var results = spool.Execute(new NodeExecutionContext(new SessionContext(_localDataSources, new StubOptions()), new StubOptions(), null, null, null)) .Select(e => new { Name = e.GetAttributeValue("src.value1").Value, Count = e.GetAttributeValue("count").Value }) .ToArray(); @@ -719,7 +719,7 @@ public void CaseInsenstiveHashMatchAggregateNodeTest() public void SqlTransformSchemaOnly() { var sql = "SELECT name FROM account; DECLARE @id uniqueidentifier; SELECT name FROM account WHERE accountid = @id"; - var transformed = SqlNode.ApplyCommandBehavior(sql, System.Data.CommandBehavior.SchemaOnly, new StubOptions()); + var transformed = SqlNode.ApplyCommandBehavior(sql, System.Data.CommandBehavior.SchemaOnly, new NodeExecutionContext(new SessionContext(_localDataSources, new StubOptions()), new StubOptions(), null, null, null)); transformed = Regex.Replace(transformed, "[ \\r\\n]+", " ").Trim(); Assert.AreEqual("SELECT name FROM account WHERE 0 = 1; DECLARE @id AS UNIQUEIDENTIFIER; SELECT name FROM account WHERE accountid = @id AND 0 = 1;", transformed); @@ -729,7 +729,7 @@ public void SqlTransformSchemaOnly() public void SqlTransformSingleRow() { var sql = "SELECT name FROM account; DECLARE @id uniqueidentifier; SELECT name FROM account WHERE accountid = @id"; - var transformed = SqlNode.ApplyCommandBehavior(sql, System.Data.CommandBehavior.SingleRow, new StubOptions()); + var transformed = SqlNode.ApplyCommandBehavior(sql, System.Data.CommandBehavior.SingleRow, new NodeExecutionContext(new SessionContext(_localDataSources, new StubOptions()), new StubOptions(), null, null, null)); transformed = Regex.Replace(transformed, "[ \\r\\n]+", " ").Trim(); Assert.AreEqual("SELECT TOP 1 name FROM account; DECLARE @id AS UNIQUEIDENTIFIER;", transformed); @@ -739,7 +739,7 @@ public void SqlTransformSingleRow() public void SqlTransformSingleResult() { var sql = "SELECT name FROM account; DECLARE @id uniqueidentifier; SELECT name FROM account WHERE accountid = @id"; - var transformed = SqlNode.ApplyCommandBehavior(sql, System.Data.CommandBehavior.SingleResult, new StubOptions()); + var transformed = SqlNode.ApplyCommandBehavior(sql, System.Data.CommandBehavior.SingleResult, new NodeExecutionContext(new SessionContext(_localDataSources, new StubOptions()), new StubOptions(), null, null, null)); transformed = Regex.Replace(transformed, "[ \\r\\n]+", " ").Trim(); Assert.AreEqual("SELECT name FROM account; DECLARE @id AS UNIQUEIDENTIFIER;", transformed); @@ -749,7 +749,7 @@ public void SqlTransformSingleResult() public void AggregateInitialTest() { var aggregate = CreateAggregateTest(); - var result = aggregate.Execute(new NodeExecutionContext(_localDataSources, new StubOptions(), null, null, null)).Single(); + var result = aggregate.Execute(new NodeExecutionContext(new SessionContext(_localDataSources, new StubOptions()), new StubOptions(), null, null, null)).Single(); Assert.AreEqual(SqlInt32.Null, result["min"]); Assert.AreEqual(SqlInt32.Null, result["max"]); @@ -767,7 +767,7 @@ public void AggregateInitialTest() public void AggregateSingleValueTest() { var aggregate = CreateAggregateTest(1); - var result = aggregate.Execute(new NodeExecutionContext(_localDataSources, new StubOptions(), null, null, null)).Single(); + var result = aggregate.Execute(new NodeExecutionContext(new SessionContext(_localDataSources, new StubOptions()), new StubOptions(), null, null, null)).Single(); Assert.AreEqual((SqlInt32)1, result["min"]); Assert.AreEqual((SqlInt32)1, result["max"]); @@ -785,7 +785,7 @@ public void AggregateSingleValueTest() public void AggregateTwoEqualValuesTest() { var aggregate = CreateAggregateTest(1, 1); - var result = aggregate.Execute(new NodeExecutionContext(_localDataSources, new StubOptions(), null, null, null)).Single(); + var result = aggregate.Execute(new NodeExecutionContext(new SessionContext(_localDataSources, new StubOptions()), new StubOptions(), null, null, null)).Single(); Assert.AreEqual((SqlInt32)1, result["min"]); Assert.AreEqual((SqlInt32)1, result["max"]); @@ -803,7 +803,7 @@ public void AggregateTwoEqualValuesTest() public void AggregateMultipleValuesTest() { var aggregate = CreateAggregateTest(1, 3, 1, 1); - var result = aggregate.Execute(new NodeExecutionContext(_localDataSources, new StubOptions(), null, null, null)).Single(); + var result = aggregate.Execute(new NodeExecutionContext(new SessionContext(_localDataSources, new StubOptions()), new StubOptions(), null, null, null)).Single(); Assert.AreEqual((SqlInt32)1, result["min"]); Assert.AreEqual((SqlInt32)3, result["max"]); @@ -962,7 +962,7 @@ public void NestedLoopJoinInnerTest() } }; - var results = node.Execute(new NodeExecutionContext(_localDataSources, new StubOptions(), null, null, null)).ToArray(); + var results = node.Execute(new NodeExecutionContext(new SessionContext(_localDataSources, new StubOptions()), new StubOptions(), null, null, null)).ToArray(); Assert.AreEqual(2, results.Length); Assert.AreEqual("Mark", ((SqlString)results[0]["f.firstname"]).Value); @@ -1034,7 +1034,7 @@ public void NestedLoopJoinLeftOuterTest() } }; - var results = node.Execute(new NodeExecutionContext(_localDataSources, new StubOptions(), null, null, null)).ToArray(); + var results = node.Execute(new NodeExecutionContext(new SessionContext(_localDataSources, new StubOptions()), new StubOptions(), null, null, null)).ToArray(); Assert.AreEqual(3, results.Length); Assert.AreEqual("Mark", ((SqlString)results[0]["f.firstname"]).Value); @@ -1068,7 +1068,7 @@ public void FetchXmlSingleTablePrimaryKey() } } }; - var schema = fetch.GetSchema(new NodeCompilationContext(_localDataSources, new StubOptions(), null, null)); + var schema = fetch.GetSchema(new NodeCompilationContext(new SessionContext(_localDataSources, new StubOptions()), new StubOptions(), null, null)); Assert.AreEqual("account.accountid", schema.PrimaryKey); } @@ -1108,7 +1108,7 @@ public void FetchXmlChildTablePrimaryKey() } } }; - var schema = fetch.GetSchema(new NodeCompilationContext(_localDataSources, new StubOptions(), null, null)); + var schema = fetch.GetSchema(new NodeCompilationContext(new SessionContext(_localDataSources, new StubOptions()), new StubOptions(), null, null)); Assert.AreEqual("contact.contactid", schema.PrimaryKey); } @@ -1148,7 +1148,7 @@ public void FetchXmlChildTableOuterJoinPrimaryKey() } } }; - var schema = fetch.GetSchema(new NodeCompilationContext(_localDataSources, new StubOptions(), null, null)); + var schema = fetch.GetSchema(new NodeCompilationContext(new SessionContext(_localDataSources, new StubOptions()), new StubOptions(), null, null)); Assert.IsNull(schema.PrimaryKey); } @@ -1188,7 +1188,7 @@ public void FetchXmlParentTablePrimaryKey() } } }; - var schema = fetch.GetSchema(new NodeCompilationContext(_localDataSources, new StubOptions(), null, null)); + var schema = fetch.GetSchema(new NodeCompilationContext(new SessionContext(_localDataSources, new StubOptions()), new StubOptions(), null, null)); Assert.AreEqual("account.accountid", schema.PrimaryKey); } @@ -1228,7 +1228,7 @@ public void FetchXmlParentTableOuterJoinPrimaryKey() } } }; - var schema = fetch.GetSchema(new NodeCompilationContext(_localDataSources, new StubOptions(), null, null)); + var schema = fetch.GetSchema(new NodeCompilationContext(new SessionContext(_localDataSources, new StubOptions()), new StubOptions(), null, null)); Assert.AreEqual("account.accountid", schema.PrimaryKey); } @@ -1268,7 +1268,7 @@ public void FetchXmlChildTableFreeTextJoinPrimaryKey() } } }; - var schema = fetch.GetSchema(new NodeCompilationContext(_localDataSources, new StubOptions(), null, null)); + var schema = fetch.GetSchema(new NodeCompilationContext(new SessionContext(_localDataSources, new StubOptions()), new StubOptions(), null, null)); Assert.IsNull(schema.PrimaryKey); } } diff --git a/MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanTests.cs b/MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanTests.cs index 33827bb5..160a7df0 100644 --- a/MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanTests.cs +++ b/MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanTests.cs @@ -46,6 +46,8 @@ public class ExecutionPlanTests : FakeXrmEasyTestsBase, IQueryExecutionOptions bool IQueryExecutionOptions.BypassCustomPlugins => false; + public event EventHandler PrimaryDataSourceChanged; + void IQueryExecutionOptions.ConfirmInsert(ConfirmDmlStatementEventArgs e) { } @@ -78,7 +80,7 @@ void IQueryExecutionOptions.Progress(double? progress, string message) [TestMethod] public void SimpleSelect() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT accountid, name FROM account"; @@ -100,7 +102,7 @@ public void SimpleSelect() [TestMethod] public void SimpleSelectStar() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT * FROM account"; @@ -136,7 +138,7 @@ public void SimpleSelectStar() [TestMethod] public void Join() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT accountid, name FROM account INNER JOIN contact ON account.accountid = contact.parentcustomerid"; @@ -160,7 +162,7 @@ public void Join() [TestMethod] public void JoinWithExtraCondition() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT @@ -193,7 +195,7 @@ public void JoinWithExtraCondition() [TestMethod] public void NonUniqueJoin() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT accountid, name FROM account INNER JOIN contact ON account.name = contact.fullname"; @@ -221,7 +223,7 @@ public void NonUniqueJoin() [TestMethod] public void NonUniqueJoinExpression() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT accountid, name FROM account INNER JOIN contact ON account.name = (contact.firstname + ' ' + contact.lastname)"; @@ -258,7 +260,7 @@ public void NonUniqueJoinExpression() [TestMethod] public void SimpleWhere() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT @@ -290,7 +292,7 @@ public void SimpleWhere() [TestMethod] public void WhereColumnComparison() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT @@ -322,7 +324,7 @@ public void WhereColumnComparison() [TestMethod] public void WhereColumnComparisonCrossTable() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT @@ -359,7 +361,7 @@ public void WhereColumnComparisonCrossTable() [TestMethod] public void SimpleSort() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT @@ -389,7 +391,7 @@ ORDER BY [TestMethod] public void SimpleSortIndex() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT @@ -419,7 +421,7 @@ ORDER BY [TestMethod] public void SimpleDistinct() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT DISTINCT @@ -445,7 +447,7 @@ SELECT DISTINCT [TestMethod] public void SimpleTop() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT TOP 10 @@ -472,7 +474,7 @@ SELECT TOP 10 [TestMethod] public void SimpleOffset() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT @@ -502,7 +504,7 @@ ORDER BY name [TestMethod] public void SimpleGroupAggregate() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT @@ -568,7 +570,7 @@ public void SimpleGroupAggregate() [TestMethod] public void AliasedAggregate() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT @@ -618,7 +620,7 @@ public void AliasedAggregate() [TestMethod] public void AliasedGroupingAggregate() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT @@ -668,7 +670,7 @@ public void AliasedGroupingAggregate() [TestMethod] public void SimpleAlias() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT accountid, name AS test FROM account"; @@ -690,7 +692,7 @@ public void SimpleAlias() [TestMethod] public void SimpleHaving() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT @@ -756,7 +758,7 @@ GROUP BY name [TestMethod] public void GroupByDatePart() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT @@ -814,7 +816,7 @@ public void GroupByDatePart() [TestMethod] public void GroupByDatePartUsingYearMonthDayFunctions() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT @@ -893,7 +895,7 @@ GROUP BY [TestMethod] public void PartialOrdering() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT @@ -932,7 +934,7 @@ public void OrderByEntityName() { using (_localDataSource.SetOrderByEntityName(true)) { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT TOP 1000 @@ -970,7 +972,7 @@ ORDER BY [TestMethod] public void PartialOrderingAvoidingLegacyPagingWithTop() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT TOP 100 @@ -1007,7 +1009,7 @@ ORDER BY [TestMethod] public void PartialWhere() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT @@ -1043,7 +1045,7 @@ public void PartialWhere() [TestMethod] public void ComputeScalarSelect() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT firstname + ' ' + lastname AS fullname FROM contact WHERE firstname = 'Mark'"; @@ -1072,7 +1074,7 @@ public void ComputeScalarSelect() [TestMethod] public void ComputeScalarFilter() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT contactid FROM contact WHERE firstname + ' ' + lastname = 'Mark Carrington'"; @@ -1098,7 +1100,7 @@ public void ComputeScalarFilter() [TestMethod] public void SelectSubqueryWithMergeJoin() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT firstname + ' ' + lastname AS fullname, 'Account: ' + (SELECT name FROM account WHERE accountid = parentcustomerid) AS accountname FROM contact WHERE firstname = 'Mark'"; @@ -1131,7 +1133,7 @@ public void SelectSubqueryWithMergeJoin() [TestMethod] public void SelectSubqueryWithNestedLoop() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT firstname + ' ' + lastname AS fullname, 'Account: ' + (SELECT name FROM account WHERE createdon = contact.createdon) AS accountname FROM contact WHERE firstname = 'Mark'"; @@ -1181,7 +1183,7 @@ public void SelectSubqueryWithNestedLoop() [TestMethod] public void SelectSubqueryWithChildRecordUsesNestedLoop() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT name, (SELECT TOP 1 fullname FROM contact WHERE parentcustomerid = account.accountid) FROM account WHERE name = 'Data8'"; @@ -1224,7 +1226,7 @@ public void SelectSubqueryWithChildRecordUsesNestedLoop() [TestMethod] public void SelectSubqueryWithSmallNestedLoop() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT TOP 10 firstname + ' ' + lastname AS fullname, 'Account: ' + (SELECT name FROM account WHERE createdon = contact.createdon) AS accountname FROM contact WHERE firstname = 'Mark'"; @@ -1269,7 +1271,7 @@ public void SelectSubqueryWithSmallNestedLoop() [TestMethod] public void SelectSubqueryWithNonCorrelatedNestedLoop() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT firstname + ' ' + lastname AS fullname, 'Account: ' + (SELECT TOP 1 name FROM account) AS accountname FROM contact WHERE firstname = 'Mark'"; @@ -1311,7 +1313,7 @@ public void SelectSubqueryWithNonCorrelatedNestedLoop() [TestMethod] public void SelectSubqueryWithCorrelatedSpooledNestedLoop() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT firstname + ' ' + lastname AS fullname, 'Account: ' + (SELECT name FROM account WHERE createdon = contact.createdon) AS accountname FROM contact WHERE firstname = 'Mark'"; @@ -1362,7 +1364,7 @@ public void SelectSubqueryWithCorrelatedSpooledNestedLoop() [TestMethod] public void SelectSubqueryWithPartiallyCorrelatedSpooledNestedLoop() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT firstname + ' ' + lastname AS fullname, 'Account: ' + (SELECT name FROM account WHERE createdon = contact.createdon AND employees > 10) AS accountname FROM contact WHERE firstname = 'Mark'"; @@ -1415,7 +1417,7 @@ public void SelectSubqueryWithPartiallyCorrelatedSpooledNestedLoop() public void SelectSubqueryUsingOuterReferenceInSelectClause() { var tableSize = new StubTableSizeCache(); - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT firstname + ' ' + lastname AS fullname, 'Account: ' + (SELECT firstname + ' ' + name FROM account WHERE accountid = parentcustomerid) AS accountname FROM contact WHERE firstname = 'Mark'"; @@ -1464,7 +1466,7 @@ public void SelectSubqueryUsingOuterReferenceInSelectClause() [TestMethod] public void SelectSubqueryUsingOuterReferenceInOrderByClause() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT firstname FROM contact ORDER BY (SELECT TOP 1 name FROM account WHERE accountid = parentcustomerid ORDER BY firstname)"; @@ -1504,7 +1506,7 @@ public void SelectSubqueryUsingOuterReferenceInOrderByClause() [TestMethod] public void WhereSubquery() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT firstname + ' ' + lastname AS fullname FROM contact WHERE (SELECT name FROM account WHERE accountid = parentcustomerid) = 'Data8'"; @@ -1535,7 +1537,7 @@ public void WhereSubquery() [TestMethod] public void ComputeScalarDistinct() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT DISTINCT TOP 10 @@ -1565,7 +1567,7 @@ SELECT DISTINCT TOP 10 [TestMethod] public void Union() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT name FROM account @@ -1608,7 +1610,7 @@ SELECT name FROM account [TestMethod] public void UnionMultiple() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT name FROM account @@ -1662,7 +1664,7 @@ SELECT fullname FROM contact [TestMethod] public void UnionSort() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT name FROM account @@ -1708,7 +1710,7 @@ SELECT fullname FROM contact [TestMethod] public void UnionSortOnAlias() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT name AS n FROM account @@ -1754,7 +1756,7 @@ SELECT fullname FROM contact [TestMethod] public void UnionSortOnAliasedColumnsOriginalName() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT name AS n FROM account @@ -1800,7 +1802,7 @@ SELECT fullname FROM contact [TestMethod] public void UnionAll() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT name FROM account @@ -1839,7 +1841,7 @@ UNION ALL [TestMethod] public void SimpleInFilter() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT @@ -1874,7 +1876,7 @@ public void SimpleInFilter() [TestMethod] public void SubqueryInFilterUncorrelated() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT @@ -1915,7 +1917,7 @@ public void SubqueryInFilterUncorrelated() [ExpectedException(typeof(NotSupportedQueryFragmentException))] public void SubqueryInFilterMultipleColumnsError() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT @@ -1932,7 +1934,7 @@ public void SubqueryInFilterMultipleColumnsError() [TestMethod] public void SubqueryInFilterUncorrelatedPrimaryKey() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT @@ -1967,7 +1969,7 @@ public void SubqueryInFilterUncorrelatedPrimaryKey() [TestMethod] public void SubqueryInFilterCorrelated() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT @@ -2015,7 +2017,7 @@ public void SubqueryInFilterCorrelated() [TestMethod] public void SubqueryNotInFilterCorrelated() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT @@ -2063,7 +2065,7 @@ public void SubqueryNotInFilterCorrelated() [TestMethod] public void ExistsFilterUncorrelated() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT @@ -2101,7 +2103,7 @@ public void ExistsFilterUncorrelated() [TestMethod] public void ExistsFilterCorrelatedPrimaryKey() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT @@ -2133,7 +2135,7 @@ public void ExistsFilterCorrelatedPrimaryKey() [TestMethod] public void ExistsFilterCorrelatedPrimaryKeyOr() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT @@ -2170,7 +2172,7 @@ public void ExistsFilterCorrelatedPrimaryKeyOr() [TestMethod] public void ExistsFilterCorrelated() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT @@ -2216,7 +2218,7 @@ public void ExistsFilterCorrelatedWithAny() { using (_localDataSource.EnableJoinOperator(JoinOperator.Any)) { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT @@ -2255,7 +2257,7 @@ public void ExistsFilterCorrelatedWithAnyParentAndChildAndAdditionalFilter() { using (_localDataSource.EnableJoinOperator(JoinOperator.Any)) { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT @@ -2302,7 +2304,7 @@ public void NotExistsFilterCorrelatedOnLinkEntity() { using (_localDataSource.EnableJoinOperator(JoinOperator.NotAny)) { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT @@ -2342,7 +2344,7 @@ public void NotExistsFilterCorrelatedOnLinkEntity() [TestMethod] public void NotExistsFilterCorrelated() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT @@ -2377,7 +2379,7 @@ public void NotExistsFilterCorrelated() [TestMethod] public void QueryDerivedTableSimple() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT TOP 10 @@ -2408,7 +2410,7 @@ SELECT TOP 10 [TestMethod] public void QueryDerivedTableAlias() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT TOP 10 @@ -2439,7 +2441,7 @@ SELECT TOP 10 [TestMethod] public void QueryDerivedTableValues() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT TOP 10 @@ -2458,7 +2460,7 @@ SELECT TOP 10 var filter = AssertNode(top.Source); var constant = AssertNode(filter.Source); - var schema = constant.GetSchema(new NodeCompilationContext(_dataSources, this, null, null)); + var schema = constant.GetSchema(new NodeCompilationContext(new SessionContext(_localDataSources, this), this, null, null)); Assert.AreEqual(typeof(SqlInt32), schema.Schema["a.ID"].Type.ToNetType(out _)); Assert.AreEqual(typeof(SqlString), schema.Schema["a.name"].Type.ToNetType(out _)); } @@ -2466,7 +2468,7 @@ SELECT TOP 10 [TestMethod] public void NoLockTableHint() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT TOP 10 @@ -2496,7 +2498,7 @@ SELECT TOP 10 [TestMethod] public void CrossJoin() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT @@ -2535,7 +2537,7 @@ CROSS JOIN [TestMethod] public void CrossApply() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT @@ -2572,7 +2574,7 @@ FROM contact [TestMethod] public void CrossApplyAllColumns() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT @@ -2619,7 +2621,7 @@ FROM contact [TestMethod] public void CrossApplyRestrictedColumns() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT @@ -2656,7 +2658,7 @@ FROM contact [TestMethod] public void CrossApplyRestrictedColumnsWithAlias() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT @@ -2696,7 +2698,7 @@ FROM contact [TestMethod] public void CrossApplyJoin() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT @@ -2743,7 +2745,7 @@ FROM contact [TestMethod] public void OuterApply() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT @@ -2780,7 +2782,7 @@ FROM contact [TestMethod] public void OuterApplyNestedLoop() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT @@ -2838,7 +2840,7 @@ ORDER BY firstname [TestMethod] public void FetchXmlNativeWhere() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT @@ -2870,7 +2872,7 @@ public void FetchXmlNativeWhere() [TestMethod] public void SimpleMetadataSelect() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT logicalname @@ -2891,7 +2893,7 @@ SELECT logicalname [TestMethod] public void SimpleMetadataWhere() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT logicalname @@ -2917,7 +2919,7 @@ FROM metadata.entity [TestMethod] public void CaseSensitiveMetadataWhere() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT logicalname @@ -2948,7 +2950,7 @@ FROM metadata.entity [TestMethod] public void SimpleUpdate() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "UPDATE account SET name = 'foo' WHERE name = 'bar'"; @@ -2977,7 +2979,7 @@ public void SimpleUpdate() [TestMethod] public void UpdateFromJoin() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "UPDATE a SET name = 'foo' FROM account a INNER JOIN contact c ON a.accountid = c.parentcustomerid WHERE name = 'bar'"; @@ -3009,7 +3011,7 @@ public void UpdateFromJoin() [TestMethod] public void QueryHints() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT accountid, name FROM account OPTION (OPTIMIZE FOR UNKNOWN, FORCE ORDER, RECOMPILE, USE HINT('DISABLE_OPTIMIZER_ROWGOAL'), USE HINT('ENABLE_QUERY_OPTIMIZER_HOTFIXES'), LOOP JOIN, MERGE JOIN, HASH JOIN, NO_PERFORMANCE_SPOOL, MAXRECURSION 2)"; @@ -3031,7 +3033,7 @@ public void QueryHints() [TestMethod] public void AggregateSort() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT name, count(*) from account group by name order by 2 desc"; @@ -3083,7 +3085,7 @@ public void AggregateSort() [TestMethod] public void FoldFilterWithNonFoldedJoin() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT name from account INNER JOIN contact ON left(name, 4) = left(firstname, 4) where name like 'Data8%' and firstname like 'Mark%'"; @@ -3124,7 +3126,7 @@ public void FoldFilterWithNonFoldedJoin() [TestMethod] public void FoldFilterWithInClause() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT name from account where name like 'Data8%' and primarycontactid in (select contactid from contact where firstname = 'Mark')"; @@ -3153,7 +3155,7 @@ public void FoldFilterWithInClause() [TestMethod] public void FoldFilterWithInClauseOr() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT name from account where name like 'Data8%' or primarycontactid in (select contactid from contact where firstname = 'Mark')"; @@ -3185,7 +3187,7 @@ public void FoldFilterWithInClauseWithoutPrimaryKey() { using (_localDataSource.EnableJoinOperator(JoinOperator.In)) { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT name from account where name like 'Data8%' and createdon in (select createdon from contact where firstname = 'Mark')"; @@ -3215,7 +3217,7 @@ public void FoldFilterWithInClauseWithoutPrimaryKey() [TestMethod] public void FoldNotInToLeftOuterJoin() { - var planBuilder = new ExecutionPlanBuilder(_dataSources.Values, new OptionsWrapper(this) { PrimaryDataSource = "uat" }); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_dataSources, new OptionsWrapper(this) { PrimaryDataSource = "uat" }), new OptionsWrapper(this) { PrimaryDataSource = "uat" }); var query = "SELECT name from account where name like 'Data8%' and createdon not in (select createdon from contact where firstname = 'Mark')"; @@ -3248,7 +3250,7 @@ public void FoldFilterWithInClauseOnLinkEntityWithoutPrimaryKey() { using (_localDataSource.EnableJoinOperator(JoinOperator.In)) { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT name from account inner join contact on account.accountid = contact.parentcustomerid where name like 'Data8%' and contact.createdon in (select createdon from contact where firstname = 'Mark')"; @@ -3282,7 +3284,7 @@ public void FoldFilterWithExistsClauseWithoutPrimaryKey() { using (_localDataSource.EnableJoinOperator(JoinOperator.Exists)) { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT name from account where name like 'Data8%' and exists (select * from contact where firstname = 'Mark' and createdon = account.createdon)"; @@ -3313,7 +3315,7 @@ public void FoldFilterWithExistsClauseWithoutPrimaryKey() [TestMethod] public void DistinctNotRequiredWithPrimaryKey() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT DISTINCT accountid, name from account"; @@ -3335,7 +3337,7 @@ public void DistinctNotRequiredWithPrimaryKey() [TestMethod] public void DistinctRequiredWithoutPrimaryKey() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT DISTINCT accountid, name from account INNER JOIN contact ON account.accountid = contact.parentcustomerid"; @@ -3361,7 +3363,7 @@ public void DistinctRequiredWithoutPrimaryKey() [TestMethod] public void SimpleDelete() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "DELETE FROM account WHERE name = 'bar'"; @@ -3387,7 +3389,7 @@ public void SimpleDelete() [TestMethod] public void SimpleInsertSelect() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "INSERT INTO account (name) SELECT fullname FROM contact WHERE firstname = 'Mark'"; @@ -3413,7 +3415,7 @@ public void SimpleInsertSelect() [TestMethod] public void SelectDuplicateColumnNames() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT fullname, lastname + ', ' + firstname as fullname FROM contact WHERE firstname = 'Mark'"; @@ -3446,7 +3448,7 @@ public void SelectDuplicateColumnNames() [ExpectedException(typeof(NotSupportedQueryFragmentException))] public void SubQueryDuplicateColumnNamesError() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT * FROM (SELECT fullname, lastname + ', ' + firstname as fullname FROM contact WHERE firstname = 'Mark') a"; @@ -3456,7 +3458,7 @@ public void SubQueryDuplicateColumnNamesError() [TestMethod] public void UnionDuplicateColumnNames() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @"SELECT fullname, lastname + ', ' + firstname as fullname FROM contact WHERE firstname = 'Mark' UNION @@ -3487,7 +3489,7 @@ public void UnionDuplicateColumnNames() [ExpectedException(typeof(NotSupportedQueryFragmentException))] public void SubQueryUnionDuplicateColumnNamesError() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @"SELECT * FROM ( SELECT fullname, lastname + ', ' + firstname as fullname FROM contact WHERE firstname = 'Mark' UNION @@ -3523,7 +3525,7 @@ private void AssertFetchXml(FetchXmlScan node, string fetchXml) [TestMethod] public void SelectStarInSubquery() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @"SELECT * FROM account WHERE accountid IN (SELECT parentcustomerid FROM contact)"; @@ -3563,7 +3565,7 @@ public void SelectStarInSubquery() [ExpectedException(typeof(NotSupportedQueryFragmentException))] public void CannotSelectColumnsFromSemiJoin() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @"SELECT contact.* FROM account WHERE accountid IN (SELECT parentcustomerid FROM contact)"; @@ -3573,7 +3575,7 @@ public void CannotSelectColumnsFromSemiJoin() [TestMethod] public void MinAggregateNotFoldedToFetchXmlForOptionset() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @"SELECT new_name, min(new_optionsetvalue) FROM new_customentity GROUP BY new_name"; @@ -3598,7 +3600,7 @@ public void MinAggregateNotFoldedToFetchXmlForOptionset() [TestMethod] public void HelpfulErrorMessageOnMissingGroupBy() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @"SELECT new_name, min(new_optionsetvalue) FROM new_customentity"; @@ -3616,7 +3618,7 @@ public void HelpfulErrorMessageOnMissingGroupBy() [TestMethod] public void AggregateInSubquery() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @"SELECT firstname FROM contact @@ -3712,7 +3714,7 @@ GROUP BY firstname }, }; - var result = select.Execute(new NodeExecutionContext(_localDataSources, this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); + var result = select.Execute(new NodeExecutionContext(new SessionContext(_localDataSources, this), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(result); @@ -3724,7 +3726,7 @@ GROUP BY firstname [TestMethod] public void SelectVirtualNameAttributeFromLinkEntity() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT parentcustomeridname FROM account INNER JOIN contact ON account.accountid = contact.parentcustomerid"; @@ -3747,7 +3749,7 @@ public void SelectVirtualNameAttributeFromLinkEntity() [TestMethod] public void DuplicatedDistinctColumns() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT DISTINCT name AS n1, name AS n2 FROM account"; @@ -3769,7 +3771,7 @@ public void DuplicatedDistinctColumns() [TestMethod] public void GroupByDatetimeWithoutDatePart() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT createdon, COUNT(*) FROM account GROUP BY createdon"; @@ -3792,7 +3794,7 @@ public void GroupByDatetimeWithoutDatePart() [TestMethod] public void MetadataExpressions() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT collectionschemaname + '.' + entitysetname FROM metadata.entity WHERE description LIKE '%test%'"; @@ -3817,7 +3819,7 @@ public void MetadataExpressions() [TestMethod] public void AliasedAttribute() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT name AS n1 FROM account WHERE name = 'test'"; @@ -3844,7 +3846,7 @@ public void AliasedAttribute() [TestMethod] public void MultipleAliases() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT name AS n1, name AS n2 FROM account WHERE name = 'test'"; @@ -3898,10 +3900,14 @@ public void CrossInstanceJoin() new DataSource { Name = "local", // Hack so that ((IQueryExecutionOptions)this).PrimaryDataSource = "local" doesn't cause test to fail + Connection = _context2.GetOrganizationService(), + Metadata = metadata2, + TableSizeCache = new StubTableSizeCache(), + MessageCache = new StubMessageCache(), DefaultCollation = Collation.USEnglish } }; - var planBuilder = new ExecutionPlanBuilder(datasources, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(datasources.ToDictionary(d => d.Name), this), this); var query = "SELECT uat.name, prod.name FROM uat.dbo.account AS uat INNER JOIN prod.dbo.account AS prod ON uat.accountid = prod.accountid WHERE uat.name <> prod.name AND uat.name LIKE '%test%'"; @@ -3947,7 +3953,7 @@ public void CrossInstanceJoin() [TestMethod] public void FilterOnGroupByExpression() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT @@ -3991,7 +3997,7 @@ GROUP BY [TestMethod] public void SystemFunctions() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT CURRENT_TIMESTAMP, CURRENT_USER, GETDATE(), USER_NAME()"; @@ -4007,7 +4013,7 @@ public void SystemFunctions() [TestMethod] public void FoldEqualsCurrentUser() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT name FROM account WHERE ownerid = CURRENT_USER"; @@ -4031,7 +4037,7 @@ public void FoldEqualsCurrentUser() [TestMethod] public void DoNotFoldEqualsCurrentUserOnStringField() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT systemuserid FROM systemuser WHERE domainname = CURRENT_USER"; @@ -4061,7 +4067,7 @@ public void DoNotFoldEqualsCurrentUserOnStringField() [TestMethod] public void UseNestedLoopToInjectDynamicFilterValue_GlobalVariable() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT * FROM account WHERE turnover = CAST(CURRENT_TIMESTAMP AS int)"; @@ -4091,7 +4097,7 @@ public void UseNestedLoopToInjectDynamicFilterValue_GlobalVariable() [TestMethod] public void UseNestedLoopToInjectDynamicFilterValue_NonDeterministicFunction() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT * FROM account WHERE turnover = CAST(GETDATE() AS int)"; @@ -4121,7 +4127,7 @@ public void UseNestedLoopToInjectDynamicFilterValue_NonDeterministicFunction() [TestMethod] public void DoNotUseNestedLoopToInjectDynamicFilterValue_DeterministicFunction() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT * FROM account WHERE turnover = CAST(LEFT('1test', 1) AS int) + 1"; @@ -4145,7 +4151,7 @@ public void DoNotUseNestedLoopToInjectDynamicFilterValue_DeterministicFunction() [TestMethod] public void EntityReferenceInQuery() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT name FROM account WHERE accountid IN ('0000000000000000-0000-0000-000000000000', '0000000000000000-0000-0000-000000000001')"; @@ -4172,7 +4178,7 @@ public void EntityReferenceInQuery() [TestMethod] public void OrderBySelectExpression() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT name + 'foo' FROM account ORDER BY 1"; @@ -4198,7 +4204,7 @@ public void OrderBySelectExpression() [TestMethod] public void OrderByAlias() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT name AS companyname FROM account ORDER BY companyname"; @@ -4223,7 +4229,7 @@ public void OrderByAlias() [ExpectedException(typeof(NotSupportedQueryFragmentException))] public void OrderByAliasCantUseExpression() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT name AS companyname FROM account ORDER BY companyname + ''"; @@ -4233,7 +4239,7 @@ public void OrderByAliasCantUseExpression() [TestMethod] public void DistinctOrderByUsesScalarAggregate() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT DISTINCT account.accountid FROM metadata.entity INNER JOIN account ON entity.metadataid = account.accountid"; @@ -4263,7 +4269,7 @@ public void DistinctOrderByUsesScalarAggregate() [ExpectedException(typeof(NotSupportedQueryFragmentException))] public void WindowFunctionsNotSupported() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT COUNT(accountid) OVER(PARTITION BY accountid) AS test FROM account"; @@ -4273,7 +4279,7 @@ public void WindowFunctionsNotSupported() [TestMethod] public void DeclareVariableSetLiteralSelect() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" DECLARE @test int @@ -4310,7 +4316,7 @@ DECLARE @test int { if (plan is IDataReaderExecutionPlanNode selectQuery) { - var results = selectQuery.Execute(new NodeExecutionContext(_dataSources, this, parameterTypes, parameterValues, null), CommandBehavior.Default); + var results = selectQuery.Execute(new NodeExecutionContext(new SessionContext(_localDataSources, this), this, parameterTypes, parameterValues, null), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(results); @@ -4320,7 +4326,7 @@ DECLARE @test int } else if (plan is IDmlQueryExecutionPlanNode dmlQuery) { - dmlQuery.Execute(new NodeExecutionContext(_dataSources, this, parameterTypes, parameterValues, null), out _, out _); + dmlQuery.Execute(new NodeExecutionContext(new SessionContext(_localDataSources, this), this, parameterTypes, parameterValues, null), out _, out _); } } } @@ -4328,7 +4334,7 @@ DECLARE @test int [TestMethod] public void SetVariableInDeclaration() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" DECLARE @test int = 1 @@ -4364,7 +4370,7 @@ public void SetVariableInDeclaration() { if (plan is IDataReaderExecutionPlanNode selectQuery) { - var results = selectQuery.Execute(new NodeExecutionContext(_dataSources, this, parameterTypes, parameterValues, null), CommandBehavior.Default); + var results = selectQuery.Execute(new NodeExecutionContext(new SessionContext(_localDataSources, this), this, parameterTypes, parameterValues, null), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(results); @@ -4374,7 +4380,7 @@ public void SetVariableInDeclaration() } else if (plan is IDmlQueryExecutionPlanNode dmlQuery) { - dmlQuery.Execute(new NodeExecutionContext(_dataSources, this, parameterTypes, parameterValues, null), out _, out _); + dmlQuery.Execute(new NodeExecutionContext(new SessionContext(_localDataSources, this), this, parameterTypes, parameterValues, null), out _, out _); } } } @@ -4383,7 +4389,7 @@ public void SetVariableInDeclaration() [ExpectedException(typeof(NotSupportedQueryFragmentException))] public void UnknownVariable() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SET @test = 1"; @@ -4394,7 +4400,7 @@ public void UnknownVariable() [ExpectedException(typeof(NotSupportedQueryFragmentException))] public void DuplicateVariable() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" DECLARE @test INT @@ -4406,7 +4412,7 @@ DECLARE @test INT [TestMethod] public void VariableTypeConversionIntToString() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" DECLARE @test varchar(3) @@ -4422,7 +4428,7 @@ DECLARE @test varchar(3) { if (plan is IDataReaderExecutionPlanNode selectQuery) { - var results = selectQuery.Execute(new NodeExecutionContext(_localDataSources, this, parameterTypes, parameterValues, null), CommandBehavior.Default); + var results = selectQuery.Execute(new NodeExecutionContext(new SessionContext(_localDataSources, this), this, parameterTypes, parameterValues, null), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(results); @@ -4432,7 +4438,7 @@ DECLARE @test varchar(3) } else if (plan is IDmlQueryExecutionPlanNode dmlQuery) { - dmlQuery.Execute(new NodeExecutionContext(_localDataSources, this, parameterTypes, parameterValues, null), out _, out _); + dmlQuery.Execute(new NodeExecutionContext(new SessionContext(_localDataSources, this), this, parameterTypes, parameterValues, null), out _, out _); } } } @@ -4440,7 +4446,7 @@ DECLARE @test varchar(3) [TestMethod] public void VariableTypeConversionStringTruncation() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" DECLARE @test varchar(3) @@ -4456,7 +4462,7 @@ DECLARE @test varchar(3) { if (plan is IDataReaderExecutionPlanNode selectQuery) { - var results = selectQuery.Execute(new NodeExecutionContext(_localDataSources, this, parameterTypes, parameterValues, null), CommandBehavior.Default); + var results = selectQuery.Execute(new NodeExecutionContext(new SessionContext(_localDataSources, this), this, parameterTypes, parameterValues, null), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(results); @@ -4466,7 +4472,7 @@ DECLARE @test varchar(3) } else if (plan is IDmlQueryExecutionPlanNode dmlQuery) { - dmlQuery.Execute(new NodeExecutionContext(_localDataSources, this, parameterTypes, parameterValues, null), out _, out _); + dmlQuery.Execute(new NodeExecutionContext(new SessionContext(_localDataSources, this), this, parameterTypes, parameterValues, null), out _, out _); } } } @@ -4475,7 +4481,7 @@ DECLARE @test varchar(3) [ExpectedException(typeof(NotSupportedQueryFragmentException))] public void CannotCombineSetVariableAndDataRetrievalInSelect() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); // A SELECT statement that assigns a value to a variable must not be combined with data-retrieval operations var query = @" @@ -4488,7 +4494,7 @@ DECLARE @test varchar(3) [TestMethod] public void SetVariableWithSelectUsesFinalValue() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" DECLARE @test varchar(3) @@ -4548,7 +4554,7 @@ DECLARE @test varchar(3) { if (plan is IDataReaderExecutionPlanNode selectQuery) { - var results = selectQuery.Execute(new NodeExecutionContext(_localDataSources, this, parameterTypes, parameterValues, null), CommandBehavior.Default); + var results = selectQuery.Execute(new NodeExecutionContext(new SessionContext(_localDataSources, this), this, parameterTypes, parameterValues, null), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(results); @@ -4558,7 +4564,7 @@ DECLARE @test varchar(3) } else if (plan is IDmlQueryExecutionPlanNode dmlQuery) { - dmlQuery.Execute(new NodeExecutionContext(_localDataSources, this, parameterTypes, parameterValues, null), out _, out _); + dmlQuery.Execute(new NodeExecutionContext(new SessionContext(_localDataSources, this), this, parameterTypes, parameterValues, null), out _, out _); } } } @@ -4566,7 +4572,7 @@ DECLARE @test varchar(3) [TestMethod] public void VarCharLengthDefaultsTo1() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" DECLARE @test varchar @@ -4582,7 +4588,7 @@ DECLARE @test varchar { if (plan is IDataReaderExecutionPlanNode selectQuery) { - var results = selectQuery.Execute(new NodeExecutionContext(_localDataSources, this, parameterTypes, parameterValues, null), CommandBehavior.Default); + var results = selectQuery.Execute(new NodeExecutionContext(new SessionContext(_localDataSources, this), this, parameterTypes, parameterValues, null), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(results); @@ -4592,7 +4598,7 @@ DECLARE @test varchar } else if (plan is IDmlQueryExecutionPlanNode dmlQuery) { - dmlQuery.Execute(new NodeExecutionContext(_localDataSources, this, parameterTypes, parameterValues, null), out _, out _); + dmlQuery.Execute(new NodeExecutionContext(new SessionContext(_localDataSources, this), this, parameterTypes, parameterValues, null), out _, out _); } } } @@ -4601,7 +4607,7 @@ DECLARE @test varchar [ExpectedException(typeof(NotSupportedQueryFragmentException))] public void CursorVariableNotSupported() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" DECLARE @test CURSOR"; @@ -4613,7 +4619,7 @@ public void CursorVariableNotSupported() [ExpectedException(typeof(NotSupportedQueryFragmentException))] public void TableVariableNotSupported() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" DECLARE @test TABLE (ID INT)"; @@ -4624,7 +4630,7 @@ public void TableVariableNotSupported() [TestMethod] public void IfStatement() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this) { EstimatedPlanOnly = true }; + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this) { EstimatedPlanOnly = true }; var query = @" IF @param1 = 1 @@ -4658,7 +4664,7 @@ INSERT INTO account (name) VALUES ('one') [TestMethod] public void WhileStatement() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this) { EstimatedPlanOnly = true }; + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this) { EstimatedPlanOnly = true }; var query = @" WHILE @param1 < 10 @@ -4687,7 +4693,7 @@ INSERT INTO account (name) VALUES (@param1) [TestMethod] public void IfNotExists() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this) { EstimatedPlanOnly = true }; + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this) { EstimatedPlanOnly = true }; var query = @" IF NOT EXISTS(SELECT * FROM account WHERE name = @param1) @@ -4712,7 +4718,7 @@ INSERT INTO account (name) VALUES (@param1) [TestMethod] public void DuplicatedAliases() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT name, createdon AS name FROM account"; @@ -4734,7 +4740,7 @@ public void DuplicatedAliases() [TestMethod] public void MetadataLeftJoinData() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT entity.logicalname, account.name, contact.firstname @@ -4773,7 +4779,7 @@ public void MetadataLeftJoinData() [TestMethod] public void NotEqualExcludesNull() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT name FROM account WHERE name <> 'Data8'"; @@ -4799,7 +4805,7 @@ public void NotEqualExcludesNull() [TestMethod] public void DistinctFromAllowsNull() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT name FROM account WHERE name IS DISTINCT FROM 'Data8'"; @@ -4824,7 +4830,7 @@ public void DistinctFromAllowsNull() [TestMethod] public void DoNotFoldFilterOnNameVirtualAttributeWithTooManyJoins() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" select top 10 a.name @@ -4917,7 +4923,7 @@ from account a [TestMethod] public void FilterOnVirtualTypeAttributeEquals() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @"SELECT firstname FROM contact WHERE parentcustomeridtype = 'contact'"; @@ -4941,7 +4947,7 @@ public void FilterOnVirtualTypeAttributeEquals() [TestMethod] public void FilterOnVirtualTypeAttributeEqualsImpossible() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @"SELECT firstname FROM contact WHERE parentcustomeridtype = 'non-existent-entity'"; @@ -4965,7 +4971,7 @@ public void FilterOnVirtualTypeAttributeEqualsImpossible() [TestMethod] public void FilterOnVirtualTypeAttributeNotEquals() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @"SELECT firstname FROM contact WHERE parentcustomeridtype <> 'contact'"; @@ -4990,7 +4996,7 @@ public void FilterOnVirtualTypeAttributeNotEquals() [TestMethod] public void FilterOnVirtualTypeAttributeNotInImpossible() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @"SELECT firstname FROM contact WHERE parentcustomeridtype NOT IN ('account', 'contact')"; @@ -5017,7 +5023,7 @@ public void FilterOnVirtualTypeAttributeNotInImpossible() [TestMethod] public void FilterOnVirtualTypeAttributeNull() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @"SELECT firstname FROM contact WHERE parentcustomeridtype IS NULL"; @@ -5041,7 +5047,7 @@ public void FilterOnVirtualTypeAttributeNull() [TestMethod] public void FilterOnVirtualTypeAttributeNotNull() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @"SELECT firstname FROM contact WHERE parentcustomeridtype IS NOT NULL"; @@ -5065,7 +5071,7 @@ public void FilterOnVirtualTypeAttributeNotNull() [TestMethod] public void SubqueriesInValueList() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @"SELECT a FROM (VALUES ('a'), ((SELECT TOP 1 firstname FROM contact)), ('b'), (1)) AS MyTable (a)"; @@ -5094,7 +5100,7 @@ public void SubqueriesInValueList() [TestMethod] public void FoldFilterOnIdentity() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @"SELECT name FROM account WHERE accountid = @@IDENTITY"; @@ -5119,7 +5125,7 @@ public void FoldFilterOnIdentity() [TestMethod] public void FoldPrimaryIdInQuery() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @"SELECT name FROM account WHERE accountid IN (SELECT accountid FROM account INNER JOIN contact ON account.primarycontactid = contact.contactid WHERE name = 'Data8')"; @@ -5145,7 +5151,7 @@ public void FoldPrimaryIdInQuery() [TestMethod] public void FoldPrimaryIdInQueryWithTop() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @"DELETE FROM account WHERE accountid IN (SELECT TOP 10 accountid FROM account ORDER BY createdon DESC)"; @@ -5168,7 +5174,7 @@ public void FoldPrimaryIdInQueryWithTop() [TestMethod] public void InsertParameters() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @"DECLARE @name varchar(100) = 'test'; INSERT INTO account (name) VALUES (@name)"; @@ -5186,7 +5192,7 @@ public void InsertParameters() [TestMethod] public void NotExistsParameters() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @"DECLARE @firstname AS VARCHAR (100) = 'Mark', @lastname AS VARCHAR (100) = 'Carrington'; @@ -5227,7 +5233,7 @@ INSERT INTO contact (firstname, lastname) [TestMethod] public void UpdateParameters() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @"declare @name varchar(100) = 'Data8', @employees int = 10 UPDATE account SET employees = @employees WHERE name = @name"; @@ -5258,7 +5264,7 @@ public void UpdateParameters() [TestMethod] public void CountUsesAggregateByDefault() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @"SELECT count(*) FROM account"; @@ -5282,7 +5288,7 @@ public void CountUsesAggregateByDefault() [TestMethod] public void CountUsesRetrieveTotalRecordCountWithHint() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @"SELECT count(*) FROM account OPTION (USE HINT ('RETRIEVE_TOTAL_RECORD_COUNT'))"; @@ -5301,7 +5307,7 @@ public void CountUsesRetrieveTotalRecordCountWithHint() [TestMethod] public void MaxDOPUsesDefault() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @"UPDATE account SET name = 'test'"; @@ -5316,7 +5322,7 @@ public void MaxDOPUsesDefault() [TestMethod] public void MaxDOPUsesHint() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @"UPDATE account SET name = 'test' OPTION (MAXDOP 7)"; @@ -5331,7 +5337,7 @@ public void MaxDOPUsesHint() [TestMethod] public void MaxDOPUsesHintInsideIfBlock() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @"IF (1 = 1) BEGIN UPDATE account SET name = 'test' OPTION (MAXDOP 7) END"; @@ -5347,7 +5353,7 @@ public void MaxDOPUsesHintInsideIfBlock() [TestMethod] public void SubqueryUsesSpoolByDefault() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @"SELECT accountid, (SELECT TOP 1 fullname FROM contact) FROM account"; @@ -5365,7 +5371,7 @@ public void SubqueryUsesSpoolByDefault() [TestMethod] public void SubqueryDoesntUseSpoolWithHint() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @"SELECT accountid, (SELECT TOP 1 fullname FROM contact) FROM account OPTION (NO_PERFORMANCE_SPOOL)"; @@ -5382,7 +5388,7 @@ public void SubqueryDoesntUseSpoolWithHint() [TestMethod] public void BypassPluginExecutionUsesDefault() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @"UPDATE account SET name = 'test'"; @@ -5397,7 +5403,7 @@ public void BypassPluginExecutionUsesDefault() [TestMethod] public void BypassPluginExecutionUsesHint() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @"UPDATE account SET name = 'test' OPTION (USE HINT ('BYPASS_CUSTOM_PLUGIN_EXECUTION'))"; @@ -5412,7 +5418,7 @@ public void BypassPluginExecutionUsesHint() [TestMethod] public void PageSizeUsesHint() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @"SELECT name FROM account OPTION (USE HINT ('FETCHXML_PAGE_SIZE_100'))"; @@ -5434,7 +5440,7 @@ public void PageSizeUsesHint() [TestMethod] public void DistinctOrderByOptionSet() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT DISTINCT new_optionsetvalue FROM new_customentity ORDER BY new_optionsetvalue"; @@ -5457,7 +5463,7 @@ public void DistinctOrderByOptionSet() [TestMethod] public void DistinctVirtualAttribute() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT DISTINCT new_optionsetvaluename FROM new_customentity"; @@ -5482,7 +5488,7 @@ public void DistinctVirtualAttribute() [TestMethod] public void TopAliasStar() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT TOP 10 A.* FROM account A"; @@ -5503,7 +5509,7 @@ public void TopAliasStar() [TestMethod] public void OrderByStar() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT * FROM account ORDER BY primarycontactid"; @@ -5525,7 +5531,7 @@ public void OrderByStar() [TestMethod] public void UpdateColumnInWhereClause() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "UPDATE account SET name = '1' WHERE name <> '1'"; @@ -5555,7 +5561,7 @@ public void UpdateColumnInWhereClause() [TestMethod] public void NestedOrFilters() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT * FROM account WHERE name = '1' OR name = '2' OR name = '3' OR name = '4'"; @@ -5583,7 +5589,7 @@ public void NestedOrFilters() [TestMethod] public void UnknownHint() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT * FROM account OPTION(USE HINT('invalid'))"; @@ -5593,7 +5599,7 @@ public void UnknownHint() [TestMethod] public void MultipleTablesJoinFromWhereClause() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT firstname FROM account, contact WHERE accountid = parentcustomerid AND lastname = 'Carrington' AND name = 'Data8'"; @@ -5622,7 +5628,7 @@ public void MultipleTablesJoinFromWhereClause() [TestMethod] public void MultipleTablesJoinFromWhereClauseReversed() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT firstname FROM account, contact WHERE lastname = 'Carrington' AND name = 'Data8' AND parentcustomerid = accountid"; @@ -5651,7 +5657,7 @@ public void MultipleTablesJoinFromWhereClauseReversed() [TestMethod] public void MultipleTablesJoinFromWhereClause3() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT firstname FROM account, contact, systemuser WHERE accountid = parentcustomerid AND lastname = 'Carrington' AND name = 'Data8' AND account.ownerid = systemuserid"; @@ -5681,7 +5687,7 @@ public void MultipleTablesJoinFromWhereClause3() [TestMethod] public void NestedInSubqueries() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT firstname FROM contact WHERE parentcustomerid IN (SELECT accountid FROM account WHERE primarycontactid IN (SELECT contactid FROM contact WHERE lastname = 'Carrington'))"; @@ -5709,7 +5715,7 @@ public void NestedInSubqueries() [TestMethod] public void SpoolNestedLoop() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT account.name, contact.fullname FROM account INNER JOIN contact ON account.accountid = contact.parentcustomerid OR account.createdon < contact.createdon"; @@ -5744,7 +5750,7 @@ public void SpoolNestedLoop() [TestMethod] public void SelectFromTVF() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT * FROM SampleMessage('test')"; @@ -5763,7 +5769,7 @@ public void SelectFromTVF() [TestMethod] public void OuterApplyCorrelatedTVF() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT account.name, msg.OutputParam1 FROM account OUTER APPLY (SELECT * FROM SampleMessage(account.name)) AS msg WHERE account.name = 'Data8'"; @@ -5790,7 +5796,7 @@ public void OuterApplyCorrelatedTVF() [TestMethod] public void OuterApplyUncorrelatedTVF() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT account.name, msg.OutputParam1 FROM account OUTER APPLY (SELECT * FROM SampleMessage('test')) AS msg WHERE account.name = 'Data8'"; @@ -5817,7 +5823,7 @@ public void OuterApplyUncorrelatedTVF() [TestMethod] public void TVFScalarSubqueryParameter() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT * FROM SampleMessage((SELECT TOP 1 name FROM account))"; @@ -5858,7 +5864,7 @@ public void TVFScalarSubqueryParameter() [TestMethod] public void ExecuteSproc() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "EXEC SampleMessage 'test'"; @@ -5875,7 +5881,7 @@ public void ExecuteSproc() [TestMethod] public void ExecuteSprocNamedParameters() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @"DECLARE @i int EXEC SampleMessage @StringParam = 'test', @OutputParam2 = @i OUTPUT @@ -5895,7 +5901,7 @@ public void FoldMultipleJoinConditionsWithKnownValue() { using (_localDataSource.SetColumnComparison(false)) { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @"SELECT a.name, c.fullname FROM account a INNER JOIN contact c ON a.accountid = c.parentcustomerid AND a.name = c.fullname WHERE a.name = 'Data8'"; @@ -5926,7 +5932,7 @@ public void FoldMultipleJoinConditionsWithKnownValue() [ExpectedException(typeof(NotSupportedQueryFragmentException))] public void CollationConflict() { - var planBuilder = new ExecutionPlanBuilder(_dataSources.Values, new OptionsWrapper(this) { PrimaryDataSource = "prod" }); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_dataSources, new OptionsWrapper(this) { PrimaryDataSource = "prod" }), new OptionsWrapper(this) { PrimaryDataSource = "prod" }); var query = "SELECT * FROM prod.dbo.account p, french.dbo.account f WHERE p.name = f.name"; planBuilder.Build(query, null, out _); } @@ -5934,7 +5940,7 @@ public void CollationConflict() [TestMethod] public void CollationConflictJoin() { - var planBuilder = new ExecutionPlanBuilder(_dataSources.Values, new OptionsWrapper(this) { PrimaryDataSource = "prod" }); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_dataSources, new OptionsWrapper(this) { PrimaryDataSource = "prod" }), new OptionsWrapper(this) { PrimaryDataSource = "prod" }); var query = "SELECT * FROM prod.dbo.account p INNER JOIN french.dbo.account f ON p.name = f.name"; try @@ -5951,7 +5957,7 @@ public void CollationConflictJoin() [TestMethod] public void TypeConflictJoin() { - var planBuilder = new ExecutionPlanBuilder(_dataSources.Values, new OptionsWrapper(this) { PrimaryDataSource = "prod" }); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_dataSources, new OptionsWrapper(this) { PrimaryDataSource = "prod" }), new OptionsWrapper(this) { PrimaryDataSource = "prod" }); var query = "SELECT * FROM account p INNER JOIN account f ON p.accountid = f.turnover"; try @@ -5968,7 +5974,7 @@ public void TypeConflictJoin() [TestMethod] public void TypeConflictCrossInstanceJoin() { - var planBuilder = new ExecutionPlanBuilder(_dataSources.Values, new OptionsWrapper(this) { PrimaryDataSource = "prod" }); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_dataSources, new OptionsWrapper(this) { PrimaryDataSource = "prod" }), new OptionsWrapper(this) { PrimaryDataSource = "prod" }); var query = "SELECT * FROM prod.dbo.account p INNER JOIN french.dbo.account f ON p.accountid = f.turnover"; try @@ -5985,7 +5991,7 @@ public void TypeConflictCrossInstanceJoin() [TestMethod] public void ExplicitCollation() { - var planBuilder = new ExecutionPlanBuilder(_dataSources.Values, new OptionsWrapper(this) { PrimaryDataSource = "prod" }); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_dataSources, new OptionsWrapper(this) { PrimaryDataSource = "prod" }), new OptionsWrapper(this) { PrimaryDataSource = "prod" }); var query = "SELECT * FROM prod.dbo.account p, french.dbo.account f WHERE p.name = f.name COLLATE French_CI_AS"; var plans = planBuilder.Build(query, null, out _); @@ -6022,7 +6028,7 @@ public void ExplicitCollation() [ExpectedException(typeof(NotSupportedQueryFragmentException))] public void NoCollationSelectListError() { - var planBuilder = new ExecutionPlanBuilder(_dataSources.Values, new OptionsWrapper(this) { PrimaryDataSource = "prod" }); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_dataSources, new OptionsWrapper(this) { PrimaryDataSource = "prod" }), new OptionsWrapper(this) { PrimaryDataSource = "prod" }); var query = "SELECT (CASE WHEN p.employees > f.employees THEN p.name ELSE f.name END) FROM prod.dbo.account p, french.dbo.account f"; planBuilder.Build(query, null, out _); } @@ -6030,7 +6036,7 @@ public void NoCollationSelectListError() [TestMethod] public void NoCollationExprWithExplicitCollationSelectList() { - var planBuilder = new ExecutionPlanBuilder(_dataSources.Values, new OptionsWrapper(this) { PrimaryDataSource = "prod" }); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_dataSources, new OptionsWrapper(this) { PrimaryDataSource = "prod" }), new OptionsWrapper(this) { PrimaryDataSource = "prod" }); var query = "SELECT (CASE WHEN p.employees > f.employees THEN p.name ELSE f.name END) COLLATE Latin1_General_CI_AS FROM prod.dbo.account p, french.dbo.account f"; planBuilder.Build(query, null, out _); } @@ -6039,7 +6045,7 @@ public void NoCollationExprWithExplicitCollationSelectList() [ExpectedException(typeof(NotSupportedQueryFragmentException))] public void NoCollationCollationSensitiveFunctionError() { - var planBuilder = new ExecutionPlanBuilder(_dataSources.Values, new OptionsWrapper(this) { PrimaryDataSource = "prod" }); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_dataSources, new OptionsWrapper(this) { PrimaryDataSource = "prod" }), new OptionsWrapper(this) { PrimaryDataSource = "prod" }); var query = "SELECT PATINDEX((CASE WHEN p.employees > f.employees THEN p.name ELSE f.name END), 'a') FROM prod.dbo.account p, french.dbo.account f"; planBuilder.Build(query, null, out _); } @@ -6047,7 +6053,7 @@ public void NoCollationCollationSensitiveFunctionError() [TestMethod] public void NoCollationExprWithExplicitCollationCollationSensitiveFunctionError() { - var planBuilder = new ExecutionPlanBuilder(_dataSources.Values, new OptionsWrapper(this) { PrimaryDataSource = "prod" }); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_dataSources, new OptionsWrapper(this) { PrimaryDataSource = "prod" }), new OptionsWrapper(this) { PrimaryDataSource = "prod" }); var query = "SELECT PATINDEX((CASE WHEN p.employees > f.employees THEN p.name ELSE f.name END) COLLATE Latin1_General_CI_AS, 'a') FROM prod.dbo.account p, french.dbo.account f"; planBuilder.Build(query, null, out _); } @@ -6055,7 +6061,7 @@ public void NoCollationExprWithExplicitCollationCollationSensitiveFunctionError( [TestMethod] public void CollationFunctions() { - var planBuilder = new ExecutionPlanBuilder(_dataSources.Values, new OptionsWrapper(this) { PrimaryDataSource = "prod" }); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_dataSources, new OptionsWrapper(this) { PrimaryDataSource = "prod" }), new OptionsWrapper(this) { PrimaryDataSource = "prod" }); var query = "SELECT *, COLLATIONPROPERTY(name, 'lcid') FROM sys.fn_helpcollations()"; var plans = planBuilder.Build(query, null, out _); @@ -6071,7 +6077,7 @@ public void CollationFunctions() [ExpectedException(typeof(NotSupportedQueryFragmentException))] public void DuplicatedTableName() { - var planBuilder = new ExecutionPlanBuilder(_dataSources.Values, new OptionsWrapper(this) { PrimaryDataSource = "prod" }); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_dataSources, new OptionsWrapper(this) { PrimaryDataSource = "prod" }), new OptionsWrapper(this) { PrimaryDataSource = "prod" }); var query = "SELECT * FROM account, account"; planBuilder.Build(query, null, out _); } @@ -6080,7 +6086,7 @@ public void DuplicatedTableName() [ExpectedException(typeof(NotSupportedQueryFragmentException))] public void DuplicatedTableNameJoin() { - var planBuilder = new ExecutionPlanBuilder(_dataSources.Values, new OptionsWrapper(this) { PrimaryDataSource = "prod" }); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_dataSources, new OptionsWrapper(this) { PrimaryDataSource = "prod" }), new OptionsWrapper(this) { PrimaryDataSource = "prod" }); var query = "SELECT * FROM account INNER JOIN contact ON contact.parentcustomerid = account.accountid INNER JOIN contact ON contact.parentcustomerid = account.accountid"; planBuilder.Build(query, null, out _); } @@ -6089,7 +6095,7 @@ public void DuplicatedTableNameJoin() [ExpectedException(typeof(NotSupportedQueryFragmentException))] public void DuplicatedAliasName() { - var planBuilder = new ExecutionPlanBuilder(_dataSources.Values, new OptionsWrapper(this) { PrimaryDataSource = "prod" }); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_dataSources, new OptionsWrapper(this) { PrimaryDataSource = "prod" }), new OptionsWrapper(this) { PrimaryDataSource = "prod" }); var query = "SELECT * FROM account x INNER JOIN contact x ON x.parentcustomerid = x.accountid"; planBuilder.Build(query, null, out _); } @@ -6098,7 +6104,7 @@ public void DuplicatedAliasName() [ExpectedException(typeof(NotSupportedQueryFragmentException))] public void TableNameMatchesAliasName() { - var planBuilder = new ExecutionPlanBuilder(_dataSources.Values, new OptionsWrapper(this) { PrimaryDataSource = "prod" }); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_dataSources, new OptionsWrapper(this) { PrimaryDataSource = "prod" }), new OptionsWrapper(this) { PrimaryDataSource = "prod" }); var query = "SELECT * FROM account INNER JOIN contact AS account ON account.parentcustomerid = account.accountid"; planBuilder.Build(query, null, out _); } @@ -6108,7 +6114,7 @@ public void AuditJoinsToCallingUserIdAndUserId() { // https://learn.microsoft.com/en-us/power-apps/developer/data-platform/auditing/retrieve-audit-data?tabs=webapi#audit-table-relationships // Audit table can only be joined to systemuser on callinguserid or userid. Both joins together, or any other joins, are not valid. - var planBuilder = new ExecutionPlanBuilder(_dataSources.Values, new OptionsWrapper(this) { PrimaryDataSource = "prod" }); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_dataSources, new OptionsWrapper(this) { PrimaryDataSource = "prod" }), new OptionsWrapper(this) { PrimaryDataSource = "prod" }); var query = "SELECT * FROM audit INNER JOIN systemuser cu ON audit.callinguserid = cu.systemuserid INNER JOIN systemuser u ON audit.userid = u.systemuserid"; var plans = planBuilder.Build(query, null, out _); @@ -6146,7 +6152,7 @@ public void AuditJoinsToObjectId() { // https://learn.microsoft.com/en-us/power-apps/developer/data-platform/auditing/retrieve-audit-data?tabs=webapi#audit-table-relationships // Audit table can only be joined to systemuser on callinguserid or userid. Both joins together, or any other joins, are not valid. - var planBuilder = new ExecutionPlanBuilder(_dataSources.Values, new OptionsWrapper(this) { PrimaryDataSource = "prod" }); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_dataSources, new OptionsWrapper(this) { PrimaryDataSource = "prod" }), new OptionsWrapper(this) { PrimaryDataSource = "prod" }); var query = "SELECT * FROM audit INNER JOIN account ON audit.objectid = account.accountid"; var plans = planBuilder.Build(query, null, out _); @@ -6180,7 +6186,7 @@ public void AuditJoinsToObjectId() public void SelectAuditObjectId() { // https://github.com/MarkMpn/Sql4Cds/issues/296 - var planBuilder = new ExecutionPlanBuilder(_dataSources.Values, new OptionsWrapper(this) { PrimaryDataSource = "prod" }); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_dataSources, new OptionsWrapper(this) { PrimaryDataSource = "prod" }), new OptionsWrapper(this) { PrimaryDataSource = "prod" }); var query = "SELECT auditid, objectidtype AS o FROM audit"; var plans = planBuilder.Build(query, null, out _); @@ -6207,7 +6213,7 @@ public void SelectAuditObjectIdDistinct(string column) // https://github.com/MarkMpn/Sql4Cds/issues/519 // We need to add the objecttypecode attribute to the FetchXML for the first issue, but combining this // with DISTINCT doesn't work because of the second issue. - var planBuilder = new ExecutionPlanBuilder(_dataSources.Values, new OptionsWrapper(this) { PrimaryDataSource = "prod" }); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_dataSources, new OptionsWrapper(this) { PrimaryDataSource = "prod" }), new OptionsWrapper(this) { PrimaryDataSource = "prod" }); var query = $"SELECT DISTINCT {column} AS o FROM audit"; var plans = planBuilder.Build(query, null, out _); @@ -6227,7 +6233,7 @@ public void SelectAuditObjectIdDistinct(string column) [TestMethod] public void FilterAuditOnLeftJoinColumn() { - var planBuilder = new ExecutionPlanBuilder(_dataSources.Values, new OptionsWrapper(this) { PrimaryDataSource = "prod" }); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_dataSources, new OptionsWrapper(this) { PrimaryDataSource = "prod" }), new OptionsWrapper(this) { PrimaryDataSource = "prod" }); var query = "SELECT * FROM audit LEFT OUTER JOIN systemuser ON audit.userid = systemuser.systemuserid WHERE systemuser.domainname IS NULL"; var plans = planBuilder.Build(query, null, out _); @@ -6249,7 +6255,7 @@ public void FilterAuditOnLeftJoinColumn() [TestMethod] public void FilterAuditOnUserId() { - var planBuilder = new ExecutionPlanBuilder(_dataSources.Values, new OptionsWrapper(this) { PrimaryDataSource = "prod" }); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_dataSources, new OptionsWrapper(this) { PrimaryDataSource = "prod" }), new OptionsWrapper(this) { PrimaryDataSource = "prod" }); var query = "SELECT * FROM audit WHERE userid = 'B52C0694-E16F-40CD-9F27-800023C47A98'"; var plans = planBuilder.Build(query, null, out _); @@ -6272,7 +6278,7 @@ public void FilterAuditOnUserIdAggregate() { // Can do aggregates on audit table when any filtering is done only on the table itself, // not on any joins - var planBuilder = new ExecutionPlanBuilder(_dataSources.Values, new OptionsWrapper(this) { PrimaryDataSource = "prod" }); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_dataSources, new OptionsWrapper(this) { PrimaryDataSource = "prod" }), new OptionsWrapper(this) { PrimaryDataSource = "prod" }); var query = "SELECT COUNT(*) FROM audit WHERE userid = 'B52C0694-E16F-40CD-9F27-800023C47A98'"; var plans = planBuilder.Build(query, null, out _); @@ -6305,7 +6311,7 @@ public void FilterAuditOnUserIdAggregate() [TestMethod] public void FilterAuditOnInnerJoinColumn() { - var planBuilder = new ExecutionPlanBuilder(_dataSources.Values, new OptionsWrapper(this) { PrimaryDataSource = "prod" }); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_dataSources, new OptionsWrapper(this) { PrimaryDataSource = "prod" }), new OptionsWrapper(this) { PrimaryDataSource = "prod" }); var query = "SELECT * FROM audit INNER JOIN systemuser ON audit.userid = systemuser.systemuserid WHERE systemuser.domainname <> 'SYSTEM'"; var plans = planBuilder.Build(query, null, out _); @@ -6332,7 +6338,7 @@ public void FilterAuditOnInnerJoinColumnAggregate() { // Can't do aggregates on audit table when filtering is done on a join // https://github.com/MarkMpn/Sql4Cds/issues/488 - var planBuilder = new ExecutionPlanBuilder(_dataSources.Values, new OptionsWrapper(this) { PrimaryDataSource = "prod" }); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_dataSources, new OptionsWrapper(this) { PrimaryDataSource = "prod" }), new OptionsWrapper(this) { PrimaryDataSource = "prod" }); var query = "SELECT COUNT(*) FROM audit INNER JOIN systemuser ON audit.userid = systemuser.systemuserid WHERE systemuser.domainname <> 'SYSTEM'"; var plans = planBuilder.Build(query, null, out _); @@ -6357,7 +6363,7 @@ public void FilterAuditOnInnerJoinColumnAggregate() [TestMethod] public void SortAuditOnJoinColumn() { - var planBuilder = new ExecutionPlanBuilder(_dataSources.Values, new OptionsWrapper(this) { PrimaryDataSource = "prod" }); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_dataSources, new OptionsWrapper(this) { PrimaryDataSource = "prod" }), new OptionsWrapper(this) { PrimaryDataSource = "prod" }); var query = "SELECT * FROM audit LEFT OUTER JOIN systemuser ON audit.userid = systemuser.systemuserid ORDER BY systemuser.domainname"; var plans = planBuilder.Build(query, null, out _); @@ -6379,7 +6385,7 @@ public void SortAuditOnJoinColumn() [TestMethod] public void NestedSubqueries() { - var planBuilder = new ExecutionPlanBuilder(_dataSources.Values, new OptionsWrapper(this) { PrimaryDataSource = "prod" }); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_dataSources, new OptionsWrapper(this) { PrimaryDataSource = "prod" }), new OptionsWrapper(this) { PrimaryDataSource = "prod" }); var query = @" SELECT name, (select STUFF((SELECT ', ' + fullname @@ -6403,7 +6409,7 @@ FOR XML PATH ('')), 1, 2, '')) [TestMethod] public void CalculatedColumnUsesEmptyName() { - var planBuilder = new ExecutionPlanBuilder(_dataSources.Values, new OptionsWrapper(this) { PrimaryDataSource = "prod" }); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_dataSources, new OptionsWrapper(this) { PrimaryDataSource = "prod" }), new OptionsWrapper(this) { PrimaryDataSource = "prod" }); var query = @"SELECT (select STUFF('abcdef', 2, 3, 'ijklmn'))"; var plans = planBuilder.Build(query, null, out _); @@ -6415,7 +6421,7 @@ public void CalculatedColumnUsesEmptyName() [TestMethod] public void OuterJoinWithFiltersConvertedToInnerJoin() { - var planBuilder = new ExecutionPlanBuilder(_dataSources.Values, new OptionsWrapper(this) { PrimaryDataSource = "prod" }); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_dataSources, new OptionsWrapper(this) { PrimaryDataSource = "prod" }), new OptionsWrapper(this) { PrimaryDataSource = "prod" }); var outerJoinQuery = "SELECT * FROM account LEFT OUTER JOIN contact ON contact.parentcustomerid = account.accountid WHERE contact.firstname = 'Mark'"; var outerJoinPlans = planBuilder.Build(outerJoinQuery, null, out _); @@ -6437,7 +6443,7 @@ public void OuterJoinWithFiltersConvertedToInnerJoin() public void CorrelatedSubqueryWithMultipleConditions() { // https://github.com/MarkMpn/Sql4Cds/issues/316 - var planBuilder = new ExecutionPlanBuilder(_dataSources.Values, new OptionsWrapper(this) { PrimaryDataSource = "prod" }); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_dataSources, new OptionsWrapper(this) { PrimaryDataSource = "prod" }), new OptionsWrapper(this) { PrimaryDataSource = "prod" }); var query = @" SELECT r.accountid, r.employees, @@ -6530,10 +6536,14 @@ public void CrossInstanceJoinOnStringColumn() new DataSource { Name = "local", // Hack so that ((IQueryExecutionOptions)this).PrimaryDataSource = "local" doesn't cause test to fail + Connection = _context2.GetOrganizationService(), + Metadata = metadata2, + TableSizeCache = new StubTableSizeCache(), + MessageCache = new StubMessageCache(), DefaultCollation = Collation.USEnglish } }; - var planBuilder = new ExecutionPlanBuilder(datasources, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(datasources.ToDictionary(d => d.Name), this), this); var query = "SELECT uat.name, prod.name FROM uat.dbo.account AS uat INNER JOIN prod.dbo.account AS prod ON uat.name = prod.name"; @@ -6575,7 +6585,7 @@ public void CrossInstanceJoinOnStringColumn() [TestMethod] public void LiftOrFilterToLinkEntityWithInnerJoin() { - var planBuilder = new ExecutionPlanBuilder(_dataSources.Values, new OptionsWrapper(this) { PrimaryDataSource = "prod" }); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_dataSources, new OptionsWrapper(this) { PrimaryDataSource = "prod" }), new OptionsWrapper(this) { PrimaryDataSource = "prod" }); var query = "SELECT * FROM contact INNER JOIN account ON contact.parentcustomerid = account.accountid WHERE account.name = 'Data8' OR account.name = 'Data 8'"; var plans = planBuilder.Build(query, null, out _); @@ -6600,7 +6610,7 @@ public void LiftOrFilterToLinkEntityWithInnerJoin() [TestMethod] public void DoNotLiftOrFilterToLinkEntityWithOuterJoin() { - var planBuilder = new ExecutionPlanBuilder(_dataSources.Values, new OptionsWrapper(this) { PrimaryDataSource = "prod" }); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_dataSources, new OptionsWrapper(this) { PrimaryDataSource = "prod" }), new OptionsWrapper(this) { PrimaryDataSource = "prod" }); var query = "SELECT * FROM contact LEFT OUTER JOIN account ON contact.parentcustomerid = account.accountid WHERE account.name = 'Data8' OR account.name = 'Data 8'"; var plans = planBuilder.Build(query, null, out _); @@ -6625,7 +6635,7 @@ public void DoNotLiftOrFilterToLinkEntityWithOuterJoin() [TestMethod] public void DoNotLiftOrFilterToLinkEntityWithDifferentEntities() { - var planBuilder = new ExecutionPlanBuilder(_dataSources.Values, new OptionsWrapper(this) { PrimaryDataSource = "prod" }); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_dataSources, new OptionsWrapper(this) { PrimaryDataSource = "prod" }), new OptionsWrapper(this) { PrimaryDataSource = "prod" }); var query = "SELECT * FROM contact LEFT OUTER JOIN account ON contact.parentcustomerid = account.accountid WHERE account.name = 'Data8' OR contact.fullname = 'Mark Carrington'"; var plans = planBuilder.Build(query, null, out _); @@ -6650,7 +6660,7 @@ public void DoNotLiftOrFilterToLinkEntityWithDifferentEntities() [TestMethod] public void FoldSortOrderToInnerJoinLeftInput() { - var planBuilder = new ExecutionPlanBuilder(_dataSources.Values, new OptionsWrapper(this) { PrimaryDataSource = "prod" }); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_dataSources, new OptionsWrapper(this) { PrimaryDataSource = "prod" }), new OptionsWrapper(this) { PrimaryDataSource = "prod" }); var query = "SELECT TOP 10 audit.* FROM contact CROSS APPLY SampleMessage(firstname) AS audit WHERE firstname = 'Mark' ORDER BY contact.createdon;"; var plans = planBuilder.Build(query, null, out _); @@ -6675,7 +6685,7 @@ public void FoldSortOrderToInnerJoinLeftInput() [TestMethod] public void UpdateFromSubquery() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "UPDATE account SET name = 'foo' FROM account INNER JOIN (SELECT name, MIN(createdon) FROM account GROUP BY name HAVING COUNT(*) > 1) AS dupes ON account.name = dupes.name"; @@ -6699,7 +6709,7 @@ public void UpdateFromSubquery() [TestMethod] public void MinPrimaryKey() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT MIN(accountid) FROM account"; @@ -6721,7 +6731,7 @@ public void MinPrimaryKey() [TestMethod] public void MinPicklist() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT MIN(new_optionsetvalue) FROM new_customentity"; @@ -6744,7 +6754,7 @@ public void MinPicklist() [ExpectedException(typeof(NotSupportedQueryFragmentException))] public void AvgGuidIsNotSupported() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT AVG(accountid) FROM account"; planBuilder.Build(query, null, out _); } @@ -6752,7 +6762,7 @@ public void AvgGuidIsNotSupported() [TestMethod] public void StringAggWithOrderAndNoGroups() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT STRING_AGG(name, ',') WITHIN GROUP (ORDER BY name DESC) FROM account"; @@ -6775,7 +6785,7 @@ public void StringAggWithOrderAndNoGroups() [TestMethod] public void StringAggWithOrderAndScalarGroups() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT STRING_AGG(name, ',') WITHIN GROUP (ORDER BY name DESC) FROM account GROUP BY employees"; @@ -6800,7 +6810,7 @@ public void StringAggWithOrderAndScalarGroups() [TestMethod] public void StringAggWithOrderAndNonScalarGroups() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT STRING_AGG(name, ',') WITHIN GROUP (ORDER BY name DESC) FROM account GROUP BY name + 'x'"; @@ -6824,7 +6834,7 @@ public void StringAggWithOrderAndNonScalarGroups() [TestMethod] public void NestedExistsAndIn() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "IF NOT EXISTS(SELECT * FROM account WHERE primarycontactid IN (SELECT contactid FROM contact WHERE firstname = 'Mark')) SELECT 1"; @@ -6835,7 +6845,7 @@ public void NestedExistsAndIn() [TestMethod] public void HashJoinUsedForDifferentDataTypes() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT * FROM account WHERE EXISTS(SELECT * FROM contact WHERE account.name = contact.createdon)"; @@ -6866,7 +6876,7 @@ public void DoNotFoldFilterOnParameterToIndexSpool() { // Subquery on right side of nested loop will use an index spool to reduce number of FetchXML requests. Do not use this logic if the // filter variable is an external parameter or the FetchXML is on the left side of the loop - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "SELECT * FROM account WHERE name = @name and primarycontactid = (SELECT contactid FROM contact WHERE firstname = 'Mark')"; @@ -6917,7 +6927,7 @@ public void DoNotFoldFilterOnParameterToIndexSpool() [TestMethod] public void DoNotFoldJoinsOnReusedAliases() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT s.systemuserid, @@ -6988,7 +6998,7 @@ systemuser AS s [TestMethod] public void ComplexFetchXmlAlias() { - var planBuilder = new ExecutionPlanBuilder(_dataSources.Values, new OptionsWrapper(this) { PrimaryDataSource = "uat" }); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_dataSources, new OptionsWrapper(this) { PrimaryDataSource = "uat" }), new OptionsWrapper(this) { PrimaryDataSource = "uat" }); var query = "SELECT name FROM account AS [acc. table]"; @@ -7004,7 +7014,7 @@ public void ComplexFetchXmlAlias() [TestMethod] public void ComplexMetadataAlias() { - var planBuilder = new ExecutionPlanBuilder(_dataSources.Values, new OptionsWrapper(this) { PrimaryDataSource = "uat" }); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_dataSources, new OptionsWrapper(this) { PrimaryDataSource = "uat" }), new OptionsWrapper(this) { PrimaryDataSource = "uat" }); var query = "SELECT logicalname FROM metadata.entity AS [m.d. table] WHERE [m.d. table].logicalname = 'account'"; @@ -7023,7 +7033,7 @@ public void ComplexMetadataAlias() [TestMethod] public void ComplexInlineTableAlias() { - var planBuilder = new ExecutionPlanBuilder(_dataSources.Values, new OptionsWrapper(this) { PrimaryDataSource = "uat" }); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_dataSources, new OptionsWrapper(this) { PrimaryDataSource = "uat" }), new OptionsWrapper(this) { PrimaryDataSource = "uat" }); var query = "SELECT [full name] FROM (VALUES ('Mark Carrington')) AS [inline table] ([full name])"; @@ -7041,7 +7051,7 @@ public void ComplexInlineTableAlias() [TestMethod] public void FoldFiltersToUnionAllAndJoins() { - var planBuilder = new ExecutionPlanBuilder(_dataSources.Values, new OptionsWrapper(this) { PrimaryDataSource = "uat" }); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_dataSources, new OptionsWrapper(this) { PrimaryDataSource = "uat" }), new OptionsWrapper(this) { PrimaryDataSource = "uat" }); var query = @" SELECT [union. all].eln, @@ -7117,7 +7127,7 @@ AND [union. all].logicalname IN ('createdon') [TestMethod] public void PreserveAdditionalFiltersInMetadataJoinConditions() { - var planBuilder = new ExecutionPlanBuilder(_dataSources.Values, new OptionsWrapper(this) { PrimaryDataSource = "uat" }); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_dataSources, new OptionsWrapper(this) { PrimaryDataSource = "uat" }), new OptionsWrapper(this) { PrimaryDataSource = "uat" }); var query = @" SELECT e.logicalname, @@ -7153,7 +7163,7 @@ public void FoldFilterToJoinWithAlias() // Filter is applied to a join with a LHS of FetchXmlScan, which the filter can be entirely folded to // Exception is thrown when trying to fold the remaining null filter to the RHS of the join - var planBuilder = new ExecutionPlanBuilder(_dataSources.Values, new OptionsWrapper(this) { PrimaryDataSource = "uat" }); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_dataSources, new OptionsWrapper(this) { PrimaryDataSource = "uat" }), new OptionsWrapper(this) { PrimaryDataSource = "uat" }); var query = @" SELECT app.* @@ -7176,7 +7186,7 @@ public void DoNotUseCustomPagingForInJoin() using (_dataSource.EnableJoinOperator(JoinOperator.In)) { - var planBuilder = new ExecutionPlanBuilder(_dataSources.Values, new OptionsWrapper(this) { PrimaryDataSource = "uat" }); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_dataSources, new OptionsWrapper(this) { PrimaryDataSource = "uat" }), new OptionsWrapper(this) { PrimaryDataSource = "uat" }); var query = @" SELECT contactid @@ -7215,7 +7225,7 @@ public void FoldFilterToCorrectTableAlias() // The same table alias can be used in the main query and in a query-derived table. Ensure filters are // folded to the correct one. - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT * @@ -7259,7 +7269,7 @@ public void FoldFilterToCorrectTableAlias() [TestMethod] public void IgnoreDupKeyHint() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @"INSERT INTO account (accountid, name) VALUES ('{CD503427-E785-40D8-AD0E-FBDF4918D298}', 'Data8') OPTION (USE HINT ('IGNORE_DUP_KEY'))"; @@ -7274,7 +7284,7 @@ public void IgnoreDupKeyHint() [TestMethod] public void GroupByWithoutAggregateUsesDistinct() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT @@ -7301,7 +7311,7 @@ public void GroupByWithoutAggregateUsesDistinct() [TestMethod] public void FilterOnCrossApply() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" select name, n from account @@ -7337,7 +7347,7 @@ cross apply (select name + '' as n) x [TestMethod] public void GotoCantMoveIntoTryBlock() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" GOTO label1 @@ -7364,7 +7374,7 @@ BEGIN CATCH [DataRow(3)] public void UpdateTop(int top) { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = $@" UPDATE account @@ -7399,7 +7409,7 @@ AND employees > 0 [TestMethod] public void RethrowMustBeWithinCatchBlock() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = "THROW;"; @@ -7417,7 +7427,7 @@ public void RethrowMustBeWithinCatchBlock() [TestMethod] public void MistypedJoinCriteriaGeneratesWarning() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT a.name, c.fullname @@ -7457,7 +7467,7 @@ public void MistypedJoinCriteriaGeneratesWarning() [TestMethod] public void AliasSameAsVirtualAttribute() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" select a.name, c.fullname as primarycontactidname from account a @@ -7483,7 +7493,7 @@ public void AliasSameAsVirtualAttribute() [TestMethod] public void OrderByOptionSetName() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @"SELECT new_customentityid FROM new_customentity ORDER BY new_optionsetvaluename"; @@ -7507,7 +7517,7 @@ public void OrderByOptionSetValueWithUseRawOrderBy() { using (_localDataSource.SetUseRawOrderByReliable(true)) { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @"SELECT new_customentityid FROM new_customentity ORDER BY new_optionsetvalue"; @@ -7530,7 +7540,7 @@ public void OrderByOptionSetValueWithUseRawOrderBy() [TestMethod] public void OrderByOptionSetValueWithoutUseRawOrderBy() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @"SELECT new_customentityid FROM new_customentity ORDER BY new_optionsetvalue"; @@ -7556,7 +7566,7 @@ public void OrderByOptionSetValueAndName() { using (_localDataSource.SetUseRawOrderByReliable(true)) { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @"SELECT new_customentityid FROM new_customentity ORDER BY new_optionsetvalue, new_optionsetvaluename"; @@ -7587,7 +7597,7 @@ public void ExistsOrInAndColumnComparisonOrderByEntityName() using (_localDataSource.EnableJoinOperator(JoinOperator.Any)) using (_localDataSource.SetOrderByEntityName(true)) { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT TOP 100 @@ -7646,7 +7656,7 @@ public void ExistsOrInAndColumnComparisonOrderByEntityNameLegacy() { using (_localDataSource.SetColumnComparison(false)) { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT TOP 100 @@ -7723,7 +7733,7 @@ ORDER BY [TestMethod] public void DistinctUsesCustomPaging() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" select distinct @@ -7755,7 +7765,7 @@ from account [TestMethod] public void NotExistWithJoin() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" select top 10 a2.name @@ -7824,7 +7834,7 @@ from account a [TestMethod] public void ScalarSubquery() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" select top 10 * from ( @@ -7854,7 +7864,7 @@ public void SubqueryInJoinCriteriaRHS() { using (_localDataSource.EnableJoinOperator(JoinOperator.In)) { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" select @@ -7887,7 +7897,7 @@ public void SubqueryInJoinCriteriaRHSCorrelatedExists() { using (_localDataSource.EnableJoinOperator(JoinOperator.In)) { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" select @@ -7920,7 +7930,7 @@ public void SubqueryInJoinCriteriaLHS() { using (_localDataSource.EnableJoinOperator(JoinOperator.In)) { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" select @@ -7953,7 +7963,7 @@ public void SubqueryInJoinCriteriaLHSCorrelatedExists() { using (_localDataSource.EnableJoinOperator(JoinOperator.In)) { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" select @@ -7986,7 +7996,7 @@ public void SubqueryInJoinCriteriaLHSAndRHSInnerJoin() { using (_localDataSource.EnableJoinOperator(JoinOperator.In)) { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" select @@ -8040,7 +8050,7 @@ public void SubqueryInJoinCriteriaLHSAndRHSOuterJoin() { using (_localDataSource.EnableJoinOperator(JoinOperator.In)) { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" select @@ -8102,7 +8112,7 @@ from account [TestMethod] public void VirtualAttributeAliases() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" select statecodename [state], parentcustomerid x, parentcustomeridname from contact"; @@ -8138,7 +8148,7 @@ FROM account as a LEFT JOIN account AS ca ON c.contactid = ca.primarycontactid WHERE c.parentcustomerid IN (SELECT contactid FROM contact WHERE createdon = today())"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var plans = planBuilder.Build(query, null, out _); @@ -8188,7 +8198,7 @@ FROM account as a LEFT JOIN account AS ca ON c.contactid = ca.primarycontactid WHERE c.contactid IN (SELECT contactid FROM contact WHERE createdon = today())"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var plans = planBuilder.Build(query, null, out _); @@ -8236,7 +8246,7 @@ metadata.entity AS p ON a.targets = p.logicalname WHERE a.entitylogicalname IN ('team')"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var plans = planBuilder.Build(query, null, out _); @@ -8260,7 +8270,7 @@ FROM account WHERE createdon >= CAST(GETDATE() AS DATE) AND createdon < DATEADD(day, 1, CAST(GETDATE() AS DATE))"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var plans = planBuilder.Build(query, null, out _); @@ -8293,7 +8303,7 @@ FROM new_customentity WHERE new_name IN ('test')) AS a ORDER BY a.new_customentityid"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var plans = planBuilder.Build(query, null, out _); @@ -8330,7 +8340,7 @@ FROM new_customentity WHERE new_name IN ('test')) AS a ORDER BY a.new_customentityid"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var plans = planBuilder.Build(query, null, out _); @@ -8364,7 +8374,7 @@ LEFT OUTER JOIN (SELECT a.parentaccountid, a.name AS test FROM account a) AS a2 ON a2.parentaccountid = a1.accountid"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var plans = planBuilder.Build(query, null, out _); @@ -8408,7 +8418,7 @@ LEFT OUTER JOIN public void EntityKeyIndexStatusFilter() { // https://github.com/MarkMpn/Sql4Cds/issues/534 - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT ak.entitylogicalname, @@ -8438,7 +8448,7 @@ FROM metadata.alternate_key AS ak public void OuterApplyOuterReference() { // https://github.com/MarkMpn/Sql4Cds/issues/547 - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT * @@ -8478,7 +8488,7 @@ OUTER APPLY ( public void FilterOnOuterApply() { // https://github.com/MarkMpn/Sql4Cds/issues/548 - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT * @@ -8511,7 +8521,7 @@ SELECT IIF(q1.name = 'Test1', 1, 0) AS [flag1], [TestMethod] public void ScalarSubqueryWithoutAlias() { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var query = @" SELECT a.accountid, @@ -8532,6 +8542,132 @@ public void ScalarSubqueryWithoutAlias() +"); + } + + [TestMethod] + public void DeleteByIdUsesConstantScan() + { + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); + + var query = "DELETE FROM account WHERE accountid = '1D3AACA6-DEA4-490F-973E-E4181D4BE11C'"; + + var plans = planBuilder.Build(query, null, out _); + + Assert.AreEqual(1, plans.Length); + + var delete = AssertNode(plans[0]); + var constant = AssertNode(delete.Source); + } + + [TestMethod] + public void DeleteByIdWithGlobalVariableUsesConstantScan() + { + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); + + var query = "DELETE FROM account WHERE accountid = @@IDENTITY"; + + var plans = planBuilder.Build(query, null, out _); + + Assert.AreEqual(1, plans.Length); + + var delete = AssertNode(plans[0]); + var constant = AssertNode(delete.Source); + } + + [TestMethod] + public void DeleteByIdWithVariableUsesConstantScan() + { + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); + + var query = "DELETE FROM account WHERE accountid = @Id"; + var parameterTypes = new Dictionary + { + ["@Id"] = DataTypeHelpers.UniqueIdentifier + }; + var plans = planBuilder.Build(query, parameterTypes, out _); + + Assert.AreEqual(1, plans.Length); + + var delete = AssertNode(plans[0]); + var constant = AssertNode(delete.Source); + } + + [TestMethod] + public void DeleteByNameDoesNotUseConstantScan() + { + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); + + var query = "DELETE FROM account WHERE name = 'Data8'"; + + var plans = planBuilder.Build(query, null, out _); + + Assert.AreEqual(1, plans.Length); + + var delete = AssertNode(plans[0]); + var fetch = AssertNode(delete.Source); + } + + [TestMethod] + public void UpdateByIdUsesConstantScan() + { + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); + + var query = "UPDATE account SET name = 'test' WHERE accountid = '1D3AACA6-DEA4-490F-973E-E4181D4BE11C'"; + + var plans = planBuilder.Build(query, null, out _); + + Assert.AreEqual(1, plans.Length); + + var update = AssertNode(plans[0]); + var computeScalar = AssertNode(update.Source); + var constant = AssertNode(computeScalar.Source); + } + + [TestMethod] + public void UpdateByIdCopyingFieldUsesConstantScan() + { + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); + + var query = "UPDATE account SET name = owneridname WHERE accountid = '1D3AACA6-DEA4-490F-973E-E4181D4BE11C'"; + + var plans = planBuilder.Build(query, null, out _); + + Assert.AreEqual(1, plans.Length); + + var update = AssertNode(plans[0]); + var fetch = AssertNode(update.Source); + } + + [TestMethod] + public void InVariables() + { + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); + + var query = "SELECT name FROM account WHERE accountid IN (@Id1, @Id2)"; + var parameterTypes = new Dictionary + { + ["@Id1"] = DataTypeHelpers.UniqueIdentifier, + ["@Id2"] = DataTypeHelpers.UniqueIdentifier + }; + + var plans = planBuilder.Build(query, parameterTypes, out _); + + Assert.AreEqual(1, plans.Length); + + var select = AssertNode(plans[0]); + var fetch = AssertNode(select.Source); + AssertFetchXml(fetch, @" + + + + + + @Id1 + @Id2 + + + "); } } diff --git a/MarkMpn.Sql4Cds.Engine.Tests/ExpressionFunctionTests.cs b/MarkMpn.Sql4Cds.Engine.Tests/ExpressionFunctionTests.cs new file mode 100644 index 00000000..e47e72d0 --- /dev/null +++ b/MarkMpn.Sql4Cds.Engine.Tests/ExpressionFunctionTests.cs @@ -0,0 +1,362 @@ +using System; +using System.Collections.Generic; +using System.Data.SqlTypes; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using MarkMpn.Sql4Cds.Engine.ExecutionPlan; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace MarkMpn.Sql4Cds.Engine.Tests +{ + [TestClass] + public class ExpressionFunctionTests + { + [TestMethod] + public void DatePart_Week() + { + // https://learn.microsoft.com/en-us/sql/t-sql/functions/datepart-transact-sql?view=sql-server-ver16#week-and-weekday-datepart-arguments + // Assuming default SET DATEFIRST 7 -- ( Sunday ) + var actual = ExpressionFunctions.DatePart("week", (SqlDateTime)new DateTime(2007, 4, 21), DataTypeHelpers.DateTime); + Assert.AreEqual(16, actual); + } + + [TestMethod] + public void DatePart_WeekDay() + { + // https://learn.microsoft.com/en-us/sql/t-sql/functions/datepart-transact-sql?view=sql-server-ver16#week-and-weekday-datepart-arguments + // Assuming default SET DATEFIRST 7 -- ( Sunday ) + var actual = ExpressionFunctions.DatePart("weekday", (SqlDateTime)new DateTime(2007, 4, 21), DataTypeHelpers.DateTime); + Assert.AreEqual(7, actual); + } + + [TestMethod] + public void DatePart_TZOffset() + { + // https://learn.microsoft.com/en-us/sql/t-sql/functions/datepart-transact-sql?view=sql-server-ver16#tzoffset + var actual = ExpressionFunctions.DatePart("tzoffset", (SqlDateTimeOffset)new DateTimeOffset(2007, 5, 10, 0, 0, 1, TimeSpan.FromMinutes(5 * 60 + 10)), DataTypeHelpers.DateTimeOffset(7)); + Assert.AreEqual(310, actual); + } + + [TestMethod] + public void DatePart_ErrorOnInvalidPartsForTimeValue() + { + // https://learn.microsoft.com/en-us/sql/t-sql/functions/datepart-transact-sql?view=sql-server-ver16#default-returned-for-a-datepart-that-isnt-in-a-date-argument + try + { + ExpressionFunctions.DatePart("year", new SqlTime(new TimeSpan(0, 12, 10, 30, 123)), DataTypeHelpers.Time(7)); + Assert.Fail(); + } + catch (QueryExecutionException ex) + { + Assert.AreEqual(9810, ex.Errors.Single().Number); + } + } + + [DataTestMethod] + [DataRow("millisecond", 123)] + [DataRow("microsecond", 123456)] + [DataRow("nanosecond", 123456700)] + public void DatePart_FractionalSeconds(string part, int expected) + { + // https://learn.microsoft.com/en-us/sql/t-sql/functions/datepart-transact-sql?view=sql-server-ver16#fractional-seconds + var actual = ExpressionFunctions.DatePart(part, (SqlString)"00:00:01.1234567", DataTypeHelpers.VarChar(100, Collation.USEnglish, CollationLabel.CoercibleDefault)); + Assert.AreEqual(expected, actual); + } + + [DataTestMethod] + [DataRow("20240830")] + [DataRow("2024-08-31")] + public void DateAdd_MonthLimitedToDaysInFollowingMonth(string date) + { + // https://learn.microsoft.com/en-us/sql/t-sql/functions/dateadd-transact-sql?view=sql-server-ver16#datepart-argument + SqlDateParsing.TryParse(date, DateFormat.mdy, out SqlDateTime startDate); + var actual = ExpressionFunctions.DateAdd("month", 1, startDate, DataTypeHelpers.DateTime); + Assert.AreEqual(new SqlDateTime(2024, 9, 30), (SqlDateTime)actual); + } + + [DataTestMethod] + [DataRow(2147483647)] + [DataRow(-2147483647)] + public void DateAdd_ThrowsIfResultIsOutOfRange(int number) + { + // https://learn.microsoft.com/en-us/sql/t-sql/functions/dateadd-transact-sql?view=sql-server-ver16#date-argument + try + { + ExpressionFunctions.DateAdd("year", number, new SqlDateTime(2024, 7, 31), DataTypeHelpers.DateTime); + } + catch (QueryExecutionException ex) + { + Assert.AreEqual(517, ex.Errors.Single().Number); + } + } + + [DataTestMethod] + [DataRow(-30, 0)] + [DataRow(29, 0)] + [DataRow(-31, -1)] + [DataRow(30, 1)] + public void DateAdd_SmallDateTimeSeconds(int number, int expectedMinutesDifference) + { + // https://learn.microsoft.com/en-us/sql/t-sql/functions/dateadd-transact-sql?view=sql-server-ver16#return-values-for-a-smalldatetime-date-and-a-second-or-fractional-seconds-datepart + var startDateTime = new DateTime(2024, 10, 5); + var actual = ((SqlSmallDateTime)ExpressionFunctions.DateAdd("second", number, new SqlSmallDateTime(startDateTime), DataTypeHelpers.SmallDateTime)).Value; + var expected = startDateTime.AddMinutes(expectedMinutesDifference); + Assert.AreEqual(expected, actual); + } + + [DataTestMethod] + [DataRow(-30001, 0)] + [DataRow(29998, 0)] + [DataRow(-30002, -1)] + [DataRow(29999, 1)] + public void DateAdd_SmallDateTimeMilliSeconds(int number, int expectedMinutesDifference) + { + // https://learn.microsoft.com/en-us/sql/t-sql/functions/dateadd-transact-sql?view=sql-server-ver16#return-values-for-a-smalldatetime-date-and-a-second-or-fractional-seconds-datepart + var startDateTime = new DateTime(2024, 10, 5); + var actual = ((SqlSmallDateTime)ExpressionFunctions.DateAdd("millisecond", number, new SqlSmallDateTime(startDateTime), DataTypeHelpers.SmallDateTime)).Value; + var expected = startDateTime.AddMinutes(expectedMinutesDifference); + Assert.AreEqual(expected, actual); + } + + [DataTestMethod] + [DataRow("millisecond", 1, 1121111)] + [DataRow("millisecond", 2, 1131111)] + [DataRow("microsecond", 1, 1111121)] + [DataRow("microsecond", 2, 1111131)] + [DataRow("nanosecond", 49, 1111111)] + [DataRow("nanosecond", 50, 1111112)] + [DataRow("nanosecond", 150, 1111113)] + public void DateAdd_FractionalSeconds(string datepart, int number, int expected) + { + // https://learn.microsoft.com/en-us/sql/t-sql/functions/dateadd-transact-sql?view=sql-server-ver16#fractional-seconds-precision + var startDateTime = new DateTime(2024, 1, 1, 13, 10, 10).AddTicks(1111111); + var actual = ExpressionFunctions.DateAdd(datepart, number, new SqlDateTime2(startDateTime), DataTypeHelpers.DateTime2(7)).Value; + Assert.AreEqual(expected, actual.Ticks % TimeSpan.TicksPerSecond); + } + + [DataTestMethod] + [DataRow("year", "2025-01-01 13:10:10.1111111")] + [DataRow("quarter", "2024-04-01 13:10:10.1111111")] + [DataRow("month", "2024-02-01 13:10:10.1111111")] + [DataRow("dayofyear", "2024-01-02 13:10:10.1111111")] + [DataRow("day", "2024-01-02 13:10:10.1111111")] + [DataRow("week", "2024-01-08 13:10:10.1111111")] + [DataRow("weekday", "2024-01-02 13:10:10.1111111")] + [DataRow("hour", "2024-01-01 14:10:10.1111111")] + [DataRow("minute", "2024-01-01 13:11:10.1111111")] + [DataRow("second", "2024-01-01 13:10:11.1111111")] + [DataRow("millisecond", "2024-01-01 13:10:10.1121111")] + [DataRow("microsecond", "2024-01-01 13:10:10.1111121")] + [DataRow("nanosecond", "2024-01-01 13:10:10.1111111")] + public void DateAdd_DateParts(string datepart, string expected) + { + // https://learn.microsoft.com/en-us/sql/t-sql/functions/dateadd-transact-sql?view=sql-server-ver16#a-increment-datepart-by-an-interval-of-1 + var startDateTime = new DateTime(2024, 1, 1, 13, 10, 10).AddTicks(1111111); + var actual = ExpressionFunctions.DateAdd(datepart, 1, new SqlDateTime2(startDateTime), DataTypeHelpers.DateTime2(7)).Value; + Assert.AreEqual(expected, actual.ToString("yyyy-MM-dd HH:mm:ss.fffffff")); + } + + [DataTestMethod] + [DataRow("quarter", 4, "2025-01-01 01:01:01.1111111")] + [DataRow("month", 13, "2025-02-01 01:01:01.1111111")] + [DataRow("dayofyear", 366, "2025-01-01 01:01:01.1111111")] // NOTE: Docs used 365, but 2024 is a leap year + [DataRow("day", 366, "2025-01-01 01:01:01.1111111")] // NOTE: Docs used 365, but 2024 is a leap year + [DataRow("week", 5, "2024-02-05 01:01:01.1111111")] + [DataRow("weekday", 31, "2024-02-01 01:01:01.1111111")] + [DataRow("hour", 23, "2024-01-02 00:01:01.1111111")] + [DataRow("minute", 59, "2024-01-01 02:00:01.1111111")] + [DataRow("second", 59, "2024-01-01 01:02:00.1111111")] + [DataRow("millisecond", 1, "2024-01-01 01:01:01.1121111")] + public void DateAdd_Carry(string datepart, int number, string expected) + { + // https://learn.microsoft.com/en-us/sql/t-sql/functions/dateadd-transact-sql?view=sql-server-ver16#b-increment-more-than-one-level-of-datepart-in-one-statement + var startDateTime = new DateTime(2024, 1, 1, 1, 1, 1).AddTicks(1111111); + var actual = ExpressionFunctions.DateAdd(datepart, number, new SqlDateTime2(startDateTime), DataTypeHelpers.DateTime2(7)).Value; + Assert.AreEqual(expected, actual.ToString("yyyy-MM-dd HH:mm:ss.fffffff")); + } + + [DataTestMethod] + [DataRow("microsecond")] + [DataRow("nanosecond")] + public void DateAdd_MicroSecondAndNanoSecondNotSupportedForSmallDateTime(string datepart) + { + // https://learn.microsoft.com/en-us/sql/t-sql/functions/dateadd-transact-sql?view=sql-server-ver16#fractional-seconds-precision + try + { + ExpressionFunctions.DateAdd(datepart, 1, new SqlSmallDateTime(new DateTime(2024, 1, 1)), DataTypeHelpers.SmallDateTime); + Assert.Fail(); + } + catch (QueryExecutionException ex) + { + Assert.AreEqual(9810, ex.Errors.Single().Number); + } + } + + [DataTestMethod] + [DataRow("microsecond")] + [DataRow("nanosecond")] + public void DateAdd_MicroSecondAndNanoSecondNotSupportedForDate(string datepart) + { + // https://learn.microsoft.com/en-us/sql/t-sql/functions/dateadd-transact-sql?view=sql-server-ver16#fractional-seconds-precision + try + { + ExpressionFunctions.DateAdd(datepart, 1, new SqlDate(new DateTime(2024, 1, 1)), DataTypeHelpers.Date); + Assert.Fail(); + } + catch (QueryExecutionException ex) + { + Assert.AreEqual(9810, ex.Errors.Single().Number); + } + } + + [DataTestMethod] + [DataRow("microsecond")] + [DataRow("nanosecond")] + public void DateAdd_MicroSecondAndNanoSecondNotSupportedForDateTime(string datepart) + { + // https://learn.microsoft.com/en-us/sql/t-sql/functions/dateadd-transact-sql?view=sql-server-ver16#fractional-seconds-precision + try + { + ExpressionFunctions.DateAdd(datepart, 1, new SqlDateTime(new DateTime(2024, 1, 1)), DataTypeHelpers.DateTime); + Assert.Fail(); + } + catch (QueryExecutionException ex) + { + Assert.AreEqual(9810, ex.Errors.Single().Number); + } + } + + [DataTestMethod] + [DataRow("weekday")] + [DataRow("tzoffset")] + [DataRow("nanosecond")] + public void DateTrunc_InvalidDateParts(string datepart) + { + // https://learn.microsoft.com/en-us/sql/t-sql/functions/datetrunc-transact-sql?view=sql-server-ver16#datepart + try + { + ExpressionFunctions.DateTrunc(datepart, new SqlDateTime(new DateTime(2024, 1, 1)), DataTypeHelpers.DateTime); + Assert.Fail(); + } + catch (QueryExecutionException ex) + { + Assert.AreEqual(9810, ex.Errors.Single().Number); + } + } + + [DataTestMethod] + // date doesn't support any of the time-based dateparts + [DataRow("date", "hour")] + [DataRow("date", "minute")] + [DataRow("date", "second")] + [DataRow("date", "millisecond")] + [DataRow("date", "microsecond")] + + // datetime doesn't support microsecond + [DataRow("datetime", "microsecond")] + + // smalldatetime doesn't support millisecond or microsecond + [DataRow("smalldatetime", "millisecond")] + [DataRow("smalldatetime", "microsecond")] + + // datetime2, datetimeoffset and time vary depending on scale + [DataRow("datetime2(1)", "millisecond")] + [DataRow("datetime2(1)", "microsecond")] + [DataRow("datetime2(2)", "millisecond")] + [DataRow("datetime2(2)", "microsecond")] + [DataRow("datetime2(3)", "microsecond")] + [DataRow("datetime2(4)", "microsecond")] + [DataRow("datetime2(5)", "microsecond")] + + [DataRow("datetimeoffset(1)", "millisecond")] + [DataRow("datetimeoffset(1)", "microsecond")] + [DataRow("datetimeoffset(2)", "millisecond")] + [DataRow("datetimeoffset(2)", "microsecond")] + [DataRow("datetimeoffset(3)", "microsecond")] + [DataRow("datetimeoffset(4)", "microsecond")] + [DataRow("datetimeoffset(5)", "microsecond")] + + [DataRow("time(1)", "millisecond")] + [DataRow("time(1)", "microsecond")] + [DataRow("time(2)", "millisecond")] + [DataRow("time(2)", "microsecond")] + [DataRow("time(3)", "microsecond")] + [DataRow("time(4)", "microsecond")] + [DataRow("time(5)", "microsecond")] + + // time also doesn't support any date-based dateparts + [DataRow("time", "year")] + [DataRow("time", "quarter")] + [DataRow("time", "month")] + [DataRow("time", "dayofyear")] + [DataRow("time", "day")] + [DataRow("time", "week")] + public void DateTrunc_RequiresMinimalPrecision(string datetype, string datepart) + { + DataTypeHelpers.TryParse(null, datetype, out var type); + + // https://learn.microsoft.com/en-us/sql/t-sql/functions/datetrunc-transact-sql?view=sql-server-ver16#fractional-time-scale-precision + try + { + ExpressionFunctions.DateTrunc(datepart, new SqlDateTime(new DateTime(2024, 1, 1)), type); + Assert.Fail(); + } + catch (QueryExecutionException ex) + { + Assert.AreEqual(9810, ex.Errors.Single().Number); + } + } + + [DataTestMethod] + [DataRow("year", "2021-01-01 00:00:00.0000000")] + [DataRow("quarter", "2021-10-01 00:00:00.0000000")] + [DataRow("month", "2021-12-01 00:00:00.0000000")] + [DataRow("week", "2021-12-05 00:00:00.0000000")] + [DataRow("iso_week", "2021-12-06 00:00:00.0000000")] + [DataRow("dayofyear", "2021-12-08 00:00:00.0000000")] + [DataRow("day", "2021-12-08 00:00:00.0000000")] + [DataRow("hour", "2021-12-08 11:00:00.0000000")] + [DataRow("minute", "2021-12-08 11:30:00.0000000")] + [DataRow("second", "2021-12-08 11:30:15.0000000")] + [DataRow("millisecond", "2021-12-08 11:30:15.1230000")] + [DataRow("microsecond", "2021-12-08 11:30:15.1234560")] + public void DateTrunc_Values(string datepart, string expected) + { + // https://learn.microsoft.com/en-us/sql/t-sql/functions/datetrunc-transact-sql?view=sql-server-ver16#a-use-different-datepart-options + var actual = ExpressionFunctions.DateTrunc(datepart, new SqlDateTime2(new DateTime(2021, 12, 8, 11, 30, 15).AddTicks(1234567)), DataTypeHelpers.DateTime2(7)); + Assert.AreEqual(expected, actual.Value.ToString("yyyy-MM-dd HH:mm:ss.fffffff")); + } + + [DataTestMethod] + [DataRow("year", "2005-12-31 23:59:59.9999999", "2006-01-01 00:00:00.0000000")] + [DataRow("quarter", "2005-12-31 23:59:59.9999999", "2006-01-01 00:00:00.0000000")] + [DataRow("month", "2005-12-31 23:59:59.9999999", "2006-01-01 00:00:00.0000000")] + [DataRow("dayofyear", "2005-12-31 23:59:59.9999999", "2006-01-01 00:00:00.0000000")] + [DataRow("day", "2005-12-31 23:59:59.9999999", "2006-01-01 00:00:00.0000000")] + [DataRow("week", "2005-12-31 23:59:59.9999999", "2006-01-01 00:00:00.0000000")] + [DataRow("weekday", "2005-12-31 23:59:59.9999999", "2006-01-01 00:00:00.0000000")] + [DataRow("hour", "2005-12-31 23:59:59.9999999", "2006-01-01 00:00:00.0000000")] + [DataRow("minute", "2005-12-31 23:59:59.9999999", "2006-01-01 00:00:00.0000000")] + [DataRow("second", "2005-12-31 23:59:59.9999999", "2006-01-01 00:00:00.0000000")] + [DataRow("millisecond", "2005-12-31 23:59:59.9999999", "2006-01-01 00:00:00.0000000")] + [DataRow("microsecond", "2005-12-31 23:59:59.9999999", "2006-01-01 00:00:00.0000000")] + public void DateDiff_1Boundary(string datepart, string startdate, string enddate) + { + // https://learn.microsoft.com/en-us/sql/t-sql/functions/datediff-transact-sql?view=sql-server-ver16#datepart-boundaries + SqlDateParsing.TryParse(startdate, DateFormat.mdy, out SqlDateTimeOffset start); + SqlDateParsing.TryParse(enddate, DateFormat.mdy, out SqlDateTimeOffset end); + var actual = ExpressionFunctions.DateDiff(datepart, start, end, DataTypeHelpers.DateTimeOffset(7), DataTypeHelpers.DateTimeOffset(7)); + Assert.AreEqual(1, actual); + } + + [TestMethod] + public void DateDiff_TimeZone() + { + var dateTime = new DateTime(2024, 10, 5, 12, 0, 0); + var offset = new DateTimeOffset(dateTime, TimeSpan.FromHours(1)); + var actual = ExpressionFunctions.DateDiff("hour", new SqlDateTime(dateTime), new SqlDateTimeOffset(offset), DataTypeHelpers.DateTime, DataTypeHelpers.DateTimeOffset(7)); + Assert.AreEqual(-1, actual); + } + } +} diff --git a/MarkMpn.Sql4Cds.Engine.Tests/ExpressionTests.cs b/MarkMpn.Sql4Cds.Engine.Tests/ExpressionTests.cs index ff6c0273..1d13faec 100644 --- a/MarkMpn.Sql4Cds.Engine.Tests/ExpressionTests.cs +++ b/MarkMpn.Sql4Cds.Engine.Tests/ExpressionTests.cs @@ -23,7 +23,7 @@ public void StringLiteral() var schema = new NodeSchema(new Dictionary(), new Dictionary>(), null, Array.Empty()); var parameterTypes = new Dictionary(); var options = new StubOptions(); - var compilationContext = new ExpressionCompilationContext(_localDataSources, options, parameterTypes, schema, null); + var compilationContext = new ExpressionCompilationContext(new SessionContext(_localDataSources, options), options, parameterTypes, schema, null); var func = expr.Compile(compilationContext); var actual = func(new ExpressionExecutionContext(compilationContext)); @@ -40,7 +40,7 @@ public void IntegerLiteral() var schema = new NodeSchema(new Dictionary(), new Dictionary>(), null, Array.Empty()); var parameterTypes = new Dictionary(); var options = new StubOptions(); - var compilationContext = new ExpressionCompilationContext(_localDataSources, options, parameterTypes, schema, null); + var compilationContext = new ExpressionCompilationContext(new SessionContext(_localDataSources, options), options, parameterTypes, schema, null); var func = expr.Compile(compilationContext); var actual = func(new ExpressionExecutionContext(compilationContext)); @@ -62,7 +62,7 @@ public void StringConcat() var schema = new NodeSchema(new Dictionary(), new Dictionary>(), null, Array.Empty()); var parameterTypes = new Dictionary(); var options = new StubOptions(); - var compilationContext = new ExpressionCompilationContext(_localDataSources, options, parameterTypes, schema, null); + var compilationContext = new ExpressionCompilationContext(new SessionContext(_localDataSources, options), options, parameterTypes, schema, null); var func = expr.Compile(compilationContext); var actual = func(new ExpressionExecutionContext(compilationContext)); @@ -83,7 +83,7 @@ public void IntegerAddition() var schema = new NodeSchema(new Dictionary(), new Dictionary>(), null, Array.Empty()); var parameterTypes = new Dictionary(); var options = new StubOptions(); - var compilationContext = new ExpressionCompilationContext(_localDataSources, options, parameterTypes, schema, null); + var compilationContext = new ExpressionCompilationContext(new SessionContext(_localDataSources, options), options, parameterTypes, schema, null); var func = expr.Compile(compilationContext); var actual = func(new ExpressionExecutionContext(compilationContext)); @@ -118,7 +118,7 @@ public void SimpleCaseExpression() }, new Dictionary>(), null, Array.Empty()); var parameterTypes = new Dictionary(); var options = new StubOptions(); - var compilationContext = new ExpressionCompilationContext(_localDataSources, options, parameterTypes, schema, null); + var compilationContext = new ExpressionCompilationContext(new SessionContext(_localDataSources, options), options, parameterTypes, schema, null); var func = expr.Compile(compilationContext); var record = new Entity @@ -158,7 +158,7 @@ public void FormatDateTime() }, new Dictionary>(), null, Array.Empty()); var parameterTypes = new Dictionary(); var options = new StubOptions(); - var compilationContext = new ExpressionCompilationContext(_localDataSources, options, parameterTypes, schema, null); + var compilationContext = new ExpressionCompilationContext(new SessionContext(_localDataSources, options), options, parameterTypes, schema, null); var func = expr.Compile(compilationContext); var record = new Entity @@ -186,7 +186,7 @@ public void LikeWithEmbeddedReturns() }, new Dictionary>(), null, Array.Empty()); var parameterTypes = new Dictionary(); var options = new StubOptions(); - var compilationContext = new ExpressionCompilationContext(_localDataSources, options, parameterTypes, schema, null); + var compilationContext = new ExpressionCompilationContext(new SessionContext(_localDataSources, options), options, parameterTypes, schema, null); var func = expr.Compile(compilationContext); var record = new Entity diff --git a/MarkMpn.Sql4Cds.Engine.Tests/MarkMpn.Sql4Cds.Engine.Tests.csproj b/MarkMpn.Sql4Cds.Engine.Tests/MarkMpn.Sql4Cds.Engine.Tests.csproj index 04f7900c..3124265e 100644 --- a/MarkMpn.Sql4Cds.Engine.Tests/MarkMpn.Sql4Cds.Engine.Tests.csproj +++ b/MarkMpn.Sql4Cds.Engine.Tests/MarkMpn.Sql4Cds.Engine.Tests.csproj @@ -77,9 +77,11 @@ + + @@ -101,6 +103,8 @@ + + diff --git a/MarkMpn.Sql4Cds.Engine.Tests/OptionsWrapper.cs b/MarkMpn.Sql4Cds.Engine.Tests/OptionsWrapper.cs index ac29f142..76db91da 100644 --- a/MarkMpn.Sql4Cds.Engine.Tests/OptionsWrapper.cs +++ b/MarkMpn.Sql4Cds.Engine.Tests/OptionsWrapper.cs @@ -60,6 +60,19 @@ public OptionsWrapper(IQueryExecutionOptions options) public ColumnOrdering ColumnOrdering { get; set; } + public event EventHandler PrimaryDataSourceChanged + { + add + { + _options.PrimaryDataSourceChanged += value; + } + + remove + { + _options.PrimaryDataSourceChanged -= value; + } + } + public void ConfirmDelete(ConfirmDmlStatementEventArgs e) { } diff --git a/MarkMpn.Sql4Cds.Engine.Tests/Sql2FetchXmlTests.cs b/MarkMpn.Sql4Cds.Engine.Tests/Sql2FetchXmlTests.cs index 99b7dd18..b1d3ef78 100644 --- a/MarkMpn.Sql4Cds.Engine.Tests/Sql2FetchXmlTests.cs +++ b/MarkMpn.Sql4Cds.Engine.Tests/Sql2FetchXmlTests.cs @@ -54,12 +54,14 @@ public class Sql2FetchXmlTests : FakeXrmEasyTestsBase, IQueryExecutionOptions ColumnOrdering IQueryExecutionOptions.ColumnOrdering => ColumnOrdering.Alphabetical; + public event EventHandler PrimaryDataSourceChanged; + [TestMethod] public void SimpleSelect() { var query = "SELECT accountid, name FROM account"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -77,7 +79,7 @@ public void SelectSameFieldMultipleTimes() { var query = "SELECT accountid, name, name FROM account"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -102,7 +104,7 @@ public void SelectStar() { var query = "SELECT * FROM account"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -135,7 +137,7 @@ public void SelectStarAndField() { var query = "SELECT *, name FROM account"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -169,7 +171,7 @@ public void SimpleFilter() { var query = "SELECT accountid, name FROM account WHERE name = 'test'"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -190,7 +192,7 @@ public void BetweenFilter() { var query = "SELECT accountid, name FROM account WHERE employees BETWEEN 1 AND 10 AND turnover NOT BETWEEN 2 AND 20"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -216,7 +218,7 @@ public void FetchFilter() { var query = "SELECT contactid, firstname FROM contact WHERE createdon = lastxdays(7)"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -237,7 +239,7 @@ public void NestedFilters() { var query = "SELECT accountid, name FROM account WHERE name = 'test' OR (employees is not null and name like 'foo%')"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -262,7 +264,7 @@ public void Sorts() { var query = "SELECT accountid, name FROM account ORDER BY name DESC, accountid"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -282,7 +284,7 @@ public void SortByColumnIndex() { var query = "SELECT accountid, name FROM account ORDER BY 2 DESC, 1"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -302,7 +304,7 @@ public void SortByAliasedColumn() { var query = "SELECT accountid, name as accountname FROM account ORDER BY name"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -321,7 +323,7 @@ public void Top() { var query = "SELECT TOP 10 accountid, name FROM account"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -339,7 +341,7 @@ public void TopBrackets() { var query = "SELECT TOP (10) accountid, name FROM account"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -357,7 +359,7 @@ public void Top10KUsesExtension() { var query = "SELECT TOP 10000 accountid, name FROM account"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -378,7 +380,7 @@ public void NoLock() { var query = "SELECT accountid, name FROM account (NOLOCK)"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -396,7 +398,7 @@ public void Distinct() { var query = "SELECT DISTINCT name FROM account"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -414,7 +416,7 @@ public void Offset() { var query = "SELECT accountid, name FROM account ORDER BY name OFFSET 100 ROWS FETCH NEXT 50 ROWS ONLY"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -433,7 +435,7 @@ public void SimpleJoin() { var query = "SELECT accountid, name FROM account INNER JOIN contact ON primarycontactid = contactid"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -453,7 +455,7 @@ public void SelfReferentialJoin() { var query = "SELECT contact.contactid, contact.firstname, manager.firstname FROM contact LEFT OUTER JOIN contact AS manager ON contact.parentcustomerid = manager.contactid"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -474,7 +476,7 @@ public void AdditionalJoinCriteria() { var query = "SELECT accountid, name FROM account INNER JOIN contact ON accountid = parentcustomerid AND (firstname = 'Mark' OR lastname = 'Carrington')"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -498,7 +500,7 @@ public void InvalidAdditionalJoinCriteria() { var query = "SELECT accountid, name FROM account INNER JOIN contact ON accountid = parentcustomerid OR (firstname = 'Mark' AND lastname = 'Carrington')"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); Assert.IsNotInstanceOfType(((SelectNode)queries[0]).Source, typeof(FetchXmlScan)); @@ -509,7 +511,7 @@ public void SortOnLinkEntity() { var query = "SELECT TOP 100 accountid, name FROM account INNER JOIN contact ON primarycontactid = contactid ORDER BY name, firstname"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -531,7 +533,7 @@ public void InvalidSortOnLinkEntity() { var query = "SELECT TOP 100 accountid, name FROM account INNER JOIN contact ON accountid = parentcustomerid ORDER BY name, firstname"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -558,7 +560,7 @@ public void SimpleAggregate() { var query = "SELECT count(*), count(name), count(DISTINCT name), max(name), min(name), avg(employees) FROM account"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -580,7 +582,7 @@ public void GroupBy() { var query = "SELECT name, count(*) FROM account GROUP BY name"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -599,7 +601,7 @@ public void GroupBySorting() { var query = "SELECT name, count(*) FROM account GROUP BY name ORDER BY name, count(*)"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -619,7 +621,7 @@ public void GroupBySortingOnLinkEntity() { var query = "SELECT name, firstname, count(*) FROM account INNER JOIN contact ON parentcustomerid = account.accountid GROUP BY name, firstname ORDER BY firstname, name"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -642,7 +644,7 @@ public void GroupBySortingOnAliasedAggregate() { var query = "SELECT name, firstname, count(*) as count FROM account INNER JOIN contact ON parentcustomerid = account.accountid GROUP BY name, firstname ORDER BY count"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -664,7 +666,7 @@ public void UpdateFieldToValue() { var query = "UPDATE contact SET firstname = 'Mark'"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -681,7 +683,7 @@ public void SelectArithmetic() { var query = "SELECT employees + 1 AS a, employees * 2 AS b, turnover / 3 AS c, turnover - 4 AS d, turnover / employees AS e FROM account"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -704,7 +706,7 @@ public void SelectArithmetic() } }; - var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); + var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(new SessionContext(GetDataSources(_context), this), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); @@ -723,7 +725,7 @@ public void WhereComparingTwoFields() { var query = "SELECT contactid FROM contact WHERE firstname = lastname"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -754,7 +756,7 @@ public void WhereComparingTwoFields() } }; - var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); + var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(new SessionContext(GetDataSources(_context), this), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); @@ -768,7 +770,7 @@ public void WhereComparingExpression() { var query = "SELECT contactid FROM contact WHERE lastname = firstname + 'rington'"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -799,7 +801,7 @@ public void WhereComparingExpression() } }; - var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); + var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(new SessionContext(GetDataSources(_context), this), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); @@ -812,7 +814,7 @@ public void BackToFrontLikeExpression() { var query = "SELECT contactid FROM contact WHERE 'Mark' like firstname"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -840,7 +842,7 @@ public void BackToFrontLikeExpression() } }; - var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); + var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(new SessionContext(GetDataSources(_context), this), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); @@ -853,7 +855,7 @@ public void UpdateFieldToField() { var query = "UPDATE contact SET firstname = lastname"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -875,7 +877,7 @@ public void UpdateFieldToField() } }; - ((UpdateNode)queries[0]).Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary(), null), out _, out _); + ((UpdateNode)queries[0]).Execute(new NodeExecutionContext(new SessionContext(GetDataSources(_context), this), this, new Dictionary(), new Dictionary(), null), out _, out _); Assert.AreEqual("Carrington", _context.Data["contact"][guid]["firstname"]); } @@ -885,7 +887,7 @@ public void UpdateFieldToExpression() { var query = "UPDATE contact SET firstname = 'Hello ' + lastname"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -907,7 +909,7 @@ public void UpdateFieldToExpression() } }; - ((UpdateNode)queries[0]).Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary(), null), out _, out _); + ((UpdateNode)queries[0]).Execute(new NodeExecutionContext(new SessionContext(GetDataSources(_context), this), this, new Dictionary(), new Dictionary(), null), out _, out _); Assert.AreEqual("Hello Carrington", _context.Data["contact"][guid]["firstname"]); } @@ -917,7 +919,7 @@ public void UpdateReplace() { var query = "UPDATE contact SET firstname = REPLACE(firstname, 'Dataflex Pro', 'CDS') WHERE lastname = 'Carrington'"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -943,7 +945,7 @@ public void UpdateReplace() } }; - ((UpdateNode)queries[0]).Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary(), null), out _, out _); + ((UpdateNode)queries[0]).Execute(new NodeExecutionContext(new SessionContext(GetDataSources(_context), this), this, new Dictionary(), new Dictionary(), null), out _, out _); Assert.AreEqual("--CDS--", _context.Data["contact"][guid]["firstname"]); } @@ -953,7 +955,7 @@ public void StringFunctions() { var query = "SELECT trim(firstname) as trim, ltrim(firstname) as ltrim, rtrim(firstname) as rtrim, substring(firstname, 2, 3) as substring23, len(firstname) as len FROM contact"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -974,7 +976,7 @@ public void StringFunctions() } }; - var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); + var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(new SessionContext(GetDataSources(_context), this), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); @@ -993,7 +995,7 @@ public void SelectExpression() { var query = "SELECT firstname, 'Hello ' + firstname AS greeting FROM contact"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -1014,7 +1016,7 @@ public void SelectExpression() } }; - var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); + var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(new SessionContext(GetDataSources(_context), this), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); @@ -1043,7 +1045,7 @@ public void SelectExpressionNullValues() { var query = "SELECT firstname, 'Hello ' + firstname AS greeting, case when createdon > '2020-01-01' then 'new' else 'old' end AS age FROM contact"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -1064,7 +1066,7 @@ public void SelectExpressionNullValues() } }; - var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); + var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(new SessionContext(GetDataSources(_context), this), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); @@ -1079,7 +1081,7 @@ public void OrderByExpression() { var query = "SELECT firstname, lastname FROM contact ORDER BY lastname + ', ' + firstname"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -1109,7 +1111,7 @@ public void OrderByExpression() } }; - var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); + var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(new SessionContext(GetDataSources(_context), this), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); @@ -1123,7 +1125,7 @@ public void OrderByAliasedField() { var query = "SELECT firstname, lastname AS surname FROM contact ORDER BY surname"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -1154,7 +1156,7 @@ public void OrderByAliasedField() } }; - var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); + var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(new SessionContext(GetDataSources(_context), this), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); @@ -1168,7 +1170,7 @@ public void OrderByCalculatedField() { var query = "SELECT firstname, lastname, lastname + ', ' + firstname AS fullname FROM contact ORDER BY fullname"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -1198,7 +1200,7 @@ public void OrderByCalculatedField() } }; - var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); + var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(new SessionContext(GetDataSources(_context), this), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); @@ -1212,7 +1214,7 @@ public void OrderByCalculatedFieldByIndex() { var query = "SELECT firstname, lastname, lastname + ', ' + firstname AS fullname FROM contact ORDER BY 3"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -1242,7 +1244,7 @@ public void OrderByCalculatedFieldByIndex() } }; - var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); + var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(new SessionContext(GetDataSources(_context), this), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); @@ -1256,7 +1258,7 @@ public void DateCalculations() { var query = "SELECT contactid, DATEADD(day, 1, createdon) AS nextday, DATEPART(minute, createdon) AS minute FROM contact WHERE DATEDIFF(hour, '2020-01-01', createdon) < 1"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -1284,7 +1286,7 @@ public void DateCalculations() } }; - var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); + var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(new SessionContext(GetDataSources(_context), this), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); @@ -1301,7 +1303,7 @@ public void TopAppliedAfterCustomFilter() { var query = "SELECT TOP 10 contactid FROM contact WHERE firstname = lastname"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -1324,7 +1326,7 @@ public void CustomFilterAggregateHavingProjectionSortAndTop() { var query = "SELECT TOP 10 lastname, SUM(CASE WHEN firstname = 'Mark' THEN 1 ELSE 0 END) as nummarks, LEFT(lastname, 1) AS lastinitial FROM contact WHERE DATEDIFF(day, '2020-01-01', createdon) > 10 GROUP BY lastname HAVING count(*) > 1 ORDER BY 2 DESC"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -1372,7 +1374,7 @@ public void CustomFilterAggregateHavingProjectionSortAndTop() } }; - var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); + var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(new SessionContext(GetDataSources(_context), this), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); @@ -1392,7 +1394,7 @@ public void FilterCaseInsensitive() { var query = "SELECT contactid FROM contact WHERE DATEDIFF(day, '2020-01-01', createdon) < 10 OR lastname = 'Carrington' ORDER BY createdon"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -1435,7 +1437,7 @@ public void FilterCaseInsensitive() } }; - var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); + var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(new SessionContext(GetDataSources(_context), this), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); @@ -1450,7 +1452,7 @@ public void GroupCaseInsensitive() { var query = "SELECT lastname, count(*) FROM contact WHERE DATEDIFF(day, '2020-01-01', createdon) > 10 GROUP BY lastname ORDER BY 2 DESC"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -1492,7 +1494,7 @@ public void GroupCaseInsensitive() } }; - var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); + var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(new SessionContext(GetDataSources(_context), this), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); @@ -1507,7 +1509,7 @@ public void AggregateExpressionsWithoutGrouping() { var query = "SELECT count(DISTINCT firstname + ' ' + lastname) AS distinctnames FROM contact"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -1545,7 +1547,7 @@ public void AggregateExpressionsWithoutGrouping() } }; - var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); + var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(new SessionContext(GetDataSources(_context), this), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); @@ -1559,7 +1561,7 @@ public void AggregateQueryProducesAlternative() { var query = "SELECT name, count(*) FROM account GROUP BY name ORDER BY 2 DESC"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); var select = (SelectNode)queries[0]; @@ -1600,7 +1602,7 @@ public void AggregateQueryProducesAlternative() } }; - var dataReader = alternativeQuery.Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); + var dataReader = alternativeQuery.Execute(new NodeExecutionContext(new SessionContext(GetDataSources(_context), this), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); @@ -1617,7 +1619,7 @@ public void GuidEntityReferenceInequality() { var query = "SELECT a.name FROM account a INNER JOIN contact c ON a.primarycontactid = c.contactid WHERE (c.parentcustomerid is null or a.accountid <> c.parentcustomerid)"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); var select = (SelectNode)queries[0]; @@ -1656,7 +1658,7 @@ public void GuidEntityReferenceInequality() } }; - var dataReader = select.Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); + var dataReader = select.Execute(new NodeExecutionContext(new SessionContext(GetDataSources(_context), this), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); @@ -1671,7 +1673,7 @@ public void UpdateGuidToEntityReference() { var query = "UPDATE a SET primarycontactid = c.contactid FROM account AS a INNER JOIN contact AS c ON a.accountid = c.parentcustomerid"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); var update = (UpdateNode)queries[0]; @@ -1708,7 +1710,7 @@ public void UpdateGuidToEntityReference() } }; - update.Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary(), null), out _, out _); + update.Execute(new NodeExecutionContext(new SessionContext(GetDataSources(_context), this), this, new Dictionary(), new Dictionary(), null), out _, out _); Assert.AreEqual(new EntityReference("contact", contact1), _context.Data["account"][account1].GetAttributeValue("primarycontactid")); Assert.AreEqual(new EntityReference("contact", contact2), _context.Data["account"][account2].GetAttributeValue("primarycontactid")); @@ -1719,7 +1721,7 @@ public void CompareDateFields() { var query = "DELETE c2 FROM contact c1 INNER JOIN contact c2 ON c1.parentcustomerid = c2.parentcustomerid AND c2.createdon > c1.createdon"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -1744,7 +1746,7 @@ public void ColumnComparison() { var query = "SELECT firstname, lastname FROM contact WHERE firstname = lastname"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -1767,7 +1769,7 @@ public void QuotedIdentifierError() try { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, new OptionsWrapper(this) { QuotedIdentifiers = true }); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, new OptionsWrapper(this) { QuotedIdentifiers = true }), new OptionsWrapper(this) { QuotedIdentifiers = true }); var queries = planBuilder.Build(query, null, out _); Assert.Fail("Expected exception"); @@ -1783,7 +1785,7 @@ public void FilterExpressionConstantValueToFetchXml() { var query = "SELECT firstname, lastname FROM contact WHERE firstname = 'Ma' + 'rk'"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, $@" @@ -1804,7 +1806,7 @@ public void Count1ConvertedToCountStar() { var query = "SELECT COUNT(1) FROM contact OPTION(USE HINT('RETRIEVE_TOTAL_RECORD_COUNT'))"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); var selectNode = (SelectNode)queries[0]; @@ -1817,7 +1819,7 @@ public void CaseInsensitive() { var query = "Select Name From Account"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, $@" @@ -1834,7 +1836,7 @@ public void ContainsValues1() { var query = "SELECT new_name FROM new_customentity WHERE CONTAINS(new_optionsetvaluecollection, '1')"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, $@" @@ -1856,7 +1858,7 @@ public void ContainsValuesFunction1() { var query = "SELECT new_name FROM new_customentity WHERE new_optionsetvaluecollection = containvalues(1)"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, $@" @@ -1878,7 +1880,7 @@ public void ContainsValues() { var query = "SELECT new_name FROM new_customentity WHERE CONTAINS(new_optionsetvaluecollection, '1 OR 2')"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, $@" @@ -1901,7 +1903,7 @@ public void ContainsValuesFunction() { var query = "SELECT new_name FROM new_customentity WHERE new_optionsetvaluecollection = containvalues(1, 2)"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, $@" @@ -1924,7 +1926,7 @@ public void NotContainsValues() { var query = "SELECT new_name FROM new_customentity WHERE NOT CONTAINS(new_optionsetvaluecollection, '1 OR 2')"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, $@" @@ -1961,7 +1963,7 @@ public void ImplicitTypeConversion() { var query = "SELECT employees / 2.0 AS half FROM account"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); var account1 = Guid.NewGuid(); @@ -1981,7 +1983,7 @@ public void ImplicitTypeConversion() } }; - var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); + var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(new SessionContext(GetDataSources(_context), this), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); @@ -1994,7 +1996,7 @@ public void ImplicitTypeConversionComparison() { var query = "SELECT accountid FROM account WHERE turnover / 2 > 10"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); var account1 = Guid.NewGuid(); @@ -2014,7 +2016,7 @@ public void ImplicitTypeConversionComparison() } }; - var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); + var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(new SessionContext(GetDataSources(_context), this), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); @@ -2026,7 +2028,7 @@ public void GlobalOptionSet() { var query = "SELECT displayname FROM metadata.globaloptionset WHERE name = 'test'"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); Assert.IsInstanceOfType(queries.Single(), typeof(SelectNode)); @@ -2038,7 +2040,7 @@ public void GlobalOptionSet() Assert.AreEqual("globaloptionset.name = 'test'", filterNode.Filter.ToNormalizedSql()); var optionsetNode = (GlobalOptionSetQueryNode)filterNode.Source; - var dataReader = selectNode.Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); + var dataReader = selectNode.Execute(new NodeExecutionContext(new SessionContext(GetDataSources(_context), this), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); @@ -2050,7 +2052,7 @@ public void EntityDetails() { var query = "SELECT logicalname FROM metadata.entity ORDER BY 1"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); Assert.IsInstanceOfType(queries.Single(), typeof(SelectNode)); @@ -2059,7 +2061,7 @@ public void EntityDetails() var sortNode = (SortNode)selectNode.Source; var metadataNode = (MetadataQueryNode)sortNode.Source; - var dataReader = selectNode.Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); + var dataReader = selectNode.Execute(new NodeExecutionContext(new SessionContext(GetDataSources(_context), this), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); @@ -2074,10 +2076,10 @@ public void AttributeDetails() { var query = "SELECT e.logicalname, a.logicalname FROM metadata.entity e INNER JOIN metadata.attribute a ON e.logicalname = a.entitylogicalname WHERE e.logicalname = 'new_customentity' ORDER BY 1, 2"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); - var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); + var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(new SessionContext(GetDataSources(_context), this), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); @@ -2099,7 +2101,7 @@ public void OptionSetNameSelect() { var query = "SELECT new_optionsetvalue, new_optionsetvaluename FROM new_customentity ORDER BY new_optionsetvaluename"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); var record1 = Guid.NewGuid(); @@ -2135,7 +2137,7 @@ public void OptionSetNameSelect() CollectionAssert.AreEqual(new[] { "new_optionsetvalue", "new_optionsetvaluename" }, select.ColumnSet.Select(c => c.OutputColumn).ToList()); - var dataReader = select.Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); + var dataReader = select.Execute(new NodeExecutionContext(new SessionContext(GetDataSources(_context), this), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); @@ -2150,7 +2152,7 @@ public void OptionSetNameFilter() { var query = "SELECT new_customentityid FROM new_customentity WHERE new_optionsetvaluename = 'Value1'"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -2169,7 +2171,7 @@ public void EntityReferenceNameSelect() { var query = "SELECT primarycontactid, primarycontactidname FROM account ORDER BY primarycontactidname"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -2190,7 +2192,7 @@ public void EntityReferenceNameFilter() { var query = "SELECT accountid FROM account WHERE primarycontactidname = 'Mark Carrington'"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -2209,7 +2211,7 @@ public void UpdateMissingAlias() { var query = "UPDATE account SET primarycontactid = c.contactid FROM account AS a INNER JOIN contact AS c ON a.name = c.fullname"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); } @@ -2220,7 +2222,7 @@ public void UpdateMissingAliasAmbiguous() try { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); Assert.Fail("Expected exception"); } @@ -2235,7 +2237,7 @@ public void ConvertIntToBool() { var query = "UPDATE new_customentity SET new_boolprop = CASE WHEN new_name = 'True' THEN 1 ELSE 0 END"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); } @@ -2246,7 +2248,7 @@ public void ImpersonateRevert() EXECUTE AS LOGIN = 'test1' REVERT"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); Assert.IsInstanceOfType(queries[0], typeof(ExecuteAsNode)); @@ -2271,7 +2273,7 @@ SELECT contact.fullname FROM contact INNER JOIN account ON contact.contactid = account.primarycontactid INNER JOIN new_customentity ON contact.parentcustomerid = new_customentity.new_parentid ORDER BY account.employees, contact.fullname"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(new[] { queries[0] }, @" @@ -2361,7 +2363,7 @@ private void BuildTDSQuery(Action action) ds.Connection = null; try { - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, new OptionsWrapper(this) { UseTDSEndpoint = true }); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, new OptionsWrapper(this) { UseTDSEndpoint = true }), new OptionsWrapper(this) { UseTDSEndpoint = true }); action(planBuilder); } finally @@ -2375,7 +2377,7 @@ public void OrderByAggregateByIndex() { var query = "SELECT firstname, count(*) FROM contact GROUP BY firstname ORDER BY 2"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -2394,7 +2396,7 @@ public void OrderByAggregateJoinByIndex() { var query = "SELECT firstname, count(*) FROM contact INNER JOIN account ON contact.parentcustomerid = account.accountid GROUP BY firstname ORDER BY 2"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -2415,7 +2417,7 @@ public void AggregateAlternativeDoesNotOrderByLinkEntity() { var query = "SELECT name, count(*) FROM contact INNER JOIN account ON contact.parentcustomerid = account.accountid GROUP BY name"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); var select = (SelectNode)queries[0]; @@ -2445,7 +2447,7 @@ public void CharIndex() { var query = "SELECT CHARINDEX('a', fullname) AS ci0, CHARINDEX('a', fullname, 1) AS ci1, CHARINDEX('a', fullname, 2) AS ci2, CHARINDEX('a', fullname, 3) AS ci3, CHARINDEX('a', fullname, 8) AS ci8 FROM contact"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); var contact1 = Guid.NewGuid(); @@ -2459,7 +2461,7 @@ public void CharIndex() } }; - var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); + var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(new SessionContext(GetDataSources(_context), this), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); @@ -2475,7 +2477,7 @@ public void CastDateTimeToDate() { var query = "SELECT CAST(createdon AS date) AS converted FROM contact"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); var contact1 = Guid.NewGuid(); @@ -2489,7 +2491,7 @@ public void CastDateTimeToDate() } }; - var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); + var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(new SessionContext(GetDataSources(_context), this), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); @@ -2501,7 +2503,7 @@ public void GroupByPrimaryFunction() { var query = "SELECT left(firstname, 1) AS initial, count(*) AS count FROM contact GROUP BY left(firstname, 1) ORDER BY 2 DESC"; - var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this); + var planBuilder = new ExecutionPlanBuilder(new SessionContext(_localDataSources, this), this); var queries = planBuilder.Build(query, null, out _); var contact1 = Guid.NewGuid(); @@ -2527,7 +2529,7 @@ public void GroupByPrimaryFunction() } }; - var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); + var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(new SessionContext(GetDataSources(_context), this), this, new Dictionary(), new Dictionary(), null), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); diff --git a/MarkMpn.Sql4Cds.Engine.Tests/SqlDateTests.cs b/MarkMpn.Sql4Cds.Engine.Tests/SqlDateTests.cs new file mode 100644 index 00000000..d459d885 --- /dev/null +++ b/MarkMpn.Sql4Cds.Engine.Tests/SqlDateTests.cs @@ -0,0 +1,127 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace MarkMpn.Sql4Cds.Engine.Tests +{ + [TestClass] + public class SqlDateTests + { + [DataTestMethod] + [DataRow("[M]M/dd/[yy]yy", "M/d/yy;M/d/yyyy;M/dd/yy;M/dd/yyyy;MM/d/yy;MM/d/yyyy;MM/dd/yy;MM/dd/yyyy")] + public void ConvertSqlFormatStringToNet(string input, string expected) + { + var actual = SqlDateParsing.SqlToNetFormatString(input); + CollectionAssert.AreEquivalent(expected.Split(';'), actual); + } + + [DataTestMethod] + [DataRow("4/21/2007")] + [DataRow("4-21-2007")] + [DataRow("4.21.2007")] + [DataRow("Apr 21, 2007")] + [DataRow("Apr 2007 21")] + [DataRow("21 April, 2007")] + [DataRow("21 2007 Apr")] + [DataRow("2007 April 21")] + [DataRow("2007-04-21")] + [DataRow("20070421")] + public void ParseMDY(string input) + { + Assert.IsTrue(SqlDate.TryParse(input, DateFormat.mdy, out var actual)); + Assert.AreEqual(new DateTime(2007, 4, 21), actual.Value); + } + + [DataTestMethod] + [DataRow("4/2007/21")] + [DataRow("4-2007-21")] + [DataRow("4.2007.21")] + [DataRow("Apr 21, 2007")] + [DataRow("Apr 2007 21")] + [DataRow("21 April, 2007")] + [DataRow("21 2007 Apr")] + [DataRow("2007 April 21")] + [DataRow("2007-04-21")] + [DataRow("20070421")] + public void ParseMYD(string input) + { + Assert.IsTrue(SqlDate.TryParse(input, DateFormat.myd, out var actual)); + Assert.AreEqual(new DateTime(2007, 4, 21), actual.Value); + } + + [DataTestMethod] + [DataRow("21/4/2007")] + [DataRow("21-4-2007")] + [DataRow("21.4.2007")] + [DataRow("Apr 21, 2007")] + [DataRow("Apr 2007 21")] + [DataRow("21 April, 2007")] + [DataRow("21 2007 Apr")] + [DataRow("2007 April 21")] + [DataRow("2007-04-21")] + [DataRow("20070421")] + public void ParseDMY(string input) + { + Assert.IsTrue(SqlDate.TryParse(input, DateFormat.dmy, out var actual)); + Assert.AreEqual(new DateTime(2007, 4, 21), actual.Value); + } + + [DataTestMethod] + [DataRow("21/2007/4")] + [DataRow("21-2007-4")] + [DataRow("21.2007.4")] + [DataRow("Apr 21, 2007")] + [DataRow("Apr 2007 21")] + [DataRow("21 April, 2007")] + [DataRow("21 2007 Apr")] + [DataRow("2007 April 21")] + [DataRow("2007-04-21")] + [DataRow("20070421")] + public void ParseDYM(string input) + { + Assert.IsTrue(SqlDate.TryParse(input, DateFormat.dym, out var actual)); + Assert.AreEqual(new DateTime(2007, 4, 21), actual.Value); + } + + [DataTestMethod] + [DataRow("2007/4/21")] + [DataRow("2007-4-21")] + [DataRow("2007.4.21")] + [DataRow("Apr 21, 2007")] + [DataRow("Apr 2007 21")] + [DataRow("21 April, 2007")] + [DataRow("21 2007 Apr")] + [DataRow("2007 April 21")] + [DataRow("2007-04-21")] + [DataRow("20070421")] + public void ParseYMD(string input) + { + Assert.IsTrue(SqlDate.TryParse(input, DateFormat.ymd, out var actual)); + Assert.AreEqual(new DateTime(2007, 4, 21), actual.Value); + } + + [DataTestMethod] + [DataRow("12:30:45")] // TIME only + //[DataRow("+02:00")] // TIMEZONE only - mentioned as valid in documentation but doesn't work in practise + [DataRow("12:30:45+02:00")] // TIME + TIMEZONE + public void UseDefaultValuesForTimeStrings(string input) + { + // https://learn.microsoft.com/en-us/sql/t-sql/data-types/date-transact-sql?view=sql-server-ver16#convert-string-literals-to-date + Assert.IsTrue(SqlDate.TryParse(input, DateFormat.ymd, out var actual)); + Assert.AreEqual(new DateTime(1900, 1, 1), actual.Value); + } + + [DataTestMethod] + [DataRow("2007-04-21 12:30:45")] // DATE + TIME + [DataRow("2007-04-21 12:30:45+02:00")] // DATE + TIME + TIMEZONE + public void IgnoreTime(string input) + { + // https://learn.microsoft.com/en-us/sql/t-sql/data-types/date-transact-sql?view=sql-server-ver16#convert-string-literals-to-date + Assert.IsTrue(SqlDate.TryParse(input, DateFormat.ymd, out var actual)); + Assert.AreEqual(new DateTime(2007, 4, 21), actual.Value); + } + } +} diff --git a/MarkMpn.Sql4Cds.Engine.Tests/SqlDateTimeTests.cs b/MarkMpn.Sql4Cds.Engine.Tests/SqlDateTimeTests.cs new file mode 100644 index 00000000..dd132651 --- /dev/null +++ b/MarkMpn.Sql4Cds.Engine.Tests/SqlDateTimeTests.cs @@ -0,0 +1,129 @@ +using System; +using System.Collections.Generic; +using System.Data.SqlTypes; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace MarkMpn.Sql4Cds.Engine.Tests +{ + [TestClass] + public class SqlDateTimeTests + { + [DataTestMethod] + [DataRow("4/15/96", "mdy")] + [DataRow("04/15/96", "mdy")] + [DataRow("4/15/1996", "mdy")] + [DataRow("04/15/1996", "mdy")] + + [DataRow("4-15-96", "mdy")] + [DataRow("04-15-96", "mdy")] + [DataRow("4-15-1996", "mdy")] + [DataRow("04-15-1996", "mdy")] + + [DataRow("4.15.96", "mdy")] + [DataRow("04.15.96", "mdy")] + [DataRow("4.15.1996", "mdy")] + [DataRow("04.15.1996", "mdy")] + + [DataRow("4/96/15", "myd")] + [DataRow("04/96/15", "myd")] + [DataRow("4/1996/15", "myd")] + [DataRow("04/1996/15", "myd")] + + [DataRow("15/4/96", "dmy")] + [DataRow("15/04/96", "dmy")] + [DataRow("15/4/1996", "dmy")] + [DataRow("15/04/1996", "dmy")] + + [DataRow("15/96/4", "dym")] + [DataRow("15/96/04", "dym")] + [DataRow("15/1996/4", "dym")] + [DataRow("15/1996/04", "dym")] + + [DataRow("96/15/4", "ydm")] + [DataRow("96/15/04", "ydm")] + [DataRow("1996/15/4", "ydm")] + [DataRow("1996/15/04", "ydm")] + + [DataRow("96/4/15", "ymd")] + [DataRow("96/04/15", "ymd")] + [DataRow("1996/4/15", "ymd")] + [DataRow("1996/04/15", "ymd")] + public void NumericDateFormat(string input, string order) + { + var format = (DateFormat)Enum.Parse(typeof(DateFormat), order); + Assert.IsTrue(SqlDateParsing.TryParse(input, format, out SqlDateTime actual)); + Assert.AreEqual(new DateTime(1996, 4, 15), actual.Value); + } + + [DataTestMethod] + [DataRow("04/15/1996 14:30", "14:30:00")] + [DataRow("04/15/1996 14:30:20", "14:30:20")] + [DataRow("04/15/1996 14:30:20:997", "14:30:20.997")] + [DataRow("04/15/1996 14:30:20.9", "14:30:20.9")] + [DataRow("04/15/1996 4am", "04:00:00")] + [DataRow("04/15/1996 4 PM", "16:00:00")] + public void NumericDateFormatWithTime(string input, string expectedTime) + { + Assert.IsTrue(SqlDateParsing.TryParse(input, DateFormat.mdy, out SqlDateTime actual)); + Assert.AreEqual(new DateTime(1996, 4, 15), actual.Value.Date); + Assert.AreEqual(TimeSpan.Parse(expectedTime), actual.Value.TimeOfDay); + } + + [DataTestMethod] + [DataRow("Apr 1996", false)] + [DataRow("April 1996", false)] + [DataRow("April 15 1996", true)] + [DataRow("April 15, 1996", true)] + [DataRow("April 15 96", true)] + [DataRow("April 15, 96", true)] + [DataRow("Apr 1996 15", true)] + [DataRow("April 1996 15", true)] + [DataRow("Apr, 1996", false)] + [DataRow("April, 1996", false)] + [DataRow("15 Apr, 1996", true)] + [DataRow("15 April, 1996", true)] + [DataRow("15 Apr,1996", true)] + [DataRow("15 April,1996", true)] + [DataRow("15 Apr,96", true)] + [DataRow("15 April,96", true)] + [DataRow("15 Apr96", true)] + [DataRow("15 April96", true)] + [DataRow("15 96 apr", true)] + [DataRow("15 96 april", true)] + [DataRow("15 1996 apr", true)] + [DataRow("15 1996 april", true)] + [DataRow("1996 apr", false)] + [DataRow("1996 april", false)] + [DataRow("1996 apr 15", true)] + [DataRow("1996 april 15", true)] + [DataRow("1996 15 apr", true)] + [DataRow("1996 15 april", true)] + public void AlphaFormat(string input, bool includesDay) + { + Assert.IsTrue(SqlDateParsing.TryParse(input, DateFormat.mdy, out SqlDateTime actual)); + + if (includesDay) + Assert.AreEqual(new DateTime(1996, 4, 15), actual.Value); + else + Assert.AreEqual(new DateTime(1996, 4, 1), actual.Value); + } + + [DataTestMethod] + [DataRow("2004-05-23T14:25:10", false)] + [DataRow("2004-05-23T14:25:10.487", true)] + [DataRow("20040523 14:25:10", false)] + [DataRow("20040523 14:25:10.487", true)] + public void IsoFormat(string input, bool milli) + { + Assert.IsTrue(SqlDateParsing.TryParse(input, DateFormat.mdy, out SqlDateTime actual)); + + if (milli) + Assert.AreEqual(new DateTime(2004, 5, 23, 14, 25, 10, 487), actual.Value); + else + Assert.AreEqual(new DateTime(2004, 5, 23, 14, 25, 10), actual.Value); + } + } +} diff --git a/MarkMpn.Sql4Cds.Engine.Tests/SqlTimeTests.cs b/MarkMpn.Sql4Cds.Engine.Tests/SqlTimeTests.cs new file mode 100644 index 00000000..967b26a6 --- /dev/null +++ b/MarkMpn.Sql4Cds.Engine.Tests/SqlTimeTests.cs @@ -0,0 +1,15 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace MarkMpn.Sql4Cds.Engine.Tests +{ + [TestClass] + public class SqlTimeTests + { + //public void ParseTime + } +} diff --git a/MarkMpn.Sql4Cds.Engine.Tests/SqlVariantTests.cs b/MarkMpn.Sql4Cds.Engine.Tests/SqlVariantTests.cs index 5d610416..ff81d206 100644 --- a/MarkMpn.Sql4Cds.Engine.Tests/SqlVariantTests.cs +++ b/MarkMpn.Sql4Cds.Engine.Tests/SqlVariantTests.cs @@ -4,19 +4,33 @@ using System.Linq; using System.Text; using System.Threading.Tasks; +using Microsoft.SqlServer.TransactSql.ScriptDom; using Microsoft.VisualStudio.TestTools.UnitTesting; namespace MarkMpn.Sql4Cds.Engine.Tests { [TestClass] - public class SqlVariantTests + public class SqlVariantTests : FakeXrmEasyTestsBase { + private readonly ExpressionExecutionContext _eec; + + public SqlVariantTests() + { + _eec = new ExpressionExecutionContext( + session: new SessionContext(_localDataSources, new StubOptions()), + options: new StubOptions(), + parameterTypes: new Dictionary(), + parameterValues: new Dictionary(), + log: e => { }, + entity: null); + } + [TestMethod] public void NullSortsBeforeAllOtherValues() { Assert.IsTrue(SqlVariant.Null.CompareTo(SqlVariant.Null) == 0); - Assert.IsTrue(SqlVariant.Null.CompareTo(new SqlVariant(DataTypeHelpers.Int, new SqlInt32(1))) < 0); - Assert.IsTrue(new SqlVariant(DataTypeHelpers.Int, new SqlInt32(1)).CompareTo(SqlVariant.Null) > 0); + Assert.IsTrue(SqlVariant.Null.CompareTo(new SqlVariant(DataTypeHelpers.Int, new SqlInt32(1), _eec)) < 0); + Assert.IsTrue(new SqlVariant(DataTypeHelpers.Int, new SqlInt32(1), _eec).CompareTo(SqlVariant.Null) > 0); } [TestMethod] @@ -28,26 +42,26 @@ public void NullDoesNotEqualNull() [TestMethod] public void ValuesFromDifferentFamiliesAreNotEqual() { - Assert.AreEqual(new SqlVariant(DataTypeHelpers.VarChar(1, Collation.USEnglish, CollationLabel.CoercibleDefault), Collation.USEnglish.ToSqlString("1")) == new SqlVariant(DataTypeHelpers.Int, new SqlInt32(1)), (SqlBoolean)false); + Assert.AreEqual(new SqlVariant(DataTypeHelpers.VarChar(1, Collation.USEnglish, CollationLabel.CoercibleDefault), Collation.USEnglish.ToSqlString("1"), _eec) == new SqlVariant(DataTypeHelpers.Int, new SqlInt32(1), _eec), (SqlBoolean)false); } [TestMethod] public void ValuesFromDifferentTypesInSameFamilyAreEqual() { - Assert.AreEqual(new SqlVariant(DataTypeHelpers.BigInt, new SqlInt64(1)) == new SqlVariant(DataTypeHelpers.Int, new SqlInt32(1)), (SqlBoolean)true); + Assert.AreEqual(new SqlVariant(DataTypeHelpers.BigInt, new SqlInt64(1), _eec) == new SqlVariant(DataTypeHelpers.Int, new SqlInt32(1), _eec), (SqlBoolean)true); } [TestMethod] public void SortsAccordingToDataTypeFamilies() { var variant = SqlVariant.Null; - var dt = new SqlVariant(DataTypeHelpers.DateTime, new SqlDateTime(2000, 1, 1)); - var approx = new SqlVariant(DataTypeHelpers.Float, new SqlSingle(1)); - var exact = new SqlVariant(DataTypeHelpers.Int, new SqlInt32(1)); - var ch = new SqlVariant(DataTypeHelpers.VarChar(10, Collation.USEnglish, CollationLabel.CoercibleDefault), Collation.USEnglish.ToSqlString("1")); - var nch = new SqlVariant(DataTypeHelpers.NVarChar(10, Collation.USEnglish, CollationLabel.CoercibleDefault), Collation.USEnglish.ToSqlString("1")); - var bin = new SqlVariant(DataTypeHelpers.VarBinary(10), new SqlBinary(new byte[] { 1 })); - var guid = new SqlVariant(DataTypeHelpers.UniqueIdentifier, new SqlGuid(Guid.NewGuid())); + var dt = new SqlVariant(DataTypeHelpers.DateTime, new SqlDateTime(2000, 1, 1), _eec); + var approx = new SqlVariant(DataTypeHelpers.Float, new SqlSingle(1), _eec); + var exact = new SqlVariant(DataTypeHelpers.Int, new SqlInt32(1), _eec); + var ch = new SqlVariant(DataTypeHelpers.VarChar(10, Collation.USEnglish, CollationLabel.CoercibleDefault), Collation.USEnglish.ToSqlString("1"), _eec); + var nch = new SqlVariant(DataTypeHelpers.NVarChar(10, Collation.USEnglish, CollationLabel.CoercibleDefault), Collation.USEnglish.ToSqlString("1"), _eec); + var bin = new SqlVariant(DataTypeHelpers.VarBinary(10), new SqlBinary(new byte[] { 1 }), _eec); + var guid = new SqlVariant(DataTypeHelpers.UniqueIdentifier, new SqlGuid(Guid.NewGuid()), _eec); var list = new List { variant, dt, approx, exact, ch, nch, bin, guid }; var rnd = new Random(); diff --git a/MarkMpn.Sql4Cds.Engine.Tests/StubOptions.cs b/MarkMpn.Sql4Cds.Engine.Tests/StubOptions.cs index fbb0655b..6871535d 100644 --- a/MarkMpn.Sql4Cds.Engine.Tests/StubOptions.cs +++ b/MarkMpn.Sql4Cds.Engine.Tests/StubOptions.cs @@ -31,6 +31,8 @@ class StubOptions : IQueryExecutionOptions ColumnOrdering IQueryExecutionOptions.ColumnOrdering => ColumnOrdering.Alphabetical; + public event EventHandler PrimaryDataSourceChanged; + void IQueryExecutionOptions.ConfirmInsert(ConfirmDmlStatementEventArgs e) { } diff --git a/MarkMpn.Sql4Cds.Engine/Ado/CancellationTokenOptionsWrapper.cs b/MarkMpn.Sql4Cds.Engine/Ado/CancellationTokenOptionsWrapper.cs index 5598817b..02133e8e 100644 --- a/MarkMpn.Sql4Cds.Engine/Ado/CancellationTokenOptionsWrapper.cs +++ b/MarkMpn.Sql4Cds.Engine/Ado/CancellationTokenOptionsWrapper.cs @@ -45,6 +45,19 @@ public CancellationTokenOptionsWrapper(IQueryExecutionOptions options, Cancellat public ColumnOrdering ColumnOrdering => _options.ColumnOrdering; + public event EventHandler PrimaryDataSourceChanged + { + add + { + _options.PrimaryDataSourceChanged += value; + } + + remove + { + _options.PrimaryDataSourceChanged -= value; + } + } + public void ConfirmDelete(ConfirmDmlStatementEventArgs e) { _options.ConfirmDelete(e); diff --git a/MarkMpn.Sql4Cds.Engine/Ado/ChangeDatabaseOptionsWrapper.cs b/MarkMpn.Sql4Cds.Engine/Ado/ChangeDatabaseOptionsWrapper.cs deleted file mode 100644 index 0c89a9ec..00000000 --- a/MarkMpn.Sql4Cds.Engine/Ado/ChangeDatabaseOptionsWrapper.cs +++ /dev/null @@ -1,107 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Text; -using System.Threading; -using Microsoft.Xrm.Sdk.Metadata; -using Microsoft.Xrm.Sdk.Query; - -namespace MarkMpn.Sql4Cds.Engine -{ - class ChangeDatabaseOptionsWrapper : IQueryExecutionOptions - { - private readonly Sql4CdsConnection _connection; - private readonly IQueryExecutionOptions _options; - - public ChangeDatabaseOptionsWrapper(Sql4CdsConnection connection, IQueryExecutionOptions options) - { - _connection = connection; - _options = options; - - PrimaryDataSource = options.PrimaryDataSource; - UseTDSEndpoint = options.UseTDSEndpoint; - BlockUpdateWithoutWhere = options.BlockUpdateWithoutWhere; - BlockDeleteWithoutWhere = options.BlockDeleteWithoutWhere; - UseBulkDelete = options.UseBulkDelete; - BatchSize = options.BatchSize; - MaxDegreeOfParallelism = options.MaxDegreeOfParallelism; - UseLocalTimeZone = options.UseLocalTimeZone; - BypassCustomPlugins = options.BypassCustomPlugins; - QuotedIdentifiers = options.QuotedIdentifiers; - ColumnOrdering = options.ColumnOrdering; - } - - public CancellationToken CancellationToken => _options.CancellationToken; - - public bool BlockUpdateWithoutWhere { get; set; } - - public bool BlockDeleteWithoutWhere { get; set; } - - public bool UseBulkDelete { get; set; } - - public int BatchSize { get; set; } - - public bool UseTDSEndpoint { get; set; } - - public int MaxDegreeOfParallelism { get; set; } - - public bool UseLocalTimeZone { get; set; } - - public bool BypassCustomPlugins { get; set; } - - public string PrimaryDataSource { get; set; } // TODO: Update UserId when changing data source - - public Guid UserId => _options.UserId; - - public bool QuotedIdentifiers { get; set; } - - public ColumnOrdering ColumnOrdering { get; set; } - - public void ConfirmDelete(ConfirmDmlStatementEventArgs e) - { - if (!e.Cancel) - _options.ConfirmDelete(e); - - if (!e.Cancel) - _connection.OnPreDelete(e); - } - - public void ConfirmInsert(ConfirmDmlStatementEventArgs e) - { - if (!e.Cancel) - _options.ConfirmInsert(e); - - if (!e.Cancel) - _connection.OnPreInsert(e); - } - - public void ConfirmUpdate(ConfirmDmlStatementEventArgs e) - { - if (!e.Cancel) - _options.ConfirmUpdate(e); - - if (!e.Cancel) - _connection.OnPreUpdate(e); - } - - public bool ContinueRetrieve(int count) - { - var cancelled = !_options.ContinueRetrieve(count); - - if (!cancelled) - { - var args = new ConfirmRetrieveEventArgs(count); - _connection.OnPreRetrieve(args); - - cancelled = args.Cancel; - } - - return !cancelled; - } - - public void Progress(double? progress, string message) - { - _options.Progress(progress, message); - _connection.OnProgress(new ProgressEventArgs(progress, message)); - } - } -} diff --git a/MarkMpn.Sql4Cds.Engine/Ado/DefaultQueryExecutionOptions.cs b/MarkMpn.Sql4Cds.Engine/Ado/DefaultQueryExecutionOptions.cs index 295d595f..9eb67536 100644 --- a/MarkMpn.Sql4Cds.Engine/Ado/DefaultQueryExecutionOptions.cs +++ b/MarkMpn.Sql4Cds.Engine/Ado/DefaultQueryExecutionOptions.cs @@ -16,85 +16,126 @@ namespace MarkMpn.Sql4Cds.Engine { class DefaultQueryExecutionOptions : IQueryExecutionOptions { - public DefaultQueryExecutionOptions(DataSource dataSource, CancellationToken cancellationToken) + private readonly Sql4CdsConnection _connection; + private string _primaryDataSource; + private Guid? _userId; + + public DefaultQueryExecutionOptions(Sql4CdsConnection connection, DataSource dataSource, CancellationToken cancellationToken) { - PrimaryDataSource = dataSource.Name; + _connection = connection; + _primaryDataSource = dataSource.Name; CancellationToken = cancellationToken; -#if NETCOREAPP - if (dataSource.Connection is ServiceClient svc) - { - UserId = svc.GetMyUserId(); - } -#else - if (dataSource.Connection is CrmServiceClient svc) - { - UserId = svc.GetMyCrmUserId(); - } -#endif - else - { - var whoami = (WhoAmIResponse)dataSource.Connection.Execute(new WhoAmIRequest()); - UserId = whoami.UserId; - } - + PrimaryDataSourceChanged += (_, __) => _userId = null; } public CancellationToken CancellationToken { get; } - public bool BlockUpdateWithoutWhere => false; + public bool BlockUpdateWithoutWhere { get; set; } - public bool BlockDeleteWithoutWhere => false; + public bool BlockDeleteWithoutWhere { get; set; } - public bool UseBulkDelete => false; + public bool UseBulkDelete { get; set; } - public int BatchSize => 100; + public int BatchSize { get; set; } = 100; - public bool UseTDSEndpoint => true; + public bool UseTDSEndpoint { get; set; } = true; - public int MaxDegreeOfParallelism => 10; + public int MaxDegreeOfParallelism { get; set; } = 10; public bool ColumnComparisonAvailable { get; } public bool OrderByEntityNameAvailable { get; } - public bool UseLocalTimeZone => false; + public bool UseLocalTimeZone { get; set; } public List JoinOperatorsAvailable { get; } - public bool BypassCustomPlugins => false; + public bool BypassCustomPlugins { get; set; } - public string PrimaryDataSource { get; } + public string PrimaryDataSource + { + get => _primaryDataSource; + set + { + if (_primaryDataSource != value) + { + _primaryDataSource = value; + OnPrimaryDataSourceChanged(); + } + } + } + + public Guid UserId + { + get + { + if (_userId == null) + { +#if NETCOREAPP + if (_connection.Session.DataSources[PrimaryDataSource].Connection is ServiceClient svc) + { + _userId = svc.GetMyUserId(); + } +#else + if (_connection.Session.DataSources[PrimaryDataSource].Connection is CrmServiceClient svc) + { + _userId = svc.GetMyCrmUserId(); + } +#endif + else + { + var whoami = (WhoAmIResponse)_connection.Session.DataSources[PrimaryDataSource].Connection.Execute(new WhoAmIRequest()); + _userId = whoami.UserId; + } + } + + return _userId.Value; + } + } - public Guid UserId { get; } + public bool QuotedIdentifiers { get; set; } = true; - public bool QuotedIdentifiers => true; + public ColumnOrdering ColumnOrdering { get; set; } - public ColumnOrdering ColumnOrdering => ColumnOrdering.Strict; + public event EventHandler PrimaryDataSourceChanged; public void ConfirmDelete(ConfirmDmlStatementEventArgs e) { + if (!e.Cancel) + _connection.OnPreDelete(e); } public void ConfirmInsert(ConfirmDmlStatementEventArgs e) { + if (!e.Cancel) + _connection.OnPreInsert(e); } public void ConfirmUpdate(ConfirmDmlStatementEventArgs e) { + if (!e.Cancel) + _connection.OnPreUpdate(e); } public bool ContinueRetrieve(int count) { - return true; + var args = new ConfirmRetrieveEventArgs(count); + _connection.OnPreRetrieve(args); + + var cancelled = args.Cancel; + + return !cancelled; } public void Progress(double? progress, string message) { + _connection.OnProgress(new ProgressEventArgs(progress, message)); } - public void RetrievingNextPage() + protected virtual void OnPrimaryDataSourceChanged() { + PrimaryDataSourceChanged?.Invoke(this, EventArgs.Empty); } } } diff --git a/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsCommand.cs b/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsCommand.cs index 6987088f..0610529f 100644 --- a/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsCommand.cs +++ b/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsCommand.cs @@ -48,7 +48,7 @@ public Sql4CdsCommand(Sql4CdsConnection connection, string commandText) CommandTimeout = 30; DbParameterCollection = new Sql4CdsParameterCollection(); - _planBuilder = new ExecutionPlanBuilder(_connection.DataSources.Values, _connection.Options); + _planBuilder = new ExecutionPlanBuilder(_connection.Session, _connection.Options); _planBuilder.Log = msg => _connection.OnInfoMessage(null, msg); } @@ -123,7 +123,7 @@ protected override DbConnection DbConnection throw new ArgumentOutOfRangeException(nameof(value), "Connection must be a Sql4CdsConnection"); _connection = con; - _planBuilder = new ExecutionPlanBuilder(_connection.DataSources.Values, _connection.Options); + _planBuilder = new ExecutionPlanBuilder(_connection.Session, _connection.Options); _planBuilder.Log = msg => _connection.OnInfoMessage(null, msg); Plan = null; UseTDSEndpointDirectly = false; @@ -261,10 +261,10 @@ protected override DbDataReader ExecuteDbDataReader(CommandBehavior behavior) if (UseTDSEndpointDirectly) { #if NETCOREAPP - var svc = (ServiceClient)_connection.DataSources[_connection.Database].Connection; + var svc = (ServiceClient)_connection.Session.DataSources[_connection.Database].Connection; var con = new SqlConnection("server=" + svc.ConnectedOrgUriActual.Host); #else - var svc = (CrmServiceClient)_connection.DataSources[_connection.Database].Connection; + var svc = (CrmServiceClient)_connection.Session.DataSources[_connection.Database].Connection; var con = new SqlConnection("server=" + svc.CrmConnectOrgUriActual.Host); #endif con.AccessToken = svc.CurrentAccessToken; @@ -272,11 +272,11 @@ protected override DbDataReader ExecuteDbDataReader(CommandBehavior behavior) var cmd = con.CreateCommand(); cmd.CommandTimeout = (int)TimeSpan.FromMinutes(2).TotalSeconds; - cmd.CommandText = SqlNode.ApplyCommandBehavior(CommandText, behavior, _connection.Options); + cmd.CommandText = SqlNode.ApplyCommandBehavior(CommandText, behavior, new NodeExecutionContext(null, _connection.Options, null, null, null)); var node = new SqlNode { Sql = cmd.CommandText, DataSource = _connection.Database }; cmd.StatementCompleted += (_, e) => { - _connection.GlobalVariableValues["@@ROWCOUNT"] = (SqlInt32)e.RecordCount; + _connection.Session.GlobalVariableValues["@@ROWCOUNT"] = (SqlInt32)e.RecordCount; OnStatementCompleted(node, e.RecordCount, $"({e.RecordCount} {(e.RecordCount == 1 ? "row" : "rows")} affected)"); }; diff --git a/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsConnection.cs b/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsConnection.cs index 167f66ac..4f8f7f78 100644 --- a/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsConnection.cs +++ b/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsConnection.cs @@ -26,11 +26,9 @@ namespace MarkMpn.Sql4Cds.Engine /// public class Sql4CdsConnection : DbConnection { - private readonly IDictionary _dataSources; - private readonly ChangeDatabaseOptionsWrapper _options; - private readonly Dictionary _globalVariableTypes; - private readonly Dictionary _globalVariableValues; + private readonly DefaultQueryExecutionOptions _options; private readonly TelemetryClient _ai; + private readonly SessionContext _session; /// /// Creates a new using the specified XRM connection string @@ -61,27 +59,8 @@ public Sql4CdsConnection(IDictionary dataSources) if (dataSources.Count == 0) throw new ArgumentOutOfRangeException("At least one data source must be supplied"); - var options = new DefaultQueryExecutionOptions(dataSources.First().Value, CancellationToken.None); - - _dataSources = dataSources; - _options = new ChangeDatabaseOptionsWrapper(this, options); - - _globalVariableTypes = new Dictionary(StringComparer.OrdinalIgnoreCase) - { - ["@@IDENTITY"] = DataTypeHelpers.EntityReference, - ["@@ROWCOUNT"] = DataTypeHelpers.Int, - ["@@SERVERNAME"] = DataTypeHelpers.NVarChar(100, _dataSources[_options.PrimaryDataSource].DefaultCollation, CollationLabel.CoercibleDefault), - ["@@VERSION"] = DataTypeHelpers.NVarChar(Int32.MaxValue, _dataSources[_options.PrimaryDataSource].DefaultCollation, CollationLabel.CoercibleDefault), - ["@@ERROR"] = DataTypeHelpers.Int, - }; - _globalVariableValues = new Dictionary(StringComparer.OrdinalIgnoreCase) - { - ["@@IDENTITY"] = SqlEntityReference.Null, - ["@@ROWCOUNT"] = (SqlInt32)0, - ["@@SERVERNAME"] = GetServerName(_dataSources[_options.PrimaryDataSource]), - ["@@VERSION"] = GetVersion(_dataSources[_options.PrimaryDataSource]), - ["@@ERROR"] = (SqlInt32)0, - }; + _options = new DefaultQueryExecutionOptions(this, dataSources.First().Value, CancellationToken.None); + _session = new SessionContext(dataSources, _options); _ai = new TelemetryClient(new Microsoft.ApplicationInsights.Extensibility.TelemetryConfiguration("79761278-a908-4575-afbf-2f4d82560da6")); @@ -93,51 +72,6 @@ public Sql4CdsConnection(IDictionary dataSources) ApplicationName = "SQL 4 CDS ADO.NET Provider"; } - private SqlString GetVersion(DataSource dataSource) - { - string orgVersion = null; - -#if NETCOREAPP - if (dataSource.Connection is ServiceClient svc) - orgVersion = svc.ConnectedOrgVersion.ToString(); -#else - if (dataSource.Connection is CrmServiceClient svc) - orgVersion = svc.ConnectedOrgVersion.ToString(); -#endif - - if (orgVersion == null) - orgVersion = ((RetrieveVersionResponse)dataSource.Execute(new RetrieveVersionRequest())).Version; - - var assembly = typeof(Sql4CdsConnection).Assembly; - var assemblyVersion = assembly.GetName().Version; - var assemblyCopyright = assembly - .GetCustomAttributes(typeof(AssemblyCopyrightAttribute), false) - .OfType() - .FirstOrDefault()? - .Copyright; - var assemblyFilename = assembly.Location; - var assemblyDate = System.IO.File.GetLastWriteTime(assemblyFilename); - - return $"Microsoft Dataverse - {orgVersion}\r\n\tSQL 4 CDS - {assemblyVersion}\r\n\t{assemblyDate:MMM dd yyyy HH:mm:ss}\r\n\t{assemblyCopyright}"; - } - - private SqlString GetServerName(DataSource dataSource) - { -#if NETCOREAPP - var svc = dataSource.Connection as ServiceClient; - - if (svc != null) - return svc.ConnectedOrgUriActual.Host; -#else - var svc = dataSource.Connection as CrmServiceClient; - - if (svc != null) - return svc.CrmConnectOrgUriActual.Host; -#endif - - return dataSource.Name; - } - private static IOrganizationService Connect(string connectionString) { #if NETCOREAPP @@ -185,7 +119,7 @@ internal void OnInfoMessage(IRootExecutionPlanNode node, Sql4CdsError message) } } - internal IDictionary DataSources => _dataSources; + internal SessionContext Session => _session; internal IQueryExecutionOptions Options => _options; @@ -289,10 +223,6 @@ public ColumnOrdering ColumnOrdering set => _options.ColumnOrdering = value; } - internal Dictionary GlobalVariableTypes => _globalVariableTypes; - - internal Dictionary GlobalVariableValues => _globalVariableValues; - internal TelemetryClient TelemetryClient => _ai; /// @@ -371,7 +301,7 @@ public override string DataSource { get { - var dataSource = _dataSources[Database]; + var dataSource = Session.DataSources[Database]; #if NETCOREAPP if (dataSource.Connection is ServiceClient svc) @@ -389,7 +319,7 @@ public override string ServerVersion { get { - var dataSource = _dataSources[Database]; + var dataSource = Session.DataSources[Database]; #if NETCOREAPP if (dataSource.Connection is ServiceClient svc) @@ -408,7 +338,7 @@ public override string ServerVersion public override void ChangeDatabase(string databaseName) { - if (!_dataSources.ContainsKey(databaseName)) + if (!Session.DataSources.ContainsKey(databaseName)) throw new Sql4CdsException(new Sql4CdsError(11, 0, 0, null, databaseName, 0, "Database is not in the list of connected databases", null)); _options.PrimaryDataSource = databaseName; diff --git a/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsDataReader.cs b/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsDataReader.cs index a88c27d0..b9cc364f 100644 --- a/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsDataReader.cs +++ b/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsDataReader.cs @@ -1,17 +1,12 @@ using System; using System.Collections; -using System.Collections.Concurrent; using System.Collections.Generic; using System.Data; using System.Data.Common; using System.Data.SqlTypes; using System.Linq; -using System.Linq.Expressions; -using System.Text; -using System.Xml; using MarkMpn.Sql4Cds.Engine.ExecutionPlan; using Microsoft.SqlServer.TransactSql.ScriptDom; -using Microsoft.Xrm.Sdk; namespace MarkMpn.Sql4Cds.Engine { @@ -21,8 +16,9 @@ class Sql4CdsDataReader : DbDataReader private readonly Sql4CdsCommand _command; private readonly IQueryExecutionOptions _options; private readonly CommandBehavior _behavior; - private readonly Dictionary _parameterTypes; - private readonly Dictionary _parameterValues; + private readonly LayeredDictionary _parameterTypes; + private readonly LayeredDictionary _parameterValues; + private readonly Stack _errorDetails; private Dictionary _labelIndexes; private int _recordsAffected; private int _instructionPointer; @@ -32,7 +28,6 @@ class Sql4CdsDataReader : DbDataReader private int _rows; private int _resultSetsReturned; private bool _closed; - private Stack _errorDetails; public Sql4CdsDataReader(Sql4CdsCommand command, IQueryExecutionOptions options, CommandBehavior behavior) { @@ -42,14 +37,15 @@ public Sql4CdsDataReader(Sql4CdsCommand command, IQueryExecutionOptions options, _behavior = behavior; _recordsAffected = -1; - _parameterTypes = ((Sql4CdsParameterCollection)command.Parameters).GetParameterTypes(); - _parameterValues = ((Sql4CdsParameterCollection)command.Parameters).GetParameterValues(); + _parameterTypes = new LayeredDictionary( + _connection.Session.GlobalVariableTypes, + ((Sql4CdsParameterCollection)command.Parameters).GetParameterTypes(), + new Dictionary(StringComparer.OrdinalIgnoreCase)); - foreach (var paramType in _connection.GlobalVariableTypes) - _parameterTypes[paramType.Key] = paramType.Value; - - foreach (var paramValue in _connection.GlobalVariableValues) - _parameterValues[paramValue.Key] = paramValue.Value; + _parameterValues = new LayeredDictionary( + _connection.Session.GlobalVariableValues, + ((Sql4CdsParameterCollection)command.Parameters).GetParameterValues(), + new Dictionary(StringComparer.OrdinalIgnoreCase)); _errorDetails = new Stack(); @@ -57,7 +53,7 @@ public Sql4CdsDataReader(Sql4CdsCommand command, IQueryExecutionOptions options, Close(); } - internal Dictionary ParameterValues => _parameterValues; + internal IDictionary ParameterValues => _parameterValues; private Dictionary LabelIndexes { @@ -75,7 +71,7 @@ private Dictionary LabelIndexes } } - private bool ExecuteWithExceptionHandling(Dictionary parameterTypes, Dictionary parameterValues) + private bool ExecuteWithExceptionHandling(LayeredDictionary parameterTypes, LayeredDictionary parameterValues) { while (true) { @@ -109,10 +105,10 @@ private bool ExecuteWithExceptionHandling(Dictionary } } - private bool Execute(Dictionary parameterTypes, Dictionary parameterValues) + private bool Execute(LayeredDictionary parameterTypes, LayeredDictionary parameterValues) { IRootExecutionPlanNode logNode = null; - var context = new NodeExecutionContext(_connection.DataSources, _options, parameterTypes, parameterValues, msg => _connection.OnInfoMessage(logNode, msg)); + var context = new NodeExecutionContext(_connection.Session, _options, parameterTypes, parameterValues, msg => _connection.OnInfoMessage(logNode, msg)); context.Error = _errorDetails.FirstOrDefault(); try @@ -297,11 +293,6 @@ private bool Execute(Dictionary parameterTypes, Dicti throw sqlErr; } } - finally - { - foreach (var paramName in _connection.GlobalVariableValues.Keys.ToArray()) - _connection.GlobalVariableValues[paramName] = parameterValues[paramName]; - } if (_options.CancellationToken.IsCancellationRequested) throw new Sql4CdsException(new Sql4CdsError(11, 0, 0, null, null, 0, _command.CancelledManually ? "Query was cancelled by user" : "Query timed out", null)); diff --git a/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsError.cs b/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsError.cs index d6885f7c..9a6ed10c 100644 --- a/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsError.cs +++ b/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsError.cs @@ -138,6 +138,12 @@ internal static Sql4CdsError InvalidLengthOrPrecision(SqlDataTypeReference type) return Create(1001, type, (SqlInt32)type.StartLine, (SqlInt32)precision); } + internal static Sql4CdsError InvalidScale(SqlDataTypeReference type, int scaleParamIndex) + { + var precision = Int32.Parse(type.Parameters[scaleParamIndex].Value); + return Create(1002, type, (SqlInt32)type.StartLine, (SqlInt32)precision); + } + internal static Sql4CdsError ArithmeticOverflow(DataTypeReference sourceType, DataTypeReference targetType, TSqlFragment fragment) { return Create(8115, fragment, GetTypeName(sourceType), GetTypeName(targetType)); @@ -228,6 +234,11 @@ internal static Sql4CdsError ConversionFailed(DataTypeReference sourceType, Lite return Create(245, sourceValue, Collation.USEnglish.ToSqlString(GetTypeName(sourceType)), (SqlInt32)sourceValue.Value.Length, Collation.USEnglish.ToSqlString(sourceValue.Value), Collation.USEnglish.ToSqlString(GetTypeName(targetType))); } + internal static Sql4CdsError ConversionOutOfRange(DataTypeReference sourceType, DataTypeReference targetType) + { + return Create(242, sourceType, Collation.USEnglish.ToSqlString(GetTypeName(sourceType)), Collation.USEnglish.ToSqlString(GetTypeName(targetType))); + } + internal static Sql4CdsError CollationConflict(TSqlFragment fragment, Collation source, Collation target, string operationName) { return Create(468, fragment, (SqlInt32)(source?.Name.Length ?? 0), Collation.USEnglish.ToSqlString(source?.Name), (SqlInt32)(target?.Name.Length ?? 0), Collation.USEnglish.ToSqlString(target?.Name), Collation.USEnglish.ToSqlString(operationName)); @@ -780,6 +791,26 @@ internal static Sql4CdsError InvalidProcedureParameterType(TSqlFragment fragment return Create(214, fragment, parameter, type); } + internal static Sql4CdsError InvalidDatePart(TSqlFragment fragment, string part, string function, DataTypeReference dataType) + { + return Create(9810, fragment, (SqlInt32)part.Length, Collation.USEnglish.ToSqlString(part), (SqlInt32)function.Length, Collation.USEnglish.ToSqlString(function), Collation.USEnglish.ToSqlString(GetTypeName(dataType))); + } + + internal static Sql4CdsError AdditionOverflow(TSqlFragment fragment, DataTypeReference type) + { + return Create(517, fragment, Collation.USEnglish.ToSqlString(GetTypeName(type))); + } + + internal static Sql4CdsError UnsupportedDatePart(TSqlFragment fragment, string part, string function) + { + return Create(9806, fragment, (SqlInt32)part.Length, Collation.USEnglish.ToSqlString(part), (SqlInt32)function.Length, Collation.USEnglish.ToSqlString(function)); + } + + internal static Sql4CdsError InvalidDateFormat(TSqlFragment fragment, string format) + { + return Create(2741, fragment, (SqlInt32)format.Length, Collation.USEnglish.ToSqlString(format)); + } + private static string GetTypeName(DataTypeReference type) { if (type is SqlDataTypeReference sqlType) diff --git a/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsParameter.cs b/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsParameter.cs index 5dd71eef..de985b64 100644 --- a/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsParameter.cs +++ b/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsParameter.cs @@ -187,7 +187,7 @@ internal DataTypeReference GetDataType() break; case DbType.DateTimeOffset: - _dataType = DataTypeHelpers.DateTimeOffset; + _dataType = DataTypeHelpers.DateTimeOffset(7); break; case DbType.Decimal: diff --git a/MarkMpn.Sql4Cds.Engine/DataTypeHelpers.cs b/MarkMpn.Sql4Cds.Engine/DataTypeHelpers.cs index e01ddc69..1779d63c 100644 --- a/MarkMpn.Sql4Cds.Engine/DataTypeHelpers.cs +++ b/MarkMpn.Sql4Cds.Engine/DataTypeHelpers.cs @@ -51,7 +51,10 @@ public static SqlDataTypeReference DateTime2(short scale) return new SqlDataTypeReference { SqlDataTypeOption = SqlDataTypeOption.DateTime2, Parameters = { new IntegerLiteral { Value = scale.ToString(CultureInfo.InvariantCulture) } } }; } - public static SqlDataTypeReference DateTimeOffset { get; } = new SqlDataTypeReference { SqlDataTypeOption = SqlDataTypeOption.DateTimeOffset }; + public static SqlDataTypeReference DateTimeOffset(short scale) + { + return new SqlDataTypeReference { SqlDataTypeOption = SqlDataTypeOption.DateTimeOffset, Parameters = { new IntegerLiteral { Value = scale.ToString(CultureInfo.InvariantCulture) } } }; + } public static SqlDataTypeReference Decimal(short precision, short scale) { @@ -418,12 +421,10 @@ public static short GetScale(this DataTypeReference type, short invalidValue = 0 case SqlDataTypeOption.DateTime: return 3; - case SqlDataTypeOption.DateTimeOffset: - return 7; - case SqlDataTypeOption.SmallDateTime: return 0; + case SqlDataTypeOption.DateTimeOffset: case SqlDataTypeOption.DateTime2: case SqlDataTypeOption.Time: if (dataType.Parameters.Count == 0 || diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/Aggregate.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/Aggregate.cs index af423080..a2fe0790 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/Aggregate.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/Aggregate.cs @@ -92,10 +92,10 @@ public AggregateFunction(Func selector) /// /// The current state of the aggregation /// The new state of the aggregation - public virtual void NextRecord(object state) + public virtual void NextRecord(object state, ExpressionExecutionContext context) { var value = _selector(); - Update(value, state); + Update(value, context, state); } /// @@ -103,10 +103,10 @@ public virtual void NextRecord(object state) /// /// The current state of the aggregation /// The new state of the aggregation - public virtual void NextPartition(object state) + public virtual void NextPartition(object state, ExpressionExecutionContext context) { var value = _selector(); - UpdatePartition(value, state); + UpdatePartition(value, context, state); } /// @@ -114,14 +114,14 @@ public virtual void NextPartition(object state) /// /// /// The current state of the aggregation - protected abstract void Update(object value, object state); + protected abstract void Update(object value, ExpressionExecutionContext context, object state); /// /// Updates the aggregation state based on a value extracted from the partition /// /// /// The current state of the aggregation - protected abstract void UpdatePartition(object value, object state); + protected abstract void UpdatePartition(object value, ExpressionExecutionContext context, object state); /// /// Returns the current value of this aggregation @@ -184,20 +184,20 @@ public Average(Func selector, DataTypeReference sourceType, DataTypeRefe _count = new CountColumn(selector); } - public override void NextRecord(object state) + public override void NextRecord(object state, ExpressionExecutionContext context) { var s = (State)state; - _sum.NextRecord(s.SumState); - _count.NextRecord(s.CountState); + _sum.NextRecord(s.SumState, context); + _count.NextRecord(s.CountState, context); } - protected override void Update(object value, object state) + protected override void Update(object value, ExpressionExecutionContext context, object state) { throw new NotImplementedException(); } - protected override void UpdatePartition(object value, object state) + protected override void UpdatePartition(object value, ExpressionExecutionContext context, object state) { throw new InvalidOperationException(); } @@ -253,13 +253,13 @@ public Count(Func selector) : base(selector) { } - protected override void Update(object value, object state) + protected override void Update(object value, ExpressionExecutionContext context, object state) { var s = (State)state; s.Value = s.Value + 1; } - protected override void UpdatePartition(object value, object state) + protected override void UpdatePartition(object value, ExpressionExecutionContext context, object state) { var s = (State)state; s.Value = s.Value + (SqlInt32)value; @@ -297,7 +297,7 @@ public CountColumn(Func selector) : base(selector) { } - protected override void Update(object value, object state) + protected override void Update(object value, ExpressionExecutionContext context, object state) { if (value == null || (value is INullable nullable && nullable.IsNull)) return; @@ -306,7 +306,7 @@ protected override void Update(object value, object state) s.Value = s.Value + 1; } - protected override void UpdatePartition(object value, object state) + protected override void UpdatePartition(object value, ExpressionExecutionContext context, object state) { var s = (State)state; s.Value = s.Value + (SqlInt32)value; @@ -350,7 +350,7 @@ public Max(Func selector, DataTypeReference type) : base(selector) Type = type; } - protected override void Update(object value, object state) + protected override void Update(object value, ExpressionExecutionContext context, object state) { if (value == null || (value is INullable nullable && nullable.IsNull)) return; @@ -363,9 +363,9 @@ protected override void Update(object value, object state) s.Value = cmp; } - protected override void UpdatePartition(object value, object state) + protected override void UpdatePartition(object value, ExpressionExecutionContext context, object state) { - Update(value, state); + Update(value, context, state); } public override object GetValue(object state) @@ -406,7 +406,7 @@ public Min(Func selector, DataTypeReference type) : base(selector) Type = type; } - protected override void Update(object value, object state) + protected override void Update(object value, ExpressionExecutionContext context, object state) { if (value == null || (value is INullable nullable && nullable.IsNull)) return; @@ -420,9 +420,9 @@ protected override void Update(object value, object state) s.Value = cmp; } - protected override void UpdatePartition(object value, object state) + protected override void UpdatePartition(object value, ExpressionExecutionContext context, object state) { - Update(value, state); + Update(value, context, state); } public override object GetValue(object state) @@ -454,7 +454,7 @@ class State } private readonly SqlDataTypeOption _type; - private readonly Func _valueSelector; + private readonly Func _valueSelector; /// /// Creates a new @@ -467,15 +467,16 @@ public Sum(Func selector, DataTypeReference sourceType, DataTypeReferenc _type = ((SqlDataTypeReference)returnType).SqlDataTypeOption; var valueParam = Expression.Parameter(typeof(object)); + var contextParam = Expression.Parameter(typeof(ExpressionExecutionContext)); var unboxed = Expression.Unbox(valueParam, sourceType.ToNetType(out _)); - var conversion = SqlTypeConverter.Convert(unboxed, sourceType, returnType); + var conversion = SqlTypeConverter.Convert(unboxed, contextParam, sourceType, returnType); conversion = Expr.Box(conversion); - _valueSelector = (Func) Expression.Lambda(conversion, valueParam).Compile(); + _valueSelector = (Func) Expression.Lambda(conversion, valueParam, contextParam).Compile(); } - protected override void Update(object value, object state) + protected override void Update(object value, ExpressionExecutionContext context, object state) { - var d = (INullable)_valueSelector(value); + var d = (INullable)_valueSelector(value, context); if (d.IsNull) return; @@ -509,9 +510,9 @@ protected override void Update(object value, object state) } } - protected override void UpdatePartition(object value, object state) + protected override void UpdatePartition(object value, ExpressionExecutionContext context, object state) { - Update(value, state); + Update(value, context, state); } public override object GetValue(object state) @@ -569,7 +570,7 @@ public First(Func selector, DataTypeReference type) : base(selector) Type = type; } - protected override void Update(object value, object state) + protected override void Update(object value, ExpressionExecutionContext context, object state) { var s = (State)state; if (s.Done) @@ -579,7 +580,7 @@ protected override void Update(object value, object state) s.Done = true; } - protected override void UpdatePartition(object value, object state) + protected override void UpdatePartition(object value, ExpressionExecutionContext context, object state) { throw new InvalidOperationException(); } @@ -611,7 +612,7 @@ public State() public SqlString Value { get; set; } } - private Func _valueSelector; + private Func _valueSelector; /// /// Creates a new @@ -622,17 +623,18 @@ public StringAgg(Func selector, DataTypeReference sourceType, DataTypeRe Type = returnType; var valueParam = Expression.Parameter(typeof(object)); + var contextParam = Expression.Parameter(typeof(ExpressionExecutionContext)); var unboxed = Expression.Unbox(valueParam, sourceType.ToNetType(out _)); - var conversion = SqlTypeConverter.Convert(unboxed, sourceType, returnType); + var conversion = SqlTypeConverter.Convert(unboxed, contextParam, sourceType, returnType); conversion = Expr.Convert(conversion, typeof(SqlString)); - _valueSelector = (Func)Expression.Lambda(conversion, valueParam).Compile(); + _valueSelector = (Func)Expression.Lambda(conversion, valueParam, contextParam).Compile(); } - protected override void Update(object value, object state) + protected override void Update(object value, ExpressionExecutionContext context, object state) { var s = (State)state; - var str = _valueSelector(value); + var str = _valueSelector(value, context); if (str.IsNull) return; @@ -643,7 +645,7 @@ protected override void Update(object value, object state) s.Value += Separator + str; } - protected override void UpdatePartition(object value, object state) + protected override void UpdatePartition(object value, ExpressionExecutionContext context, object state) { throw new InvalidOperationException(); } @@ -687,21 +689,21 @@ public DistinctAggregate(AggregateFunction func, Func selector) : base(s public override DataTypeReference Type => _func.Type; - public override void NextRecord(object state) + public override void NextRecord(object state, ExpressionExecutionContext context) { var value = _selector(); var s = (State)state; if (s.Distinct.Add(value)) - _func.NextRecord(s.InnerState); + _func.NextRecord(s.InnerState, context); } - protected override void UpdatePartition(object value, object state) + protected override void UpdatePartition(object value, ExpressionExecutionContext context, object state) { throw new InvalidOperationException(); } - protected override void Update(object value, object state) + protected override void Update(object value, ExpressionExecutionContext context, object state) { throw new NotImplementedException(); } diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/AssignVariablesNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/AssignVariablesNode.cs index af59ef61..feebab8f 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/AssignVariablesNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/AssignVariablesNode.cs @@ -50,11 +50,14 @@ public override void Execute(NodeExecutionContext context, out int recordsAffect var entities = GetDmlSourceEntities(context, out var schema); var valueAccessors = CompileValueAccessors(schema, entities, context.ParameterTypes); + var eec = new ExpressionExecutionContext(context); foreach (var entity in entities) { + eec.Entity = entity; + foreach (var variable in Variables) - context.ParameterValues[variable.VariableName] = (INullable)valueAccessors[variable.VariableName](entity); + context.ParameterValues[variable.VariableName] = (INullable)valueAccessors[variable.VariableName](eec); count++; } @@ -74,10 +77,10 @@ public override void Execute(NodeExecutionContext context, out int recordsAffect /// The attributes in the target metadata /// The time zone that datetime values are supplied in /// - protected Dictionary> CompileValueAccessors(INodeSchema schema, List entities, IDictionary variableTypes) + protected Dictionary> CompileValueAccessors(INodeSchema schema, List entities, IDictionary variableTypes) { - var valueAccessors = new Dictionary>(); - var entityParam = Expression.Parameter(typeof(Entity)); + var valueAccessors = new Dictionary>(); + var contextParam = Expression.Parameter(typeof(ExpressionExecutionContext)); foreach (var mapping in Variables) { @@ -94,7 +97,8 @@ protected Dictionary> CompileValueAccessors(INodeSc var destNetType = destSqlType.ToNetType(out _); - var expr = (Expression)Expression.Property(entityParam, typeof(Entity).GetCustomAttribute().MemberName, Expression.Constant(sourceColumnName)); + var entity = Expression.Property(contextParam, nameof(ExpressionExecutionContext.Entity)); + var expr = (Expression)Expression.Property(entity, typeof(Entity).GetCustomAttribute().MemberName, Expression.Constant(sourceColumnName)); var originalExpr = expr; if (sourceSqlType.IsSameAs(DataTypeHelpers.Int) && !SqlTypeConverter.CanChangeTypeExplicit(sourceSqlType, destSqlType) && entities.All(e => ((SqlInt32)e[sourceColumnName]).IsNull)) @@ -108,12 +112,12 @@ protected Dictionary> CompileValueAccessors(INodeSc expr = Expression.Convert(expr, sourceSqlType.ToNetType(out _)); // Convert to destination SQL type - expr = SqlTypeConverter.Convert(expr, sourceSqlType, destSqlType); + expr = SqlTypeConverter.Convert(expr, contextParam, sourceSqlType, destSqlType); } expr = Expr.Box(expr); - valueAccessors[destVariableName] = Expression.Lambda>(expr, entityParam).Compile(); + valueAccessors[destVariableName] = Expression.Lambda>(expr, contextParam).Compile(); } return valueAccessors; diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseDataNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseDataNode.cs index 05e76a3a..bedb2dad 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseDataNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseDataNode.cs @@ -387,40 +387,8 @@ private bool TranslateFetchXMLCriteria(NodeCompilationContext context, DataSourc } // Select the correct FetchXML operator - @operator op; - - switch (type) - { - case BooleanComparisonType.Equals: - case BooleanComparisonType.IsNotDistinctFrom: - op = @operator.eq; - break; - - case BooleanComparisonType.GreaterThan: - op = @operator.gt; - break; - - case BooleanComparisonType.GreaterThanOrEqualTo: - op = @operator.ge; - break; - - case BooleanComparisonType.LessThan: - op = @operator.lt; - break; - - case BooleanComparisonType.LessThanOrEqualTo: - op = @operator.le; - break; - - case BooleanComparisonType.NotEqualToBrackets: - case BooleanComparisonType.NotEqualToExclamation: - case BooleanComparisonType.IsDistinctFrom: - op = @operator.ne; - break; - - default: - throw new NotSupportedQueryFragmentException(Sql4CdsError.SyntaxError(comparison)) { Suggestion = "Unsupported comparison type" }; - } + if (!type.TryConvertToFetchXml(out var op)) + throw new NotSupportedQueryFragmentException(Sql4CdsError.SyntaxError(comparison)) { Suggestion = "Unsupported comparison type" }; // Find the entity that the condition applies to, which may be different to the entity that the condition FetchXML element will be // added within @@ -741,7 +709,7 @@ private bool TranslateFetchXMLCriteria(NodeCompilationContext context, DataSourc if (inPred.Subquery != null) return false; - if (!inPred.Values.All(v => v is Literal)) + if (!inPred.Values.All(v => v is ValueExpression)) return false; var columnName = inCol.GetColumnName(); @@ -764,7 +732,7 @@ private bool TranslateFetchXMLCriteria(NodeCompilationContext context, DataSourc var meta = dataSource.Metadata[entityName]; - return TranslateFetchXMLCriteriaWithVirtualAttributes(context, meta, entityAlias, attrName, null, op, inPred.Values.Cast().ToArray(), dataSource, targetEntityAlias, items, out condition, out filter); + return TranslateFetchXMLCriteriaWithVirtualAttributes(context, meta, entityAlias, attrName, null, op, inPred.Values.Cast().ToArray(), dataSource, targetEntityAlias, items, out condition, out filter); } if (criteria is BooleanIsNullExpression isNull) diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseDmlNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseDmlNode.cs index e404fcc4..30aca68c 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseDmlNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseDmlNode.cs @@ -181,7 +181,7 @@ private int GetMaxDOP(NodeCompilationContext context, IList query if (DataSource == null) return 1; - if (!context.DataSources.TryGetValue(DataSource, out var dataSource)) + if (!context.Session.DataSources.TryGetValue(DataSource, out var dataSource)) throw new NotSupportedQueryFragmentException("Unknown datasource"); return ParallelismHelper.GetMaxDOP(dataSource, context, queryHints); @@ -235,6 +235,66 @@ private bool GetContinueOnError(NodeCompilationContext context, IList hints, string logicalName, string[] requiredColumns) + { + if (hints != null && hints.OfType().Any(hint => hint.Hints.Any(s => s.Value.Equals("NO_DIRECT_DML", StringComparison.OrdinalIgnoreCase)))) + return; + + // Work out the fields that we should use as the primary key for these records. + var dataSource = context.Session.DataSources[DataSource]; + var targetMetadata = dataSource.Metadata[logicalName]; + var keyAttributes = new[] { targetMetadata.PrimaryIdAttribute }; + + if (targetMetadata.LogicalName == "listmember") + { + keyAttributes = new[] { "listid", "entityid" }; + } + else if (targetMetadata.IsIntersect == true) + { + var relationship = targetMetadata.ManyToManyRelationships.Single(); + keyAttributes = new[] { relationship.Entity1IntersectAttribute, relationship.Entity2IntersectAttribute }; + } + else if (targetMetadata.DataProviderId == DataProviders.ElasticDataProvider) + { + // Elastic tables need the partitionid as part of the primary key + keyAttributes = new[] { targetMetadata.PrimaryIdAttribute, "partitionid" }; + } + else if (targetMetadata.LogicalName == "activitypointer") + { + // Can't do DML operations on base activitypointer table, need to read the record to + // find the concrete activity type. + return; + } + + // Skip any ComputeScalar node that is being used to generate additional values, + // unless they reference additional values in the data source + var compute = Source as ComputeScalarNode; + + if (compute != null) + { + if (compute.Columns.Any(c => c.Value.GetColumns().Except(keyAttributes).Any())) + return; + + // Ignore any columns being created by the ComputeScalar node + foreach (var col in compute.Columns) + requiredColumns = requiredColumns.Except(new[] { col.Key }).ToArray(); + } + + if ((compute?.Source ?? Source) is FetchXmlScan fetch) + { + var folded = fetch.FoldDmlSource(context, hints, logicalName, requiredColumns, keyAttributes); + + if (compute != null) + compute.Source = folded; + else + Source = folded; + } + else if (Source is SqlNode sql) + { + Source = sql.FoldDmlSource(context, hints, logicalName, requiredColumns, keyAttributes); + } + } + /// /// Changes the name of source columns /// @@ -269,7 +329,7 @@ protected List GetDmlSourceEntities(NodeExecutionContext context, out IN var dataTable = new DataTable(); var schemaTable = dataReader.GetSchemaTable(); var columnTypes = new ColumnList(); - var targetDataSource = DataSource == null ? context.PrimaryDataSource : context.DataSources[DataSource]; + var targetDataSource = DataSource == null ? context.PrimaryDataSource : context.Session.DataSources[DataSource]; for (var i = 0; i < schemaTable.Rows.Count; i++) { @@ -297,7 +357,7 @@ protected List GetDmlSourceEntities(NodeExecutionContext context, out IN case "smalldatetime": colSqlType = DataTypeHelpers.SmallDateTime; break; case "date": colSqlType = DataTypeHelpers.Date; break; case "time": colSqlType = DataTypeHelpers.Time(scale); break; - case "datetimeoffset": colSqlType = DataTypeHelpers.DateTimeOffset; break; + case "datetimeoffset": colSqlType = DataTypeHelpers.DateTimeOffset(scale); break; case "datetime2": colSqlType = DataTypeHelpers.DateTime2(scale); break; case "decimal": colSqlType = DataTypeHelpers.Decimal(precision, scale); break; case "numeric": colSqlType = DataTypeHelpers.Decimal(precision, scale); break; @@ -384,13 +444,14 @@ protected List GetDmlSourceEntities(NodeExecutionContext context, out IN /// The time zone that datetime values are supplied in /// The records that are being mapped /// - protected Dictionary> CompileColumnMappings(DataSource dataSource, string logicalName, IDictionary mappings, INodeSchema schema, DateTimeKind dateTimeKind, List entities) + protected Dictionary> CompileColumnMappings(DataSource dataSource, string logicalName, IDictionary mappings, INodeSchema schema, DateTimeKind dateTimeKind, List entities) { var metadata = dataSource.Metadata[logicalName]; var attributes = metadata.Attributes.ToDictionary(a => a.LogicalName, StringComparer.OrdinalIgnoreCase); - var attributeAccessors = new Dictionary>(); - var entityParam = Expression.Parameter(typeof(Entity)); + var attributeAccessors = new Dictionary>(); + var contextParam = Expression.Parameter(typeof(ExpressionExecutionContext)); + var entityParam = Expression.Property(contextParam, nameof(ExpressionExecutionContext.Entity)); foreach (var mapping in mappings) { @@ -438,7 +499,7 @@ protected Dictionary> CompileColumnMappings(DataSou { convertedExpr = expr; expr = originalExpr; - convertedExpr = SqlTypeConverter.Convert(convertedExpr, typeof(EntityReference)); + convertedExpr = SqlTypeConverter.Convert(convertedExpr, contextParam, typeof(EntityReference)); } else if (sourceSqlType == DataTypeHelpers.ImplicitIntForNullLiteral) { @@ -472,13 +533,13 @@ protected Dictionary> CompileColumnMappings(DataSou { // Normally we want to specify the target type as a logical name var stringType = DataTypeHelpers.NVarChar(MetadataExtensions.EntityLogicalNameMaxLength, dataSource.DefaultCollation, CollationLabel.Implicit); - targetExpr = SqlTypeConverter.Convert(targetExpr, sourceTargetType, stringType); - targetExpr = SqlTypeConverter.Convert(targetExpr, typeof(string)); + targetExpr = SqlTypeConverter.Convert(targetExpr, contextParam, sourceTargetType, stringType); + targetExpr = SqlTypeConverter.Convert(targetExpr, contextParam, typeof(string)); } } - convertedExpr = SqlTypeConverter.Convert(expr, sourceSqlType, DataTypeHelpers.UniqueIdentifier); - convertedExpr = SqlTypeConverter.Convert(convertedExpr, typeof(Guid)); + convertedExpr = SqlTypeConverter.Convert(expr, contextParam, sourceSqlType, DataTypeHelpers.UniqueIdentifier); + convertedExpr = SqlTypeConverter.Convert(convertedExpr, contextParam, typeof(Guid)); convertedExpr = Expression.New( typeof(EntityReference).GetConstructor(new[] { typeof(string), typeof(Guid) }), targetExpr, @@ -490,8 +551,8 @@ protected Dictionary> CompileColumnMappings(DataSou { var partitionIdExpr = (Expression)Expression.Property(entityParam, typeof(Entity).GetCustomAttribute().MemberName, Expression.Constant(partitionIdColumn)); partitionIdExpr = Expression.Convert(partitionIdExpr, schema.Schema[partitionIdColumn].Type.ToNetType(out _)); - partitionIdExpr = SqlTypeConverter.Convert(partitionIdExpr, schema.Schema[partitionIdColumn].Type, DataTypeHelpers.NVarChar(100, dataSource.DefaultCollation, CollationLabel.Implicit)); - partitionIdExpr = SqlTypeConverter.Convert(partitionIdExpr, typeof(string)); + partitionIdExpr = SqlTypeConverter.Convert(partitionIdExpr, contextParam, schema.Schema[partitionIdColumn].Type, DataTypeHelpers.NVarChar(100, dataSource.DefaultCollation, CollationLabel.Implicit)); + partitionIdExpr = SqlTypeConverter.Convert(partitionIdExpr, contextParam, typeof(string)); convertedExpr = Expr.Call(() => CreateElasticEntityReference(Expr.Arg(), Expr.Arg(), Expr.Arg()), convertedExpr, partitionIdExpr, Expression.Constant(dataSource.Metadata)); } @@ -505,11 +566,11 @@ protected Dictionary> CompileColumnMappings(DataSou { // Convert to destination SQL type - don't do this if we're converting from an EntityReference to a PartyList so // we don't lose the entity name during the conversion via a string - expr = SqlTypeConverter.Convert(expr, sourceSqlType, destSqlType, throwOnTruncate: true, table: logicalName, column: destAttributeName); + expr = SqlTypeConverter.Convert(expr, contextParam, sourceSqlType, destSqlType, throwOnTruncate: true, table: logicalName, column: destAttributeName); } // Convert to final .NET SDK type - convertedExpr = SqlTypeConverter.Convert(expr, destType); + convertedExpr = SqlTypeConverter.Convert(expr, contextParam, destType); if (attr is EnumAttributeMetadata && !(attr is MultiSelectPicklistAttributeMetadata)) { @@ -546,10 +607,10 @@ protected Dictionary> CompileColumnMappings(DataSou convertedExpr); if (expr.Type.IsValueType) - expr = SqlTypeConverter.Convert(expr, typeof(object)); + expr = SqlTypeConverter.Convert(expr, contextParam, typeof(object)); } - attributeAccessors[destAttributeName] = Expression.Lambda>(expr, entityParam).Compile(); + attributeAccessors[destAttributeName] = Expression.Lambda>(expr, contextParam).Compile(); } return attributeAccessors; diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BulkDeleteJobNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BulkDeleteJobNode.cs index af94d3d2..932a8a4c 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BulkDeleteJobNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BulkDeleteJobNode.cs @@ -53,7 +53,7 @@ public void Execute(NodeExecutionContext context, out int recordsAffected, out s try { - if (!context.DataSources.TryGetValue(DataSource, out var dataSource)) + if (!context.Session.DataSources.TryGetValue(DataSource, out var dataSource)) throw new QueryExecutionException("Missing datasource " + DataSource); Microsoft.Xrm.Sdk.Query.QueryExpression query; diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/DeleteNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/DeleteNode.cs index e7063661..8f24a378 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/DeleteNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/DeleteNode.cs @@ -74,7 +74,7 @@ public override IRootExecutionPlanNodeInternal[] FoldQuery(NodeCompilationContex if (result.Length != 1 || result[0] != this) return result; - if (!context.DataSources.TryGetValue(DataSource, out var dataSource)) + if (!context.Session.DataSources.TryGetValue(DataSource, out var dataSource)) throw new NotSupportedQueryFragmentException("Missing datasource " + DataSource); // Use bulk delete if requested & possible @@ -92,6 +92,9 @@ Source is FetchXmlScan fetch && } } + // Replace a source query with a list of known IDs if possible + FoldIdsToConstantScan(context, hints, LogicalName, ColumnMappings.Values.ToArray()); + return new[] { this }; } @@ -114,12 +117,13 @@ public override void Execute(NodeExecutionContext context, out int recordsAffect try { - if (!context.DataSources.TryGetValue(DataSource, out var dataSource)) + if (!context.Session.DataSources.TryGetValue(DataSource, out var dataSource)) throw new QueryExecutionException("Missing datasource " + DataSource); List entities; EntityMetadata meta; - Dictionary> attributeAccessors; + Dictionary> attributeAccessors; + var eec = new ExpressionExecutionContext(context); using (_timer.Run()) { @@ -163,7 +167,11 @@ public override void Execute(NodeExecutionContext context, out int recordsAffect context.Options, entities, meta, - entity => CreateDeleteRequest(meta, entity, attributeAccessors), + entity => + { + eec.Entity = entity; + return CreateDeleteRequest(meta, eec, attributeAccessors); + }, new OperationNames { InProgressUppercase = "Deleting", @@ -188,14 +196,14 @@ public override void Execute(NodeExecutionContext context, out int recordsAffect } } - private OrganizationRequest CreateDeleteRequest(EntityMetadata meta, Entity entity, Dictionary> attributeAccessors) + private OrganizationRequest CreateDeleteRequest(EntityMetadata meta, ExpressionExecutionContext context, Dictionary> attributeAccessors) { if (meta.LogicalName == "principalobjectaccess") { - var objectId = (Guid)attributeAccessors["objectid"](entity); - var objectTypeCode = entity.GetAttributeValue(ColumnMappings["objecttypecode"]).Value; - var principalId = (Guid)attributeAccessors["principalid"](entity); - var principalTypeCode = entity.GetAttributeValue(ColumnMappings["principaltypecode"]).Value; + var objectId = (Guid)attributeAccessors["objectid"](context); + var objectTypeCode = context.Entity.GetAttributeValue(ColumnMappings["objecttypecode"]).Value; + var principalId = (Guid)attributeAccessors["principalid"](context); + var principalTypeCode = context.Entity.GetAttributeValue(ColumnMappings["principaltypecode"]).Value; return new RevokeAccessRequest { @@ -209,16 +217,16 @@ private OrganizationRequest CreateDeleteRequest(EntityMetadata meta, Entity enti { return new RemoveMemberListRequest { - ListId = (Guid)attributeAccessors["listid"](entity), - EntityId = (Guid)attributeAccessors["entityid"](entity) + ListId = (Guid)attributeAccessors["listid"](context), + EntityId = (Guid)attributeAccessors["entityid"](context) }; } else if (meta.IsIntersect == true) { var relationship = meta.ManyToManyRelationships.Single(); - var targetId = (Guid)attributeAccessors[relationship.Entity1IntersectAttribute](entity); - var relatedId = (Guid)attributeAccessors[relationship.Entity2IntersectAttribute](entity); + var targetId = (Guid)attributeAccessors[relationship.Entity1IntersectAttribute](context); + var relatedId = (Guid)attributeAccessors[relationship.Entity2IntersectAttribute](context); return new DisassociateRequest { @@ -228,7 +236,7 @@ private OrganizationRequest CreateDeleteRequest(EntityMetadata meta, Entity enti }; } - var id = (Guid)attributeAccessors[meta.PrimaryIdAttribute](entity); + var id = (Guid)attributeAccessors[meta.PrimaryIdAttribute](context); var req = new DeleteRequest { Target = new EntityReference(LogicalName, id) @@ -242,14 +250,14 @@ private OrganizationRequest CreateDeleteRequest(EntityMetadata meta, Entity enti KeyAttributes = { [meta.PrimaryIdAttribute] = id, - ["partitionid"] = attributeAccessors["partitionid"](entity) + ["partitionid"] = attributeAccessors["partitionid"](context) } }; } // Special case for activitypointer - need to set the specific activity type code if (LogicalName == "activitypointer") - req.Target.LogicalName = entity.GetAttributeValue(ColumnMappings["activitytypecode"]).Value; + req.Target.LogicalName = context.Entity.GetAttributeValue(ColumnMappings["activitytypecode"]).Value; return req; } diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/DistinctNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/DistinctNode.cs index 15dc6bff..274b438c 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/DistinctNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/DistinctNode.cs @@ -96,7 +96,7 @@ public override IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext if (fetch.Entity.name == "audit" && Columns.Any(col => col.StartsWith(fetch.Alias.EscapeIdentifier() + ".objectid"))) return this; - var metadata = context.DataSources[fetch.DataSource].Metadata; + var metadata = context.Session.DataSources[fetch.DataSource].Metadata; // Can't apply DISTINCT to partylist attributes // https://github.com/MarkMpn/Sql4Cds/issues/528 diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExecuteAsNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExecuteAsNode.cs index e4ae0a11..99405697 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExecuteAsNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExecuteAsNode.cs @@ -75,7 +75,7 @@ public override void Execute(NodeExecutionContext context, out int recordsAffect { using (_timer.Run()) { - if (!context.DataSources.TryGetValue(DataSource, out var dataSource)) + if (!context.Session.DataSources.TryGetValue(DataSource, out var dataSource)) throw new QueryExecutionException("Missing datasource " + DataSource); var entities = GetDmlSourceEntities(context, out var schema); @@ -93,10 +93,11 @@ public override void Execute(NodeExecutionContext context, out int recordsAffect throw new QueryExecutionException(Sql4CdsError.ImpersonationError(username), new ApplicationException("Ambiguous username")); // Precompile mappings with type conversions + var eec = new ExpressionExecutionContext(context) { Entity = entities[0] }; var attributeAccessors = CompileColumnMappings(dataSource, "systemuser", new Dictionary(StringComparer.OrdinalIgnoreCase) { ["systemuserid"] = UserIdSource }, schema, DateTimeKind.Unspecified, entities); var userIdAccessor = attributeAccessors["systemuserid"]; - var userId = (Guid)userIdAccessor(entities[0]); + var userId = (Guid)userIdAccessor(eec); PropertyInfo callerIdProp; diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExecuteMessageNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExecuteMessageNode.cs index 05fcb19c..ba75b75b 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExecuteMessageNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExecuteMessageNode.cs @@ -136,7 +136,7 @@ public override IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext var expr = value.Value.Compile(expressionContext); var sqlConversion = SqlTypeConverter.GetConversion(sourceSqlType, destSqlType); var netConversion = SqlTypeConverter.GetConversion(destSqlType, destNetType); - var conversion = (Func) ((ExpressionExecutionContext ctx) => netConversion(sqlConversion(expr(ctx)))); + var conversion = (Func) ((ExpressionExecutionContext ctx) => netConversion(sqlConversion(expr(ctx), ctx), ctx)); if (ValueTypes[value.Key] == typeof(Entity)) { var conversionToString = conversion; @@ -283,7 +283,7 @@ protected override IEnumerable ExecuteInternal(NodeExecutionContext cont { PagesRetrieved = 0; - if (!context.DataSources.TryGetValue(DataSource, out var dataSource)) + if (!context.Session.DataSources.TryGetValue(DataSource, out var dataSource)) throw new NotSupportedQueryFragmentException("Missing datasource " + DataSource); context.Options.Progress(0, $"Executing {MessageName}..."); diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExpressionExtensions.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExpressionExtensions.cs index 9bd47bde..edff418e 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExpressionExtensions.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExpressionExtensions.cs @@ -111,8 +111,8 @@ public static Func Compile(this BooleanExpress private static Expression ToExpression(this TSqlFragment expr, ExpressionCompilationContext context, bool createExpression, out ParameterExpression[] parameters, out DataTypeReference sqlType, out string cacheKey) { - var contextParam = createExpression ? Expression.Parameter(typeof(ExpressionExecutionContext)) : null; - var exprParam = createExpression ? Expression.Parameter(typeof(TSqlFragment)) : null; + var contextParam = Expression.Parameter(typeof(ExpressionExecutionContext)); + var exprParam = Expression.Parameter(typeof(TSqlFragment)); Expression expression; @@ -334,17 +334,21 @@ private static Expression ToExpression(ColumnReferenceExpression col, Expression return Expression.Convert(expr, returnType); } - private static Expression ToExpression(IdentifierLiteral guid, ExpressionCompilationContext context, ParameterExpression contextParam, ParameterExpression exprParam, bool createExpression, out DataTypeReference sqlType, out string cacheKey) + private static Expression ToExpression(IdentifierLiteral id, ExpressionCompilationContext context, ParameterExpression contextParam, ParameterExpression exprParam, bool createExpression, out DataTypeReference sqlType, out string cacheKey) { - sqlType = DataTypeHelpers.UniqueIdentifier; - cacheKey = ""; + sqlType = DataTypeHelpers.VarChar(id.Value.Length, context.PrimaryDataSource.DefaultCollation, CollationLabel.CoercibleDefault); + + cacheKey = $""; if (!createExpression) return null; - return Expr.Call( - () => SqlGuid.Parse(Expr.Arg()), - Expression.Property(Expression.Convert(exprParam, typeof(IdentifierLiteral)), nameof(IdentifierLiteral.Value))); + var value = Expression.Property(Expression.Convert(exprParam, typeof(IdentifierLiteral)), nameof(IdentifierLiteral.Value)); + + var expr = (Expression)Expression.Property(contextParam, nameof(ExpressionExecutionContext.PrimaryDataSource)); + expr = Expression.Property(expr, nameof(DataSource.DefaultCollation)); + expr = Expression.Call(expr, nameof(Collation.ToSqlString), Array.Empty(), value); + return expr; } private static Expression ToExpression(IntegerLiteral i, ExpressionCompilationContext context, ParameterExpression contextParam, ParameterExpression exprParam, bool createExpression, out DataTypeReference sqlType, out string cacheKey) @@ -545,7 +549,7 @@ cmp.SecondExpression is FunctionCall func && } if (!lhsType.IsSameAs(type)) - lhs = createExpression ? SqlTypeConverter.Convert(lhs, lhsType, type) : null; + lhs = createExpression ? SqlTypeConverter.Convert(lhs, contextParam, lhsType, type) : null; if (!rhsType.IsSameAs(type)) { @@ -565,7 +569,7 @@ cmp.SecondExpression is StringLiteral str && }; } - rhs = createExpression ? SqlTypeConverter.Convert(rhs, rhsType, type) : null; + rhs = createExpression ? SqlTypeConverter.Convert(rhs, contextParam, rhsType, type) : null; } AssertCollationSensitive(type); @@ -629,7 +633,7 @@ private static Expression ToExpression(DistinctPredicate distinct, ExpressionCom } if (!lhsType.IsSameAs(type)) - lhs = createExpression ? SqlTypeConverter.Convert(lhs, lhsType, type) : null; + lhs = createExpression ? SqlTypeConverter.Convert(lhs, contextParam, lhsType, type) : null; if (!rhsType.IsSameAs(type)) { @@ -649,7 +653,7 @@ distinct.SecondExpression is StringLiteral str && }; } - rhs = createExpression ? SqlTypeConverter.Convert(rhs, rhsType, type) : null; + rhs = createExpression ? SqlTypeConverter.Convert(rhs, contextParam, rhsType, type) : null; } AssertCollationSensitive(type); @@ -804,17 +808,17 @@ rhsSqlType is SqlDataTypeReference sqlRhsType && type = DataTypeHelpers.Decimal(p, s); if (lhs.Type != typeof(SqlDecimal)) - lhs = SqlTypeConverter.Convert(lhs, lhsSqlType, DataTypeHelpers.Decimal(lhsSqlType.GetPrecision(), lhsSqlType.GetScale())); + lhs = SqlTypeConverter.Convert(lhs, contextParam, lhsSqlType, DataTypeHelpers.Decimal(lhsSqlType.GetPrecision(), lhsSqlType.GetScale())); if (rhs.Type != typeof(SqlDecimal)) - rhs = SqlTypeConverter.Convert(rhs, rhsSqlType, DataTypeHelpers.Decimal(rhsSqlType.GetPrecision(), rhsSqlType.GetScale())); + rhs = SqlTypeConverter.Convert(rhs, contextParam, rhsSqlType, DataTypeHelpers.Decimal(rhsSqlType.GetPrecision(), rhsSqlType.GetScale())); } else { if (!lhsSqlType.IsSameAs(type)) - lhs = SqlTypeConverter.Convert(lhs, lhsSqlType, type); + lhs = SqlTypeConverter.Convert(lhs, contextParam, lhsSqlType, type); if (!rhsSqlType.IsSameAs(type)) - rhs = SqlTypeConverter.Convert(rhs, rhsSqlType, type); + rhs = SqlTypeConverter.Convert(rhs, contextParam, rhsSqlType, type); } sqlType = null; @@ -954,12 +958,13 @@ private static MethodInfo GetMethod(FunctionCall func, ExpressionCompilationCont var paramExpressionsWithType = func.Parameters .Select((param, index) => { - // Special case for DATEPART / DATEDIFF / DATEADD - first parameter looks like a field but is actually an identifier + // Special case for DATEPART / DATEDIFF / DATEADD / DATETRUNC - first parameter looks like a field but is actually an identifier if (index == 0 && ( func.FunctionName.Value.Equals("DATEPART", StringComparison.OrdinalIgnoreCase) || func.FunctionName.Value.Equals("DATEDIFF", StringComparison.OrdinalIgnoreCase) || - func.FunctionName.Value.Equals("DATEADD", StringComparison.OrdinalIgnoreCase) + func.FunctionName.Value.Equals("DATEADD", StringComparison.OrdinalIgnoreCase) || + func.FunctionName.Value.Equals("DATETRUNC", StringComparison.OrdinalIgnoreCase) )) { // Check parameter is an expected datepart value @@ -969,16 +974,9 @@ private static MethodInfo GetMethod(FunctionCall func, ExpressionCompilationCont ex = new NotSupportedQueryFragmentException(Sql4CdsError.InvalidParameter(param, 1, "datepart")); col = null; } - else + else if (!ExpressionFunctions.TryParseDatePart(col.MultiPartIdentifier.Identifiers.Single().Value, out _)) { - try - { - ExpressionFunctions.DatePartToInterval(col.MultiPartIdentifier.Identifiers.Single().Value); - } - catch - { - ex = new NotSupportedQueryFragmentException(Sql4CdsError.InvalidOptionValue(param, "datepart")); - } + ex = new NotSupportedQueryFragmentException(Sql4CdsError.InvalidOptionValue(param, "datepart")); } if (ex != null) @@ -1010,6 +1008,46 @@ private static MethodInfo GetMethod(FunctionCall func, ExpressionCompilationCont } var paramExpr = func.InvokeSubExpression(x => x.Parameters[index], (x, i) => x.Parameters[i], index, context, contextParam, exprParam, createExpression, out var paramType, out var paramCacheKey, out var paramException); + + // Special case for DATEPART - second parameter can accept any datetime family type. Function is implemented + // to accept datetimeoffset for highest precision, but also needs to accept numeric types which can't be converted + // to datetimeoffset. Convert them to datetime first + if (index == 1 && + func.FunctionName.Value.Equals("DATEPART", StringComparison.OrdinalIgnoreCase) && + paramType is SqlDataTypeReference datePartParamSqlType && + datePartParamSqlType.SqlDataTypeOption.IsNumeric()) + { + var dateTimeType = (DataTypeReference) DataTypeHelpers.DateTime; + paramExpr = Convert(context, contextParam, paramExpr, paramType, paramCacheKey, ref dateTimeType, null, null, null, func.Parameters[index], "IMPLICIT", out paramCacheKey); + paramType = dateTimeType; + } + + // Special case for DATETRUNC - second parameter can accept any datetime family type. String values should be + // implicitly converted to datetime2. Numeric values are not supported for DATETRUNC + if (index == 1 && + func.FunctionName.Value.Equals("DATETRUNC", StringComparison.OrdinalIgnoreCase) && + paramType is SqlDataTypeReference dateTruncParamSqlType && + dateTruncParamSqlType.SqlDataTypeOption.IsStringType()) + { + var dateTimeType = (DataTypeReference)DataTypeHelpers.DateTime2(7); + paramExpr = Convert(context, contextParam, paramExpr, paramType, paramCacheKey, ref dateTimeType, null, null, null, func.Parameters[index], "IMPLICIT", out paramCacheKey); + paramType = dateTimeType; + } + + // Special case for DATEADD - third parameter can accept any datetime family type. Function is implemented + // to accept datetimeoffset for highest precision, but also needs to accept numeric types which can't be converted + // to datetimeoffset. Convert them to datetime first. Can also accept string values which should be converted + // to datetime + if (index == 2 && + func.FunctionName.Value.Equals("DATEADD", StringComparison.OrdinalIgnoreCase) && + paramType is SqlDataTypeReference dateAddParamSqlType && + (dateAddParamSqlType.SqlDataTypeOption.IsNumeric() || dateAddParamSqlType.SqlDataTypeOption.IsStringType())) + { + var dateTimeType = (DataTypeReference) DataTypeHelpers.DateTime; + paramExpr = Convert(context, contextParam, paramExpr, paramType, paramCacheKey, ref dateTimeType, null, null, null, func.Parameters[index], "IMPLICIT", out paramCacheKey); + paramType = dateTimeType; + } + return new { Expression = paramExpr, Type = paramType, CacheKey = paramCacheKey, Exception = paramException }; }) .ToList(); @@ -1081,7 +1119,7 @@ private static MethodInfo GetMethod(ExpressionCompilationContext context, Type t var method = correctParameterCount[0].Method; var parameters = correctParameterCount[0].Parameters; - DataTypeReference sourceType = null; + var parameterTypes = new Dictionary(); cacheKey = method.Name; if (correctParameterCount[0].Method.IsGenericMethodDefinition) @@ -1095,10 +1133,7 @@ private static MethodInfo GetMethod(ExpressionCompilationContext context, Type t for (var i = 0; i < genericArguments.Length; i++) { if (param.ParameterType == genericArguments[i] && genericArgumentValues[i] == null) - { genericArgumentValues[i] = paramTypes[i].ToNetType(out _); - sourceType = paramTypes[i]; - } } } @@ -1193,17 +1228,21 @@ private static MethodInfo GetMethod(ExpressionCompilationContext context, Type t if (paramType == typeof(DataTypeReference)) { - if (parameters[i].GetCustomAttribute() != null) + var sourceType = parameters[i].GetCustomAttribute(); + if (sourceType != null) { - cacheKey += $"(TYPE:{sourceType.ToSql()})"; + cacheKey += $"(TYPE:{parameterTypes[sourceType.SourceParameter].ToSql()})"; if (createExpression) { var paramsWithType = new Expression[paramExpressions.Length + 1]; paramExpressions.CopyTo(paramsWithType, 0); - paramsWithType[i] = Expression.Constant(sourceType); + paramsWithType[i] = Expression.Constant(parameterTypes[sourceType.SourceParameter]); paramExpressions = paramsWithType; } hiddenParams++; + + if (parameters[i].GetCustomAttribute() != null) + sqlType = parameterTypes[sourceType.SourceParameter]; } else { @@ -1242,6 +1281,8 @@ private static MethodInfo GetMethod(ExpressionCompilationContext context, Type t if (paramType != typeof(INullable) && !SqlTypeConverter.CanChangeTypeImplicit(paramTypes[i - hiddenParams], paramType.ToSqlType(primaryDataSource))) throw new NotSupportedQueryFragmentException(Sql4CdsError.TypeClash(i < paramOffset ? func : func.Parameters[i - paramOffset], paramTypes[i], paramType.ToSqlType(primaryDataSource))); + + parameterTypes[parameters[i].Name] = paramTypes[i - hiddenParams]; } for (var i = parameters.Length; i < paramTypes.Length; i++) @@ -1262,7 +1303,7 @@ private static MethodInfo GetMethod(ExpressionCompilationContext context, Type t if (arrayType == typeof(INullable)) arrayMembers.Add(Expression.Convert(paramExpressions[i], typeof(INullable))); else - arrayMembers.Add(SqlTypeConverter.Convert(paramExpressions[i], arrayType)); + arrayMembers.Add(SqlTypeConverter.Convert(paramExpressions[i], contextParam, arrayType)); } var arrayParam = Expression.NewArrayInit(arrayType, arrayMembers); @@ -1392,12 +1433,12 @@ private static Expression ToExpression(this FunctionCall func, ExpressionCompila for (var i = 0; i < parameters.Length; i++) { if (paramValues[i].Type != parameters[i].ParameterType) - paramValues[i] = SqlTypeConverter.Convert(paramValues[i], parameters[i].ParameterType); + paramValues[i] = SqlTypeConverter.Convert(paramValues[i], contextParam, parameters[i].ParameterType); } var expr = (Expression) Expression.Call(method, paramValues); - if (expr.Type == typeof(object) && parameters.Any(p => p.GetCustomAttribute() != null)) + if (parameters.Any(p => p.GetCustomAttribute() != null)) expr = Expression.Convert(expr, sqlType.ToNetType(out _)); return expr; @@ -1481,10 +1522,10 @@ private static Expression ToExpression(this InPredicate inPred, ExpressionCompil var convertedExprValue = exprValue; if (!exprType.IsSameAs(type)) - convertedExprValue = SqlTypeConverter.Convert(convertedExprValue, exprType, type); + convertedExprValue = SqlTypeConverter.Convert(convertedExprValue, contextParam, exprType, type); if (!comparisonType.IsSameAs(type)) - comparisonValue = SqlTypeConverter.Convert(comparisonValue, comparisonType, type); + comparisonValue = SqlTypeConverter.Convert(comparisonValue, contextParam, comparisonType, type); var comparison = inPred.NotDefined ? Expression.NotEqual(convertedExprValue, comparisonValue) : Expression.Equal(convertedExprValue, comparisonValue); @@ -1561,7 +1602,7 @@ private static Expression ToExpression(this BooleanIsNullExpression isNull, Expr cacheKey += " IS NULL"; } - value = createExpression ? SqlTypeConverter.Convert(value, typeof(SqlBoolean)) : null; + value = createExpression ? SqlTypeConverter.Convert(value, contextParam, typeof(SqlBoolean)) : null; sqlType = DataTypeHelpers.Bit; return value; } @@ -1587,7 +1628,7 @@ private static Expression ToExpression(this LikePredicate like, ExpressionCompil if (!SqlTypeConverter.CanChangeTypeImplicit(valueType, stringType)) throw new NotSupportedQueryFragmentException(Sql4CdsError.TypeClash(like.FirstExpression, valueType, stringType)); - value = createExpression ? SqlTypeConverter.Convert(value, valueType, stringType) : null; + value = createExpression ? SqlTypeConverter.Convert(value, contextParam, valueType, stringType) : null; valueType = stringType; } @@ -1596,7 +1637,7 @@ private static Expression ToExpression(this LikePredicate like, ExpressionCompil if (!SqlTypeConverter.CanChangeTypeImplicit(patternType, stringType)) throw new NotSupportedQueryFragmentException(Sql4CdsError.TypeClash(like.FirstExpression, patternType, stringType)); - pattern = createExpression ? SqlTypeConverter.Convert(pattern, patternType, stringType) : null; + pattern = createExpression ? SqlTypeConverter.Convert(pattern, contextParam, patternType, stringType) : null; patternType = stringType; } @@ -1605,7 +1646,7 @@ private static Expression ToExpression(this LikePredicate like, ExpressionCompil if (!SqlTypeConverter.CanChangeTypeImplicit(escapeType, stringType)) throw new NotSupportedQueryFragmentException(Sql4CdsError.TypeClash(like.FirstExpression, escapeType, stringType)); - escape = createExpression ? SqlTypeConverter.Convert(escape, escapeType, stringType) : null; + escape = createExpression ? SqlTypeConverter.Convert(escape, contextParam, escapeType, stringType) : null; escapeType = stringType; } @@ -1894,7 +1935,7 @@ private static Expression ToExpression(this SimpleCaseExpression simpleCase, Exp if (elseValue != null) { if (!elseType.IsSameAs(type)) - elseValue = SqlTypeConverter.Convert(elseValue, elseType, type); + elseValue = SqlTypeConverter.Convert(elseValue, contextParam, elseType, type); result = elseValue; } @@ -1911,17 +1952,17 @@ private static Expression ToExpression(this SimpleCaseExpression simpleCase, Exp var caseType = caseTypes[i]; if (!valueType.IsSameAs(caseType)) - valueCopy = SqlTypeConverter.Convert(valueCopy, valueType, caseType); + valueCopy = SqlTypeConverter.Convert(valueCopy, contextParam, valueType, caseType); if (!whenType.IsSameAs(caseType)) - whenValue = SqlTypeConverter.Convert(whenValue, whenType, caseType); + whenValue = SqlTypeConverter.Convert(whenValue, contextParam, whenType, caseType); var comparison = Expression.Equal(valueCopy, whenValue); var returnValue = thenClauses[i].Expression; var returnType = thenClauses[i].Type; if (!returnType.IsSameAs(type)) - returnValue = SqlTypeConverter.Convert(returnValue, returnType, type); + returnValue = SqlTypeConverter.Convert(returnValue, contextParam, returnType, type); result = Expression.Condition(Expression.IsTrue(comparison), returnValue, result); } @@ -1996,7 +2037,7 @@ private static Expression ToExpression(this SearchedCaseExpression searchedCase, if (elseValue != null) { if (!elseType.IsSameAs(type)) - elseValue = SqlTypeConverter.Convert(elseValue, elseType, type); + elseValue = SqlTypeConverter.Convert(elseValue, contextParam, elseType, type); result = elseValue; } @@ -2014,10 +2055,10 @@ private static Expression ToExpression(this SearchedCaseExpression searchedCase, var returnValue = thenClauses[i].Expression; var returnType = thenClauses[i].Type; - whenValue = SqlTypeConverter.Convert(whenValue, whenType, bitType); + whenValue = SqlTypeConverter.Convert(whenValue, contextParam, whenType, bitType); whenValue = Expression.IsTrue(whenValue); - returnValue = SqlTypeConverter.Convert(returnValue, returnType, type); + returnValue = SqlTypeConverter.Convert(returnValue, contextParam, returnType, type); result = Expression.Condition(whenValue, returnValue, result); } @@ -2065,7 +2106,7 @@ private static Expression ToExpression(this BooleanNotExpression not, Expression [SqlDataTypeOption.Numeric] = typeof(SqlDecimal), [SqlDataTypeOption.NVarChar] = typeof(SqlString), [SqlDataTypeOption.Real] = typeof(SqlSingle), - [SqlDataTypeOption.SmallDateTime] = typeof(SqlDateTime), + [SqlDataTypeOption.SmallDateTime] = typeof(SqlSmallDateTime), [SqlDataTypeOption.SmallInt] = typeof(SqlInt16), [SqlDataTypeOption.SmallMoney] = typeof(SqlMoney), [SqlDataTypeOption.Text] = typeof(SqlString), @@ -2140,7 +2181,7 @@ public static bool IsType(this DataTypeReference type, SqlDataTypeOption sqlType [typeof(SqlGuid)] = DataTypeHelpers.UniqueIdentifier, [typeof(SqlEntityReference)] = DataTypeHelpers.EntityReference, [typeof(SqlDateTime2)] = DataTypeHelpers.DateTime2(7), - [typeof(SqlDateTimeOffset)] = DataTypeHelpers.DateTimeOffset, + [typeof(SqlDateTimeOffset)] = DataTypeHelpers.DateTimeOffset(7), [typeof(SqlDate)] = DataTypeHelpers.Date, [typeof(SqlTime)] = DataTypeHelpers.Time(7), [typeof(SqlXml)] = DataTypeHelpers.Xml, @@ -2173,10 +2214,10 @@ private static Expression ToExpression(this ConvertCall convert, ExpressionCompi sqlType = convert.DataType; - return Convert(context, value, valueType, valueCacheKey, ref sqlType, style, styleType, styleCacheKey, convert, "CONVERT", out cacheKey); + return Convert(context, contextParam, value, valueType, valueCacheKey, ref sqlType, style, styleType, styleCacheKey, convert, "CONVERT", out cacheKey); } - private static Expression Convert(ExpressionCompilationContext context, Expression value, DataTypeReference valueType, string valueCacheKey, ref DataTypeReference sqlType, Expression style, DataTypeReference styleType, string styleCacheKey, TSqlFragment expr, string cacheKeyRoot, out string cacheKey) + private static Expression Convert(ExpressionCompilationContext context, ParameterExpression contextParam, Expression value, DataTypeReference valueType, string valueCacheKey, ref DataTypeReference sqlType, Expression style, DataTypeReference styleType, string styleCacheKey, TSqlFragment expr, string cacheKeyRoot, out string cacheKey) { if (sqlType is SqlDataTypeReference sqlTargetType && sqlTargetType.SqlDataTypeOption.IsStringType()) @@ -2205,7 +2246,7 @@ private static Expression Convert(ExpressionCompilationContext context, Expressi cacheKey += ", " + styleCacheKey; cacheKey += ")"; - return value == null ? null : SqlTypeConverter.Convert(value, valueType, sqlType, style, styleType, expr); + return value == null ? null : SqlTypeConverter.Convert(value, contextParam, valueType, sqlType, style, styleType, expr); } private static Expression ToExpression(this CastCall cast, ExpressionCompilationContext context, ParameterExpression contextParam, ParameterExpression exprParam, bool createExpression, out DataTypeReference sqlType, out string cacheKey) @@ -2217,7 +2258,7 @@ private static Expression ToExpression(this CastCall cast, ExpressionCompilation sqlType = cast.DataType; - return Convert(context, value, valueType, valueCacheKey, ref sqlType, null, null, null, cast, "CAST", out cacheKey); + return Convert(context, contextParam, value, valueType, valueCacheKey, ref sqlType, null, null, null, cast, "CAST", out cacheKey); } private static readonly Regex _containsParser = new Regex("^\\S+( OR \\S+)*$", RegexOptions.IgnoreCase | RegexOptions.Compiled); @@ -2251,7 +2292,7 @@ private static Expression ToExpression(this FullTextPredicate fullText, Expressi if (!SqlTypeConverter.CanChangeTypeImplicit(colType, stringType)) throw new NotSupportedQueryFragmentException(Sql4CdsError.InvalidColumnForFullTextSearch(fullText.Columns[0])); - col = createExpression ? SqlTypeConverter.Convert(col, colType, stringType) : null; + col = createExpression ? SqlTypeConverter.Convert(col, contextParam, colType, stringType) : null; sqlType = DataTypeHelpers.Bit; if (fullText.Value is StringLiteral lit) @@ -2268,7 +2309,7 @@ private static Expression ToExpression(this FullTextPredicate fullText, Expressi if (!SqlTypeConverter.CanChangeTypeImplicit(valueType, stringType)) throw new NotSupportedQueryFragmentException(Sql4CdsError.TypeClash(fullText.Value, valueType, stringType)); - value = createExpression ? SqlTypeConverter.Convert(value, valueType, stringType) : null; + value = createExpression ? SqlTypeConverter.Convert(value, contextParam, valueType, stringType) : null; cacheKey = $"{colCacheKey} CONTAINS {valueCacheKey}"; return createExpression ? Expr.Call(() => Contains(Expr.Arg(), Expr.Arg()), col, value) : null; @@ -2710,6 +2751,50 @@ public static BooleanComparisonType TransitiveComparison(this BooleanComparisonT } } + /// + /// Returns the equivalent Fetch XML condition operator for this comparison + /// + /// The comparison type to convert to Fetch XML + /// The equivalent Fetch XML condition operator for this comparison + public static bool TryConvertToFetchXml(this BooleanComparisonType cmp, out @operator op) + { + switch (cmp) + { + case BooleanComparisonType.Equals: + case BooleanComparisonType.IsNotDistinctFrom: + op = @operator.eq; + break; + + case BooleanComparisonType.GreaterThan: + op = @operator.gt; + break; + + case BooleanComparisonType.GreaterThanOrEqualTo: + op = @operator.ge; + break; + + case BooleanComparisonType.LessThan: + op = @operator.lt; + break; + + case BooleanComparisonType.LessThanOrEqualTo: + op = @operator.le; + break; + + case BooleanComparisonType.NotEqualToBrackets: + case BooleanComparisonType.NotEqualToExclamation: + case BooleanComparisonType.IsDistinctFrom: + op = @operator.ne; + break; + + default: + op = @operator.eq; + return false; + } + + return true; + } + private static string GetTypeKey(DataTypeReference type, bool includeStringLength) { if (type is XmlDataTypeReference) diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/FetchXmlScan.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/FetchXmlScan.cs index ae754431..81360f4e 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/FetchXmlScan.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/FetchXmlScan.cs @@ -280,16 +280,17 @@ protected override IEnumerable ExecuteInternal(NodeExecutionContext cont { PagesRetrieved = 0; - if (!context.DataSources.TryGetValue(DataSource, out var dataSource)) + if (!context.Session.DataSources.TryGetValue(DataSource, out var dataSource)) throw new NotSupportedQueryFragmentException("Missing datasource " + DataSource); var schema = GetSchema(context); + var eec = new ExpressionExecutionContext(context); ApplyParameterValues(context); FindEntityNameGroupings(dataSource.Metadata); - VerifyFilterValueTypes(Entity.name, Entity.Items, dataSource); + VerifyFilterValueTypes(Entity.name, Entity.Items, dataSource, eec); var mainEntity = FetchXml.Items.OfType().Single(); var name = mainEntity.name; @@ -454,14 +455,14 @@ protected override IEnumerable ExecuteInternal(NodeExecutionContext cont } } - private void VerifyFilterValueTypes(string entityName, object[] items, DataSource dataSource) + private void VerifyFilterValueTypes(string entityName, object[] items, DataSource dataSource, ExpressionExecutionContext context) { if (items == null) return; // Check the value(s) supplied for filter values can be converted to the expected types foreach (var filter in items.OfType()) - VerifyFilterValueTypes(entityName, filter.Items, dataSource); + VerifyFilterValueTypes(entityName, filter.Items, dataSource, context); foreach (var condition in items.OfType()) { @@ -517,17 +518,17 @@ private void VerifyFilterValueTypes(string entityName, object[] items, DataSourc var conversion = SqlTypeConverter.GetConversion(DataTypeHelpers.NVarChar(Int32.MaxValue, dataSource.DefaultCollation, CollationLabel.CoercibleDefault), attrType); if (condition.value != null) - conversion(dataSource.DefaultCollation.ToSqlString(condition.value)); + conversion(dataSource.DefaultCollation.ToSqlString(condition.value), context); if (condition.Items != null) { foreach (var value in condition.Items) - conversion(dataSource.DefaultCollation.ToSqlString(value.Value)); + conversion(dataSource.DefaultCollation.ToSqlString(value.Value), context); } } foreach (var linkEntity in items.OfType()) - VerifyFilterValueTypes(linkEntity.name, linkEntity.Items, dataSource); + VerifyFilterValueTypes(linkEntity.name, linkEntity.Items, dataSource, context); } private void AddPagingFilters(filter filter, IQueryExecutionOptions options) @@ -1019,7 +1020,7 @@ public override IEnumerable GetSources() public override INodeSchema GetSchema(NodeCompilationContext context) { - if (!context.DataSources.TryGetValue(DataSource, out var dataSource)) + if (!context.Session.DataSources.TryGetValue(DataSource, out var dataSource)) throw new NotSupportedQueryFragmentException("Missing datasource " + DataSource); var fetchXmlString = FetchXmlString; @@ -1474,7 +1475,7 @@ public override IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext // Move partitionid filter to partitionId parameter // https://learn.microsoft.com/en-us/power-apps/developer/data-platform/use-elastic-tables?tabs=sdk#query-rows-of-an-elastic-table - var meta = context.DataSources[DataSource].Metadata[Entity.name]; + var meta = context.Session.DataSources[DataSource].Metadata[Entity.name]; if (meta.DataProviderId == DataProviders.ElasticDataProvider && Entity.Items != null) { @@ -1700,7 +1701,7 @@ private void RemoveIdentitySemiJoinLinkEntities(NodeCompilationContext context) { // If we've got a semi join link entity that matches to the parent entity by primary key, // remove the link entity and move the conditions to the parent entity - var dataSource = context.DataSources[DataSource]; + var dataSource = context.Session.DataSources[DataSource]; Entity.Items = RemoveIdentitySemiJoinLinkEntities(Entity.name, dataSource.Metadata, Entity.Items); } @@ -1981,7 +1982,7 @@ private void MergeNestedFilters(filter filter) public override void AddRequiredColumns(NodeCompilationContext context, IList requiredColumns) { - if (!context.DataSources.TryGetValue(DataSource, out var dataSource)) + if (!context.Session.DataSources.TryGetValue(DataSource, out var dataSource)) throw new NotSupportedQueryFragmentException("Missing datasource " + DataSource); var schema = GetSchema(context); @@ -2060,7 +2061,7 @@ public override void AddRequiredColumns(NodeCompilationContext context, IList>(); _lastPageValues = new List(); @@ -2085,7 +2086,7 @@ public override void AddRequiredColumns(NodeCompilationContext context, IList GetVariablesInternal() return FindParameterizedConditions().Keys; } + internal IDataExecutionPlanNodeInternal FoldDmlSource(NodeCompilationContext context, IList hints, string logicalName, string[] requiredColumns, string[] keyAttributes) + { + if (Entity.name != logicalName || Entity.Items == null) + return this; + + // Can't produce any values except the primary key + var requiredAttributes = requiredColumns + .Select(col => col.SplitMultiPartIdentifier().Last()) + .ToArray(); + + if (requiredAttributes.Except(keyAttributes).Any()) + return this; + + if (Entity.GetLinkEntities().Any()) + return this; + + var filters = Entity.Items.OfType().ToList(); + + if (filters.Count != 1) + return this; + + if (!filters[0].Items.All(x => x is condition)) + return this; + + if (filters[0].Items.Cast().Any(c => c.ValueOf != null)) + return this; + + var dataSource = context.Session.DataSources[DataSource]; + var metadata = dataSource.Metadata[logicalName]; + var conditions = filters[0].Items.Cast().ToList(); + var ecc = new ExpressionCompilationContext(context, null, null); + var schema = GetSchema(context); + var constantScan = new ConstantScanNode + { + Alias = Alias + }; + + for (var i = 0; i < requiredColumns.Length; i++) + constantScan.Schema[requiredAttributes[i]] = schema.Schema[requiredColumns[i]]; + + // We can handle compound keys, but only if they are all ANDed together + if (keyAttributes.Length > 1 && filters[0].type == filterType.and) + { + var values = new Dictionary(); + + foreach (var keyAttribute in keyAttributes) + { + var condition = conditions.FirstOrDefault(c => c.attribute == keyAttribute); + if (condition == null) + return this; + + if (condition.@operator != @operator.eq) + return this; + + var attribute = metadata.Attributes.Single(a => a.LogicalName == condition.attribute); + values[condition.attribute] = attribute.GetDmlValue(condition.value, condition.IsVariable, ecc, dataSource); + } + + constantScan.Values.Add(values); + return constantScan; + } + + // We can also handle multiple values for a single key being ORed together + else if (keyAttributes.Length == 1 && + conditions.All(c => c.attribute == metadata.PrimaryIdAttribute) && + conditions.All(c => c.@operator == @operator.eq || c.@operator == @operator.@in) && + (conditions.Count == 1 || filters[0].type == filterType.or)) + { + foreach (var condition in conditions) + { + var attribute = metadata.Attributes.Single(a => a.LogicalName == condition.attribute); + + if (condition.@operator == @operator.eq) + { + constantScan.Values.Add(new Dictionary { [condition.attribute] = attribute.GetDmlValue(condition.value, condition.IsVariable, ecc, dataSource) }); + } + else if (condition.@operator == @operator.@in) + { + foreach (var value in condition.Items) + constantScan.Values.Add(new Dictionary { [condition.attribute] = attribute.GetDmlValue(value.Value, value.IsVariable, ecc, dataSource) }); + } + } + + return constantScan; + } + + return this; + } + public override string ToString() { return "FetchXML Query"; diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/FilterNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/FilterNode.cs index 5ba5a272..813b840b 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/FilterNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/FilterNode.cs @@ -1081,7 +1081,7 @@ private bool FoldInExistsToFetchXml(NodeCompilationContext context, IList l.alias).Intersect(rightFetch.Entity.GetLinkEntities().Select(l => l.alias), StringComparer.OrdinalIgnoreCase).Any() && (leftFetch.FetchXml.top == null || rightFetch.FetchXml.top == null)) { @@ -1210,7 +1210,7 @@ private bool FoldInExistsToFetchXml(NodeCompilationContext context, IList> Accessors { get; set; } + public IDictionary> Accessors { get; set; } public DataTypeReference SqlType { get; set; } public Type NetType { get; set; } public IComparable[] DataMemberOrder { get; set; } @@ -157,10 +157,11 @@ public override IEnumerable GetSources() protected override IEnumerable ExecuteInternal(NodeExecutionContext context) { - if (!context.DataSources.TryGetValue(DataSource, out var dataSource)) + if (!context.Session.DataSources.TryGetValue(DataSource, out var dataSource)) throw new NotSupportedQueryFragmentException("Missing datasource " + DataSource); var resp = (RetrieveAllOptionSetsResponse)dataSource.Connection.Execute(new RetrieveAllOptionSetsRequest()); + var eec = new ExpressionExecutionContext(context); foreach (var optionset in resp.OptionSetMetadata) { @@ -174,7 +175,7 @@ protected override IEnumerable ExecuteInternal(NodeExecutionContext cont continue; } - converted[col.Key] = optionsetProp(optionset); + converted[col.Key] = optionsetProp(optionset, eec); } yield return converted; diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/HashMatchAggregateNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/HashMatchAggregateNode.cs index 066b52e5..16349329 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/HashMatchAggregateNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/HashMatchAggregateNode.cs @@ -58,7 +58,7 @@ protected override IEnumerable ExecuteInternal(NodeExecutionContext cont executionContext.Entity = entity; foreach (var func in values.Values) - func.AggregateFunction.NextRecord(func.State); + func.AggregateFunction.NextRecord(func.State, executionContext); } foreach (var group in groups) @@ -98,7 +98,7 @@ Source is FetchXmlScan fetch && GroupBy.Count == 0 && Aggregates.Count == 1 && Aggregates.Single().Value.AggregateType == AggregateType.CountStar && - context.DataSources[fetch.DataSource].Metadata[fetch.Entity.name].DataProviderId == null && // RetrieveTotalRecordCountRequest is not valid for virtual entities + context.Session.DataSources[fetch.DataSource].Metadata[fetch.Entity.name].DataProviderId == null && // RetrieveTotalRecordCountRequest is not valid for virtual entities fetch.FetchXml.DataSource == null) // RetrieveTotalRecordCountRequest is not valid for archive data { var count = new RetrieveTotalRecordCountNode { DataSource = fetch.DataSource, EntityName = fetch.Entity.name }; @@ -269,7 +269,7 @@ Source is FetchXmlScan fetch && } } - var metadata = context.DataSources[fetchXml.DataSource].Metadata; + var metadata = context.Session.DataSources[fetchXml.DataSource].Metadata; // Aggregates are not supported on archive data if (fetchXml.FetchXml.DataSource == "retained") diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/IndexSpoolNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/IndexSpoolNode.cs index 2f451740..dc423336 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/IndexSpoolNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/IndexSpoolNode.cs @@ -18,8 +18,8 @@ namespace MarkMpn.Sql4Cds.Engine.ExecutionPlan class IndexSpoolNode : BaseDataNode, ISingleSourceExecutionPlanNode, ISpoolProducerNode { private IDictionary> _hashTable; - private Func _keySelector; - private Func _seekSelector; + private Func _keySelector; + private Func _seekSelector; private Stack _stack; [Browsable(false)] @@ -219,7 +219,7 @@ private IDataExecutionPlanNodeInternal FoldCTEToFetchXml(NodeCompilationContext return this; } - var metadata = context.DataSources[anchorFetchXml.DataSource].Metadata[anchorFetchXml.Entity.name]; + var metadata = context.Session.DataSources[anchorFetchXml.DataSource].Metadata[anchorFetchXml.Entity.name]; var hierarchicalRelationship = metadata.OneToManyRelationships.SingleOrDefault(r => r.IsHierarchical == true); if (hierarchicalRelationship == null || @@ -411,15 +411,22 @@ protected override IEnumerable ExecuteInternal(NodeExecutionContext cont if (WithStack) return ExecuteInternalWithStack(context); + var eec = new ExpressionExecutionContext(context); + // Build an internal hash table of the source indexed by the key column if (_hashTable == null) { _hashTable = Source.Execute(context) - .GroupBy(e => _keySelector((INullable)e[KeyColumn])) + .GroupBy(e => + { + eec.Entity = e; + return _keySelector((INullable)e[KeyColumn], eec); + }) .ToDictionary(g => g.Key, g => g.ToList()); } - var keyValue = _seekSelector((INullable)context.ParameterValues[SeekValue]); + eec.Entity = null; + var keyValue = _seekSelector((INullable)context.ParameterValues[SeekValue], eec); if (!_hashTable.TryGetValue(keyValue, out var matches)) return Array.Empty(); diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/InsertNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/InsertNode.cs index 491d19e2..89e70332 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/InsertNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/InsertNode.cs @@ -85,7 +85,7 @@ private bool GetIgnoreDuplicateKey(NodeCompilationContext context, IList entities; EntityMetadata meta; Dictionary attributes; - Dictionary> attributeAccessors; - Func primaryIdAccessor; + Dictionary> attributeAccessors; + Func primaryIdAccessor; + var eec = new ExpressionExecutionContext(context); using (_timer.Run()) { @@ -146,7 +147,11 @@ public override void Execute(NodeExecutionContext context, out int recordsAffect context.Options, entities, meta, - entity => CreateInsertRequest(meta, entity, attributeAccessors, primaryIdAccessor, attributes), + entity => + { + eec.Entity = entity; + return CreateInsertRequest(meta, eec, attributeAccessors, primaryIdAccessor, attributes); + }, new OperationNames { InProgressUppercase = "Inserting", @@ -173,13 +178,13 @@ public override void Execute(NodeExecutionContext context, out int recordsAffect } } - private OrganizationRequest CreateInsertRequest(EntityMetadata meta, Entity entity, Dictionary> attributeAccessors, Func primaryIdAccessor, Dictionary attributes) + private OrganizationRequest CreateInsertRequest(EntityMetadata meta, ExpressionExecutionContext context, Dictionary> attributeAccessors, Func primaryIdAccessor, Dictionary attributes) { // Special cases for intersect entities if (LogicalName == "listmember") { - var listId = GetNotNull("listid", entity, attributeAccessors); - var entityId = GetNotNull("entityid", entity, attributeAccessors); + var listId = GetNotNull("listid", context, attributeAccessors); + var entityId = GetNotNull("entityid", context, attributeAccessors); return new AddMemberListRequest { @@ -194,8 +199,8 @@ private OrganizationRequest CreateInsertRequest(EntityMetadata meta, Entity enti // the relationship that this is the intersect entity for var relationship = meta.ManyToManyRelationships.Single(); - var e1 = GetNotNull(relationship.Entity1IntersectAttribute, entity, attributeAccessors); - var e2 = GetNotNull(relationship.Entity2IntersectAttribute, entity, attributeAccessors); + var e1 = GetNotNull(relationship.Entity1IntersectAttribute, context, attributeAccessors); + var e2 = GetNotNull(relationship.Entity2IntersectAttribute, context, attributeAccessors); return new AssociateRequest { @@ -208,9 +213,9 @@ private OrganizationRequest CreateInsertRequest(EntityMetadata meta, Entity enti if (LogicalName == "principalobjectaccess") { // Insert into principalobjectaccess is equivalent to a share - var objectId = GetNotNull("objectid", entity, attributeAccessors); - var principalId = GetNotNull("principalid", entity, attributeAccessors); - var accessRightsMask = GetNotNull("accessrightsmask", entity, attributeAccessors); + var objectId = GetNotNull("objectid", context, attributeAccessors); + var principalId = GetNotNull("principalid", context, attributeAccessors); + var accessRightsMask = GetNotNull("accessrightsmask", context, attributeAccessors); return new GrantAccessRequest { @@ -226,7 +231,7 @@ private OrganizationRequest CreateInsertRequest(EntityMetadata meta, Entity enti var insert = new Entity(LogicalName); if (primaryIdAccessor != null) - insert.Id = (Guid) primaryIdAccessor(entity); + insert.Id = (Guid) primaryIdAccessor(context); foreach (var attributeAccessor in attributeAccessors) { @@ -238,7 +243,7 @@ private OrganizationRequest CreateInsertRequest(EntityMetadata meta, Entity enti if (!String.IsNullOrEmpty(attr.AttributeOf)) continue; - var value = attributeAccessor.Value(entity); + var value = attributeAccessor.Value(context); insert[attr.LogicalName] = value; } @@ -246,9 +251,9 @@ private OrganizationRequest CreateInsertRequest(EntityMetadata meta, Entity enti return new CreateRequest { Target = insert }; } - private T GetNotNull(string attribute, Entity entity, Dictionary> attributeAccessors) + private T GetNotNull(string attribute, ExpressionExecutionContext context, Dictionary> attributeAccessors) { - var value = attributeAccessors[attribute](entity); + var value = attributeAccessors[attribute](context); if (value == null) throw new QueryExecutionException(Sql4CdsError.NotNullInsert(new Identifier { Value = attribute }, new Identifier { Value = LogicalName }, "Insert")); @@ -285,7 +290,7 @@ protected override bool FilterErrors(NodeExecutionContext context, OrganizationR logMessage += $". The duplicate values were ({create.Target.Id})"; } - context.Log(new Sql4CdsError(10, LineNumber, 0, null, context.DataSources[DataSource].Name, 0, logMessage)); + context.Log(new Sql4CdsError(10, LineNumber, 0, null, context.Session.DataSources[DataSource].Name, 0, logMessage)); return false; } } diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/MetadataQueryNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/MetadataQueryNode.cs index c399ce66..c9600e75 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/MetadataQueryNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/MetadataQueryNode.cs @@ -26,7 +26,7 @@ class MetadataProperty { public string SqlName { get; set; } public string PropertyName { get; set; } - public Func Accessor { get; set; } + public Func Accessor { get; set; } public DataTypeReference SqlType { get; set; } public Type Type { get; set; } public IComparable[] DataMemberOrder { get; set; } @@ -36,7 +36,7 @@ class AttributeProperty { public string SqlName { get; set; } public string PropertyName { get; set; } - public IDictionary> Accessors { get; set; } + public IDictionary> Accessors { get; set; } public DataTypeReference SqlType { get; set; } public Type Type { get; set; } public bool IsNullable { get; set; } @@ -552,7 +552,7 @@ private CompiledExpression CompileExpression(ScalarExpression expression, Expres var targetNetType = targetSqlType.ToNetType(out _); var netConverter = SqlTypeConverter.GetConversion(targetNetType, targetValueType); - return new CompiledExpression(expression, context => netConverter(sqlConverter((INullable)expr(context)))); + return new CompiledExpression(expression, context => netConverter(sqlConverter((INullable)expr(context), context), context)); } public override INodeSchema GetSchema(NodeCompilationContext context) @@ -843,10 +843,11 @@ internal static DataTypeReference GetPropertyType(Type propType) return SqlTypeConverter.NetToSqlType(propType).ToSqlType(null); } - internal static Func GetPropertyAccessor(PropertyInfo prop, Type targetType) + internal static Func GetPropertyAccessor(PropertyInfo prop, Type targetType) { var rawParam = Expression.Parameter(typeof(object)); - var param = SqlTypeConverter.Convert(rawParam, prop.DeclaringType); + var contextParam = Expression.Parameter(typeof(ExpressionExecutionContext)); + var param = SqlTypeConverter.Convert(rawParam, contextParam, prop.DeclaringType); var value = (Expression)Expression.Property(param, prop); // Extract base value from complex types @@ -869,13 +870,13 @@ internal static Func GetPropertyAccessor(PropertyInfo prop, Type value = Expression.Property(value, nameof(Label.UserLocalizedLabel)); if (value.Type == typeof(LocalizedLabel)) - value = Expression.Condition(Expression.Equal(value, Expression.Constant(null)), SqlTypeConverter.Convert(Expression.Constant(null), typeof(string)), Expression.Property(value, nameof(LocalizedLabel.Label))); + value = Expression.Condition(Expression.Equal(value, Expression.Constant(null)), Expression.Constant(null, typeof(string)), Expression.Property(value, nameof(LocalizedLabel.Label))); if (value.Type.IsEnum) value = Expression.Call(value, nameof(Enum.ToString), Array.Empty()); if (typeof(MetadataBase).IsAssignableFrom(value.Type)) - value = Expression.Condition(Expression.Equal(value, Expression.Constant(null)), SqlTypeConverter.Convert(Expression.Constant(null), typeof(Guid?)), Expression.Property(value, nameof(MetadataBase.MetadataId))); + value = Expression.Condition(Expression.Equal(value, Expression.Constant(null)), Expression.Constant(null, typeof(Guid?)), Expression.Property(value, nameof(MetadataBase.MetadataId))); var directConversionType = SqlTypeConverter.NetToSqlType(value.Type); @@ -887,10 +888,10 @@ internal static Func GetPropertyAccessor(PropertyInfo prop, Type if (value.Type == typeof(string) && directConversionType == typeof(SqlString) && targetType == typeof(SqlString)) converted = Expr.Call(() => ApplyCollation(Expr.Arg()), value); else - converted = SqlTypeConverter.Convert(value, directConversionType); + converted = SqlTypeConverter.Convert(value, contextParam, directConversionType); if (targetType != directConversionType) - converted = SqlTypeConverter.Convert(converted, targetType); + converted = SqlTypeConverter.Convert(converted, contextParam, targetType); // Return null literal if final value is null if (!value.Type.IsValueType) @@ -904,7 +905,7 @@ internal static Func GetPropertyAccessor(PropertyInfo prop, Type // Compile the function value = Expr.Box(value); - var func = (Func) Expression.Lambda(value, rawParam).Compile(); + var func = (Func) Expression.Lambda(value, rawParam, contextParam).Compile(); return func; } @@ -1000,7 +1001,7 @@ protected override IEnumerable ExecuteInternal(NodeExecutionContext cont query.Properties.PropertyNames.Add(nameof(EntityMetadata.Keys)); } - if (!context.DataSources.TryGetValue(DataSource, out var dataSource)) + if (!context.Session.DataSources.TryGetValue(DataSource, out var dataSource)) throw new NotSupportedQueryFragmentException("Missing datasource " + DataSource); var resp = (RetrieveMetadataChangesResponse)dataSource.Connection.Execute(new RetrieveMetadataChangesRequest { Query = query }); @@ -1009,6 +1010,7 @@ protected override IEnumerable ExecuteInternal(NodeExecutionContext cont var oneToManyRelationshipProps = typeof(OneToManyRelationshipMetadata).GetProperties().ToDictionary(p => p.Name); var manyToManyRelationshipProps = typeof(ManyToManyRelationshipMetadata).GetProperties().ToDictionary(p => p.Name); var keyProps = typeof(EntityKeyMetadata).GetProperties().ToDictionary(p => p.Name); + var eec = new ExpressionExecutionContext(context); var results = resp.EntityMetadata.Select(e => new { Entity = e, Attribute = (AttributeMetadata)null, Relationship = (RelationshipMetadataBase)null, Key = (EntityKeyMetadata)null }); @@ -1037,7 +1039,7 @@ protected override IEnumerable ExecuteInternal(NodeExecutionContext cont converted.Id = result.Entity.MetadataId ?? Guid.Empty; foreach (var prop in _entityCols) - converted[prop.Key] = prop.Value.Accessor(result.Entity); + converted[prop.Key] = prop.Value.Accessor(result.Entity, eec); } if (MetadataSource.HasFlag(MetadataSource.Attribute)) @@ -1053,7 +1055,7 @@ protected override IEnumerable ExecuteInternal(NodeExecutionContext cont continue; } - converted[prop.Key] = accessor(result.Attribute); + converted[prop.Key] = accessor(result.Attribute, eec); } } @@ -1063,7 +1065,7 @@ protected override IEnumerable ExecuteInternal(NodeExecutionContext cont converted.Id = result.Relationship.MetadataId ?? Guid.Empty; foreach (var prop in _oneToManyRelationshipCols) - converted[prop.Key] = prop.Value.Accessor(result.Relationship); + converted[prop.Key] = prop.Value.Accessor(result.Relationship, eec); } if (MetadataSource.HasFlag(MetadataSource.ManyToOneRelationship)) @@ -1072,7 +1074,7 @@ protected override IEnumerable ExecuteInternal(NodeExecutionContext cont converted.Id = result.Relationship.MetadataId ?? Guid.Empty; foreach (var prop in _manyToOneRelationshipCols) - converted[prop.Key] = prop.Value.Accessor(result.Relationship); + converted[prop.Key] = prop.Value.Accessor(result.Relationship, eec); } if (MetadataSource.HasFlag(MetadataSource.ManyToManyRelationship)) @@ -1081,7 +1083,7 @@ protected override IEnumerable ExecuteInternal(NodeExecutionContext cont converted.Id = result.Relationship.MetadataId ?? Guid.Empty; foreach (var prop in _manyToManyRelationshipCols) - converted[prop.Key] = prop.Value.Accessor(result.Relationship); + converted[prop.Key] = prop.Value.Accessor(result.Relationship, eec); } if (MetadataSource.HasFlag(MetadataSource.Key)) @@ -1090,7 +1092,7 @@ protected override IEnumerable ExecuteInternal(NodeExecutionContext cont converted.Id = result.Key.MetadataId ?? Guid.Empty; foreach (var prop in _keyCols) - converted[prop.Key] = prop.Value.Accessor(result.Key); + converted[prop.Key] = prop.Value.Accessor(result.Key, eec); } foreach (var attr in converted.Attributes.ToList()) diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/NestedLoopNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/NestedLoopNode.cs index 91436897..3273e85b 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/NestedLoopNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/NestedLoopNode.cs @@ -64,7 +64,7 @@ protected override IEnumerable ExecuteInternal(NodeExecutionContext cont innerParameterTypes[kvp.Value] = leftSchema.Schema[kvp.Key].Type; } - rightCompilationContext = new NodeCompilationContext(context.DataSources, context.Options, innerParameterTypes, context.Log); + rightCompilationContext = new NodeCompilationContext(context.Session, context.Options, innerParameterTypes, context.Log); } var innerParameters = context.ParameterValues; @@ -85,7 +85,7 @@ protected override IEnumerable ExecuteInternal(NodeExecutionContext cont var hasRight = false; - foreach (var right in RightSource.Execute(new NodeExecutionContext(context.DataSources, context.Options, innerParameterTypes, innerParameters, context.Log))) + foreach (var right in RightSource.Execute(new NodeExecutionContext(context, innerParameterTypes, innerParameters))) { if (rightSchema == null) { @@ -183,7 +183,7 @@ public override IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext LeftSource.Parent = this; var innerParameterTypes = GetInnerParameterTypes(leftSchema, context.ParameterTypes); - var innerContext = new NodeCompilationContext(context.DataSources, context.Options, innerParameterTypes, context.Log); + var innerContext = new NodeCompilationContext(context.Session, context.Options, innerParameterTypes, context.Log); var rightSchema = RightSource.GetSchema(innerContext); RightSource = RightSource.FoldQuery(innerContext, hints); RightSource.Parent = this; @@ -329,7 +329,7 @@ public override void AddRequiredColumns(NodeCompilationContext context, IList OutputRightSchema) .Concat(criteriaCols) @@ -346,7 +346,7 @@ protected override INodeSchema GetRightSchema(NodeCompilationContext context) { var leftSchema = LeftSource.GetSchema(context); var innerParameterTypes = GetInnerParameterTypes(leftSchema, context.ParameterTypes); - var innerContext = new NodeCompilationContext(context.DataSources, context.Options, innerParameterTypes, context.Log); + var innerContext = new NodeCompilationContext(context.Session, context.Options, innerParameterTypes, context.Log); return RightSource.GetSchema(innerContext); } @@ -356,7 +356,7 @@ protected override RowCountEstimate EstimateRowsOutInternal(NodeCompilationConte ParseEstimate(leftEstimate, out var leftMin, out var leftMax, out var leftIsRange); var leftSchema = LeftSource.GetSchema(context); var innerParameterTypes = GetInnerParameterTypes(leftSchema, context.ParameterTypes); - var innerContext = new NodeCompilationContext(context.DataSources, context.Options, innerParameterTypes, context.Log); + var innerContext = new NodeCompilationContext(context.Session, context.Options, innerParameterTypes, context.Log); var rightEstimate = RightSource.EstimateRowsOut(innerContext); ParseEstimate(rightEstimate, out var rightMin, out var rightMax, out var rightIsRange); diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/OffsetFetchNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/OffsetFetchNode.cs index afc53e5e..2cb5924d 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/OffsetFetchNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/OffsetFetchNode.cs @@ -36,8 +36,8 @@ protected override IEnumerable ExecuteInternal(NodeExecutionContext cont { var expressionCompilationContext = new ExpressionCompilationContext(context, null, null); var expressionExecutionContext = new ExpressionExecutionContext(context); - var offset = SqlTypeConverter.ChangeType(Offset.Compile(expressionCompilationContext)(expressionExecutionContext)); - var fetch = SqlTypeConverter.ChangeType(Fetch.Compile(expressionCompilationContext)(expressionExecutionContext)); + var offset = SqlTypeConverter.ChangeType(Offset.Compile(expressionCompilationContext)(expressionExecutionContext), expressionExecutionContext); + var fetch = SqlTypeConverter.ChangeType(Fetch.Compile(expressionCompilationContext)(expressionExecutionContext), expressionExecutionContext); if (offset < 0) throw new QueryExecutionException(Sql4CdsError.Create(10742, null)); @@ -65,7 +65,7 @@ public override IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext Source = Source.FoldQuery(context, hints); Source.Parent = this; - var expressionCompilationContext = new ExpressionCompilationContext(context.DataSources, context.Options, null, null, null); + var expressionCompilationContext = new ExpressionCompilationContext(context, null, null); if (!Offset.IsConstantValueExpression(expressionCompilationContext, out var offsetLiteral) || !Fetch.IsConstantValueExpression(expressionCompilationContext, out var fetchLiteral)) @@ -78,8 +78,8 @@ public override IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext return this; var expressionExecutionContext = new ExpressionExecutionContext(expressionCompilationContext); - var offset = SqlTypeConverter.ChangeType(offsetLiteral.Compile(expressionCompilationContext)(expressionExecutionContext)); - var count = SqlTypeConverter.ChangeType(fetchLiteral.Compile(expressionCompilationContext)(expressionExecutionContext)); + var offset = SqlTypeConverter.ChangeType(offsetLiteral.Compile(expressionCompilationContext)(expressionExecutionContext), expressionExecutionContext); + var count = SqlTypeConverter.ChangeType(fetchLiteral.Compile(expressionCompilationContext)(expressionExecutionContext), expressionExecutionContext); var page = offset / count; if (page * count == offset && count <= 5000) diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/OpenJsonNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/OpenJsonNode.cs index 86c3d0fd..0cfaf933 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/OpenJsonNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/OpenJsonNode.cs @@ -15,7 +15,7 @@ class OpenJsonNode : BaseDataNode private Func _jsonExpression; private Func _pathExpression; private Collation _jsonCollation; - private List> _conversions; + private List> _conversions; private static readonly Collation _keyCollation; @@ -279,7 +279,7 @@ protected override IEnumerable ExecuteInternal(NodeExecutionContext cont } else { - yield return TokenToEntity(prop.Value, schema, mappings); + yield return TokenToEntity(prop.Value, schema, mappings, eec); } } } @@ -304,7 +304,7 @@ protected override IEnumerable ExecuteInternal(NodeExecutionContext cont } else { - yield return TokenToEntity(item, schema, mappings); + yield return TokenToEntity(item, schema, mappings, eec); } i++; @@ -319,7 +319,7 @@ protected override IEnumerable ExecuteInternal(NodeExecutionContext cont } } - private Entity TokenToEntity(JsonElement token, INodeSchema schema, JsonPath[] mappings) + private Entity TokenToEntity(JsonElement token, INodeSchema schema, JsonPath[] mappings, ExpressionExecutionContext context) { var result = new Entity(); @@ -373,7 +373,7 @@ private Entity TokenToEntity(JsonElement token, INodeSchema schema, JsonPath[] m } var sqlStringValue = Collation.USEnglish.ToSqlString(stringValue); - var sqlValue = _conversions[i](sqlStringValue); + var sqlValue = _conversions[i](sqlStringValue, context); result[PrefixWithAlias(Schema[i].ColumnDefinition.ColumnIdentifier.Value, null)] = sqlValue; } diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/PartitionedAggregateNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/PartitionedAggregateNode.cs index bbc31602..393c0303 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/PartitionedAggregateNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/PartitionedAggregateNode.cs @@ -68,7 +68,7 @@ private int GetMaxDOP(NodeCompilationContext context, IList query if (fetchXmlNode.DataSource == null) return 1; - if (!context.DataSources.TryGetValue(fetchXmlNode.DataSource, out var dataSource)) + if (!context.Session.DataSources.TryGetValue(fetchXmlNode.DataSource, out var dataSource)) throw new NotSupportedQueryFragmentException("Unknown datasource"); return ParallelismHelper.GetMaxDOP(dataSource, context, queryHints); @@ -96,7 +96,7 @@ protected override IEnumerable ExecuteInternal(NodeExecutionContext cont var fetchXmlNode = (FetchXmlScan)Source.Clone(); var name = fetchXmlNode.Entity.name; - var meta = context.DataSources[fetchXmlNode.DataSource].Metadata[name]; + var meta = context.Session.DataSources[fetchXmlNode.DataSource].Metadata[name]; context.Options.Progress(0, $"Partitioning {GetDisplayName(0, meta)}..."); // Get the minimum and maximum primary keys from the source @@ -143,7 +143,7 @@ protected override IEnumerable ExecuteInternal(NodeExecutionContext cont SplitPartition(fullRange); // Multi-thread where possible - var org = context.DataSources[fetchXmlNode.DataSource].Connection; + var org = context.Session.DataSources[fetchXmlNode.DataSource].Connection; _lock = new object(); #if NETCOREAPP @@ -172,10 +172,10 @@ protected override IEnumerable ExecuteInternal(NodeExecutionContext cont [fetchXmlNode.DataSource] = new DataSource { Connection = svc?.Clone() ?? org, - Metadata = context.DataSources[fetchXmlNode.DataSource].Metadata, + Metadata = context.Session.DataSources[fetchXmlNode.DataSource].Metadata, Name = fetchXmlNode.DataSource, - TableSizeCache = context.DataSources[fetchXmlNode.DataSource].TableSizeCache, - MessageCache = context.DataSources[fetchXmlNode.DataSource].MessageCache + TableSizeCache = context.Session.DataSources[fetchXmlNode.DataSource].TableSizeCache, + MessageCache = context.Session.DataSources[fetchXmlNode.DataSource].MessageCache } }; @@ -199,7 +199,7 @@ protected override IEnumerable ExecuteInternal(NodeExecutionContext cont partitionParameterValues[kvp.Key] = kvp.Value; } - var partitionContext = new NodeExecutionContext(context.DataSources, context.Options, context.ParameterTypes, partitionParameterValues, context.Log); + var partitionContext = new NodeExecutionContext(context, context.ParameterTypes, partitionParameterValues); return new { Context = partitionContext, Fetch = fetch }; }, @@ -335,7 +335,7 @@ private void ExecuteAggregate(NodeExecutionContext context, ExpressionExecutionC expressionContext.Entity = entity; foreach (var func in values.Values) - func.AggregateFunction.NextPartition(func.State); + func.AggregateFunction.NextPartition(func.State, expressionContext); } } } diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/RetrieveTotalRecordCountNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/RetrieveTotalRecordCountNode.cs index 731e3522..87348327 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/RetrieveTotalRecordCountNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/RetrieveTotalRecordCountNode.cs @@ -29,7 +29,7 @@ class RetrieveTotalRecordCountNode : BaseDataNode protected override IEnumerable ExecuteInternal(NodeExecutionContext context) { - if (!context.DataSources.TryGetValue(DataSource, out var dataSource)) + if (!context.Session.DataSources.TryGetValue(DataSource, out var dataSource)) throw new QueryExecutionException("Missing datasource " + DataSource); var count = ((RetrieveTotalRecordCountResponse)dataSource.Connection.Execute(new RetrieveTotalRecordCountRequest { EntityNames = new[] { EntityName } })).EntityRecordCountCollection[EntityName]; diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/RevertNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/RevertNode.cs index ee2b69e2..0fbfa131 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/RevertNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/RevertNode.cs @@ -63,7 +63,7 @@ public void Execute(NodeExecutionContext context, out int recordsAffected, out s { using (_timer.Run()) { - if (!context.DataSources.TryGetValue(DataSource, out var dataSource)) + if (!context.Session.DataSources.TryGetValue(DataSource, out var dataSource)) throw new QueryExecutionException("Missing datasource " + DataSource); #if NETCOREAPP diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SelectNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SelectNode.cs index 1f93ac0d..ef2f2528 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SelectNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SelectNode.cs @@ -111,7 +111,7 @@ internal static void FoldFetchXmlColumns(IDataExecutionPlanNode source, List _dateFormat; + private int _executionCount; + private readonly Timer _timer = new Timer(); + + public SetDateFormatNode(ScalarExpression expression) + { + DateFormat = expression; + } + + [Category("Settings")] + [Description("The date format to use when parsing dates")] + public ScalarExpression DateFormat { get; set; } + + [Browsable(false)] + public string Sql { get; set; } + + [Browsable(false)] + public int Index { get; set; } + + [Browsable(false)] + public int Length { get; set; } + + [Browsable(false)] + public int LineNumber { get; set; } + + public override int ExecutionCount => _executionCount; + + public override TimeSpan Duration => _timer.Duration; + + public override void AddRequiredColumns(NodeCompilationContext context, IList requiredColumns) + { + } + + public IRootExecutionPlanNodeInternal[] FoldQuery(NodeCompilationContext context, IList hints) + { + _dateFormat = DateFormat.Compile(new ExpressionCompilationContext(context, null, null)); + return new[] { this }; + } + + public void Execute(NodeExecutionContext context, out int recordsAffected, out string message) + { + _executionCount++; + + using (_timer.Run()) + { + var formatString = ((SqlString)_dateFormat(new ExpressionExecutionContext(context))).Value; + + if (!Enum.TryParse(formatString, out var dateFormat)) + throw new QueryExecutionException(Sql4CdsError.InvalidDateFormat(DateFormat, formatString)); + + context.Session.DateFormat = dateFormat; + } + + recordsAffected = -1; + message = null; + } + + public override IEnumerable GetSources() + { + return Enumerable.Empty(); + } + + public object Clone() + { + return new SetDateFormatNode(DateFormat) + { + _dateFormat = _dateFormat, + }; + } + } +} diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SortNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SortNode.cs index 57807164..2b5588af 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SortNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SortNode.cs @@ -266,10 +266,10 @@ private IDataExecutionPlanNodeInternal FoldSorts(NodeCompilationContext context) var canFold = fetchXml != null; DataSource dataSource = null; - if (canFold && fetchXml.RequiresCustomPaging(context.DataSources)) + if (canFold && fetchXml.RequiresCustomPaging(context.Session.DataSources)) canFold = false; - if (canFold && !context.DataSources.TryGetValue(fetchXml.DataSource, out dataSource)) + if (canFold && !context.Session.DataSources.TryGetValue(fetchXml.DataSource, out dataSource)) throw new QueryExecutionException("Missing datasource " + fetchXml.DataSource); var isAuditOrElastic = fetchXml != null && (fetchXml.Entity.name == "audit" || dataSource != null && dataSource.Metadata[fetchXml.Entity.name].DataProviderId == DataProviders.ElasticDataProvider); diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SqlNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SqlNode.cs index 187ef022..1ac26485 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SqlNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SqlNode.cs @@ -9,8 +9,12 @@ using System.Linq; using System.Text; using System.Threading.Tasks; +using MarkMpn.Sql4Cds.Engine.FetchXml; +using MarkMpn.Sql4Cds.Engine.Visitors; using Microsoft.SqlServer.TransactSql.ScriptDom; using Microsoft.Xrm.Sdk; +using Microsoft.Xrm.Sdk.Metadata; + #if NETCOREAPP using Microsoft.PowerPlatform.Dataverse.Client; #else @@ -53,6 +57,8 @@ public SqlNode() { } [Browsable(false)] public HashSet Parameters { get; private set; } = new HashSet(StringComparer.OrdinalIgnoreCase); + internal SelectStatement SelectStatement { get; set; } + public override void AddRequiredColumns(NodeCompilationContext context, IList requiredColumns) { } @@ -65,7 +71,7 @@ public DbDataReader Execute(NodeExecutionContext context, CommandBehavior behavi { try { - if (!context.DataSources.TryGetValue(DataSource, out var dataSource)) + if (!context.Session.DataSources.TryGetValue(DataSource, out var dataSource)) throw new QueryExecutionException("Missing datasource " + DataSource); if (context.Options.UseLocalTimeZone) @@ -89,7 +95,7 @@ public DbDataReader Execute(NodeExecutionContext context, CommandBehavior behavi var cmd = con.CreateCommand(); cmd.CommandTimeout = (int)TimeSpan.FromMinutes(2).TotalSeconds; - cmd.CommandText = ApplyCommandBehavior(Sql, behavior, context.Options); + cmd.CommandText = ApplyCommandBehavior(Sql, behavior, context); foreach (var paramValue in context.ParameterValues) { @@ -145,13 +151,19 @@ public DbDataReader Execute(NodeExecutionContext context, CommandBehavior behavi } } - internal static string ApplyCommandBehavior(string sql, CommandBehavior behavior, IQueryExecutionOptions options) + internal static string ApplyCommandBehavior(string sql, CommandBehavior behavior, NodeExecutionContext context) { + if (context.Session.DateFormat != DateFormat.mdy) + { + // mdy is the default format for the TDS Endpoint, so we need to switch it as necessary + sql = "SET DATEFORMAT " + context.Session.DateFormat.ToString() + ";\r\n" + sql; + } + if (behavior == CommandBehavior.Default) return sql; // TDS Endpoint doesn't support command behavior flags, so fake them by modifying the SQL query - var dom = new TSql160Parser(options.QuotedIdentifiers); + var dom = new TSql160Parser(context.Options.QuotedIdentifiers); var script = (TSqlScript) dom.Parse(new StringReader(sql), out _); if (behavior.HasFlag(CommandBehavior.SchemaOnly)) @@ -225,6 +237,144 @@ public override IEnumerable GetSources() return Array.Empty(); } + internal IExecutionPlanNodeInternal FoldDmlSource(NodeCompilationContext context, IList hints, string logicalName, string[] requiredColumns, string[] keyAttributes) + { + if (!(SelectStatement?.QueryExpression is QuerySpecification querySpec)) + return this; + + if (querySpec.FromClause == null || querySpec.FromClause.TableReferences.Count != 1 || !(querySpec.FromClause.TableReferences[0] is NamedTableReference table)) + return this; + + if (table.SchemaObject.BaseIdentifier.Value != logicalName) + return this; + + if (querySpec.WhereClause == null || querySpec.WhereClause.SearchCondition == null) + return this; + + var filterVisitor = new SimpleFilterVisitor(); + querySpec.WhereClause.SearchCondition.Accept(filterVisitor); + + if (filterVisitor.BinaryType == null) + return this; + + var dataSource = context.Session.DataSources[DataSource]; + var metadata = dataSource.Metadata[logicalName]; + var conditions = filterVisitor.Conditions.ToList(); + var ecc = new ExpressionCompilationContext(context, null, null); + + if (!TryGetDmlSchema(dataSource, metadata, querySpec, new ExpressionCompilationContext(context, null, null), out var schema, out var literalValues)) + return this; + + // Every column must either be a literal or a key attribute + if (requiredColumns.Except(literalValues.Keys).Except(keyAttributes).Any()) + return this; + + var constantScan = new ConstantScanNode(); + + foreach (var col in requiredColumns) + constantScan.Schema[col.SplitMultiPartIdentifier().Last()] = new ColumnDefinition(schema[col], true, false); + + // We can handle compound keys, but only if they are all ANDed together + if (keyAttributes.Length > 1 && filterVisitor.BinaryType == BooleanBinaryExpressionType.And) + { + var values = new Dictionary(); + + foreach (var col in requiredColumns) + { + if (literalValues.ContainsKey(col)) + continue; + + var condition = conditions.FirstOrDefault(c => c.attribute == col.SplitMultiPartIdentifier().Last()); + if (condition == null) + return this; + + if (condition.@operator != @operator.eq) + return this; + + var attribute = metadata.Attributes.Single(a => a.LogicalName == condition.attribute); + values[condition.attribute] = attribute.GetDmlValue(condition.value, condition.IsVariable, ecc, dataSource); + } + + constantScan.Values.Add(values); + + AddLiteralValues(constantScan, literalValues); + return constantScan; + } + + // We can also handle multiple values for a single key being ORed together + else if (keyAttributes.Length == 1 && + conditions.All(c => c.attribute == metadata.PrimaryIdAttribute) && + conditions.All(c => c.@operator == @operator.eq || c.@operator == @operator.@in) && + (conditions.Count == 1 || filterVisitor.BinaryType == BooleanBinaryExpressionType.Or)) + { + foreach (var condition in conditions) + { + var attribute = metadata.Attributes.Single(a => a.LogicalName == condition.attribute); + + if (condition.@operator == @operator.eq) + { + constantScan.Values.Add(new Dictionary { [condition.attribute] = attribute.GetDmlValue(condition.value, condition.IsVariable, ecc, dataSource) }); + } + else if (condition.@operator == @operator.@in) + { + foreach (var value in condition.Items) + constantScan.Values.Add(new Dictionary { [condition.attribute] = attribute.GetDmlValue(value.Value, value.IsVariable, ecc, dataSource) }); + } + } + + AddLiteralValues(constantScan, literalValues); + return constantScan; + } + + return this; + } + + private void AddLiteralValues(ConstantScanNode constantScan, Dictionary literalValues) + { + foreach (var row in constantScan.Values) + { + foreach (var value in literalValues) + { + row[value.Key] = value.Value; + } + } + } + + private bool TryGetDmlSchema(DataSource dataSource, EntityMetadata metadata, QuerySpecification querySpec, ExpressionCompilationContext context, out Dictionary schema, out Dictionary literalValues) + { + schema = new Dictionary(StringComparer.OrdinalIgnoreCase); + literalValues = new Dictionary(StringComparer.OrdinalIgnoreCase); + + foreach (var select in querySpec.SelectElements) + { + if (!(select is SelectScalarExpression scalar)) + return false; + + if (!scalar.Expression.GetColumns().Any() && scalar.ColumnName?.Value != null) + { + scalar.Expression.GetType(context, out var literalType); + schema[scalar.ColumnName.Value] = literalType; + literalValues[scalar.ColumnName.Value] = scalar.Expression; + continue; + } + + if (!(scalar.Expression is ColumnReferenceExpression col)) + return false; + + var attribute = metadata.Attributes.SingleOrDefault(a => a.LogicalName.Equals(col.MultiPartIdentifier.Identifiers.Last().Value, StringComparison.OrdinalIgnoreCase)); + + if (attribute == null) + return false; + + var type = attribute.GetAttributeSqlType(dataSource, false); + var name = scalar.ColumnName?.Value ?? attribute.LogicalName; + + schema[name] = type; + } + + return true; + } + public override string ToString() { return "TDS Endpoint"; @@ -239,7 +389,8 @@ public object Clone() Index = Index, Length = Length, LineNumber = LineNumber, - Parameters = Parameters + Parameters = Parameters, + SelectStatement = SelectStatement }; } } diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SqlTypeConverter.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SqlTypeConverter.cs index 42e9b8fa..aff5bc84 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SqlTypeConverter.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SqlTypeConverter.cs @@ -52,8 +52,8 @@ class SqlTypeConverter private static readonly IDictionary _nullValues; private static readonly CultureInfo _hijriCulture; - private static ConcurrentDictionary> _conversions; - private static ConcurrentDictionary> _sqlConversions; + private static ConcurrentDictionary> _conversions; + private static ConcurrentDictionary> _sqlConversions; private static Dictionary _netToSqlTypeConversions; private static Dictionary> _netToSqlTypeConversionFuncs; private static Dictionary _sqlToNetTypeConversions; @@ -80,6 +80,7 @@ static SqlTypeConverter() [typeof(SqlDate)] = SqlDate.Null, [typeof(SqlDateTime2)] = SqlDateTime2.Null, [typeof(SqlDateTimeOffset)] = SqlDateTimeOffset.Null, + [typeof(SqlSmallDateTime)] = SqlSmallDateTime.Null, [typeof(SqlTime)] = SqlTime.Null, [typeof(SqlXml)] = SqlXml.Null, [typeof(SqlVariant)] = SqlVariant.Null @@ -88,8 +89,8 @@ static SqlTypeConverter() _hijriCulture = (CultureInfo)CultureInfo.GetCultureInfo("ar-JO").Clone(); _hijriCulture.DateTimeFormat.Calendar = new HijriCalendar(); - _conversions = new ConcurrentDictionary>(); - _sqlConversions = new ConcurrentDictionary>(); + _conversions = new ConcurrentDictionary>(); + _sqlConversions = new ConcurrentDictionary>(); _netToSqlTypeConversions = new Dictionary(); _netToSqlTypeConversionFuncs = new Dictionary>(); @@ -110,9 +111,10 @@ static SqlTypeConverter() AddTypeConversion((ds, v, dt) => v, v => v.Value); AddNullableTypeConversion((ds, v, dt) => ((SqlDataTypeReferenceWithCollation)dt).Collation.ToSqlString(v), v => v.Value); AddTypeConversion((ds, v, dt) => v, v => v.Value); - AddTypeConversion((ds, v, dt) => (SqlDateTime)v, v => v.Value); - AddTypeConversion((ds, v, dt) => (SqlDateTime)v, v => v.Value); + AddTypeConversion((ds, v, dt) => new SqlDate(v), v => v.Value); + AddTypeConversion((ds, v, dt) => new SqlDateTime2(v), v => v.Value); AddTypeConversion((ds, v, dt) => new SqlDateTimeOffset(v), v => v.Value); + AddTypeConversion((ds, v, dt) => new SqlSmallDateTime(v), v => v.Value); AddTypeConversion((ds, v, dt) => new SqlTime(v), v => v.Value); AddNullableTypeConversion((ds, v, dt) => new SqlXml(new MemoryStream(Encoding.GetEncoding("utf-16").GetBytes(v))), v => v.Value); @@ -461,30 +463,46 @@ public static bool CanChangeTypeExplicit(DataTypeReference from, DataTypeReferen /// Produces the required expression to convert values to a specific type /// /// The expression that generates the values to convert + /// The expression which contains the the expression will be evaluated in /// The type to convert to - /// The expression which contains the the expression will be evaluated in /// An expression to generate values of the required type - public static Expression Convert(Expression expr, Type to) + public static Expression Convert(Expression expr, Expression context, Type to) { - if (expr.Type == typeof(SqlDateTime) && (to == typeof(SqlBoolean) || to == typeof(SqlByte) || to == typeof(SqlInt16) || to == typeof(SqlInt32) || to == typeof(SqlInt64) || to == typeof(SqlDecimal) || to == typeof(SqlSingle) || to == typeof(SqlDouble))) + if ((expr.Type == typeof(SqlDateTime) || expr.Type == typeof(SqlSmallDateTime)) && (to == typeof(SqlBoolean) || to == typeof(SqlByte) || to == typeof(SqlInt16) || to == typeof(SqlInt32) || to == typeof(SqlInt64) || to == typeof(SqlDecimal) || to == typeof(SqlSingle) || to == typeof(SqlDouble))) { + // Conversion from datetime types to numeric types uses the number of days since 1900-01-01 + Expression converted; + + if (expr.Type == typeof(SqlDateTime)) + converted = Expression.Convert(expr, typeof(DateTime)); + else + converted = Expression.PropertyOrField(expr, nameof(SqlSmallDateTime.Value)); + + converted = Expression.PropertyOrField( + Expression.Subtract( + converted, + Expression.Constant(new DateTime(1900, 1, 1)) + ), + nameof(TimeSpan.TotalDays) + ); + + // Round the value when converting from datetime to an integer type, rather than + // using the standard double -> int truncation + if (to == typeof(SqlBoolean) || to == typeof(SqlByte) || to == typeof(SqlInt16) || to == typeof(SqlInt32) || to == typeof(SqlInt64)) + { + converted = Expr.Call(() => Math.Round(Expr.Arg()), converted); + } + + // Convert the result to a SqlDouble and handle null values expr = Expression.Condition( Expression.PropertyOrField(expr, nameof(SqlDateTime.IsNull)), Expression.Constant(SqlDouble.Null), - Expression.Convert( - Expression.PropertyOrField( - Expression.Subtract( - Expression.Convert(expr, typeof(DateTime)), - Expression.Constant(new DateTime(1900, 1, 1)) - ), - nameof(TimeSpan.TotalDays) - ), - typeof(SqlDouble) + Expression.Convert(converted, typeof(SqlDouble) ) ); } - if ((expr.Type == typeof(SqlBoolean) || expr.Type == typeof(SqlByte) || expr.Type == typeof(SqlInt16) || expr.Type == typeof(SqlInt32) || expr.Type == typeof(SqlInt64) || expr.Type == typeof(SqlDecimal) || expr.Type == typeof(SqlSingle) || expr.Type == typeof(SqlDouble)) && to == typeof(SqlDateTime)) + if ((expr.Type == typeof(SqlBoolean) || expr.Type == typeof(SqlByte) || expr.Type == typeof(SqlInt16) || expr.Type == typeof(SqlInt32) || expr.Type == typeof(SqlInt64) || expr.Type == typeof(SqlDecimal) || expr.Type == typeof(SqlSingle) || expr.Type == typeof(SqlDouble)) && (to == typeof(SqlDateTime) || to == typeof(SqlSmallDateTime))) { expr = Expression.Condition( NullCheck(expr), @@ -520,6 +538,21 @@ public static Expression Convert(Expression expr, Type to) if (expr.Type == typeof(SqlString) && to == typeof(SqlXml)) expr = Expr.Call(() => ParseXml(Expr.Arg()), expr); + if (expr.Type == typeof(SqlString) && to == typeof(SqlTime)) + expr = Expr.Call(() => ParseTime(Expr.Arg(), Expr.Arg()), expr, context); + + if (expr.Type == typeof(SqlString) && to == typeof(SqlDate)) + expr = Expr.Call(() => ParseDate(Expr.Arg(), Expr.Arg()), expr, context); + + if (expr.Type == typeof(SqlString) && (to == typeof(SqlDateTime) || to == typeof(SqlSmallDateTime))) + expr = Expr.Call(() => ParseDateTime(Expr.Arg(), Expr.Arg()), expr, context); + + if (expr.Type == typeof(SqlString) && to == typeof(SqlDateTime2)) + expr = Expr.Call(() => ParseDateTime2(Expr.Arg(), Expr.Arg()), expr, context); + + if (expr.Type == typeof(SqlString) && to == typeof(SqlDateTimeOffset)) + expr = Expr.Call(() => ParseDateTimeOffset(Expr.Arg(), Expr.Arg()), expr, context); + if (expr.Type == typeof(SqlDecimal) && (to == typeof(SqlInt64) || to == typeof(SqlInt32) || to == typeof(SqlInt16) || to == typeof(SqlByte))) { // Built-in conversion uses rounding, should use truncation @@ -548,9 +581,44 @@ public static Expression Convert(Expression expr, Type to) return expr; } - private static SqlString ApplyCollation(ExpressionExecutionContext context, SqlString sqlString) + private static SqlTime ParseTime(SqlString value, ExpressionExecutionContext context) + { + if (!SqlDateParsing.TryParse(value.Value, context.Session.DateFormat, out SqlTime time)) + throw new QueryExecutionException(Sql4CdsError.DateTimeParseError(null)); + + return time; + } + + private static SqlDate ParseDate(SqlString value, ExpressionExecutionContext context) { - return ApplyCollation(context.PrimaryDataSource.DefaultCollation, sqlString); + if (!SqlDateParsing.TryParse(value.Value, context.Session.DateFormat, out SqlDate date)) + throw new QueryExecutionException(Sql4CdsError.DateTimeParseError(null)); + + return date; + } + + private static SqlDateTime ParseDateTime(SqlString value, ExpressionExecutionContext context) + { + if (!SqlDateParsing.TryParse(value.Value, context.Session.DateFormat, out SqlDateTime dateTime)) + throw new QueryExecutionException(Sql4CdsError.DateTimeParseError(null)); + + return dateTime; + } + + private static SqlDateTime2 ParseDateTime2(SqlString value, ExpressionExecutionContext context) + { + if (!SqlDateParsing.TryParse(value.Value, context.Session.DateFormat, out SqlDateTime2 dateTime2)) + throw new QueryExecutionException(Sql4CdsError.DateTimeParseError(null)); + + return dateTime2; + } + + private static SqlDateTimeOffset ParseDateTimeOffset(SqlString value, ExpressionExecutionContext context) + { + if (!SqlDateParsing.TryParse(value.Value, context.Session.DateFormat, out SqlDateTimeOffset dateTimeOffset)) + throw new QueryExecutionException(Sql4CdsError.DateTimeParseError(null)); + + return dateTimeOffset; } private static SqlString ApplyCollation(Collation collation, SqlString sqlString) @@ -565,13 +633,14 @@ private static SqlString ApplyCollation(Collation collation, SqlString sqlString /// Produces the required expression to convert values to a specific type /// /// The expression that generates the values to convert + /// The expression that provides the current when the expression is evaluated /// The type to convert from /// The type to convert to /// An optional parameter defining the style of the conversion /// An optional parameter defining the type of the expression /// An optional parameter containing the SQL CONVERT() function call to report any errors against /// An expression to generate values of the required type - public static Expression Convert(Expression expr, DataTypeReference from, DataTypeReference to, Expression style = null, DataTypeReference styleType = null, TSqlFragment convert = null, bool throwOnTruncate = false, string table = null, string column = null) + public static Expression Convert(Expression expr, Expression context, DataTypeReference from, DataTypeReference to, Expression style = null, DataTypeReference styleType = null, TSqlFragment convert = null, bool throwOnTruncate = false, string table = null, string column = null) { if (from.IsSameAs(to)) return expr; @@ -605,20 +674,20 @@ public static Expression Convert(Expression expr, DataTypeReference from, DataTy // Special case for conversion to sql_variant if (to.IsSameAs(DataTypeHelpers.Variant)) { - var ctor = typeof(SqlVariant).GetConstructor(new[] { typeof(DataTypeReference), typeof(INullable) }); - return Expression.New(ctor, Expression.Constant(from), Expression.Convert(expr, typeof(INullable))); + var ctor = typeof(SqlVariant).GetConstructor(new[] { typeof(DataTypeReference), typeof(INullable), typeof(ExpressionExecutionContext) }); + return Expression.New(ctor, Expression.Constant(from), Expression.Convert(expr, typeof(INullable)), context); } // Special case for conversion from sql_variant if (from.IsSameAs(DataTypeHelpers.Variant)) { - return Expression.Convert(Expr.Call(() => Convert(Expr.Arg(), Expr.Arg(), Expr.Arg()), expr, Expression.Constant(to), style), targetType); + return Expression.Convert(Expr.Call(() => Convert(Expr.Arg(), Expr.Arg(), Expr.Arg(), Expr.Arg()), expr, Expression.Constant(to), style, context), targetType); } var targetCollation = (to as SqlDataTypeReferenceWithCollation)?.Collation; if (fromSqlType != null && (fromSqlType.SqlDataTypeOption.IsDateTimeType() || fromSqlType.SqlDataTypeOption == SqlDataTypeOption.Date || fromSqlType.SqlDataTypeOption == SqlDataTypeOption.Time) && targetType == typeof(SqlString)) - expr = Expr.Call(() => Convert(Expr.Arg(), Expr.Arg(), Expr.Arg(), Expr.Arg(), Expr.Arg(), Expr.Arg(), Expr.Arg(), Expr.Arg()), Convert(expr, typeof(SqlDateTime)), Expression.Constant(fromSqlType.SqlDataTypeOption != SqlDataTypeOption.Time), Expression.Constant(fromSqlType.SqlDataTypeOption != SqlDataTypeOption.Date), Expression.Constant(from.GetScale()), Expression.Constant(from), Expression.Constant(to), style, Expression.Constant(targetCollation)); + expr = Expr.Call(() => Convert(Expr.Arg(), Expr.Arg(), Expr.Arg(), Expr.Arg(), Expr.Arg(), Expr.Arg(), Expr.Arg(), Expr.Arg()), Convert(expr, context, typeof(SqlDateTimeOffset)), Expression.Constant(fromSqlType.SqlDataTypeOption != SqlDataTypeOption.Time), Expression.Constant(fromSqlType.SqlDataTypeOption != SqlDataTypeOption.Date), Expression.Constant(from.GetScale()), Expression.Constant(from), Expression.Constant(to), style, Expression.Constant(targetCollation)); else if ((expr.Type == typeof(SqlDouble) || expr.Type == typeof(SqlSingle)) && targetType == typeof(SqlString)) expr = Expr.Call(() => Convert(Expr.Arg(), Expr.Arg(), Expr.Arg()), expr, style, Expression.Constant(targetCollation)); else if (expr.Type == typeof(SqlMoney) && targetType == typeof(SqlString)) @@ -628,7 +697,7 @@ public static Expression Convert(Expression expr, DataTypeReference from, DataTy if (expr.Type != targetType) { - expr = Convert(expr, targetType); + expr = Convert(expr, context, targetType); // Handle errors during type conversion var conversionError = Expr.Call(() => Sql4CdsError.ConversionError(Expr.Arg(), Expr.Arg(), Expr.Arg()), Expression.Constant(convert, typeof(TSqlFragment)), Expression.Constant(from), Expression.Constant(to)); @@ -731,7 +800,7 @@ public static Expression Convert(Expression expr, DataTypeReference from, DataTy if (toSqlType.Parameters.Count > 0) { - if (!Int32.TryParse(toSqlType.Parameters[0].Value, out precision)) + if (!Int32.TryParse(toSqlType.Parameters[0].Value, out precision) || precision < 0) throw new NotSupportedQueryFragmentException(Sql4CdsError.SyntaxError(toSqlType)) { Suggestion = "Invalid attributes specified for type " + toSqlType.SqlDataTypeOption }; if (precision < 1) @@ -739,8 +808,11 @@ public static Expression Convert(Expression expr, DataTypeReference from, DataTy if (toSqlType.Parameters.Count > 1) { - if (!Int32.TryParse(toSqlType.Parameters[1].Value, out scale)) + if (!Int32.TryParse(toSqlType.Parameters[1].Value, out scale) || scale < 0) throw new NotSupportedQueryFragmentException(Sql4CdsError.SyntaxError(toSqlType)) { Suggestion = "Invalid attributes specified for type " + toSqlType.SqlDataTypeOption }; + + if (scale > precision) + throw new NotSupportedQueryFragmentException(Sql4CdsError.InvalidScale(toSqlType, 0)) { Suggestion = "Scale cannot be greater than precision" }; } if (toSqlType.Parameters.Count > 2) @@ -755,6 +827,25 @@ public static Expression Convert(Expression expr, DataTypeReference from, DataTy Expression.Constant(to), Expression.Constant(convert, typeof(TSqlFragment))); } + else if (expr.Type == typeof(SqlDateTime2) || expr.Type == typeof(SqlDateTimeOffset)) + { + // Default scale is 7 + var scale = 7; + + if (toSqlType.Parameters.Count > 0) + { + if (!Int32.TryParse(toSqlType.Parameters[0].Value, out scale) || scale < 0) + throw new NotSupportedQueryFragmentException(Sql4CdsError.SyntaxError(toSqlType)) { Suggestion = "Invalid attributes specified for type " + toSqlType.SqlDataTypeOption }; + + if (scale > 7) + throw new NotSupportedQueryFragmentException(Sql4CdsError.InvalidScale(toSqlType, 1)) { Suggestion = "Scale cannot be greater than 7" }; + } + + if (expr.Type == typeof(SqlDateTime2)) + expr = Expr.Call(() => ApplyScale(Expr.Arg(), Expr.Arg()), expr, Expression.Constant(scale)); + else + expr = Expr.Call(() => ApplyScale(Expr.Arg(), Expr.Arg()), expr, Expression.Constant(scale)); + } return expr; } @@ -771,6 +862,38 @@ private static SqlDecimal ApplyPrecisionScale(SqlDecimal value, int precision, i } } + private static SqlDateTime2 ApplyScale(SqlDateTime2 value, int scale) + { + if (value.IsNull) + return value; + + var ticks = value.Value.Ticks % TimeSpan.TicksPerSecond; + var integerSeconds = value.Value.AddTicks(-ticks); + + if (scale == 0) + return integerSeconds; + + var fractionalSeconds = (decimal)ticks / TimeSpan.TicksPerSecond; + var rounded = Math.Round(fractionalSeconds, scale, MidpointRounding.AwayFromZero); + return integerSeconds.AddTicks((int)(rounded * TimeSpan.TicksPerSecond)); + } + + private static SqlDateTimeOffset ApplyScale(SqlDateTimeOffset value, int scale) + { + if (value.IsNull) + return value; + + var ticks = value.Value.Ticks % TimeSpan.TicksPerSecond; + var integerSeconds = value.Value.AddTicks(-ticks); + + if (scale == 0) + return integerSeconds; + + var fractionalSeconds = (decimal)ticks / TimeSpan.TicksPerSecond; + var rounded = Math.Round(fractionalSeconds, scale, MidpointRounding.AwayFromZero); + return integerSeconds.AddTicks((int)(rounded * TimeSpan.TicksPerSecond)); + } + /// /// Converts a value from one collation to another /// @@ -890,7 +1013,7 @@ private static SqlXml ParseXml(SqlString value) /// The style to apply /// The collation to use for the returned result /// The converted string - private static SqlString Convert(SqlDateTime value, bool date, bool time, int timeScale, DataTypeReference fromType, DataTypeReference toType, SqlInt32 style, Collation collation) + private static SqlString Convert(SqlDateTimeOffset value, bool date, bool time, int timeScale, DataTypeReference fromType, DataTypeReference toType, SqlInt32 style, Collation collation) { if (value.IsNull || style.IsNull) return SqlString.Null; @@ -973,7 +1096,12 @@ private static SqlString Convert(SqlDateTime value, bool date, bool time, int ti case 9: case 109: dateFormatString = "MMM dd yyyy"; - timeFormatString = "hh:mm:ss:" + new string('f', timeScale) + "tt"; + timeFormatString = "hh:mm:ss"; + + if (timeScale > 0) + timeFormatString += ":" + new string('f', timeScale); + + timeFormatString += "tt"; break; case 10: @@ -1003,12 +1131,18 @@ private static SqlString Convert(SqlDateTime value, bool date, bool time, int ti case 13: case 113: dateFormatString = "dd MMM yyyy"; - timeFormatString = "HH:mm:ss:" + new string('f', timeScale); + timeFormatString = "HH:mm:ss"; + + if (timeScale > 0) + timeFormatString += ":" + new string('f', timeScale); break; case 14: case 114: - timeFormatString = "HH:mm:ss:" + new string('f', timeScale); + timeFormatString = "HH:mm:ss"; + + if (timeScale > 0) + timeFormatString += ":" + new string('f', timeScale); break; case 20: @@ -1021,7 +1155,10 @@ private static SqlString Convert(SqlDateTime value, bool date, bool time, int ti case 25: case 121: dateFormatString = "yyyy-MM-dd"; - timeFormatString = "HH:mm:ss." + new string('f', timeScale); + timeFormatString = "HH:mm:ss"; + + if (timeScale > 0) + timeFormatString += "." + new string('f', timeScale); break; case 22: @@ -1036,24 +1173,44 @@ private static SqlString Convert(SqlDateTime value, bool date, bool time, int ti case 126: dateFormatString = "yyyy-MM-dd"; dateTimeSeparator = "T"; - timeFormatString = "HH:mm:ss." + new string('F', timeScale); + timeFormatString = "HH:mm:ss"; + + if (timeScale > 0) + timeFormatString += "." + new string('F', timeScale); break; case 127: dateFormatString = "yyyy-MM-dd"; dateTimeSeparator = "T"; - timeFormatString = "HH:mm:ss." + new string('F', timeScale) + "\\Z"; + timeFormatString = "HH:mm:ss"; + + if (timeScale > 0) + timeFormatString += "." + new string('F', timeScale); + + timeFormatString += "\\Z"; break; case 130: dateFormatString = "dd MMMM yyyy"; - timeFormatString = "hh:mm:ss:" + new string('f', timeScale) + "tt"; + timeFormatString = "hh:mm:ss"; + + if (timeScale > 0) + timeFormatString += ":" + new string('f', timeScale); + + timeFormatString += "tt"; + cultureInfo = _hijriCulture; break; case 131: dateFormatString = "dd/MM/yyyy"; - timeFormatString = "HH:mm:ss:" + new string('f', timeScale) + "tt"; + timeFormatString = "hh:mm:ss"; + + if (timeScale > 0) + timeFormatString += ":" + new string('f', timeScale); + + timeFormatString += "tt"; + cultureInfo = _hijriCulture; break; @@ -1075,8 +1232,13 @@ private static SqlString Convert(SqlDateTime value, bool date, bool time, int ti } if (time && !String.IsNullOrEmpty(timeFormatString)) + { formatString += timeFormatString; + if (fromType is SqlDataTypeReference sqlFromType && sqlFromType.SqlDataTypeOption == SqlDataTypeOption.DateTimeOffset) + formatString += " zzz"; + } + var formatted = value.Value.ToString(formatString, cultureInfo); return collation.ToSqlString(formatted); } @@ -1174,7 +1336,7 @@ private static SqlString Convert(SqlBinary value, Collation collation, bool unic /// The type to convert to /// The style of the conversion /// The underlying value of the variant converted to the target type - private static INullable Convert(SqlVariant variant, DataTypeReference targetType, SqlInt32 style) + private static INullable Convert(SqlVariant variant, DataTypeReference targetType, SqlInt32 style, ExpressionExecutionContext context) { if (variant.BaseType == null) return GetNullValue(targetType.ToNetType(out _)); @@ -1186,7 +1348,7 @@ private static INullable Convert(SqlVariant variant, DataTypeReference targetTyp throw new QueryExecutionException(Sql4CdsError.ExplicitConversionNotAllowed(null, variant.BaseType, targetType)); var conversion = GetConversion(variant.BaseType, targetType, style.IsNull ? (int?)null : style.Value); - return conversion(variant.Value); + return conversion(variant.Value, context); } /// @@ -1311,9 +1473,9 @@ public static INullable GetNullValue(Type sqlType) /// The type to convert the value to /// The value to convert /// The value converted to the requested type - public static T ChangeType(object value) + public static T ChangeType(object value, ExpressionExecutionContext context) { - return (T)ChangeType(value, typeof(T)); + return (T)ChangeType(value, typeof(T), context); } /// @@ -1323,13 +1485,13 @@ public static T ChangeType(object value) /// The value to convert /// The type to convert the value to /// The value converted to the requested type - public static object ChangeType(object value, Type type) + public static object ChangeType(object value, Type type, ExpressionExecutionContext context) { if (value != null && value.GetType() == type) return value; var conversion = GetConversion(value.GetType(), type); - return conversion(value); + return conversion(value, context); } /// @@ -1338,18 +1500,19 @@ public static object ChangeType(object value, Type type) /// The type to convert from /// The type to convert to /// A function that converts between the requested types - public static Func GetConversion(Type sourceType, Type destType) + public static Func GetConversion(Type sourceType, Type destType) { var key = sourceType.FullName + " -> " + destType.FullName; return _conversions.GetOrAdd(key, _ => CompileConversion(sourceType, destType)); } - private static Func CompileConversion(Type sourceType, Type destType) + private static Func CompileConversion(Type sourceType, Type destType) { if (sourceType == destType) - return (object value) => value; + return (object value, ExpressionExecutionContext _) => value; var param = Expression.Parameter(typeof(object)); + var contextParam = Expression.Parameter(typeof(ExpressionExecutionContext)); var expression = (Expression) Expression.Convert(param, sourceType); // Special case for converting from string to enum for metadata filters @@ -1408,14 +1571,14 @@ private static Func CompileConversion(Type sourceType, Type dest } else { - expression = Convert(expression, destType); + expression = Convert(expression, contextParam, destType); } //if (destType == typeof(SqlString)) // expression = Expr.Call(() => ApplyCollation(Expr.Arg(), Expr.Arg()), contextParam, expression); expression = Expression.Convert(expression, typeof(object)); - return Expression.Lambda>(expression, param).Compile(); + return Expression.Lambda>(expression, param, contextParam).Compile(); } /// @@ -1425,7 +1588,7 @@ private static Func CompileConversion(Type sourceType, Type dest /// The type to convert to /// The style of the converesion /// A function that converts between the requested types - public static Func GetConversion(DataTypeReference sourceType, DataTypeReference destType, int? style = null) + public static Func GetConversion(DataTypeReference sourceType, DataTypeReference destType, int? style = null) { var key = sourceType.ToSql() + " -> " + destType.ToSql(); @@ -1443,15 +1606,16 @@ public static Func GetConversion(DataTypeReference sourceT return _sqlConversions.GetOrAdd(key, _ => CompileConversion(sourceType, destType, style)); } - private static Func CompileConversion(DataTypeReference sourceType, DataTypeReference destType, int? style = null) + private static Func CompileConversion(DataTypeReference sourceType, DataTypeReference destType, int? style = null) { if (sourceType.IsSameAs(destType)) - return (INullable value) => value; + return (INullable value, ExpressionExecutionContext _) => value; var sourceNetType = sourceType.ToNetType(out _); var destNetType = destType.ToNetType(out _); var param = Expression.Parameter(typeof(INullable)); + var contextParam = Expression.Parameter(typeof(ExpressionExecutionContext)); var expression = (Expression)Expression.Convert(param, sourceNetType); var styleExpr = (Expression)null; var styleType = (DataTypeReference)null; @@ -1462,9 +1626,9 @@ private static Func CompileConversion(DataTypeReference so styleType = DataTypeHelpers.Int; } - expression = Convert(expression, sourceType, destType, styleExpr, styleType); + expression = Convert(expression, contextParam, sourceType, destType, styleExpr, styleType); expression = Expression.Convert(expression, typeof(INullable)); - return Expression.Lambda>(expression, param).Compile(); + return Expression.Lambda>(expression, param, contextParam).Compile(); } /// diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/StreamAggregateNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/StreamAggregateNode.cs index 3183d33e..1303b39b 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/StreamAggregateNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/StreamAggregateNode.cs @@ -77,7 +77,7 @@ protected override IEnumerable ExecuteInternal(NodeExecutionContext cont expressionExecutionContext.Entity = entity; foreach (var func in states.Values) - func.AggregateFunction.NextRecord(func.State); + func.AggregateFunction.NextRecord(func.State, expressionExecutionContext); } if (states != null) diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SystemFunctionNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SystemFunctionNode.cs index dac954ed..dbf48ed3 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SystemFunctionNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SystemFunctionNode.cs @@ -52,7 +52,7 @@ public override IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext public override INodeSchema GetSchema(NodeCompilationContext context) { - var dataSource = context.DataSources[DataSource]; + var dataSource = context.Session.DataSources[DataSource]; switch (SystemFunction) { @@ -97,7 +97,7 @@ protected override RowCountEstimate EstimateRowsOutInternal(NodeCompilationConte protected override IEnumerable ExecuteInternal(NodeExecutionContext context) { - var dataSource = context.DataSources[DataSource]; + var dataSource = context.Session.DataSources[DataSource]; switch (SystemFunction) { diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/UpdateNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/UpdateNode.cs index 3bf0dbe2..05916a35 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/UpdateNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/UpdateNode.cs @@ -98,7 +98,27 @@ public override IRootExecutionPlanNodeInternal[] FoldQuery(NodeCompilationContex { UseLegacyUpdateMessages = GetUseLegacyUpdateMessages(context, hints); - return base.FoldQuery(context, hints); + var result = base.FoldQuery(context, hints); + + if (result.Length != 1 || result[0] != this) + return result; + + // If we don't need to use any of the original values and we're filtering by ID we can bypass reading + // the record to update + if (ColumnMappings.Values.All(m => m.OldValueColumn == null)) + { + var dataSource = context.Session.DataSources[DataSource]; + var meta = dataSource.Metadata[LogicalName]; + + var requiredColumns = ColumnMappings + .Select(kvp => kvp.Value.NewValueColumn) + .Union(new[] { PrimaryIdSource }) + .ToArray(); + + FoldIdsToConstantScan(context, hints, LogicalName, requiredColumns); + } + + return new[] { this }; } private bool GetUseLegacyUpdateMessages(NodeCompilationContext context, IList queryHints) @@ -120,15 +140,16 @@ public override void Execute(NodeExecutionContext context, out int recordsAffect try { - if (!context.DataSources.TryGetValue(DataSource, out var dataSource)) + if (!context.Session.DataSources.TryGetValue(DataSource, out var dataSource)) throw new QueryExecutionException("Missing datasource " + DataSource); List entities; EntityMetadata meta; Dictionary attributes; - Dictionary> newAttributeAccessors; - Dictionary> oldAttributeAccessors; - Func primaryIdAccessor; + Dictionary> newAttributeAccessors; + Dictionary> oldAttributeAccessors; + Func primaryIdAccessor; + var eec = new ExpressionExecutionContext(context); using (_timer.Run()) { @@ -213,8 +234,9 @@ public override void Execute(NodeExecutionContext context, out int recordsAffect meta, entity => { - var preImage = ExtractEntity(entity, meta, attributes, oldAttributeAccessors, primaryIdAccessor); - var update = ExtractEntity(entity, meta, attributes, newAttributeAccessors, primaryIdAccessor); + eec.Entity = entity; + var preImage = ExtractEntity(eec, meta, attributes, oldAttributeAccessors, primaryIdAccessor); + var update = ExtractEntity(eec, meta, attributes, newAttributeAccessors, primaryIdAccessor); var requests = new OrganizationRequestCollection(); @@ -568,9 +590,9 @@ private OptionSetValue GetDefaultStatusCode(EntityMetadata meta, int statecode) return new OptionSetValue((int)((StateOptionMetadata)stateCode).DefaultStatus); } - private Entity ExtractEntity(Entity entity, EntityMetadata meta, Dictionary attributes, Dictionary> newAttributeAccessors, Func primaryIdAccessor) + private Entity ExtractEntity(ExpressionExecutionContext context, EntityMetadata meta, Dictionary attributes, Dictionary> newAttributeAccessors, Func primaryIdAccessor) { - var update = new Entity(LogicalName, (Guid)primaryIdAccessor(entity)); + var update = new Entity(LogicalName, (Guid)primaryIdAccessor(context)); foreach (var attributeAccessor in newAttributeAccessors) { @@ -582,14 +604,14 @@ private Entity ExtractEntity(Entity entity, EntityMetadata meta, Dictionary(ColumnMappings["activitytypecode"].OldValueColumn).Value; + update.LogicalName = context.Entity.GetAttributeValue(ColumnMappings["activitytypecode"].OldValueColumn).Value; return update; } diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlanBuilder.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlanBuilder.cs index bcb9b974..97c72d90 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlanBuilder.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlanBuilder.cs @@ -21,21 +21,21 @@ class ExecutionPlanBuilder private NodeCompilationContext _nodeContext; private Dictionary _cteSubplans; - public ExecutionPlanBuilder(IEnumerable dataSources, IQueryExecutionOptions options) + public ExecutionPlanBuilder(SessionContext session, IQueryExecutionOptions options) { - DataSources = dataSources.ToDictionary(ds => ds.Name, StringComparer.OrdinalIgnoreCase); + Session = session; Options = options; - if (!DataSources.ContainsKey(Options.PrimaryDataSource)) + if (!Session.DataSources.ContainsKey(Options.PrimaryDataSource)) throw new ArgumentOutOfRangeException(nameof(options), "Primary data source " + options.PrimaryDataSource + " not found"); EstimatedPlanOnly = true; } /// - /// The connections that will be used by this conversion + /// The session that the query will be executed in /// - public IDictionary DataSources { get; } + public SessionContext Session { get; } /// /// Indicates how the query will be executed @@ -52,7 +52,7 @@ public ExecutionPlanBuilder(IEnumerable dataSources, IQueryExecution /// public Action Log { get; set; } - private DataSource PrimaryDataSource => DataSources[Options.PrimaryDataSource]; + private DataSource PrimaryDataSource => Session.DataSources[Options.PrimaryDataSource]; /// /// Builds the execution plans for a SQL command @@ -65,22 +65,19 @@ public IRootExecutionPlanNode[] Build(string sql, IDictionary(StringComparer.OrdinalIgnoreCase); - _staticContext = new ExpressionCompilationContext(DataSources, Options, parameterTypes, null, null); - _nodeContext = new NodeCompilationContext(DataSources, Options, parameterTypes, Log); + var localParameterTypes = new Dictionary(StringComparer.OrdinalIgnoreCase); if (parameters != null) { foreach (var param in parameters) - parameterTypes[param.Key] = param.Value; + localParameterTypes[param.Key] = param.Value; } // Add in standard global variables - parameterTypes["@@IDENTITY"] = DataTypeHelpers.EntityReference; - parameterTypes["@@ROWCOUNT"] = DataTypeHelpers.Int; - parameterTypes["@@SERVERNAME"] = DataTypeHelpers.NVarChar(100, DataSources[Options.PrimaryDataSource].DefaultCollation, CollationLabel.CoercibleDefault); - parameterTypes["@@VERSION"] = DataTypeHelpers.NVarChar(Int32.MaxValue, DataSources[Options.PrimaryDataSource].DefaultCollation, CollationLabel.CoercibleDefault); - parameterTypes["@@ERROR"] = DataTypeHelpers.Int; + var parameterTypes = new LayeredDictionary(Session.GlobalVariableTypes, localParameterTypes); + + _staticContext = new ExpressionCompilationContext(Session, Options, parameterTypes, null, null); + _nodeContext = new NodeCompilationContext(Session, Options, parameterTypes, Log); var queries = new List(); @@ -96,11 +93,11 @@ public IRootExecutionPlanNode[] Build(string sql, IDictionary ConvertSetCommand(c)) + .ToArray(); + } + + private IDmlQueryExecutionPlanNode ConvertSetCommand(SetCommand setCommand) + { + if (setCommand is GeneralSetCommand cmd) + { + switch (cmd.CommandType) + { + case GeneralSetCommandType.DateFormat: + return new SetDateFormatNode(cmd.Parameter); + } + } + + throw new NotSupportedQueryFragmentException(Sql4CdsError.NotSupported(setCommand, setCommand.ToNormalizedSql())); + } + private SelectNode ConvertRecursiveCTEQuery(QueryExpression queryExpression, INodeSchema anchorSchema, CteValidatorVisitor cteValidator, Dictionary outerReferences) { // Convert the query using the anchor query as a subquery to check for ambiguous column names @@ -1376,8 +1396,8 @@ private DataSource SelectDataSource(SchemaObjectName schemaObject) { var databaseName = schemaObject.DatabaseIdentifier?.Value ?? Options.PrimaryDataSource; - if (!DataSources.TryGetValue(databaseName, out var dataSource)) - throw new NotSupportedQueryFragmentException(Sql4CdsError.InvalidObjectName(schemaObject)) { Suggestion = $"Available database names:\r\n* {String.Join("\r\n* ", DataSources.Keys.OrderBy(k => k))}" }; + if (!Session.DataSources.TryGetValue(databaseName, out var dataSource)) + throw new NotSupportedQueryFragmentException(Sql4CdsError.InvalidObjectName(schemaObject)) { Suggestion = $"Available database names:\r\n* {String.Join("\r\n* ", Session.DataSources.Keys.OrderBy(k => k))}" }; return dataSource; } @@ -1717,8 +1737,8 @@ private DeleteNode ConvertDeleteStatement(DeleteSpecification delete, IList k))}" }; + if (!Session.DataSources.TryGetValue(deleteTarget.TargetDataSource, out var dataSource)) + throw new NotSupportedQueryFragmentException(Sql4CdsError.InvalidObjectName(target.SchemaObject)) { Suggestion = $"Available database names:\r\n* {String.Join("\r\n*", Session.DataSources.Keys.OrderBy(k => k))}" }; ValidateDMLSchema(deleteTarget.Target, true); @@ -2061,8 +2081,8 @@ private UpdateNode ConvertUpdateStatement(UpdateSpecification update, IList k))}" }; + if (!Session.DataSources.TryGetValue(updateTarget.TargetDataSource, out var dataSource)) + throw new NotSupportedQueryFragmentException(Sql4CdsError.InvalidObjectName(target.SchemaObject)) { Suggestion = $"Available database names:\r\n* {String.Join("\r\n*", Session.DataSources.Keys.OrderBy(k => k))}" }; ValidateDMLSchema(updateTarget.Target, false); @@ -2607,11 +2627,11 @@ private UpdateNode ConvertSetClause(IList setClauses, HashSet private IRootExecutionPlanNodeInternal ConvertSelectStatement(SelectStatement select) { - if (TDSEndpoint.CanUseTDSEndpoint(Options, DataSources[Options.PrimaryDataSource].Connection)) + if (TDSEndpoint.CanUseTDSEndpoint(Options, PrimaryDataSource.Connection)) { - using (var con = DataSources[Options.PrimaryDataSource].Connection == null ? null : TDSEndpoint.Connect(DataSources[Options.PrimaryDataSource].Connection)) + using (var con = PrimaryDataSource.Connection == null ? null : TDSEndpoint.Connect(PrimaryDataSource.Connection)) { - var tdsEndpointCompatibilityVisitor = new TDSEndpointCompatibilityVisitor(con, DataSources[Options.PrimaryDataSource].Metadata, false); + var tdsEndpointCompatibilityVisitor = new TDSEndpointCompatibilityVisitor(con, PrimaryDataSource.Metadata, false); select.Accept(tdsEndpointCompatibilityVisitor); // Remove any custom optimizer hints @@ -2627,7 +2647,8 @@ private IRootExecutionPlanNodeInternal ConvertSelectStatement(SelectStatement se var sql = new SqlNode { DataSource = Options.PrimaryDataSource, - Sql = select.ToSql() + Sql = select.ToSql(), + SelectStatement = select }; var variables = new VariableCollectingVisitor(); @@ -3259,7 +3280,7 @@ private IDataExecutionPlanNodeInternal ConvertInSubquery(IDataExecutionPlanNodeI else { // We need the inner list to be distinct to avoid creating duplicates during the join - var innerSchema = innerQuery.Source.GetSchema(new NodeCompilationContext(DataSources, Options, parameters, Log)); + var innerSchema = innerQuery.Source.GetSchema(new NodeCompilationContext(Session, Options, parameters, Log)); if (innerQuery.ColumnSet[0].SourceColumn != innerSchema.PrimaryKey && !(innerQuery.Source is DistinctNode)) { innerQuery.Source = new DistinctNode @@ -3374,7 +3395,7 @@ private IDataExecutionPlanNodeInternal ConvertExistsSubquery(IDataExecutionPlanN var innerContext = new NodeCompilationContext(context, parameters); var references = new Dictionary(); var innerQuery = ConvertSelectStatement(existsSubquery.Subquery.QueryExpression, hints, schema, references, innerContext); - var innerSchema = innerQuery.Source.GetSchema(new NodeCompilationContext(DataSources, Options, parameters, Log)); + var innerSchema = innerQuery.Source.GetSchema(new NodeCompilationContext(Session, Options, parameters, Log)); var innerSchemaPrimaryKey = innerSchema.PrimaryKey; // Create the join @@ -4494,8 +4515,8 @@ private bool UseMergeJoin(IDataExecutionPlanNodeInternal node, IDataExecutionPla if (outerKey == null) return false; - var outerSchema = node.GetSchema(new NodeCompilationContext(DataSources, Options, null, Log)); - var innerSchema = subNode.GetSchema(new NodeCompilationContext(DataSources, Options, null, Log)); + var outerSchema = node.GetSchema(new NodeCompilationContext(Session, Options, null, Log)); + var innerSchema = subNode.GetSchema(new NodeCompilationContext(Session, Options, null, Log)); if (!outerSchema.ContainsColumn(outerKey, out outerKey) || !innerSchema.ContainsColumn(innerKey, out innerKey)) @@ -4568,7 +4589,7 @@ private bool UseMergeJoin(IDataExecutionPlanNodeInternal node, IDataExecutionPla if (semiJoin) { // Regenerate the schema after changing the alias - innerSchema = subNode.GetSchema(new NodeCompilationContext(DataSources, Options, null, Log)); + innerSchema = subNode.GetSchema(new NodeCompilationContext(Session, Options, null, Log)); if (innerSchema.PrimaryKey != rightAttribute.GetColumnName() && !(merge.RightSource is DistinctNode)) { diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlanOptimizer.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlanOptimizer.cs index fe267e69..cc11571c 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlanOptimizer.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlanOptimizer.cs @@ -17,16 +17,16 @@ namespace MarkMpn.Sql4Cds.Engine /// class ExecutionPlanOptimizer { - public ExecutionPlanOptimizer(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, bool compileConditions, Action log) + public ExecutionPlanOptimizer(SessionContext session, IQueryExecutionOptions options, IDictionary parameterTypes, bool compileConditions, Action log) { - DataSources = dataSources; + Session = session; Options = options; ParameterTypes = parameterTypes; CompileConditions = compileConditions; Log = log; } - public IDictionary DataSources { get; } + public SessionContext Session { get; } public IQueryExecutionOptions Options { get; } @@ -52,7 +52,7 @@ public IRootExecutionPlanNodeInternal[] Optimize(IRootExecutionPlanNodeInternal hints.Add(new ConditionalNode.DoNotCompileConditionsHint()); } - var context = new NodeCompilationContext(DataSources, Options, ParameterTypes, Log); + var context = new NodeCompilationContext(Session, Options, ParameterTypes, Log); // Move any additional operators down to the FetchXml var bypassOptimization = hints != null && hints.OfType().Any(list => list.Hints.Any(h => h.Value.Equals("DEBUG_BYPASS_OPTIMIZATION", StringComparison.OrdinalIgnoreCase))); diff --git a/MarkMpn.Sql4Cds.Engine/ExpressionFunctions.cs b/MarkMpn.Sql4Cds.Engine/ExpressionFunctions.cs index 44ffaf11..6f1841fa 100644 --- a/MarkMpn.Sql4Cds.Engine/ExpressionFunctions.cs +++ b/MarkMpn.Sql4Cds.Engine/ExpressionFunctions.cs @@ -1,7 +1,6 @@ using MarkMpn.Sql4Cds.Engine.ExecutionPlan; using Microsoft.Crm.Sdk.Messages; using Microsoft.SqlServer.TransactSql.ScriptDom; -using Microsoft.VisualBasic; using Microsoft.Xrm.Sdk; using System; using System.Collections.Generic; @@ -194,17 +193,98 @@ public static SqlBoolean Json_Path_Exists(SqlString json, SqlString jpath) /// The modified date /// [SqlFunction(IsDeterministic = true)] - public static SqlDateTime DateAdd(SqlString datepart, SqlDouble number, SqlDateTime date) + public static SqlDateTimeOffset DateAdd(SqlString datepart, SqlInt32 number, SqlDateTimeOffset date, [SourceType(nameof(date)), TargetType] DataTypeReference dateType) { if (number.IsNull || date.IsNull) return SqlDateTime.Null; - var interval = DatePartToInterval(datepart.Value); - var value = DateAndTime.DateAdd(interval, number.Value, date.Value); + if (!TryParseDatePart(datepart.Value, out var interval)) + throw new QueryExecutionException(Sql4CdsError.InvalidOptionValue(new StringLiteral { Value = datepart.Value }, "dateadd")); - // DateAdd loses the Kind property for some interval types - add it back in again - if (value.Kind == DateTimeKind.Unspecified) - value = DateTime.SpecifyKind(value, date.Value.Kind); + if (interval == Engine.DatePart.TZOffset) + throw new QueryExecutionException(Sql4CdsError.InvalidDatePart(null, datepart.Value, "dateadd", dateType)); + + DateTimeOffset value; + + try + { + switch (interval) + { + case Engine.DatePart.Year: + value = date.Value.AddYears(number.Value); + break; + + case Engine.DatePart.Quarter: + value = date.Value.AddMonths(number.Value * 3); + break; + + case Engine.DatePart.Month: + value = date.Value.AddMonths(number.Value); + break; + + case Engine.DatePart.DayOfYear: + case Engine.DatePart.Day: + case Engine.DatePart.WeekDay: + value = date.Value.AddDays(number.Value); + break; + + case Engine.DatePart.Week: + value = date.Value.AddDays(number.Value * 7); + break; + + case Engine.DatePart.Hour: + value = date.Value.AddHours(number.Value); + break; + + case Engine.DatePart.Minute: + value = date.Value.AddMinutes(number.Value); + break; + + case Engine.DatePart.Second: + value = date.Value.AddSeconds(number.Value); + break; + + case Engine.DatePart.Millisecond: + value = date.Value.AddMilliseconds(number.Value); + + // https://learn.microsoft.com/en-us/sql/t-sql/functions/dateadd-transact-sql?view=sql-server-ver16#return-values-for-a-smalldatetime-date-and-a-second-or-fractional-seconds-datepart + if (dateType.IsSameAs(DataTypeHelpers.SmallDateTime)) + value = value.AddMilliseconds(1); + break; + + case Engine.DatePart.Microsecond: + // Check data type & precision + // https://learn.microsoft.com/en-us/sql/t-sql/functions/dateadd-transact-sql?view=sql-server-ver16#fractional-seconds-precision + if (dateType.IsSameAs(DataTypeHelpers.SmallDateTime) || dateType.IsSameAs(DataTypeHelpers.DateTime) || dateType.IsSameAs(DataTypeHelpers.Date)) + throw new QueryExecutionException(Sql4CdsError.InvalidDatePart(null, datepart.Value, "dateadd", dateType)); + + value = date.Value.AddTicks(number.Value * 10); + break; + + case Engine.DatePart.Nanosecond: + // Check data type & precision + // https://learn.microsoft.com/en-us/sql/t-sql/functions/dateadd-transact-sql?view=sql-server-ver16#fractional-seconds-precision + if (dateType.IsSameAs(DataTypeHelpers.SmallDateTime) || dateType.IsSameAs(DataTypeHelpers.DateTime) || dateType.IsSameAs(DataTypeHelpers.Date)) + throw new QueryExecutionException(Sql4CdsError.InvalidDatePart(null, datepart.Value, "dateadd", dateType)); + + var ticks = (int)Math.Round(number.Value / 100M, MidpointRounding.AwayFromZero); + value = date.Value.AddTicks(ticks); + break; + + default: + throw new QueryExecutionException(Sql4CdsError.InvalidOptionValue(new StringLiteral { Value = datepart.Value }, "datepart")); + } + } + catch (ArgumentOutOfRangeException) + { + throw new QueryExecutionException(Sql4CdsError.AdditionOverflow(null, dateType)); + } + + if (dateType.IsSameAs(DataTypeHelpers.SmallDateTime) && (value.DateTime < SqlSmallDateTime.MinValue.Value || value.DateTime > SqlSmallDateTime.MaxValue.Value)) + throw new QueryExecutionException(Sql4CdsError.AdditionOverflow(null, dateType)); + + if (dateType.IsSameAs(DataTypeHelpers.DateTime) && (value.DateTime < SqlDateTime.MinValue.Value || value.DateTime > SqlDateTime.MaxValue.Value)) + throw new QueryExecutionException(Sql4CdsError.AdditionOverflow(null, dateType)); return value; } @@ -218,13 +298,175 @@ public static SqlDateTime DateAdd(SqlString datepart, SqlDouble number, SqlDateT /// The number of whole units between and /// [SqlFunction(IsDeterministic = true)] - public static SqlInt32 DateDiff(SqlString datepart, SqlDateTime startdate, SqlDateTime enddate) + public static SqlInt32 DateDiff(SqlString datepart, SqlDateTimeOffset startdate, SqlDateTimeOffset enddate, [SourceType(nameof(startdate))] DataTypeReference startdateType, [SourceType(nameof(enddate))] DataTypeReference enddateType) { if (startdate.IsNull || enddate.IsNull) return SqlInt32.Null; - var interval = DatePartToInterval(datepart.Value); - return (int) DateAndTime.DateDiff(interval, startdate.Value, enddate.Value); + if (!TryParseDatePart(datepart.Value, out var interval)) + throw new QueryExecutionException(Sql4CdsError.InvalidOptionValue(new StringLiteral { Value = datepart.Value }, "datediff")); + + if (interval == Engine.DatePart.Nanosecond) + return 0; + else if (interval == Engine.DatePart.TZOffset) + throw new QueryExecutionException(Sql4CdsError.UnsupportedDatePart(null, datepart.Value, "datediff")); + + if (interval == Engine.DatePart.WeekDay) + { + startdate = DateTrunc("day", startdate, startdateType); + enddate = DateTrunc("day", enddate, enddateType); + } + else + { + startdate = DateTrunc(datepart, startdate, startdateType); + enddate = DateTrunc(datepart, enddate, enddateType); + } + + switch (interval) + { + case Engine.DatePart.Year: + return enddate.Value.UtcDateTime.Year - startdate.Value.UtcDateTime.Year; + + case Engine.DatePart.Quarter: + var endQuarter = enddate.Value.UtcDateTime.Year * 4 + (enddate.Value.UtcDateTime.Month - 1) / 3 + 1; + var startQuarter = startdate.Value.UtcDateTime.Year * 4 + (startdate.Value.UtcDateTime.Month - 1) / 3 + 1; + return endQuarter - startQuarter; + + case Engine.DatePart.Month: + return (enddate.Value.UtcDateTime.Year - startdate.Value.UtcDateTime.Year) * 12 + enddate.Value.UtcDateTime.Month - startdate.Value.UtcDateTime.Month; + + case Engine.DatePart.DayOfYear: + case Engine.DatePart.Day: + case Engine.DatePart.WeekDay: + return (enddate.Value.UtcDateTime - startdate.Value.UtcDateTime).Days; + + case Engine.DatePart.Week: + case Engine.DatePart.ISOWeek: + return (enddate.Value.UtcDateTime - startdate.Value.UtcDateTime).Days / 7; + + case Engine.DatePart.Hour: + return (int)(enddate.Value.UtcDateTime - startdate.Value.UtcDateTime).TotalHours; + + case Engine.DatePart.Minute: + return (int)(enddate.Value.UtcDateTime - startdate.Value.UtcDateTime).TotalMinutes; + + case Engine.DatePart.Second: + return (int)(enddate.Value.UtcDateTime - startdate.Value.UtcDateTime).TotalSeconds; + + case Engine.DatePart.Millisecond: + return (int)(enddate.Value.UtcDateTime - startdate.Value.UtcDateTime).TotalMilliseconds; + + case Engine.DatePart.Microsecond: + return (int)((enddate.Value.UtcDateTime - startdate.Value.UtcDateTime).Ticks / 10); + + case Engine.DatePart.Nanosecond: + return (int)((enddate.Value.UtcDateTime - startdate.Value.UtcDateTime).Ticks * 100); + + default: + throw new QueryExecutionException(Sql4CdsError.InvalidOptionValue(new StringLiteral { Value = datepart.Value }, "datepart")); + } + } + + /// + /// Implements the DATETRUNC function + /// + /// Specifies the precision for truncation + /// The date value to be truncated + /// The truncated version of the date + /// + [SqlFunction(IsDeterministic = true)] + public static SqlDateTimeOffset DateTrunc(SqlString datepart, SqlDateTimeOffset date, [SourceType(nameof(date)), TargetType] DataTypeReference dateType) + { + if (date.IsNull) + return SqlDateTimeOffset.Null; + + if (!TryParseDatePart(datepart.Value, out var interval)) + throw new QueryExecutionException(Sql4CdsError.InvalidOptionValue(new StringLiteral { Value = datepart.Value }, "datetrunc")); + + if (interval == Engine.DatePart.WeekDay || interval == Engine.DatePart.TZOffset || interval == Engine.DatePart.Nanosecond) + throw new QueryExecutionException(Sql4CdsError.InvalidDatePart(null, datepart.Value, "datetrunc", dateType)); + + if (!(dateType is SqlDataTypeReference sqlDateType)) + throw new QueryExecutionException(Sql4CdsError.InvalidDatePart(null, datepart.Value, "datetrunc", dateType)); + + var scale = dateType.GetScale(); + + switch (interval) + { + case Engine.DatePart.Year: + if (sqlDateType.SqlDataTypeOption == SqlDataTypeOption.Time) + throw new QueryExecutionException(Sql4CdsError.InvalidDatePart(null, datepart.Value, "datetrunc", dateType)); + + return new DateTimeOffset(date.Value.Year, 1, 1, 0, 0, 0, date.Value.Offset); + + case Engine.DatePart.Quarter: + if (sqlDateType.SqlDataTypeOption == SqlDataTypeOption.Time) + throw new QueryExecutionException(Sql4CdsError.InvalidDatePart(null, datepart.Value, "datetrunc", dateType)); + + return new DateTimeOffset(date.Value.Year, (((date.Value.Month - 1) / 3) * 3) + 1, 1, 0, 0, 0, date.Value.Offset); + + case Engine.DatePart.Month: + if (sqlDateType.SqlDataTypeOption == SqlDataTypeOption.Time) + throw new QueryExecutionException(Sql4CdsError.InvalidDatePart(null, datepart.Value, "datetrunc", dateType)); + + return new DateTimeOffset(date.Value.Year, date.Value.Month, 1, 0, 0, 0, date.Value.Offset); + + case Engine.DatePart.DayOfYear: + case Engine.DatePart.Day: + if (sqlDateType.SqlDataTypeOption == SqlDataTypeOption.Time) + throw new QueryExecutionException(Sql4CdsError.InvalidDatePart(null, datepart.Value, "datetrunc", dateType)); + + return new DateTimeOffset(date.Value.Date, date.Value.Offset); + + case Engine.DatePart.Week: + if (sqlDateType.SqlDataTypeOption == SqlDataTypeOption.Time) + throw new QueryExecutionException(Sql4CdsError.InvalidDatePart(null, datepart.Value, "datetrunc", dateType)); + + return new DateTimeOffset(date.Value.Date, date.Value.Offset).AddDays(-(int)date.Value.DayOfWeek); + + case Engine.DatePart.ISOWeek: + if (sqlDateType.SqlDataTypeOption == SqlDataTypeOption.Time) + throw new QueryExecutionException(Sql4CdsError.InvalidDatePart(null, datepart.Value, "datetrunc", dateType)); + + var day = (int)date.Value.DayOfWeek; + if (day == 0) + day = 7; + + return new DateTimeOffset(date.Value.Date, date.Value.Offset).AddDays(-day + 1); + + case Engine.DatePart.Hour: + if (sqlDateType.SqlDataTypeOption == SqlDataTypeOption.Date) + throw new QueryExecutionException(Sql4CdsError.InvalidDatePart(null, datepart.Value, "datetrunc", dateType)); + + return new DateTimeOffset(date.Value.Year, date.Value.Month, date.Value.Day, date.Value.Hour, 0, 0, date.Value.Offset); + + case Engine.DatePart.Minute: + if (sqlDateType.SqlDataTypeOption == SqlDataTypeOption.Date) + throw new QueryExecutionException(Sql4CdsError.InvalidDatePart(null, datepart.Value, "datetrunc", dateType)); + + return new DateTimeOffset(date.Value.Year, date.Value.Month, date.Value.Day, date.Value.Hour, date.Value.Minute, 0, date.Value.Offset); + + case Engine.DatePart.Second: + if (sqlDateType.SqlDataTypeOption == SqlDataTypeOption.Date) + throw new QueryExecutionException(Sql4CdsError.InvalidDatePart(null, datepart.Value, "datetrunc", dateType)); + + return new DateTimeOffset(date.Value.Year, date.Value.Month, date.Value.Day, date.Value.Hour, date.Value.Minute, date.Value.Second, date.Value.Offset); + + case Engine.DatePart.Millisecond: + if (scale < 3) + throw new QueryExecutionException(Sql4CdsError.InvalidDatePart(null, datepart.Value, "datetrunc", dateType)); + + return new DateTimeOffset(date.Value.Year, date.Value.Month, date.Value.Day, date.Value.Hour, date.Value.Minute, date.Value.Second, date.Value.Millisecond, date.Value.Offset); + + case Engine.DatePart.Microsecond: + if (scale < 6) + throw new QueryExecutionException(Sql4CdsError.InvalidDatePart(null, datepart.Value, "datetrunc", dateType)); + + return date.Value.AddTicks(-date.Value.Ticks % 10); + + default: + throw new QueryExecutionException(Sql4CdsError.InvalidDatePart(null, datepart.Value, "datetrunc", dateType)); + } } /// @@ -234,75 +476,192 @@ public static SqlInt32 DateDiff(SqlString datepart, SqlDateTime startdate, SqlDa /// The date to extract the from /// The of the [SqlFunction(IsDeterministic = true)] - public static SqlInt32 DatePart(SqlString datepart, SqlDateTime date) + public static SqlInt32 DatePart(SqlString datepart, SqlDateTimeOffset date, [SourceType(nameof(date))] DataTypeReference dateType) { if (date.IsNull) return SqlInt32.Null; - var interval = DatePartToInterval(datepart.Value); - return DateAndTime.DatePart(interval, date.Value); + if (!TryParseDatePart(datepart.Value, out var interval)) + throw new QueryExecutionException(Sql4CdsError.InvalidOptionValue(new StringLiteral { Value = datepart.Value }, "datepart")); + + var sqlDateType = dateType as SqlDataTypeReference; + + switch (interval) + { + case Engine.DatePart.Year: + if (sqlDateType?.SqlDataTypeOption == SqlDataTypeOption.Time) + throw new QueryExecutionException(Sql4CdsError.InvalidDatePart(null, datepart.Value, "datepart", dateType)); + + return date.Value.Year; + + case Engine.DatePart.Quarter: + if (sqlDateType?.SqlDataTypeOption == SqlDataTypeOption.Time) + throw new QueryExecutionException(Sql4CdsError.InvalidDatePart(null, datepart.Value, "datepart", dateType)); + + return (date.Value.Month - 1) / 3 + 1; + + case Engine.DatePart.Month: + if (sqlDateType?.SqlDataTypeOption == SqlDataTypeOption.Time) + throw new QueryExecutionException(Sql4CdsError.InvalidDatePart(null, datepart.Value, "datepart", dateType)); + + return date.Value.Month; + + case Engine.DatePart.DayOfYear: + if (sqlDateType?.SqlDataTypeOption == SqlDataTypeOption.Time) + throw new QueryExecutionException(Sql4CdsError.InvalidDatePart(null, datepart.Value, "datepart", dateType)); + + return date.Value.DayOfYear; + + case Engine.DatePart.Day: + if (sqlDateType?.SqlDataTypeOption == SqlDataTypeOption.Time) + throw new QueryExecutionException(Sql4CdsError.InvalidDatePart(null, datepart.Value, "datepart", dateType)); + + return date.Value.Day; + + case Engine.DatePart.Week: + if (sqlDateType?.SqlDataTypeOption == SqlDataTypeOption.Time) + throw new QueryExecutionException(Sql4CdsError.InvalidDatePart(null, datepart.Value, "datepart", dateType)); + + return CultureInfo.CurrentCulture.Calendar.GetWeekOfYear(date.Value.DateTime, CalendarWeekRule.FirstFourDayWeek, DayOfWeek.Sunday); + + case Engine.DatePart.WeekDay: + if (sqlDateType?.SqlDataTypeOption == SqlDataTypeOption.Time) + throw new QueryExecutionException(Sql4CdsError.InvalidDatePart(null, datepart.Value, "datepart", dateType)); + + return (int)date.Value.DayOfWeek + 1; + + case Engine.DatePart.Hour: + return date.Value.Hour; + + case Engine.DatePart.Minute: + return date.Value.Minute; + + case Engine.DatePart.Second: + return date.Value.Second; + + case Engine.DatePart.Millisecond: + return date.Value.Millisecond; + + case Engine.DatePart.Microsecond: + return (int)(date.Value.Ticks % 10_000_000) / 10; + + case Engine.DatePart.Nanosecond: + return (int)(date.Value.Ticks % 10_000_000) * 100; + + case Engine.DatePart.TZOffset: + if (sqlDateType?.SqlDataTypeOption != SqlDataTypeOption.DateTimeOffset && sqlDateType?.SqlDataTypeOption != SqlDataTypeOption.DateTime2 && sqlDateType?.SqlDataTypeOption.IsStringType() != true) + throw new QueryExecutionException(Sql4CdsError.InvalidDatePart(null, datepart.Value, "datepart", dateType)); + + return (int)date.Value.Offset.TotalMinutes; + + case Engine.DatePart.ISOWeek: + if (sqlDateType?.SqlDataTypeOption == SqlDataTypeOption.Time) + throw new QueryExecutionException(Sql4CdsError.InvalidDatePart(null, datepart.Value, "datepart", dateType)); + + return CultureInfo.CurrentCulture.Calendar.GetWeekOfYear(date.Value.DateTime, CalendarWeekRule.FirstFourDayWeek, DayOfWeek.Monday); + + default: + throw new QueryExecutionException(Sql4CdsError.InvalidOptionValue(new StringLiteral { Value = datepart.Value }, "datepart")); + } } /// - /// Converts the SQL datepart argument names to the equivalent enum values used by VisualBasic + /// Converts the SQL datepart argument names to the equivalent enum value /// /// The SQL name for the datepart argument - /// The equivalent value - internal static DateInterval DatePartToInterval(string datepart) + /// The equivalent value + internal static bool TryParseDatePart(string datepart, out DatePart parsed) { switch (datepart.ToLower()) { case "year": case "yy": case "yyyy": - return DateInterval.Year; + parsed = Engine.DatePart.Year; + return true; case "quarter": case "qq": case "q": - return DateInterval.Quarter; + parsed = Engine.DatePart.Quarter; + return true; case "month": case "mm": case "m": - return DateInterval.Month; + parsed = Engine.DatePart.Month; + return true; case "dayofyear": case "dy": case "y": - return DateInterval.DayOfYear; + parsed = Engine.DatePart.DayOfYear; + return true; case "day": case "dd": case "d": - return DateInterval.Day; - - case "weekday": - case "dw": - case "w": - return DateInterval.Weekday; + parsed = Engine.DatePart.Day; + return true; case "week": case "wk": case "ww": - return DateInterval.WeekOfYear; + parsed = Engine.DatePart.Week; + return true; + + case "weekday": + case "dw": + case "w": // Abbreviation is lised for DATEADD but not DATEPART + parsed = Engine.DatePart.WeekDay; + return true; case "hour": case "hh": - return DateInterval.Hour; + parsed = Engine.DatePart.Hour; + return true; case "minute": - case "mi": + case "mi": // Abbreviation is lised for DATEADD but not DATEPART case "n": - return DateInterval.Minute; + parsed = Engine.DatePart.Minute; + return true; case "second": case "ss": case "s": - return DateInterval.Second; + parsed = Engine.DatePart.Second; + return true; + + case "millisecond": + case "ms": + parsed = Engine.DatePart.Millisecond; + return true; + + case "microsecond": + case "mcs": + parsed = Engine.DatePart.Microsecond; + return true; + + case "nanosecond": + case "ns": + parsed = Engine.DatePart.Nanosecond; + return true; + + case "tzoffset": + case "tz": + parsed = Engine.DatePart.TZOffset; + return true; + + case "iso_week": + case "isowk": + case "isoww": + parsed = Engine.DatePart.ISOWeek; + return true; default: - throw new ArgumentOutOfRangeException(nameof(datepart), $"Unsupported DATEPART value {datepart}"); + parsed = Engine.DatePart.Year; + return false; } } @@ -474,7 +833,7 @@ public static SqlInt32 Len(SqlString s) /// Any expression /// [SqlFunction(IsDeterministic = true)] - public static SqlInt32 DataLength(T value, [SourceType] DataTypeReference type) + public static SqlInt32 DataLength(T value, [SourceType(nameof(value))] DataTypeReference type) where T:INullable { if (value.IsNull) @@ -731,7 +1090,7 @@ public static T IsNull(T check, T replacement) /// Optional argument specifying a culture /// [SqlFunction(IsDeterministic = false)] - public static SqlString Format(T value, SqlString format, [Optional] SqlString culture, [SourceType] DataTypeReference type, ExpressionExecutionContext context) + public static SqlString Format(T value, SqlString format, [Optional] SqlString culture, [SourceType(nameof(value))] DataTypeReference type, ExpressionExecutionContext context) where T : INullable { if (value.IsNull) @@ -940,7 +1299,7 @@ public static object Value(SqlXml value, XPath2Expression query, [TargetType] Da throw new QueryExecutionException(Sql4CdsError.NotSupported(null, $"XPath return type '{result.GetType().Name}'")); if (sqlValue.GetType() != targetNetType) - sqlValue = (INullable) SqlTypeConverter.ChangeType(sqlValue, targetNetType); + sqlValue = (INullable) SqlTypeConverter.ChangeType(sqlValue, targetNetType, context); return sqlValue; } @@ -988,25 +1347,25 @@ public static SqlVariant ServerProperty(SqlString propertyName, ExpressionExecut switch (propertyName.Value.ToLowerInvariant()) { case "collation": - return new SqlVariant(DataTypeHelpers.NVarChar(128, dataSource.DefaultCollation, CollationLabel.CoercibleDefault), dataSource.DefaultCollation.ToSqlString(dataSource.DefaultCollation.Name)); + return new SqlVariant(DataTypeHelpers.NVarChar(128, dataSource.DefaultCollation, CollationLabel.CoercibleDefault), dataSource.DefaultCollation.ToSqlString(dataSource.DefaultCollation.Name), context); case "collationid": - return new SqlVariant(DataTypeHelpers.Int, new SqlInt32(dataSource.DefaultCollation.LCID)); + return new SqlVariant(DataTypeHelpers.Int, new SqlInt32(dataSource.DefaultCollation.LCID), context); case "comparisonstyle": - return new SqlVariant(DataTypeHelpers.Int, new SqlInt32((int)dataSource.DefaultCollation.CompareOptions)); + return new SqlVariant(DataTypeHelpers.Int, new SqlInt32((int)dataSource.DefaultCollation.CompareOptions), context); case "edition": - return new SqlVariant(DataTypeHelpers.NVarChar(128, dataSource.DefaultCollation, CollationLabel.CoercibleDefault), dataSource.DefaultCollation.ToSqlString("Enterprise Edition")); + return new SqlVariant(DataTypeHelpers.NVarChar(128, dataSource.DefaultCollation, CollationLabel.CoercibleDefault), dataSource.DefaultCollation.ToSqlString("Enterprise Edition"), context); case "editionid": - return new SqlVariant(DataTypeHelpers.BigInt, new SqlInt64(1804890536)); + return new SqlVariant(DataTypeHelpers.BigInt, new SqlInt64(1804890536), context); case "enginedition": - return new SqlVariant(DataTypeHelpers.Int, new SqlInt32(3)); + return new SqlVariant(DataTypeHelpers.Int, new SqlInt32(3), context); case "issingleuser": - return new SqlVariant(DataTypeHelpers.Int, new SqlInt32(0)); + return new SqlVariant(DataTypeHelpers.Int, new SqlInt32(0), context); case "machinename": case "servername": @@ -1019,13 +1378,13 @@ public static SqlVariant ServerProperty(SqlString propertyName, ExpressionExecut if (svc != null) machineName = svc.CrmConnectOrgUriActual.Host; #endif - return new SqlVariant(DataTypeHelpers.NVarChar(128, dataSource.DefaultCollation, CollationLabel.CoercibleDefault), dataSource.DefaultCollation.ToSqlString(machineName)); + return new SqlVariant(DataTypeHelpers.NVarChar(128, dataSource.DefaultCollation, CollationLabel.CoercibleDefault), dataSource.DefaultCollation.ToSqlString(machineName), context); case "pathseparator": - return new SqlVariant(DataTypeHelpers.NVarChar(1, dataSource.DefaultCollation, CollationLabel.CoercibleDefault), dataSource.DefaultCollation.ToSqlString(Path.DirectorySeparatorChar.ToString())); + return new SqlVariant(DataTypeHelpers.NVarChar(1, dataSource.DefaultCollation, CollationLabel.CoercibleDefault), dataSource.DefaultCollation.ToSqlString(Path.DirectorySeparatorChar.ToString()), context); case "processid": - return new SqlVariant(DataTypeHelpers.Int, new SqlInt32(System.Diagnostics.Process.GetCurrentProcess().Id)); + return new SqlVariant(DataTypeHelpers.Int, new SqlInt32(System.Diagnostics.Process.GetCurrentProcess().Id), context); case "productversion": string orgVersion = null; @@ -1041,14 +1400,14 @@ public static SqlVariant ServerProperty(SqlString propertyName, ExpressionExecut if (orgVersion == null) orgVersion = ((RetrieveVersionResponse)dataSource.Execute(new RetrieveVersionRequest())).Version; - return new SqlVariant(DataTypeHelpers.NVarChar(128, dataSource.DefaultCollation, CollationLabel.CoercibleDefault), dataSource.DefaultCollation.ToSqlString(orgVersion)); + return new SqlVariant(DataTypeHelpers.NVarChar(128, dataSource.DefaultCollation, CollationLabel.CoercibleDefault), dataSource.DefaultCollation.ToSqlString(orgVersion), context); } return SqlVariant.Null; } [SqlFunction(IsDeterministic = false)] - public static SqlVariant Sql_Variant_Property(SqlVariant expression, SqlString property) + public static SqlVariant Sql_Variant_Property(SqlVariant expression, SqlString property, ExpressionExecutionContext context) { if (property.IsNull) return SqlVariant.Null; @@ -1057,42 +1416,42 @@ public static SqlVariant Sql_Variant_Property(SqlVariant expression, SqlString p { case "basetype": if (expression.BaseType == null) - return new SqlVariant(DataTypeHelpers.NVarChar(128, Collation.USEnglish, CollationLabel.CoercibleDefault), SqlString.Null); + return new SqlVariant(DataTypeHelpers.NVarChar(128, Collation.USEnglish, CollationLabel.CoercibleDefault), SqlString.Null, context); if (expression.BaseType is SqlDataTypeReference sqlType) - return new SqlVariant(DataTypeHelpers.NVarChar(128, Collation.USEnglish, CollationLabel.CoercibleDefault), Collation.USEnglish.ToSqlString(sqlType.SqlDataTypeOption.ToString().ToLowerInvariant())); + return new SqlVariant(DataTypeHelpers.NVarChar(128, Collation.USEnglish, CollationLabel.CoercibleDefault), Collation.USEnglish.ToSqlString(sqlType.SqlDataTypeOption.ToString().ToLowerInvariant()), context); - return new SqlVariant(DataTypeHelpers.NVarChar(128, Collation.USEnglish, CollationLabel.CoercibleDefault), Collation.USEnglish.ToSqlString(expression.BaseType.ToSql())); + return new SqlVariant(DataTypeHelpers.NVarChar(128, Collation.USEnglish, CollationLabel.CoercibleDefault), Collation.USEnglish.ToSqlString(expression.BaseType.ToSql()), context); case "precision": if (expression.BaseType == null) - return new SqlVariant(DataTypeHelpers.Int, SqlInt32.Null); + return new SqlVariant(DataTypeHelpers.Int, SqlInt32.Null, context); - return new SqlVariant(DataTypeHelpers.Int, new SqlInt32(expression.BaseType.GetPrecision())); + return new SqlVariant(DataTypeHelpers.Int, new SqlInt32(expression.BaseType.GetPrecision()), context); case "scale": if (expression.BaseType == null) - return new SqlVariant(DataTypeHelpers.Int, SqlInt32.Null); + return new SqlVariant(DataTypeHelpers.Int, SqlInt32.Null, context); - return new SqlVariant(DataTypeHelpers.Int, new SqlInt32(expression.BaseType.GetScale())); + return new SqlVariant(DataTypeHelpers.Int, new SqlInt32(expression.BaseType.GetScale()), context); case "totalbytes": if (expression.BaseType == null) - return new SqlVariant(DataTypeHelpers.Int, SqlInt32.Null); + return new SqlVariant(DataTypeHelpers.Int, SqlInt32.Null, context); - return new SqlVariant(DataTypeHelpers.Int, DataLength(expression.Value, expression.BaseType)); + return new SqlVariant(DataTypeHelpers.Int, DataLength(expression.Value, expression.BaseType), context); case "collation": if (!(expression.BaseType is SqlDataTypeReferenceWithCollation coll)) - return new SqlVariant(DataTypeHelpers.NVarChar(128, Collation.USEnglish, CollationLabel.CoercibleDefault), SqlString.Null); + return new SqlVariant(DataTypeHelpers.NVarChar(128, Collation.USEnglish, CollationLabel.CoercibleDefault), SqlString.Null, context); - return new SqlVariant(DataTypeHelpers.NVarChar(128, Collation.USEnglish, CollationLabel.CoercibleDefault), Collation.USEnglish.ToSqlString(coll.Collation.Name)); + return new SqlVariant(DataTypeHelpers.NVarChar(128, Collation.USEnglish, CollationLabel.CoercibleDefault), Collation.USEnglish.ToSqlString(coll.Collation.Name), context); case "maxlength": if (expression.BaseType == null) - return new SqlVariant(DataTypeHelpers.Int, SqlInt32.Null); + return new SqlVariant(DataTypeHelpers.Int, SqlInt32.Null, context); - return new SqlVariant(DataTypeHelpers.Int, new SqlInt32(expression.BaseType.GetSize())); + return new SqlVariant(DataTypeHelpers.Int, new SqlInt32(expression.BaseType.GetSize()), context); } return SqlVariant.Null; @@ -1661,11 +2020,20 @@ class TargetTypeAttribute : Attribute } /// - /// Indicates that the parameter gives the type of the preceding parameter + /// Indicates that the parameter gives the orignal SQL type of another parameter /// [AttributeUsage(AttributeTargets.Parameter)] class SourceTypeAttribute : Attribute { + public SourceTypeAttribute(string sourceParameter) + { + SourceParameter = sourceParameter; + } + + /// + /// Returns the name of the parameter this provides the original type of + /// + public string SourceParameter { get; } } /// @@ -1683,4 +2051,26 @@ class OptionalAttribute : Attribute class CollationSensitiveAttribute : Attribute { } + + /// + /// The available date parts for the DATEPART function + /// + enum DatePart + { + Year, + Quarter, + Month, + DayOfYear, + Day, + Week, + WeekDay, + Hour, + Minute, + Second, + Millisecond, + Microsecond, + Nanosecond, + TZOffset, + ISOWeek, + } } diff --git a/MarkMpn.Sql4Cds.Engine/IQueryExecutionOptions.cs b/MarkMpn.Sql4Cds.Engine/IQueryExecutionOptions.cs index 507d425c..ea1169f2 100644 --- a/MarkMpn.Sql4Cds.Engine/IQueryExecutionOptions.cs +++ b/MarkMpn.Sql4Cds.Engine/IQueryExecutionOptions.cs @@ -93,6 +93,11 @@ interface IQueryExecutionOptions /// string PrimaryDataSource { get; } + /// + /// An event that is fired when the changes + /// + event EventHandler PrimaryDataSourceChanged; + /// /// Returns the unique identifier of the current user /// diff --git a/MarkMpn.Sql4Cds.Engine/LayeredDictionary.cs b/MarkMpn.Sql4Cds.Engine/LayeredDictionary.cs new file mode 100644 index 00000000..3b952f73 --- /dev/null +++ b/MarkMpn.Sql4Cds.Engine/LayeredDictionary.cs @@ -0,0 +1,135 @@ +using System.Collections; +using System.Collections.Generic; +using System.Linq; + +namespace MarkMpn.Sql4Cds.Engine +{ + /// + /// A dictionary that combines multiple dictionaries into a single view + /// + /// + /// + class LayeredDictionary : IDictionary + { + private readonly IDictionary[] _inner; + private readonly IDictionary _fallback; + + public LayeredDictionary(params IDictionary[] inner) + { + _inner = inner; + _fallback = inner.Last(); + } + + public TValue this[TKey key] + { + get + { + foreach (var dict in _inner) + { + if (dict.TryGetValue(key, out var value)) + return value; + } + + throw new KeyNotFoundException(); + } + set + { + foreach (var dict in _inner) + { + if (dict == _fallback || dict.ContainsKey(key)) + { + dict[key] = value; + return; + } + } + } + } + + public ICollection Keys => _inner.SelectMany(d => d.Keys).ToArray(); + + public ICollection Values => _inner.SelectMany(d => d.Values).ToArray(); + + public int Count => _inner.Sum(d => d.Count); + + public bool IsReadOnly => _inner.Any(d => d.IsReadOnly); + + public void Add(TKey key, TValue value) + { + _fallback.Add(key, value); + } + + public void Add(KeyValuePair item) + { + _fallback.Add(item); + } + + public void Clear() + { + foreach (var dict in _inner) + dict.Clear(); + } + + public bool Contains(KeyValuePair item) + { + return _inner.Any(d => d.Contains(item)); + } + + public bool ContainsKey(TKey key) + { + return _inner.Any(d => d.ContainsKey(key)); + } + + public void CopyTo(KeyValuePair[] array, int arrayIndex) + { + foreach (var dict in _inner) + { + dict.CopyTo(array, arrayIndex); + arrayIndex += dict.Count; + } + } + + public IEnumerator> GetEnumerator() + { + return _inner.SelectMany(d => d).GetEnumerator(); + } + + public bool Remove(TKey key) + { + foreach (var dict in _inner) + { + if (dict.Remove(key)) + return true; + } + + return false; + } + + public bool Remove(KeyValuePair item) + { + foreach (var dict in _inner) + { + if (dict.Remove(item)) + return true; + } + + return false; + } + + public bool TryGetValue(TKey key, out TValue value) + { + foreach (var dict in _inner) + { + if (dict.TryGetValue(key, out value)) + return true; + } + + value = default; + return false; + } + + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + } +} diff --git a/MarkMpn.Sql4Cds.Engine/MarkMpn.Sql4Cds.Engine.csproj b/MarkMpn.Sql4Cds.Engine/MarkMpn.Sql4Cds.Engine.csproj index 73e366a1..fb51d2ca 100644 --- a/MarkMpn.Sql4Cds.Engine/MarkMpn.Sql4Cds.Engine.csproj +++ b/MarkMpn.Sql4Cds.Engine/MarkMpn.Sql4Cds.Engine.csproj @@ -36,7 +36,6 @@ - diff --git a/MarkMpn.Sql4Cds.Engine/MetaMetadataCache.cs b/MarkMpn.Sql4Cds.Engine/MetaMetadataCache.cs index 6ae8895b..27dbdafd 100644 --- a/MarkMpn.Sql4Cds.Engine/MetaMetadataCache.cs +++ b/MarkMpn.Sql4Cds.Engine/MetaMetadataCache.cs @@ -46,6 +46,8 @@ class StubOptions : IQueryExecutionOptions public string PrimaryDataSource => throw new NotImplementedException(); + public event EventHandler PrimaryDataSourceChanged; + public Guid UserId => throw new NotImplementedException(); public bool QuotedIdentifiers => throw new NotImplementedException(); diff --git a/MarkMpn.Sql4Cds.Engine/MetadataExtensions.cs b/MarkMpn.Sql4Cds.Engine/MetadataExtensions.cs index 83858604..984294d6 100644 --- a/MarkMpn.Sql4Cds.Engine/MetadataExtensions.cs +++ b/MarkMpn.Sql4Cds.Engine/MetadataExtensions.cs @@ -5,6 +5,7 @@ using System.Linq; using System.Text; using System.Threading.Tasks; +using MarkMpn.Sql4Cds.Engine.ExecutionPlan; using Microsoft.SqlServer.TransactSql.ScriptDom; using Microsoft.Xrm.Sdk; using Microsoft.Xrm.Sdk.Metadata; @@ -251,6 +252,52 @@ public static DataTypeReference GetAttributeSqlType(this AttributeMetadata attrM throw new ApplicationException("Unknown attribute type " + attrMetadata.GetType()); } + + /// + /// Converts a constant string value to an expression that generates the value of the appropriate type for use in a DML operation + /// + /// The attribute to convert the value for + /// The string representation of the value + /// Indicates if the value is a variable name + /// The context that the expression is being compiled in + /// The data source that the operation will be performed in + /// An expression that returns the value in the appropriate type + internal static ScalarExpression GetDmlValue(this AttributeMetadata attribute, string value, bool isVariable, ExpressionCompilationContext context, DataSource dataSource) + { + var expr = (ScalarExpression)new StringLiteral { Value = value }; + + if (isVariable) + expr = new VariableReference { Name = value }; + + expr.GetType(context, out var exprType); + var attrType = attribute.GetAttributeSqlType(dataSource, false); + + if (DataTypeHelpers.IsSameAs(exprType, attrType)) + return expr; + + if (attribute.IsPrimaryId == true) + { + expr = new FunctionCall + { + FunctionName = new Identifier { Value = nameof(ExpressionFunctions.CreateLookup) }, + Parameters = + { + new StringLiteral { Value = attribute.EntityLogicalName }, + expr + } + }; + } + else + { + expr = new CastCall + { + Parameter = expr, + DataType = attrType + }; + } + + return expr; + } } /// diff --git a/MarkMpn.Sql4Cds.Engine/NodeContext.cs b/MarkMpn.Sql4Cds.Engine/NodeContext.cs index 01d16b73..27f792b1 100644 --- a/MarkMpn.Sql4Cds.Engine/NodeContext.cs +++ b/MarkMpn.Sql4Cds.Engine/NodeContext.cs @@ -24,12 +24,12 @@ class NodeCompilationContext /// The names and types of the parameters that are available to the query /// A callback function to log messages public NodeCompilationContext( - IDictionary dataSources, + SessionContext session, IQueryExecutionOptions options, IDictionary parameterTypes, Action log) { - DataSources = dataSources; + Session = session; Options = options; ParameterTypes = parameterTypes; GlobalCalculations = new NestedLoopNode @@ -58,7 +58,7 @@ public NodeCompilationContext( NodeCompilationContext parentContext, IDictionary parameterTypes) { - DataSources = parentContext.DataSources; + Session = parentContext.Session; Options = parentContext.Options; ParameterTypes = parameterTypes; GlobalCalculations = parentContext.GlobalCalculations; @@ -67,9 +67,9 @@ public NodeCompilationContext( } /// - /// Returns the data sources that are available to the query + /// Returns the connection session the query will be executed in /// - public IDictionary DataSources { get; } + public SessionContext Session { get; } /// /// Returns the options that the query will be executed with @@ -84,7 +84,7 @@ public NodeCompilationContext( /// /// Returns the details of the primary data source /// - public DataSource PrimaryDataSource => DataSources[Options.PrimaryDataSource]; + public DataSource PrimaryDataSource => Session.DataSources[Options.PrimaryDataSource]; /// /// Returns a which can be used to calculate global values to be injected into other nodes @@ -142,12 +142,40 @@ class NodeExecutionContext : NodeCompilationContext /// The current value of each parameter /// A callback function to log messages public NodeExecutionContext( - IDictionary dataSources, + SessionContext session, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues, Action log) - : base(dataSources, options, parameterTypes, log) + : base(session, options, parameterTypes, log) + { + ParameterValues = parameterValues; + } + + /// + /// Creates a new based on a + /// + /// The to inherit settings from + /// The values to use for any parameters + public NodeExecutionContext( + NodeCompilationContext parentContext, + IDictionary parameterValues) + : base(parentContext, parentContext.ParameterTypes) + { + ParameterValues = parameterValues; + } + + /// + /// Creates a new based on another context but with additional parameters for a subquery + /// + /// The to inherit settings from + /// The names and types of the parameters that are available to the subquery + /// The current value of each parameter + public NodeExecutionContext( + NodeExecutionContext parentContext, + IDictionary parameterTypes, + IDictionary parameterValues) + : base(parentContext, parameterTypes) { ParameterValues = parameterValues; } @@ -157,6 +185,9 @@ public NodeExecutionContext( /// public IDictionary ParameterValues { get; } + /// + /// Returns or sets the current error + /// public Sql4CdsError Error { get; set; } } @@ -174,12 +205,12 @@ class ExpressionCompilationContext : NodeCompilationContext /// The schema of data which is available to the expression /// The schema of data prior to aggregation public ExpressionCompilationContext( - IDictionary dataSources, + SessionContext session, IQueryExecutionOptions options, IDictionary parameterTypes, INodeSchema schema, INodeSchema nonAggregateSchema) - : base(dataSources, options, parameterTypes, null) + : base(session, options, parameterTypes, null) { Schema = schema; NonAggregateSchema = nonAggregateSchema; @@ -195,7 +226,7 @@ public ExpressionCompilationContext( NodeCompilationContext nodeContext, INodeSchema schema, INodeSchema nonAggregateSchema) - : base(nodeContext.DataSources, nodeContext.Options, nodeContext.ParameterTypes, nodeContext.Log) + : base(nodeContext.Session, nodeContext.Options, nodeContext.ParameterTypes, nodeContext.Log) { Schema = schema; NonAggregateSchema = nonAggregateSchema; @@ -228,13 +259,13 @@ class ExpressionExecutionContext : NodeExecutionContext /// The values for the current row the expression is being evaluated for /// The current value of each parameter public ExpressionExecutionContext( - IDictionary dataSources, + SessionContext session, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues, Action log, Entity entity) - : base(dataSources, options, parameterTypes, parameterValues, log) + : base(session, options, parameterTypes, parameterValues, log) { Entity = entity; } @@ -249,7 +280,7 @@ public ExpressionExecutionContext( /// representing each row as it is processed. /// public ExpressionExecutionContext(NodeExecutionContext nodeContext) - : base(nodeContext.DataSources, nodeContext.Options, nodeContext.ParameterTypes, nodeContext.ParameterValues, nodeContext.Log) + : base(nodeContext, nodeContext.ParameterValues) { Entity = null; Error = nodeContext.Error; @@ -268,7 +299,7 @@ public ExpressionExecutionContext(NodeExecutionContext nodeContext) /// representing each row as it is processed. /// public ExpressionExecutionContext(ExpressionCompilationContext compilationContext) - : base(compilationContext.DataSources, compilationContext.Options, compilationContext.ParameterTypes, null, null) + : base(compilationContext, null) { Entity = null; } diff --git a/MarkMpn.Sql4Cds.Engine/SessionContext.cs b/MarkMpn.Sql4Cds.Engine/SessionContext.cs new file mode 100644 index 00000000..d8dda573 --- /dev/null +++ b/MarkMpn.Sql4Cds.Engine/SessionContext.cs @@ -0,0 +1,233 @@ +using System; +using System.Collections; +using System.Collections.Generic; +using System.Data.SqlTypes; +using System.Linq; +using System.Reflection; +using System.Text; +using System.Threading.Tasks; +using Microsoft.Crm.Sdk.Messages; +using Microsoft.SqlServer.TransactSql.ScriptDom; +#if NETCOREAPP +using Microsoft.PowerPlatform.Dataverse.Client; +#else +using Microsoft.Xrm.Tooling.Connector; +#endif + +namespace MarkMpn.Sql4Cds.Engine +{ + /// + /// Holds context for a session that should be persisted across multiple queries + /// + class SessionContext + { + /// + /// Provides just-in-time access to global variables that require a service request to determine the value for + /// + class SessionContextVariables : IDictionary + { + private readonly SessionContext _context; + private readonly Dictionary> _values; + + public SessionContextVariables(SessionContext context) + { + _context = context; + _values = new Dictionary>(StringComparer.OrdinalIgnoreCase); + Reset(); + } + + public void Reset() + { + _values["@@SERVERNAME"] = new Lazy(() => GetServerName()); + _values["@@VERSION"] = new Lazy(() => GetVersion()); + } + + private SqlString GetVersion() + { + var dataSource = _context.DataSources[_context._options.PrimaryDataSource]; + string orgVersion = null; + +#if NETCOREAPP + if (dataSource.Connection is ServiceClient svc) + orgVersion = svc.ConnectedOrgVersion.ToString(); +#else + if (dataSource.Connection is CrmServiceClient svc) + orgVersion = svc.ConnectedOrgVersion.ToString(); +#endif + + if (orgVersion == null) + orgVersion = ((RetrieveVersionResponse)dataSource.Execute(new RetrieveVersionRequest())).Version; + + var assembly = typeof(Sql4CdsConnection).Assembly; + var assemblyVersion = assembly.GetName().Version; + var assemblyCopyright = assembly + .GetCustomAttributes(typeof(AssemblyCopyrightAttribute), false) + .OfType() + .FirstOrDefault()? + .Copyright; + var assemblyFilename = assembly.Location; + var assemblyDate = System.IO.File.GetLastWriteTime(assemblyFilename); + + return $"Microsoft Dataverse - {orgVersion}\r\n\tSQL 4 CDS - {assemblyVersion}\r\n\t{assemblyDate:MMM dd yyyy HH:mm:ss}\r\n\t{assemblyCopyright}"; + } + + private SqlString GetServerName() + { + var dataSource = _context.DataSources[_context._options.PrimaryDataSource]; + +#if NETCOREAPP + var svc = dataSource.Connection as ServiceClient; + + if (svc != null) + return svc.ConnectedOrgUriActual.Host; +#else + var svc = dataSource.Connection as CrmServiceClient; + + if (svc != null) + return svc.CrmConnectOrgUriActual.Host; +#endif + + return dataSource.Name; + } + + public INullable this[string key] + { + get => _values[key].Value; + set => throw new NotImplementedException(); + } + + public ICollection Keys => _values.Keys; + + public ICollection Values => _values.Values.Select(v => v.Value).ToArray(); + + public int Count => _values.Count; + + public bool IsReadOnly => true; + + public void Add(string key, INullable value) + { + throw new NotImplementedException(); + } + + public void Add(KeyValuePair item) + { + throw new NotImplementedException(); + } + + public void Clear() + { + throw new NotImplementedException(); + } + + public bool Contains(KeyValuePair item) + { + return _values.TryGetValue(item.Key, out var val) + && val.Value.Equals(item.Value); + } + + public bool ContainsKey(string key) + { + return _values.ContainsKey(key); + } + + public void CopyTo(KeyValuePair[] array, int arrayIndex) + { + foreach (var item in this) + { + array[arrayIndex] = item; + arrayIndex++; + } + } + + public IEnumerator> GetEnumerator() + { + foreach (var item in _values) + yield return new KeyValuePair(item.Key, item.Value.Value); + } + + public bool Remove(string key) + { + throw new NotImplementedException(); + } + + public bool Remove(KeyValuePair item) + { + throw new NotImplementedException(); + } + + public bool TryGetValue(string key, out INullable value) + { + if (!_values.TryGetValue(key, out var lazy)) + { + value = default; + return false; + } + + value = lazy.Value; + return true; + } + + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + } + + private readonly IQueryExecutionOptions _options; + private readonly SessionContextVariables _variables; + + public SessionContext(IDictionary dataSources, IQueryExecutionOptions options) + { + _options = options; + DataSources = dataSources; + DateFormat = DateFormat.mdy; + _variables = new SessionContextVariables(this); + + GlobalVariableTypes = new Dictionary(StringComparer.OrdinalIgnoreCase) + { + ["@@IDENTITY"] = DataTypeHelpers.EntityReference, + ["@@ROWCOUNT"] = DataTypeHelpers.Int, + ["@@ERROR"] = DataTypeHelpers.Int, + }; + + GlobalVariableValues = new LayeredDictionary( + new Dictionary(StringComparer.OrdinalIgnoreCase) + { + ["@@IDENTITY"] = SqlEntityReference.Null, + ["@@ROWCOUNT"] = (SqlInt32)0, + ["@@ERROR"] = (SqlInt32)0, + }, + _variables); + + GetServerDetails(); + _options.PrimaryDataSourceChanged += (_, __) => GetServerDetails(); + } + + private void GetServerDetails() + { + GlobalVariableTypes["@@SERVERNAME"] = DataTypeHelpers.NVarChar(100, DataSources[_options.PrimaryDataSource].DefaultCollation, CollationLabel.CoercibleDefault); + GlobalVariableTypes["@@VERSION"] = DataTypeHelpers.NVarChar(Int32.MaxValue, DataSources[_options.PrimaryDataSource].DefaultCollation, CollationLabel.CoercibleDefault); + _variables.Reset(); + } + + /// + /// Returns the data sources that are available to the query + /// + public IDictionary DataSources { get; } + + /// + /// Returns or sets the current SET DATEFORMAT option + /// + public DateFormat DateFormat { get; set; } + + /// + /// Returns the types of the global variables + /// + internal Dictionary GlobalVariableTypes { get; } + + /// + /// Returns the values of the global variables + /// + internal IDictionary GlobalVariableValues { get; } + } +} diff --git a/MarkMpn.Sql4Cds.Engine/SqlDateTypes.cs b/MarkMpn.Sql4Cds.Engine/SqlDateTypes.cs index e37fe40c..3058b8a2 100644 --- a/MarkMpn.Sql4Cds.Engine/SqlDateTypes.cs +++ b/MarkMpn.Sql4Cds.Engine/SqlDateTypes.cs @@ -1,300 +1,600 @@ using System; using System.Collections.Generic; using System.Data.SqlTypes; +using System.Globalization; +using System.Linq; using System.Text; +using System.Threading; +using Microsoft.Crm.Sdk.Messages; namespace MarkMpn.Sql4Cds.Engine { - public struct SqlDate : INullable, IComparable + public struct SqlSmallDateTime : INullable, IComparable { - private SqlDateTime _dt; + private readonly DateTime? _dt; - public SqlDate(SqlDateTime dt) + public SqlSmallDateTime(DateTime? dt) { - _dt = dt.IsNull ? dt : dt.Value.Date; + if (dt == null) + { + _dt = dt; + } + else + { + // Value is rounded to the nearest minute + // https://learn.microsoft.com/en-us/sql/t-sql/functions/dateadd-transact-sql?view=sql-server-ver16#return-values-for-a-smalldatetime-date-and-a-second-or-fractional-seconds-datepart + _dt = new DateTime(dt.Value.Year, dt.Value.Month, dt.Value.Day, dt.Value.Hour, dt.Value.Minute, 0); + + if (dt.Value.TimeOfDay.Seconds >= 30) + _dt = _dt.Value.AddMinutes(1); + } } - public static readonly SqlDate Null = new SqlDate(SqlDateTime.Null); + public static readonly SqlSmallDateTime Null = new SqlSmallDateTime(null); + + public static readonly SqlSmallDateTime MinValue = new SqlSmallDateTime(new DateTime(1900, 1, 1)); + + public static readonly SqlSmallDateTime MaxValue = new SqlSmallDateTime(new DateTime(2079, 6, 6, 23, 59, 0)); - public bool IsNull => _dt.IsNull; + public bool IsNull => _dt == null; public int CompareTo(object obj) { - if (obj is SqlDate dt) - obj = dt._dt; + var value = (SqlSmallDateTime)obj; + + if (IsNull) + { + if (!value.IsNull) + { + return -1; + } + return 0; + } + + if (value.IsNull) + { + return 1; + } + + if (this < value) + { + return -1; + } - return _dt.CompareTo(obj); + if (this > value) + { + return 1; + } + + return 0; } public override int GetHashCode() { - return _dt.GetHashCode(); + return _dt?.GetHashCode() ?? 0; } public override bool Equals(object obj) { - if (obj is SqlDate dt) - obj = dt._dt; + if (!(obj is SqlSmallDateTime dt)) + return false; - return _dt.Equals(obj); + return _dt == dt._dt; } public DateTime Value => _dt.Value; - public static SqlBoolean operator ==(SqlDate x, SqlDate y) => x._dt == y._dt; + public static SqlBoolean operator ==(SqlSmallDateTime x, SqlSmallDateTime y) => x.IsNull || y.IsNull ? SqlBoolean.Null : x._dt == y._dt; - public static SqlBoolean operator !=(SqlDate x, SqlDate y) => x._dt != y._dt; + public static SqlBoolean operator !=(SqlSmallDateTime x, SqlSmallDateTime y) => !(x == y); - public static SqlBoolean operator <(SqlDate x, SqlDate y) => x._dt < y._dt; + public static SqlBoolean operator <(SqlSmallDateTime x, SqlSmallDateTime y) => x.IsNull || y.IsNull ? SqlBoolean.Null : x._dt < y._dt; - public static SqlBoolean operator >(SqlDate x, SqlDate y) => x._dt > y._dt; + public static SqlBoolean operator >(SqlSmallDateTime x, SqlSmallDateTime y) => x.IsNull || y.IsNull ? SqlBoolean.Null : x._dt > y._dt; - public static SqlBoolean operator <=(SqlDate x, SqlDate y) => x._dt <= y._dt; + public static SqlBoolean operator <=(SqlSmallDateTime x, SqlSmallDateTime y) => x.IsNull || y.IsNull ? SqlBoolean.Null : x._dt <= y._dt; - public static SqlBoolean operator >=(SqlDate x, SqlDate y) => x._dt >= y._dt; + public static SqlBoolean operator >=(SqlSmallDateTime x, SqlSmallDateTime y) => x.IsNull || y.IsNull ? SqlBoolean.Null : x._dt >= y._dt; - public static implicit operator SqlDateTime(SqlDate dt) + public static implicit operator SqlDateTime(SqlSmallDateTime dt) { - return dt._dt; + if (dt.IsNull) + return SqlDateTime.Null; + + return dt._dt.Value; } - public static implicit operator SqlDateTime2(SqlDate dt) + public static implicit operator SqlSmallDateTime(SqlDateTime dt) { - return dt._dt; + if (dt.IsNull) + return Null; + + return new SqlSmallDateTime(dt.Value); } - public static implicit operator SqlDateTimeOffset(SqlDate dt) + public static implicit operator SqlDateTime2(SqlSmallDateTime dt) { - return dt._dt; + if (dt.IsNull) + return SqlDateTime2.Null; + + return new SqlDateTime2(dt.Value); } - public static implicit operator SqlDate(SqlDateTime dt) + public static implicit operator SqlSmallDateTime(SqlDateTime2 dt) { - return new SqlDate(dt); + if (dt.IsNull) + return Null; + + return new SqlSmallDateTime(dt.Value); } - public static implicit operator SqlDate(SqlDateTime2 dt) + public static implicit operator SqlDateTimeOffset(SqlSmallDateTime dt) { - return new SqlDate(dt); + if (dt.IsNull) + return SqlDateTimeOffset.Null; + + return new SqlDateTimeOffset(dt.Value); } - public static implicit operator SqlDate(SqlDateTimeOffset dt) + public static implicit operator SqlSmallDateTime(SqlDateTimeOffset dt) { - return new SqlDate(dt); + if (dt.IsNull) + return Null; + + return new SqlSmallDateTime(dt.Value.DateTime); } - public static implicit operator SqlDate(SqlString str) + public static implicit operator SqlDate(SqlSmallDateTime dt) { - return new SqlDate((SqlDateTime)str); + if (dt.IsNull) + return SqlDate.Null; + + return new SqlDate(dt.Value); } - public override string ToString() + public static implicit operator SqlSmallDateTime(SqlDate dt) + { + if (dt.IsNull) + return Null; + + return new SqlSmallDateTime(dt.Value); + } + + public static implicit operator SqlTime(SqlSmallDateTime dt) + { + if (dt.IsNull) + return SqlTime.Null; + + return new SqlTime(dt.Value.TimeOfDay); + } + + public static implicit operator SqlSmallDateTime(SqlTime dt) { - return _dt.ToString(); + if (dt.IsNull) + return Null; + + return new SqlSmallDateTime(new DateTime(1900, 1, 1) + dt.Value); } } - public struct SqlTime : INullable, IComparable + public struct SqlDate : INullable, IComparable { - private SqlDateTime _dt; + private readonly DateTime? _dt; + + public SqlDate(DateTime? dt) + { + _dt = dt; + } + + public static readonly SqlDate Null = new SqlDate(null); + + public bool IsNull => _dt == null; - public SqlTime(SqlDateTime dt) + public int CompareTo(object obj) + { + var value = (SqlDate)obj; + + if (IsNull) + { + if (!value.IsNull) + { + return -1; + } + return 0; + } + + if (value.IsNull) + { + return 1; + } + + if (this < value) + { + return -1; + } + + if (this > value) + { + return 1; + } + + return 0; + } + + public override int GetHashCode() + { + return _dt?.GetHashCode() ?? 0; + } + + public override bool Equals(object obj) + { + if (!(obj is SqlDate dt)) + return false; + + return _dt == dt._dt; + } + + public DateTime Value => _dt.Value; + + public static SqlBoolean operator ==(SqlDate x, SqlDate y) => x.IsNull || y.IsNull ? SqlBoolean.Null : x._dt == y._dt; + + public static SqlBoolean operator !=(SqlDate x, SqlDate y) => !(x == y); + + public static SqlBoolean operator <(SqlDate x, SqlDate y) => x.IsNull || y.IsNull ? SqlBoolean.Null : x._dt < y._dt; + + public static SqlBoolean operator >(SqlDate x, SqlDate y) => x.IsNull || y.IsNull ? SqlBoolean.Null : x._dt > y._dt; + + public static SqlBoolean operator <=(SqlDate x, SqlDate y) => x.IsNull || y.IsNull ? SqlBoolean.Null : x._dt <= y._dt; + + public static SqlBoolean operator >=(SqlDate x, SqlDate y) => x.IsNull || y.IsNull ? SqlBoolean.Null : x._dt >= y._dt; + + public static implicit operator SqlDateTime(SqlDate dt) + { + if (dt.IsNull) + return SqlDateTime.Null; + + return dt._dt.Value; + } + + public static implicit operator SqlDate(SqlDateTime dt) + { + if (dt.IsNull) + return Null; + + return new SqlDate(dt.Value.Date); + } + + public static implicit operator SqlDate(DateTime? dt) { - _dt = dt.IsNull ? dt : new DateTime(1900, 1, 1).AddTicks(dt.TimeTicks * TimeSpan.TicksPerSecond / SqlDateTime.SQLTicksPerSecond); + return new SqlDate(dt); } + public override string ToString() + { + return _dt?.ToString() ?? "Null"; + } + + /// + /// Converts a string literal to the equivalent value + /// + /// + /// The string literal to parse + /// The current DATEFORMAT setting to control the expected formats + /// The parsed version of the date + /// if the was parsed successfully, or otherwise + internal static bool TryParse(SqlString value, DateFormat dateFormat, out SqlDate date) + { + return SqlDateParsing.TryParse(value, dateFormat, out date); + } + } + + public struct SqlTime : INullable, IComparable + { + private readonly TimeSpan? _ts; + private static readonly DateTime _defaultDate = new DateTime(1900, 1, 1); + public SqlTime(TimeSpan? ts) { - _dt = ts == null ? SqlDateTime.Null : new DateTime(1900, 1, 1).Add(ts.Value); + _ts = ts; } - public static readonly SqlTime Null = new SqlTime(SqlDateTime.Null); + public static readonly SqlTime Null = new SqlTime(null); - public bool IsNull => _dt.IsNull; + public bool IsNull => _ts == null; public int CompareTo(object obj) { - if (obj is SqlTime dt) - obj = dt._dt; + var value = (SqlTime)obj; + + if (IsNull) + { + if (!value.IsNull) + { + return -1; + } + return 0; + } + + if (value.IsNull) + { + return 1; + } + + if (this < value) + { + return -1; + } - return _dt.CompareTo(obj); + if (this > value) + { + return 1; + } + + return 0; } public override int GetHashCode() { - return _dt.GetHashCode(); + return _ts?.GetHashCode() ?? 0; } public override bool Equals(object obj) { - if (obj is SqlTime dt) - obj = dt._dt; + if (!(obj is SqlTime dt)) + return false; - return _dt.Equals(obj); + return _ts == dt._ts; } - public TimeSpan Value => _dt.Value.TimeOfDay; + public TimeSpan Value => _ts.Value; - public static SqlBoolean operator ==(SqlTime x, SqlTime y) => x._dt == y._dt; + public static SqlBoolean operator ==(SqlTime x, SqlTime y) => x.IsNull || y.IsNull ? SqlBoolean.Null : x._ts == y._ts; - public static SqlBoolean operator !=(SqlTime x, SqlTime y) => x._dt != y._dt; + public static SqlBoolean operator !=(SqlTime x, SqlTime y) => !(x == y); - public static SqlBoolean operator <(SqlTime x, SqlTime y) => x._dt < y._dt; + public static SqlBoolean operator <(SqlTime x, SqlTime y) => x.IsNull || y.IsNull ? SqlBoolean.Null : x._ts < y._ts; - public static SqlBoolean operator >(SqlTime x, SqlTime y) => x._dt > y._dt; + public static SqlBoolean operator >(SqlTime x, SqlTime y) => x.IsNull || y.IsNull ? SqlBoolean.Null : x._ts > y._ts; - public static SqlBoolean operator <=(SqlTime x, SqlTime y) => x._dt <= y._dt; + public static SqlBoolean operator <=(SqlTime x, SqlTime y) => x.IsNull || y.IsNull ? SqlBoolean.Null : x._ts <= y._ts; - public static SqlBoolean operator >=(SqlTime x, SqlTime y) => x._dt >= y._dt; + public static SqlBoolean operator >=(SqlTime x, SqlTime y) => x.IsNull || y.IsNull ? SqlBoolean.Null : x._ts >= y._ts; public static implicit operator SqlDateTime(SqlTime dt) { - return dt._dt; - } + if (dt.IsNull) + return SqlDateTime.Null; + return _defaultDate + dt._ts.Value; + } + /* public static implicit operator SqlDateTime2(SqlTime dt) { - return dt._dt; + if (dt.IsNull) + return SqlDateTime2.Null; + + return _defaultDate + dt._ts.Value; } public static implicit operator SqlDateTimeOffset(SqlTime dt) { - return dt._dt; - } + if (dt.IsNull) + return SqlDateTimeOffset.Null; + return (DateTimeOffset)(_defaultDate + dt._ts.Value); + } + */ public static implicit operator SqlTime(SqlDateTime dt) { - return new SqlTime(dt); - } + if (dt.IsNull) + return Null; + return new SqlTime(dt.Value.TimeOfDay); + } + /* public static implicit operator SqlTime(SqlDateTime2 dt) { - return new SqlTime(dt); + if (dt.IsNull) + return Null; + + return new SqlTime(dt.Value.TimeOfDay); } public static implicit operator SqlTime(SqlDateTimeOffset dt) { - return new SqlTime(dt); - } + if (dt.IsNull) + return Null; + return new SqlTime(dt.Value.TimeOfDay); + } + */ public static implicit operator SqlTime(SqlString str) { - return new SqlTime((SqlDateTime)str); + return (SqlTime)(SqlDateTime)str; + } + + public static implicit operator SqlTime(TimeSpan? ts) + { + return new SqlTime(ts); } public override string ToString() { - return _dt.ToString(); + return _ts?.ToString() ?? "Null"; } } public struct SqlDateTime2 : INullable, IComparable { - private SqlDateTime _dt; + private readonly DateTime? _dt; + private static readonly string[] x_DateTimeFormats = new string[8] { "MMM d yyyy hh:mm:ss:ffftt", "MMM d yyyy hh:mm:ss:fff", "d MMM yyyy hh:mm:ss:ffftt", "d MMM yyyy hh:mm:ss:fff", "hh:mm:ss:ffftt", "hh:mm:ss:fff", "yyMMdd", "yyyyMMdd" }; + private static readonly DateTime _defaultDate = new DateTime(1900, 1, 1); - public SqlDateTime2(SqlDateTime dt) + public SqlDateTime2(DateTime? dt) { _dt = dt; } - public static readonly SqlDateTime2 Null = new SqlDateTime2(SqlDateTime.Null); + public static readonly SqlDateTime2 Null = new SqlDateTime2(null); - public bool IsNull => _dt.IsNull; + public bool IsNull => _dt == null; public int CompareTo(object obj) { - if (obj is SqlDateTime2 dt) - obj = dt._dt; + var value = (SqlDateTime2)obj; + + if (IsNull) + { + if (!value.IsNull) + { + return -1; + } + return 0; + } + + if (value.IsNull) + { + return 1; + } + + if (this < value) + { + return -1; + } - return _dt.CompareTo(obj); + if (this > value) + { + return 1; + } + + return 0; } public override int GetHashCode() { - return _dt.GetHashCode(); + return _dt?.GetHashCode() ?? 0; } public override bool Equals(object obj) { - if (obj is SqlDateTime2 dt) - obj = dt._dt; + if (!(obj is SqlDateTime2 value)) + return false; - return _dt.Equals(obj); + return _dt == value._dt; } public DateTime Value => _dt.Value; - public static SqlBoolean operator ==(SqlDateTime2 x, SqlDateTime2 y) => x._dt == y._dt; + public static SqlBoolean operator ==(SqlDateTime2 x, SqlDateTime2 y) => x.IsNull || y.IsNull ? SqlBoolean.Null : x._dt == y._dt; - public static SqlBoolean operator !=(SqlDateTime2 x, SqlDateTime2 y) => x._dt != y._dt; + public static SqlBoolean operator !=(SqlDateTime2 x, SqlDateTime2 y) => !(x == y); - public static SqlBoolean operator <(SqlDateTime2 x, SqlDateTime2 y) => x._dt < y._dt; + public static SqlBoolean operator <(SqlDateTime2 x, SqlDateTime2 y) => x.IsNull || y.IsNull ? SqlBoolean.Null : x._dt < y._dt; - public static SqlBoolean operator >(SqlDateTime2 x, SqlDateTime2 y) => x._dt > y._dt; + public static SqlBoolean operator >(SqlDateTime2 x, SqlDateTime2 y) => x.IsNull || y.IsNull ? SqlBoolean.Null : x._dt > y._dt; - public static SqlBoolean operator <=(SqlDateTime2 x, SqlDateTime2 y) => x._dt <= y._dt; + public static SqlBoolean operator <=(SqlDateTime2 x, SqlDateTime2 y) => x.IsNull || y.IsNull ? SqlBoolean.Null : x._dt <= y._dt; - public static SqlBoolean operator >=(SqlDateTime2 x, SqlDateTime2 y) => x._dt >= y._dt; + public static SqlBoolean operator >=(SqlDateTime2 x, SqlDateTime2 y) => x.IsNull || y.IsNull ? SqlBoolean.Null : x._dt >= y._dt; public static implicit operator SqlDateTime(SqlDateTime2 dt) { - return dt._dt; + if (dt.IsNull) + return SqlDateTime.Null; + + return (SqlDateTime)dt._dt.Value; } public static implicit operator SqlDate(SqlDateTime2 dt) { - return dt._dt; + if (dt.IsNull) + return SqlDate.Null; + + return (SqlDate)dt._dt.Value.Date; } public static implicit operator SqlTime(SqlDateTime2 dt) { - return dt._dt; - } + if (dt.IsNull) + return SqlTime.Null; + return (SqlTime)dt._dt.Value.TimeOfDay; + } + /* public static implicit operator SqlDateTimeOffset(SqlDateTime2 dt) { - return dt._dt; - } + if (dt.IsNull) + return SqlDateTimeOffset.Null; + return (DateTimeOffset)dt._dt.Value; + } + */ public static implicit operator SqlDateTime2(SqlDateTime dt) { - return new SqlDateTime2(dt); + if (dt.IsNull) + return Null; + + return new SqlDateTime2(dt.Value); } public static implicit operator SqlDateTime2(SqlDate dt) { - return new SqlDateTime2(dt); + if (dt.IsNull) + return Null; + + return new SqlDateTime2(dt.Value); } public static implicit operator SqlDateTime2(SqlTime dt) { - return new SqlDateTime2(dt); - } + if (dt.IsNull) + return Null; + return new SqlDateTime2(_defaultDate + dt.Value); + } + /* public static implicit operator SqlDateTime2(SqlDateTimeOffset dt) + { + if (dt.IsNull) + return Null; + + return new SqlDateTime2(dt.Value.DateTime); + } + */ + public static implicit operator SqlDateTime2(DateTime? dt) { return new SqlDateTime2(dt); } public static implicit operator SqlDateTime2(SqlString str) { - return new SqlDateTime2((SqlDateTime)str); + if (str.IsNull) + return Null; + + DateTime value; + try + { + value = DateTime.Parse(str.Value, CultureInfo.InvariantCulture); + } + catch (FormatException) + { + DateTimeFormatInfo provider = (DateTimeFormatInfo)Thread.CurrentThread.CurrentCulture.GetFormat(typeof(DateTimeFormatInfo)); + value = DateTime.ParseExact(str.Value, x_DateTimeFormats, provider, DateTimeStyles.AllowWhiteSpaces); + } + + return new SqlDateTime2((DateTime?)value); } public override string ToString() { - return _dt.ToString(); + return _dt?.ToString() ?? "Null"; } } public struct SqlDateTimeOffset : INullable, IComparable { - private DateTimeOffset? _dt; - - public SqlDateTimeOffset(SqlDateTime dt) - { - _dt = dt.IsNull ? (DateTimeOffset?)null : dt.Value; - } + private readonly DateTimeOffset? _dt; + private static readonly DateTime _defaultDate = new DateTime(1900, 1, 1, 0, 0, 0, DateTimeKind.Utc); public SqlDateTimeOffset(DateTimeOffset? dt) { @@ -307,83 +607,124 @@ public SqlDateTimeOffset(DateTimeOffset? dt) public int CompareTo(object obj) { - if (obj is SqlDateTimeOffset dt) - obj = dt._dt; + var value = (SqlDateTimeOffset)obj; + + if (IsNull) + { + if (!value.IsNull) + { + return -1; + } + return 0; + } + + if (value.IsNull) + { + return 1; + } - if (_dt == null) - return obj == null ? 0 : -1; + if (this < value) + { + return -1; + } - if (obj == null) + if (this > value) + { return 1; + } - return _dt.Value.CompareTo((DateTimeOffset) obj); + return 0; } public override int GetHashCode() { - return _dt.GetHashCode(); + return _dt?.GetHashCode() ?? 0; } public override bool Equals(object obj) { - if (obj is SqlDateTimeOffset dt) - obj = dt._dt; + if (!(obj is SqlDateTimeOffset dt)) + return false; - return _dt.Equals(obj); + return _dt == dt._dt; } public DateTimeOffset Value => _dt.Value; - public static SqlBoolean operator ==(SqlDateTimeOffset x, SqlDateTimeOffset y) => x._dt == y._dt; + public static SqlBoolean operator ==(SqlDateTimeOffset x, SqlDateTimeOffset y) => x.IsNull || y.IsNull ? SqlBoolean.Null : x._dt == y._dt; - public static SqlBoolean operator !=(SqlDateTimeOffset x, SqlDateTimeOffset y) => x._dt != y._dt; + public static SqlBoolean operator !=(SqlDateTimeOffset x, SqlDateTimeOffset y) => !(x._dt != y._dt); - public static SqlBoolean operator <(SqlDateTimeOffset x, SqlDateTimeOffset y) => x._dt < y._dt; + public static SqlBoolean operator <(SqlDateTimeOffset x, SqlDateTimeOffset y) => x.IsNull || y.IsNull ? SqlBoolean.Null : x._dt < y._dt; - public static SqlBoolean operator >(SqlDateTimeOffset x, SqlDateTimeOffset y) => x._dt > y._dt; + public static SqlBoolean operator >(SqlDateTimeOffset x, SqlDateTimeOffset y) => x.IsNull || y.IsNull ? SqlBoolean.Null : x._dt > y._dt; - public static SqlBoolean operator <=(SqlDateTimeOffset x, SqlDateTimeOffset y) => x._dt <= y._dt; + public static SqlBoolean operator <=(SqlDateTimeOffset x, SqlDateTimeOffset y) => x.IsNull || y.IsNull ? SqlBoolean.Null : x._dt <= y._dt; - public static SqlBoolean operator >=(SqlDateTimeOffset x, SqlDateTimeOffset y) => x._dt >= y._dt; + public static SqlBoolean operator >=(SqlDateTimeOffset x, SqlDateTimeOffset y) => x.IsNull || y.IsNull ? SqlBoolean.Null : x._dt >= y._dt; public static implicit operator SqlDateTime(SqlDateTimeOffset dt) { - return dt._dt == null ? SqlDateTime.Null : (SqlDateTime)dt._dt.Value.DateTime; + if (dt.IsNull) + return SqlDateTime.Null; + + return (SqlDateTime)dt._dt.Value.DateTime; } public static implicit operator SqlDate(SqlDateTimeOffset dt) { - return dt._dt == null ? SqlDate.Null : (SqlDate)(SqlDateTime)dt._dt.Value.DateTime; + if (dt.IsNull) + return SqlDateTime.Null; + + return (SqlDate)dt._dt.Value.Date; } public static implicit operator SqlTime(SqlDateTimeOffset dt) { - return dt._dt == null ? SqlTime.Null : (SqlTime)(SqlDateTime)dt._dt.Value.DateTime; + if (dt.IsNull) + return SqlTime.Null; + + return (SqlTime)dt._dt.Value.TimeOfDay; } public static implicit operator SqlDateTime2(SqlDateTimeOffset dt) { - return dt._dt == null ? SqlDateTime2.Null : (SqlDateTime2)(SqlDateTime)dt._dt.Value.DateTime; + if (dt.IsNull) + return SqlDateTime.Null; + + return (SqlDateTime2)dt._dt.Value.DateTime; } public static implicit operator SqlDateTimeOffset(SqlDateTime dt) { - return new SqlDateTimeOffset(dt); + if (dt.IsNull) + return Null; + + return new SqlDateTimeOffset(DateTime.SpecifyKind(dt.Value, DateTimeKind.Utc)); } public static implicit operator SqlDateTimeOffset(SqlDate dt) { - return new SqlDateTimeOffset(dt); + if (dt.IsNull) + return Null; + + return new SqlDateTimeOffset(DateTime.SpecifyKind(dt.Value, DateTimeKind.Utc)); } public static implicit operator SqlDateTimeOffset(SqlTime dt) { - return new SqlDateTimeOffset(dt); + if (dt.IsNull) + return Null; + + return new SqlDateTimeOffset(_defaultDate + dt.Value); } public static implicit operator SqlDateTimeOffset(SqlDateTime2 dt) { - return new SqlDateTimeOffset(dt); + if (dt.IsNull) + return Null; + + return new SqlDateTimeOffset(DateTime.SpecifyKind(dt.Value, DateTimeKind.Utc)); } public static implicit operator SqlDateTimeOffset(SqlString str) @@ -395,9 +736,517 @@ public static implicit operator SqlDateTimeOffset(SqlString str) return new SqlDateTimeOffset(dto); } + public static implicit operator SqlDateTimeOffset(DateTimeOffset? dt) + { + return new SqlDateTimeOffset(dt); + } + public override string ToString() { - return _dt.ToString(); + return _dt?.ToString() ?? "Null"; + } + } + + /// + /// Sets the DATEFORMAT that is used to parse string literals into dates + /// + /// + enum DateFormat + { + mdy, + dmy, + ymd, + ydm, + myd, + dym + } + + /// + /// Provides methods to parse string literals into various date/time types + /// + /// + /// The built-in conversion for strings to SqlDateTime does not respect the DATEFORMAT setting. The + /// conversion from strings to the other date/time/datetime2/datetimeoffset types also support a different + /// set of formats, so we need to implement them here. Those types use a consistent set of formats between + /// them, so we implement them here in the most precise format (datetimeoffset) and then down-cast the result + /// to the other types as necessary. + /// + static class SqlDateParsing + { + /// + /// Converts a string literal to the equivalent value + /// + /// + /// The string literal to parse + /// The current DATEFORMAT setting to control the expected formats + /// The parsed version of the date + /// if the was parsed successfully, or otherwise + public static bool TryParse(SqlString value, DateFormat dateFormat, out SqlDate date) + { + var ret = TryParse(value, dateFormat, out SqlDateTimeOffset dto); + date = dto; + return ret; + } + + /// + /// Converts a string literal to the equivalent value + /// + /// + /// The string literal to parse + /// The current DATEFORMAT setting to control the expected formats + /// The parsed version of the datetime + /// if the was parsed successfully, or otherwise + public static bool TryParse(SqlString value, DateFormat dateFormat, out SqlDateTime2 dateTime2) + { + var ret = TryParse(value, dateFormat, out SqlDateTimeOffset dto); + dateTime2 = dto; + return ret; + } + + /// + /// Converts a string literal to the equivalent value + /// + /// + /// The string literal to parse + /// The current DATEFORMAT setting to control the expected formats + /// The parsed version of the time + /// if the was parsed successfully, or otherwise + public static bool TryParse(SqlString value, DateFormat dateFormat, out SqlTime time) + { + var ret = TryParse(value, dateFormat, out SqlDateTimeOffset dto); + time = dto; + return ret; + } + + /// + /// Converts a string literal to the equivalent value + /// + /// + /// The string literal to parse + /// The current DATEFORMAT setting to control the expected formats + /// The parsed version of the datetimeoffset + /// if the was parsed successfully, or otherwise + public static bool TryParse(SqlString value, DateFormat dateFormat, out SqlDateTimeOffset dateTimeOffset) + { + if (value.IsNull) + { + dateTimeOffset = SqlDateTimeOffset.Null; + return true; + } + + // Allowed formats vary depending on format, but all formats are consistent combinations of + // year, month and day with different separators + string formatStringPart1; + string formatStringPart2; + string formatStringPart3; + string[] separators = new[] { "/", "-", "." }; + + switch (dateFormat) + { + case DateFormat.mdy: + formatStringPart1 = "[M]M"; + formatStringPart2 = "[d]d"; + formatStringPart3 = "[yy]yy"; + break; + + case DateFormat.myd: + formatStringPart1 = "[M]M"; + formatStringPart2 = "[yy]yy"; + formatStringPart3 = "[d]d"; + break; + + case DateFormat.dmy: + formatStringPart1 = "[d]d"; + formatStringPart2 = "[M]M"; + formatStringPart3 = "[yy]yy"; + break; + + case DateFormat.dym: + formatStringPart1 = "[d]d"; + formatStringPart2 = "[yy]yy"; + formatStringPart3 = "[M]M"; + break; + + case DateFormat.ymd: + formatStringPart1 = "[yy]yy"; + formatStringPart2 = "[M]M"; + formatStringPart3 = "[d]d"; + break; + + default: + // ydm format is not supported for datetimeoffset/datetime2/date/time + dateTimeOffset = SqlDateTimeOffset.Null; + return false; + } + + var numericFormatStrings = separators + .Select(s => $"{formatStringPart1}{s}{formatStringPart2}{s}{formatStringPart3}"); + + var alphaFormatStrings = new[] + { + "mon [dd][,] yyyy", + "mon dd[,] [yy]", + "mon yyyy [dd]", + "[dd] mon[,] yyyy", + "dd mon[,][yy]yy", + "dd [yy]yy mon", + "[dd] yyyy mon", + "yyyy mon [dd]", + "yyyy [dd] mon" + }; + + var isoFormatStrings = new[] + { + "yyyy-MM-dd", + "yyyyMMdd" + }; + + var unseparatedFormatStrings = new[] + { + "yyMMdd", + "yyyy" + }; + + var w3cFormatString = new[] + { + "yyyy-MM-ddK" + }; + + var timeFormatStrings = new[] + { + "HH:mm", + "HH:mm:ss", + "HH:mm:ss:fffffff", + "HH:mm:ss.fffffff", + "hh:mmtt", + "hh:mm:sstt", + "hh:mm:ss:ffffffftt", + "hh:mm:ss.ffffffftt", + "hhtt", + "hh tt" + }; + + var allDateFormats = numericFormatStrings + .Concat(alphaFormatStrings) + .Concat(isoFormatStrings) + .Concat(unseparatedFormatStrings) + .Concat(w3cFormatString) + .SelectMany(f => SqlToNetFormatString(f)) + .ToArray(); + + var allTimeFormats = timeFormatStrings + .SelectMany(f => new[] { f, f + "K", f + " K" }) // Allow optional timezone with all time formats, with optional space + .SelectMany(f => SqlToNetFormatString(f)) + .ToArray(); + + var allDateTimeFormats = allDateFormats + .Concat(allDateFormats.SelectMany(d => allTimeFormats.Select(t => d + " " + t))) + .ToArray(); + + if (!DateTimeOffset.TryParseExact(value.Value.Trim(), allDateTimeFormats, CultureInfo.InvariantCulture, DateTimeStyles.AssumeUniversal, out var parsed)) + { + if (DateTimeOffset.TryParseExact(value.Value, allTimeFormats, CultureInfo.InvariantCulture, DateTimeStyles.AssumeUniversal, out parsed)) + { + dateTimeOffset = new DateTimeOffset(new DateTime(1900, 1, 1), parsed.Offset) + parsed.TimeOfDay; + return true; + } + + dateTimeOffset = SqlDateTimeOffset.Null; + return false; + } + + dateTimeOffset = parsed; + return true; + } + /// + /// Converts a string literal to the equivalent value + /// + /// + /// The string literal to parse + /// The current DATEFORMAT setting to control the expected formats + /// The parsed version of the date + /// if the was parsed successfully, or otherwise + public static bool TryParse(SqlString value, DateFormat dateFormat, out SqlDateTime date) + { + if (value.IsNull) + { + date = SqlDateTime.Null; + return true; + } + + // Allowed formats vary depending on format, but all formats are consistent combinations of + // year, month and day with different separators + string formatStringPart1; + string formatStringPart2; + string formatStringPart3; + string formatString4DigitYearPart1; + string formatString4DigitYearPart2; + string[] separators = new[] { "/", "-", "." }; + + switch (dateFormat) + { + case DateFormat.mdy: + formatStringPart1 = "[M]M"; + formatStringPart2 = "[d]d"; + formatStringPart3 = "[yy]yy"; + + formatString4DigitYearPart1 = "[M]M"; + formatString4DigitYearPart2 = "[d]d"; + break; + + case DateFormat.myd: + formatStringPart1 = "[M]M"; + formatStringPart2 = "[yy]yy"; + formatStringPart3 = "[d]d"; + + formatString4DigitYearPart1 = "[M]M"; + formatString4DigitYearPart2 = "[d]d"; + break; + + case DateFormat.dmy: + formatStringPart1 = "[d]d"; + formatStringPart2 = "[M]M"; + formatStringPart3 = "[yy]yy"; + + formatString4DigitYearPart1 = "[d]d"; + formatString4DigitYearPart2 = "[M]M"; + break; + + case DateFormat.dym: + formatStringPart1 = "[d]d"; + formatStringPart2 = "[yy]yy"; + formatStringPart3 = "[M]M"; + + formatString4DigitYearPart1 = "[d]d"; + formatString4DigitYearPart2 = "[M]M"; + break; + + case DateFormat.ymd: + formatStringPart1 = "[yy]yy"; + formatStringPart2 = "[M]M"; + formatStringPart3 = "[d]d"; + + formatString4DigitYearPart1 = "[M]M"; + formatString4DigitYearPart2 = "[d]d"; + break; + + case DateFormat.ydm: + formatStringPart1 = "[yy]yy"; + formatStringPart2 = "[d]d"; + formatStringPart3 = "[M]M"; + + formatString4DigitYearPart1 = "[d]d"; + formatString4DigitYearPart2 = "[M]M"; + break; + + default: + date = SqlDateTime.Null; + return false; + } + + var numericTimeFormatStrings = new[] + { + "", + "HH:mm", + "HH:mm:ss", + "HH:mm:ss:fff", + "HH:mm:ss.fff", + "hhtt", + "hh tt" + }; + + var numericFormatStrings = separators + .SelectMany(s => new[] { + // Parts in the expected order + $"{formatStringPart1}{s}{formatStringPart2}{s}{formatStringPart3}", + + // 4-digit year can come in any position and the other parts remain + // in their original relative order + $"yyyy{s}{formatString4DigitYearPart1}{s}{formatString4DigitYearPart2}", + $"{formatString4DigitYearPart1}{s}yyyy{s}{formatString4DigitYearPart2}", + $"{formatString4DigitYearPart1}{s}{formatString4DigitYearPart2}{s}yyyy", + }) + .SelectMany(s => numericTimeFormatStrings.Select(t => s + " " + t)); + + var alphaFormatStrings = new[] + { + "mon [dd][,] yyyy", + "mon dd[,] [yy]", + "mon yyyy [dd]", + "[dd] mon[,] yyyy", + "dd mon[,][yy]yy", + "dd [yy]yy mon", + "[dd] yyyy mon", + "yyyy mon [dd]", + "yyyy [dd] mon" + }; + + var isoFormatStrings = new[] + { + "yyyy-MM-ddTHH:mm:ss", + "yyyy-MM-ddTHH:mm:ss.fff", + "yyyyMMdd", + "yyyyMMdd HH:mm:ss", + "yyyyMMdd HH:mm:ss.fff" + }; + + var allFormats = numericFormatStrings + .Concat(alphaFormatStrings) + .Concat(isoFormatStrings) + .SelectMany(f => SqlToNetFormatString(f)) + .ToArray(); + + if (!DateTime.TryParseExact(value.Value.Trim(), allFormats, CultureInfo.InvariantCulture, DateTimeStyles.None, out var parsed)) + { + date = SqlDateTime.Null; + return false; + } + + date = parsed; + return true; + } + + /// + /// Converts a SQL format string from the SQL Server documentation to the corresponding .NET format strings + /// + /// + /// + public static string[] SqlToNetFormatString(string formatString) + { + var parts = new List(); + + // Parse the SQL format string into parts + // dd and [d]d both indicate a 1- or 2-digit day + // M and [M]M is a 1- or 2-digit month + // MM is a 2-digit month + // mon is an abbreviated or full month name + // yy is a 2-digit year + // yyyy is a 4-digit year + // [yy]yy is a 2- or 4-digit year + // HH is a 1- or 2-digit hour (24 hour clock) + // hh is a 1- or 2-digit hour (12 hour clock) + // mm is a 1- or 2-digit minute + // ss is a 1- or 2-digit second + // fff is a 1- to 3-digit fraction of a second + // fffffff is a 1- to 7-digit fraction of a second + // Anything else in [] is optional + // Anything else is a literal + // + // Note docs sometimes use m or mm for month - we need to change this to M or MM to consistently + // interpret it as a month instead of minutes + var optional = false; + + for (var i = 0; i < formatString.Length; i++) + { + var length = formatString.Length - i; + var setOptional = false; + + while (length >= 1) + { + var part = formatString.Substring(i, length); + + if (part == "d" || part == "dd" || part == "[d]d") + { + parts.Add(new[] { "d", "dd" }); + break; + } + else if (part == "M" || part == "MM" || part == "[M]M") + { + parts.Add(new[] { "M", "MM" }); + break; + } + else if (part == "mon") + { + parts.Add(new[] { "MMM", "MMMM" }); + break; + } + else if (part == "yy") + { + parts.Add(new[] { "yy" }); + break; + } + else if (part == "yyyy") + { + parts.Add(new[] { "yyyy" }); + break; + } + else if (part == "[yy]yy") + { + parts.Add(new[] { "yy", "yyyy" }); + break; + } + else if (part == "HH") + { + parts.Add(new[] { "H", "HH" }); + break; + } + else if (part == "hh") + { + parts.Add(new[] { "h", "hh" }); + break; + } + else if (part == "mm") + { + parts.Add(new[] { "m", "mm" }); + break; + } + else if (part == "ss") + { + parts.Add(new[] { "s", "ss" }); + break; + } + else if (part == "fff") + { + parts.Add(new[] { "FFF" }); + break; + } + else if (part == "fffffff") + { + parts.Add(new[] { "FFFFFFF" }); + break; + } + else if (part == "[") + { + optional = true; + setOptional = true; + break; + } + else if (length == 1) + { + // Literal + parts.Add(new[] { part }); + break; + } + else + { + // Try a shorter part + length--; + } + } + + if (length == 0) + throw new FormatException(); + + i += length - 1; + + if (optional && !setOptional) + { + parts[parts.Count - 1] = parts[parts.Count - 1].Concat(new[] { "" }).ToArray(); + optional = false; + + if (formatString[i + 1] == ']') + i++; + else + throw new FormatException(); + } + } + + var formatStrings = parts[0]; + + for (var i = 1; i < parts.Count; i++) + formatStrings = formatStrings.SelectMany(s => parts[i].Select(p => s + p)).ToArray(); + + return formatStrings.Select(s => s.Trim().Replace(" ", " ")).ToArray(); } } } diff --git a/MarkMpn.Sql4Cds.Engine/SqlVariant.cs b/MarkMpn.Sql4Cds.Engine/SqlVariant.cs index 803d9687..6606e386 100644 --- a/MarkMpn.Sql4Cds.Engine/SqlVariant.cs +++ b/MarkMpn.Sql4Cds.Engine/SqlVariant.cs @@ -11,16 +11,20 @@ namespace MarkMpn.Sql4Cds.Engine { struct SqlVariant : INullable, IComparable { + private readonly ExpressionExecutionContext _context; + private SqlVariant(bool @null) { BaseType = DataTypeHelpers.Variant; Value = null; + _context = null; } - public SqlVariant(DataTypeReference baseType, INullable value) + public SqlVariant(DataTypeReference baseType, INullable value, ExpressionExecutionContext context) { BaseType = baseType ?? throw new ArgumentNullException(nameof(baseType)); Value = value ?? throw new ArgumentNullException(nameof(value)); + _context = context ?? throw new ArgumentNullException(nameof(context)); } public static readonly SqlVariant Null = new SqlVariant(true); @@ -77,8 +81,8 @@ sqlVariant.BaseType is SqlDataTypeReferenceWithCollation coll2 && if (!SqlTypeConverter.CanMakeConsistentTypes(BaseType, sqlVariant.BaseType, null, null, null, out var consistentType)) throw new ArgumentException(); - var value1 = SqlTypeConverter.GetConversion(BaseType, consistentType)(Value); - var value2 = SqlTypeConverter.GetConversion(sqlVariant.BaseType, consistentType)(sqlVariant.Value); + var value1 = SqlTypeConverter.GetConversion(BaseType, consistentType)(Value, _context); + var value2 = SqlTypeConverter.GetConversion(sqlVariant.BaseType, consistentType)(sqlVariant.Value, sqlVariant._context); if (!(value1 is IComparable comparable1)) throw new ArgumentException(); diff --git a/MarkMpn.Sql4Cds.Engine/Visitors/OptimizerHintValidatingVisitor.cs b/MarkMpn.Sql4Cds.Engine/Visitors/OptimizerHintValidatingVisitor.cs index 99e622ec..f3f374b7 100644 --- a/MarkMpn.Sql4Cds.Engine/Visitors/OptimizerHintValidatingVisitor.cs +++ b/MarkMpn.Sql4Cds.Engine/Visitors/OptimizerHintValidatingVisitor.cs @@ -60,12 +60,18 @@ class OptimizerHintValidatingVisitor : TSqlFragmentVisitor // Ignore duplicate keys on insert, equivalent to IGNORE_DUP_KEY option on creation of index "IGNORE_DUP_KEY", + + // Custom hint to disable converting DML source queries to constant scans + "NO_DIRECT_DML", }; private static readonly HashSet _removableSql4CdsQueryHints = new HashSet(StringComparer.OrdinalIgnoreCase) { // DML-related hint can be removed from the SELECT statement sent to the TDS Endpoint "BYPASS_CUSTOM_PLUGIN_EXECUTION", + + // Custom hint to disable converting DML source queries to constant scans + "NO_DIRECT_DML", }; private static readonly string[] _tsqlQueryHintPrefixes = new[] diff --git a/MarkMpn.Sql4Cds.Engine/Visitors/SimpleFilterVisitor.cs b/MarkMpn.Sql4Cds.Engine/Visitors/SimpleFilterVisitor.cs new file mode 100644 index 00000000..d11e0f8e --- /dev/null +++ b/MarkMpn.Sql4Cds.Engine/Visitors/SimpleFilterVisitor.cs @@ -0,0 +1,118 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using MarkMpn.Sql4Cds.Engine.ExecutionPlan; +using MarkMpn.Sql4Cds.Engine.FetchXml; +using Microsoft.SqlServer.TransactSql.ScriptDom; + +namespace MarkMpn.Sql4Cds.Engine.Visitors +{ + /// + /// Flattens a nested AND/OR filter into a simple list of conditions + /// + class SimpleFilterVisitor : TSqlFragmentVisitor + { + private List _conditions; + private BooleanBinaryExpressionType _binaryType; + private bool _invalid; + private bool _setType; + + public SimpleFilterVisitor() + { + _conditions = new List(); + _binaryType = BooleanBinaryExpressionType.And; + } + + public override void Visit(BooleanBinaryExpression node) + { + base.Visit(node); + + if (_setType && _binaryType != node.BinaryExpressionType) + _invalid = true; + + _binaryType = node.BinaryExpressionType; + _setType = true; + } + + public override void Visit(BooleanExpression node) + { + base.Visit(node); + + if (node is BooleanBinaryExpression) + { + // NOOP + } + else if (node is BooleanComparisonExpression cmp) + { + if (cmp.FirstExpression is ColumnReferenceExpression col1 && cmp.SecondExpression is ValueExpression lit2) + { + if (!cmp.ComparisonType.TryConvertToFetchXml(out var op)) + _invalid = true; + + _conditions.Add(new condition + { + attribute = col1.MultiPartIdentifier.Identifiers.Last().Value, + @operator = op, + value = GetValue(lit2), + IsVariable = !(lit2 is Literal) + }); + } + else if (cmp.FirstExpression is ValueExpression lit1 && cmp.SecondExpression is ColumnReferenceExpression col2) + { + if (!cmp.ComparisonType.TransitiveComparison().TryConvertToFetchXml(out var op)) + _invalid = true; + + _conditions.Add(new condition + { + attribute = col2.MultiPartIdentifier.Identifiers.Last().Value, + @operator = op, + value = GetValue(lit1), + IsVariable = !(lit1 is Literal) + }); + } + else + { + // Unsupported + _invalid = true; + } + } + else if (node is InPredicate @in && @in.Expression is ColumnReferenceExpression inCol && @in.Subquery == null && @in.Values.All(v => v is ValueExpression)) + { + _conditions.Add(new condition + { + attribute = inCol.MultiPartIdentifier.Identifiers.Last().Value, + @operator = @operator.@in, + Items = @in.Values.Cast().Select(l => new conditionValue { Value = GetValue(l), IsVariable = !(l is Literal) }).ToArray() + }); + } + else + { + // Unsupported + _invalid = true; + } + } + + private string GetValue(ValueExpression expr) + { + if (expr is Literal lit) + return lit.Value; + + if (expr is GlobalVariableExpression g) + return g.Name; + + if (expr is VariableReference v) + return v.Name; + + throw new NotSupportedException("Unknown value type"); + } + + /// + /// The type of comparison used to combine the conditions. if the filter is not . + /// + public BooleanBinaryExpressionType? BinaryType => _invalid ? null : (BooleanBinaryExpressionType?)_binaryType; + + public IEnumerable Conditions => _conditions; + } +} diff --git a/MarkMpn.Sql4Cds.Engine/Visitors/TDSEndpointCompatibilityVisitor.cs b/MarkMpn.Sql4Cds.Engine/Visitors/TDSEndpointCompatibilityVisitor.cs index 36cb1110..656421bf 100644 --- a/MarkMpn.Sql4Cds.Engine/Visitors/TDSEndpointCompatibilityVisitor.cs +++ b/MarkMpn.Sql4Cds.Engine/Visitors/TDSEndpointCompatibilityVisitor.cs @@ -458,5 +458,18 @@ public override void Visit(CommonTableExpression node) _ctes[node.ExpressionName.Value] = node; } + + public override void Visit(GeneralSetCommand node) + { + if (node.CommandType == GeneralSetCommandType.DateFormat) + { + // SET DATEFORMAT does work, but isn't persisted correctly across the session + // Mark it as not compatible so we can track the selected format internally + // and pass it to the TDS Endpoint on each call as required. + IsCompatible = false; + } + + base.Visit(node); + } } } diff --git a/MarkMpn.Sql4Cds.Export/ValueFormatter.cs b/MarkMpn.Sql4Cds.Export/ValueFormatter.cs index 7bc6603a..9fb4911f 100644 --- a/MarkMpn.Sql4Cds.Export/ValueFormatter.cs +++ b/MarkMpn.Sql4Cds.Export/ValueFormatter.cs @@ -48,6 +48,10 @@ public static DbCellValue Format(object value, string dataTypeName, int? numeric text = dt.ToString("yyyy-MM-dd HH:mm:ss" + (numericScale == 0 ? "" : ("." + new string('f', numericScale.Value)))); } } + else if (value is DateTimeOffset dto && !localFormatDates) + { + text = dto.ToString("yyyy-MM-dd HH:mm:ss" + (numericScale == 0 ? "" : ("." + new string('f', numericScale.Value))) + " zzz"); + } else if (value is TimeSpan ts && !localFormatDates) { text = ts.ToString("hh\\:mm\\:ss" + (numericScale == 0 ? "" : ("\\." + new string('f', numericScale.Value)))); diff --git a/MarkMpn.Sql4Cds.XTB/Images/LargeIcon_Smooth.png b/MarkMpn.Sql4Cds.XTB/Images/LargeIcon_Smooth.png new file mode 100644 index 00000000..2bb8b0ad Binary files /dev/null and b/MarkMpn.Sql4Cds.XTB/Images/LargeIcon_Smooth.png differ