diff --git a/src/Microsoft.SqlTools.ServiceLayer/EditData/UpdateManagement/RowEditBase.cs b/src/Microsoft.SqlTools.ServiceLayer/EditData/UpdateManagement/RowEditBase.cs index 8efdaf915c..bb12d317cb 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/EditData/UpdateManagement/RowEditBase.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/EditData/UpdateManagement/RowEditBase.cs @@ -181,38 +181,60 @@ protected WhereClause GetWhereClause(bool parameterize) } else { - if (cellData.RawObject is byte[] || - col.DbColumn.DataTypeName.Equals("TEXT", StringComparison.OrdinalIgnoreCase) || - col.DbColumn.DataTypeName.Equals("NTEXT", StringComparison.OrdinalIgnoreCase)) + if (parameterize) { - // Special cases for byte[] and TEXT/NTEXT types - cellDataClause = "IS NOT NULL"; - } - else - { - // General case is to just use the value from the cell - if (parameterize) + // Add a parameter and parameterized clause component + // NOTE: We include the row ID to make sure the parameter is unique if + // we execute multiple row edits at once. + string paramName = $"@Param{RowId}{col.Ordinal}"; + if (cellData.RawObject is byte[]) { - // Add a parameter and parameterized clause component - // NOTE: We include the row ID to make sure the parameter is unique if - // we execute multiple row edits at once. - string paramName = $"@Param{RowId}{col.Ordinal}"; - cellDataClause = $"= {paramName}"; - SqlParameter parameter = new SqlParameter(paramName, col.DbColumn.SqlDbType) - { - Value = cellData.RawObject - }; - output.Parameters.Add(parameter); + cellDataClause = $"= CONVERT (VARBINARY(MAX), {paramName})"; + } + else if (col.DbColumn.DataTypeName.Equals("TEXT", StringComparison.OrdinalIgnoreCase) || (col.DbColumn.DataTypeName.Equals("TEXT", StringComparison.OrdinalIgnoreCase) || + col.DbColumn.DataTypeName.Equals("NTEXT", StringComparison.OrdinalIgnoreCase))) + { + // Special case for TEXT/NTEXT types. + //NOTE: the types are not compatible with n/varchar so direct comparison + // will not work for these types, must convert first. + cellDataClause = $"= CONVERT (NVARCHAR(MAX), {paramName})"; } else { - // Add the clause component with the formatted value - cellDataClause = $"= {ToSqlScript.FormatValue(cellData, col.DbColumn)}"; + cellDataClause = $"= {paramName}"; } + + SqlParameter parameter = new SqlParameter(paramName, col.DbColumn.SqlDbType) + { + Value = cellData.RawObject + }; + output.Parameters.Add(parameter); + } + else + { + // Add the clause component with the formatted value + cellDataClause = $"= {ToSqlScript.FormatValue(cellData, col.DbColumn)}"; } } - string completeComponent = $"({col.EscapedName} {cellDataClause})"; + string completeComponent; + + if (cellData.RawObject is byte[]) + { + completeComponent = $"(CONVERT (VARBINARY(MAX), {col.EscapedName}) {cellDataClause})"; + } + + else if (col.DbColumn.DataTypeName.Equals("TEXT", StringComparison.OrdinalIgnoreCase) || + col.DbColumn.DataTypeName.Equals("NTEXT", StringComparison.OrdinalIgnoreCase)) + { + // Special case for TEXT/NTEXT types as explained on line 197. + completeComponent = $"(CONVERT (NVARCHAR(MAX), {col.EscapedName}) {cellDataClause})"; + } + else + { + completeComponent = $"({col.EscapedName} {cellDataClause})"; + } + output.ClauseComponents.Add(completeComponent); } diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/EditData/RowEditBaseTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/EditData/RowEditBaseTests.cs index c9217b3caf..e0c8cef17c 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/EditData/RowEditBaseTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/EditData/RowEditBaseTests.cs @@ -80,7 +80,14 @@ public async Task GetWhereClauseSimple(DbColumn col, object val, string nullClau EditTableMetadata etm = Common.GetCustomEditTableMetadata(cols); RowEditTester rt = new RowEditTester(rs, etm); - rt.ValidateWhereClauseSingleKey(nullClause); + if (val == DBNull.Value) + { + rt.ValidateWhereClauseNullKey(nullClause); + } + else + { + rt.ValidateWhereClauseSingleKey(nullClause); + } } public static IEnumerable GetWhereClauseIsNotNullData @@ -95,7 +102,7 @@ public static IEnumerable GetWhereClauseIsNotNullData DataType = typeof(byte[]) }, new byte[5], - "IS NOT NULL" + "= 0x0000000000" }; yield return new object[] { @@ -105,7 +112,7 @@ public static IEnumerable GetWhereClauseIsNotNullData DataTypeName = "TEXT" }, "abc", - "IS NOT NULL" + "= N'abc'" }; yield return new object[] { @@ -116,7 +123,7 @@ public static IEnumerable GetWhereClauseIsNotNullData }, "abc", - "IS NOT NULL" + "= N'abc'" }; } } @@ -252,11 +259,11 @@ public void ValidateColumn(int columnId) } // ReSharper disable once UnusedParameter.Local - public void ValidateWhereClauseSingleKey(string nullValue) + public void ValidateWhereClauseNullKey(string nullValue) { // If: I generate a where clause with one is null column value WhereClause wc = GetWhereClause(false); - + // Then: // ... There should only be one component Assert.AreEqual(1, wc.ClauseComponents.Count); @@ -272,6 +279,27 @@ public void ValidateWhereClauseSingleKey(string nullValue) Assert.AreEqual($"WHERE {wc.ClauseComponents[0]}", wc.CommandText); } + public void ValidateWhereClauseSingleKey(string clauseValue) + { + // If: I generate a where clause with one is null column value + WhereClause wc = GetWhereClause(false); + + // Then: + // ... There should only be one component + Assert.AreEqual(1, wc.ClauseComponents.Count); + + // ... Parameterization should be empty + Assert.IsEmpty(wc.Parameters); + + // ... The component should contain the name of the column and the value + Assert.True(wc.ClauseComponents[0].Contains(AssociatedObjectMetadata.Columns.First().EscapedName)); + Regex r = new Regex($@"\(CONVERT \([A-Z]*\(MAX\), {AssociatedObjectMetadata.Columns.First().EscapedName}\) {clauseValue}\)"); + Assert.True(r.IsMatch(wc.ClauseComponents[0])); + + // ... The complete clause should contain a single WHERE + Assert.AreEqual($"WHERE {wc.ClauseComponents[0]}", wc.CommandText); + } + public void ValidateWhereClauseMultipleKeys() { // If: I generate a where clause with multiple key columns