diff --git a/src/Discord.Net.Rest/Entities/Guilds/RestGuild.cs b/src/Discord.Net.Rest/Entities/Guilds/RestGuild.cs index c534af7303..b1c9dde380 100644 --- a/src/Discord.Net.Rest/Entities/Guilds/RestGuild.cs +++ b/src/Discord.Net.Rest/Entities/Guilds/RestGuild.cs @@ -1068,11 +1068,11 @@ public async Task CreateStickerAsync(string name, Image image, IE /// /// A task that represents the asynchronous creation operation. The task result contains the created sticker. /// - public Task CreateStickerAsync(string name, string path, IEnumerable tags, string description = null, + public async Task CreateStickerAsync(string name, string path, IEnumerable tags, string description = null, RequestOptions options = null) { - var fs = File.OpenRead(path); - return CreateStickerAsync(name, fs, Path.GetFileName(fs.Name), tags, description,options); + using var fs = File.OpenRead(path); + return await CreateStickerAsync(name, fs, Path.GetFileName(fs.Name), tags, description,options); } /// /// Creates a new sticker in this guild diff --git a/src/Discord.Net.Rest/Entities/Interactions/Modals/RestModal.cs b/src/Discord.Net.Rest/Entities/Interactions/Modals/RestModal.cs index 524b70adba..3d2f4ba4a3 100644 --- a/src/Discord.Net.Rest/Entities/Interactions/Modals/RestModal.cs +++ b/src/Discord.Net.Rest/Entities/Interactions/Modals/RestModal.cs @@ -228,6 +228,7 @@ public override async Task FollowupWithFileAsync( fileName ??= Path.GetFileName(filePath); Preconditions.NotNullOrEmpty(fileName, nameof(fileName), "File Name must not be empty or null"); + using var fileStream = !string.IsNullOrEmpty(filePath) ? new MemoryStream(File.ReadAllBytes(filePath), false) : null; var args = new API.Rest.CreateWebhookMessageParams { Content = text, @@ -235,7 +236,7 @@ public override async Task FollowupWithFileAsync( IsTTS = isTTS, Embeds = embeds.Select(x => x.ToModel()).ToArray(), Components = component?.Components.Select(x => new API.ActionRowComponent(x)).ToArray() ?? Optional.Unspecified, - File = !string.IsNullOrEmpty(filePath) ? new MultipartFile(new MemoryStream(File.ReadAllBytes(filePath), false), fileName) : Optional.Unspecified + File = fileStream != null ? new MultipartFile(fileStream, fileName) : Optional.Unspecified }; if (ephemeral) diff --git a/src/Discord.Net.Rest/Net/DefaultRestClient.cs b/src/Discord.Net.Rest/Net/DefaultRestClient.cs index 8959caa485..46522a36ec 100644 --- a/src/Discord.Net.Rest/Net/DefaultRestClient.cs +++ b/src/Discord.Net.Rest/Net/DefaultRestClient.cs @@ -1,7 +1,7 @@ -using Discord.Net.Converters; using Newtonsoft.Json; using System; using System.Collections.Generic; +using System.Collections.Immutable; using System.Globalization; using System.IO; using System.Linq; @@ -101,62 +101,68 @@ public async Task SendAsync(string method, string endpoint, IReadO IEnumerable>> requestHeaders = null) { string uri = Path.Combine(_baseUrl, endpoint); - using (var restRequest = new HttpRequestMessage(GetMethod(method), uri)) + + // HttpRequestMessage implements IDisposable but we do not need to dispose it as it merely disposes of its Content property, + // which we can do as needed. And regarding that, we do not want to take responsibility for disposing of content provided by + // the caller of this function, since it's possible that the caller wants to reuse it or is forced to reuse it because of a + // 429 response. Therefore, by convention, we only dispose the content objects created in this function (if any). + // + // See this comment explaining why this is safe: https://github.com/aspnet/Security/issues/886#issuecomment-229181249 + // See also the source for HttpRequestMessage: https://github.com/microsoft/referencesource/blob/master/System/net/System/Net/Http/HttpRequestMessage.cs +#pragma warning disable IDISP004 + var restRequest = new HttpRequestMessage(GetMethod(method), uri); +#pragma warning restore IDISP004 + + if (reason != null) + restRequest.Headers.Add("X-Audit-Log-Reason", Uri.EscapeDataString(reason)); + if (requestHeaders != null) + foreach (var header in requestHeaders) + restRequest.Headers.Add(header.Key, header.Value); + var content = new MultipartFormDataContent("Upload----" + DateTime.Now.ToString(CultureInfo.InvariantCulture)); + + static StreamContent GetStreamContent(Stream stream) { - if (reason != null) - restRequest.Headers.Add("X-Audit-Log-Reason", Uri.EscapeDataString(reason)); - if (requestHeaders != null) - foreach (var header in requestHeaders) - restRequest.Headers.Add(header.Key, header.Value); - var content = new MultipartFormDataContent("Upload----" + DateTime.Now.ToString(CultureInfo.InvariantCulture)); - MemoryStream memoryStream = null; - if (multipartParams != null) + if (stream.CanSeek) + { + // Reset back to the beginning; it may have been used elsewhere or in a previous request. + stream.Position = 0; + } + +#pragma warning disable IDISP004 + return new StreamContent(stream); +#pragma warning restore IDISP004 + } + + foreach (var p in multipartParams ?? ImmutableDictionary.Empty) + { + switch (p.Value) { - foreach (var p in multipartParams) - { - switch (p.Value) - { #pragma warning disable IDISP004 - case string stringValue: - { content.Add(new StringContent(stringValue, Encoding.UTF8, "text/plain"), p.Key); continue; } - case byte[] byteArrayValue: - { content.Add(new ByteArrayContent(byteArrayValue), p.Key); continue; } - case Stream streamValue: - { content.Add(new StreamContent(streamValue), p.Key); continue; } - case MultipartFile fileValue: - { - var stream = fileValue.Stream; - if (!stream.CanSeek) - { - memoryStream = new MemoryStream(); - await stream.CopyToAsync(memoryStream).ConfigureAwait(false); - memoryStream.Position = 0; -#pragma warning disable IDISP001 - stream = memoryStream; -#pragma warning restore IDISP001 - } - - var streamContent = new StreamContent(stream); - var extension = fileValue.Filename.Split('.').Last(); - - if (fileValue.ContentType != null) - streamContent.Headers.ContentType = new MediaTypeHeaderValue(fileValue.ContentType); - - content.Add(streamContent, p.Key, fileValue.Filename); + case string stringValue: + { content.Add(new StringContent(stringValue, Encoding.UTF8, "text/plain"), p.Key); continue; } + case byte[] byteArrayValue: + { content.Add(new ByteArrayContent(byteArrayValue), p.Key); continue; } + case Stream streamValue: + { content.Add(GetStreamContent(streamValue), p.Key); continue; } + case MultipartFile fileValue: + { + var streamContent = GetStreamContent(fileValue.Stream); + + if (fileValue.ContentType != null) + streamContent.Headers.ContentType = new MediaTypeHeaderValue(fileValue.ContentType); + + content.Add(streamContent, p.Key, fileValue.Filename); #pragma warning restore IDISP004 - continue; - } - default: - throw new InvalidOperationException($"Unsupported param type \"{p.Value.GetType().Name}\"."); + continue; } - } + default: + throw new InvalidOperationException($"Unsupported param type \"{p.Value.GetType().Name}\"."); } - restRequest.Content = content; - var result = await SendInternalAsync(restRequest, cancelToken, headerOnly).ConfigureAwait(false); - memoryStream?.Dispose(); - return result; } + + restRequest.Content = content; + return await SendInternalAsync(restRequest, cancelToken, headerOnly).ConfigureAwait(false); } private async Task SendInternalAsync(HttpRequestMessage request, CancellationToken cancelToken, bool headerOnly) diff --git a/src/Discord.Net.WebSocket/Entities/Guilds/SocketGuild.cs b/src/Discord.Net.WebSocket/Entities/Guilds/SocketGuild.cs index ab1d5fcc68..74adf5a179 100644 --- a/src/Discord.Net.WebSocket/Entities/Guilds/SocketGuild.cs +++ b/src/Discord.Net.WebSocket/Entities/Guilds/SocketGuild.cs @@ -1558,11 +1558,11 @@ public async Task CreateStickerAsync(string name, Image im /// /// A task that represents the asynchronous creation operation. The task result contains the created sticker. /// - public Task CreateStickerAsync(string name, string path, IEnumerable tags, string description = null, + public async Task CreateStickerAsync(string name, string path, IEnumerable tags, string description = null, RequestOptions options = null) { - var fs = File.OpenRead(path); - return CreateStickerAsync(name, fs, Path.GetFileName(fs.Name), tags, description, options); + using var fs = File.OpenRead(path); + return await CreateStickerAsync(name, fs, Path.GetFileName(fs.Name), tags, description, options); } /// /// Creates a new sticker in this guild