Skip to content

Commit

Permalink
properly escape single quotes values in PK and RK Uri paths (#21650)
Browse files Browse the repository at this point in the history
* properly escape single quotes values in PK and RK Uri paths
  • Loading branch information
christothes authored Jun 7, 2021
1 parent 0f550f5 commit f5b425d
Show file tree
Hide file tree
Showing 11 changed files with 1,522 additions and 35 deletions.
64 changes: 32 additions & 32 deletions sdk/tables/Azure.Data.Tables/src/TableClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ public virtual Response<TableItem> Create(CancellationToken cancellationToken =
try
{
var response = _tableOperations.Create(
new TableProperties() { TableName = Name },
new TableProperties { TableName = Name },
null,
_defaultQueryOptions,
cancellationToken);
Expand All @@ -347,7 +347,7 @@ public virtual async Task<Response<TableItem>> CreateAsync(CancellationToken can
try
{
var response = await _tableOperations.CreateAsync(
new TableProperties() { TableName = Name },
new TableProperties { TableName = Name },
null,
_defaultQueryOptions,
cancellationToken)
Expand All @@ -374,7 +374,7 @@ public virtual Response<TableItem> CreateIfNotExists(CancellationToken cancellat
try
{
var response = _tableOperations.Create(
new TableProperties() { TableName = Name },
new TableProperties { TableName = Name },
null,
_defaultQueryOptions,
cancellationToken);
Expand Down Expand Up @@ -404,7 +404,7 @@ public virtual async Task<Response<TableItem>> CreateIfNotExistsAsync(Cancellati
try
{
var response = await _tableOperations.CreateAsync(
new TableProperties() { TableName = Name },
new TableProperties { TableName = Name },
null,
_defaultQueryOptions,
cancellationToken)
Expand Down Expand Up @@ -577,7 +577,7 @@ public virtual Response<T> GetEntity<T>(string partitionKey, string rowKey, IEnu
Name,
partitionKey,
rowKey,
queryOptions: new QueryOptions() { Format = _defaultQueryOptions.Format, Select = selectArg },
queryOptions: new QueryOptions { Format = _defaultQueryOptions.Format, Select = selectArg },
cancellationToken: cancellationToken);

var result = ((Dictionary<string, object>)response.Value).ToTableEntity<T>();
Expand Down Expand Up @@ -620,7 +620,7 @@ public virtual async Task<Response<T>> GetEntityAsync<T>(
Name,
partitionKey,
rowKey,
queryOptions: new QueryOptions() { Format = _defaultQueryOptions.Format, Select = selectArg },
queryOptions: new QueryOptions { Format = _defaultQueryOptions.Format, Select = selectArg },
cancellationToken: cancellationToken)
.ConfigureAwait(false);

Expand Down Expand Up @@ -660,16 +660,16 @@ public virtual async Task<Response> UpsertEntityAsync<T>(
{
TableUpdateMode.Replace => await _tableOperations.UpdateEntityAsync(
Name,
entity!.PartitionKey,
entity.RowKey,
TableOdataFilter.EscapeStringValue(entity!.PartitionKey),
TableOdataFilter.EscapeStringValue(entity.RowKey),
tableEntityProperties: entity.ToOdataAnnotatedDictionary(),
queryOptions: _defaultQueryOptions,
cancellationToken: cancellationToken)
.ConfigureAwait(false),
TableUpdateMode.Merge => await _tableOperations.MergeEntityAsync(
Name,
entity!.PartitionKey,
entity.RowKey,
TableOdataFilter.EscapeStringValue(entity!.PartitionKey),
TableOdataFilter.EscapeStringValue(entity.RowKey),
tableEntityProperties: entity.ToOdataAnnotatedDictionary(),
queryOptions: _defaultQueryOptions,
cancellationToken: cancellationToken)
Expand Down Expand Up @@ -708,15 +708,15 @@ public virtual Response UpsertEntity<T>(T entity, TableUpdateMode mode = TableUp
{
TableUpdateMode.Replace => _tableOperations.UpdateEntity(
Name,
entity!.PartitionKey,
entity.RowKey,
TableOdataFilter.EscapeStringValue(entity!.PartitionKey),
TableOdataFilter.EscapeStringValue(entity.RowKey),
tableEntityProperties: entity.ToOdataAnnotatedDictionary(),
queryOptions: _defaultQueryOptions,
cancellationToken: cancellationToken),
TableUpdateMode.Merge => _tableOperations.MergeEntity(
Name,
entity!.PartitionKey,
entity.RowKey,
TableOdataFilter.EscapeStringValue(entity!.PartitionKey),
TableOdataFilter.EscapeStringValue(entity.RowKey),
tableEntityProperties: entity.ToOdataAnnotatedDictionary(),
queryOptions: _defaultQueryOptions,
cancellationToken: cancellationToken),
Expand Down Expand Up @@ -770,8 +770,8 @@ public virtual async Task<Response> UpdateEntityAsync<T>(
{
return await _tableOperations.UpdateEntityAsync(
Name,
entity.PartitionKey,
entity.RowKey,
TableOdataFilter.EscapeStringValue(entity.PartitionKey),
TableOdataFilter.EscapeStringValue(entity.RowKey),
tableEntityProperties: entity.ToOdataAnnotatedDictionary(),
ifMatch: ifMatch.ToString(),
queryOptions: _defaultQueryOptions,
Expand All @@ -782,8 +782,8 @@ public virtual async Task<Response> UpdateEntityAsync<T>(
{
return await _tableOperations.MergeEntityAsync(
Name,
entity!.PartitionKey,
entity.RowKey,
TableOdataFilter.EscapeStringValue(entity!.PartitionKey),
TableOdataFilter.EscapeStringValue(entity.RowKey),
tableEntityProperties: entity.ToOdataAnnotatedDictionary(),
ifMatch: ifMatch.ToString(),
queryOptions: _defaultQueryOptions,
Expand Down Expand Up @@ -836,8 +836,8 @@ public virtual Response UpdateEntity<T>(T entity, ETag ifMatch, TableUpdateMode
{
return _tableOperations.UpdateEntity(
Name,
entity!.PartitionKey,
entity!.RowKey,
TableOdataFilter.EscapeStringValue(entity!.PartitionKey),
TableOdataFilter.EscapeStringValue(entity!.RowKey),
tableEntityProperties: entity.ToOdataAnnotatedDictionary(),
ifMatch: ifMatch.ToString(),
queryOptions: _defaultQueryOptions,
Expand All @@ -847,8 +847,8 @@ public virtual Response UpdateEntity<T>(T entity, ETag ifMatch, TableUpdateMode
{
return _tableOperations.MergeEntity(
Name,
entity.PartitionKey,
entity.RowKey,
TableOdataFilter.EscapeStringValue(entity.PartitionKey),
TableOdataFilter.EscapeStringValue(entity.RowKey),
tableEntityProperties: entity.ToOdataAnnotatedDictionary(),
ifMatch: ifMatch.ToString(),
queryOptions: _defaultQueryOptions,
Expand Down Expand Up @@ -975,7 +975,7 @@ public virtual AsyncPageable<T> QueryAsync<T>(
{
var response = await _tableOperations.QueryEntitiesAsync(
Name,
queryOptions: new QueryOptions() { Format = _defaultQueryOptions.Format, Top = pageSizeHint, Filter = filter, Select = selectArg },
queryOptions: new QueryOptions { Format = _defaultQueryOptions.Format, Top = pageSizeHint, Filter = filter, Select = selectArg },
cancellationToken: cancellationToken)
.ConfigureAwait(false);

Expand All @@ -1000,7 +1000,7 @@ public virtual AsyncPageable<T> QueryAsync<T>(

var response = await _tableOperations.QueryEntitiesAsync(
Name,
queryOptions: new QueryOptions() { Format = _defaultQueryOptions.Format, Top = pageSizeHint, Filter = filter, Select = selectArg },
queryOptions: new QueryOptions { Format = _defaultQueryOptions.Format, Top = pageSizeHint, Filter = filter, Select = selectArg },
nextPartitionKey: NextPartitionKey,
nextRowKey: NextRowKey,
cancellationToken: cancellationToken)
Expand Down Expand Up @@ -1054,7 +1054,7 @@ public virtual Pageable<T> Query<T>(
scope.Start();
try
{
var queryOptions = new QueryOptions() { Format = _defaultQueryOptions.Format, Top = pageSizeHint, Filter = filter, Select = selectArg };
var queryOptions = new QueryOptions { Format = _defaultQueryOptions.Format, Top = pageSizeHint, Filter = filter, Select = selectArg };

var response = _tableOperations.QueryEntities(
Name,
Expand All @@ -1080,7 +1080,7 @@ public virtual Pageable<T> Query<T>(
{
var (NextPartitionKey, NextRowKey) = ParseContinuationToken(continuationToken);

var queryOptions = new QueryOptions() { Format = _defaultQueryOptions.Format, Top = pageSizeHint, Filter = filter, Select = selectArg };
var queryOptions = new QueryOptions { Format = _defaultQueryOptions.Format, Top = pageSizeHint, Filter = filter, Select = selectArg };

var response = _tableOperations.QueryEntities(
Name,
Expand Down Expand Up @@ -1422,8 +1422,8 @@ private MultipartContent BuildChangeSet(
new QueryOptions { Format = _defaultQueryOptions.Format!.Value }),
TableTransactionActionType.Delete => batchOperations.CreateDeleteEntityRequest(
Name,
item.Entity.PartitionKey,
item.Entity.RowKey,
TableOdataFilter.EscapeStringValue(item.Entity.PartitionKey),
TableOdataFilter.EscapeStringValue(item.Entity.RowKey),
item.ETag == default ? ETag.All.ToString() : item.ETag.ToString(),
null,
new QueryOptions { Format = _defaultQueryOptions.Format!.Value }),
Expand All @@ -1445,16 +1445,16 @@ private HttpMessage CreateUpdateOrMergeRequest(TableRestClient batchOperations,
{
TableUpdateMode.Replace => batchOperations.CreateUpdateEntityRequest(
Name,
entity.PartitionKey,
entity.RowKey,
TableOdataFilter.EscapeStringValue(entity.PartitionKey),
TableOdataFilter.EscapeStringValue(entity.RowKey),
null,
ifMatch == default ? null : ifMatch.ToString(),
entity.ToOdataAnnotatedDictionary(),
new QueryOptions { Format = _defaultQueryOptions.Format!.Value }),
TableUpdateMode.Merge => batchOperations.CreateMergeEntityRequest(
Name,
entity.PartitionKey,
entity.RowKey,
TableOdataFilter.EscapeStringValue(entity.PartitionKey),
TableOdataFilter.EscapeStringValue(entity.RowKey),
null,
ifMatch == default ? null : ifMatch.ToString(),
entity.ToOdataAnnotatedDictionary(),
Expand Down
16 changes: 13 additions & 3 deletions sdk/tables/Azure.Data.Tables/src/TableOdataFilter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ public static string Create(FormattableString filter)
DateTime x => $"{XmlConstants.LiteralPrefixDateTime}'{XmlConvert.ToString(x.ToUniversalTime(), XmlDateTimeSerializationMode.RoundtripKind)}'",

// Text
string x => $"'{x.Replace("'", "''")}'",
char x => $"'{x.ToString().Replace("'", "''")}'",
StringBuilder x => $"'{x.Replace("'", "''")}'",
string x => $"'{EscapeStringValue(x)}'",
char x => $"'{EscapeStringValue(x)}'",
StringBuilder x => $"'{EscapeStringValue(x)}'",

// Everything else
object x => throw new ArgumentException(
Expand All @@ -80,5 +80,15 @@ public static string Create(FormattableString filter)

return string.Format(CultureInfo.InvariantCulture, filter.Format, args);
}

internal static string EscapeStringValue(string s) => s.Replace("'", "''");
internal static StringBuilder EscapeStringValue(StringBuilder s) => s.Replace("'", "''");

internal static string EscapeStringValue(char s) =>
s switch
{
_ when s == '\'' => "''",
_ => s.ToString()
};
}
}
Loading

0 comments on commit f5b425d

Please sign in to comment.