diff --git a/Gradio.Net.AspNetCore/GradioServiceExtensions.cs b/Gradio.Net.AspNetCore/GradioServiceExtensions.cs index 3cf4d4a..b45701c 100644 --- a/Gradio.Net.AspNetCore/GradioServiceExtensions.cs +++ b/Gradio.Net.AspNetCore/GradioServiceExtensions.cs @@ -86,17 +86,17 @@ public static WebApplication UseGradio(this WebApplication webApplication context.Response.Headers.Add("Content-Type", "text/event-stream"); - StreamWriter streamWriter = new(context.Response.Body); - var sessionHash = context.Request.Query["session_hash"].FirstOrDefault(); - await foreach (SSEMessage message in app.QueueData(sessionHash, stoppingToken)) - { - await streamWriter.WriteLineAsync(message.ProcessMsg()); - await streamWriter.FlushAsync(); - } - await streamWriter.WriteLineAsync(new CloseStreamMessage().ProcessMsg()); + StreamWriter streamWriter = new(context.Response.Body); + var sessionHash = context.Request.Query["session_hash"].FirstOrDefault(); + await foreach (SSEMessage message in app.QueueData(sessionHash, stoppingToken)) + { + await streamWriter.WriteLineAsync(message.ProcessMsg()); await streamWriter.FlushAsync(); - Context.PendingEventIdsSession.TryRemove(sessionHash, out _); - }); + } + await streamWriter.WriteLineAsync(new CloseStreamMessage().ProcessMsg()); + await streamWriter.FlushAsync(); + app.ClonseSession(sessionHash); + }); webApplication.MapPost("/upload", async (HttpRequest request, [FromServices] GradioApp app) => { diff --git a/Gradio.Net/Context.cs b/Gradio.Net/Context.cs index 7282afc..fe65e89 100644 --- a/Gradio.Net/Context.cs +++ b/Gradio.Net/Context.cs @@ -4,10 +4,10 @@ namespace Gradio.Net; -public static class Context +internal static class Context { internal static object PendingMessageLock = new object(); - public static ConcurrentDictionary> PendingEventIdsSession { get; private set; } = new ConcurrentDictionary>(); + internal static ConcurrentDictionary> PendingEventIdsSession { get; private set; } = new ConcurrentDictionary>(); internal static ConcurrentDictionary DownloadableFiles { get; private set; } = new ConcurrentDictionary(); internal static ConcurrentDictionary EventResults { get; private set; } = new ConcurrentDictionary(); internal static Channel EventChannel { get; private set; } = Channel.CreateUnbounded(); diff --git a/Gradio.Net/GradioApp.cs b/Gradio.Net/GradioApp.cs index fa9151a..c26a91f 100644 --- a/Gradio.Net/GradioApp.cs +++ b/Gradio.Net/GradioApp.cs @@ -74,8 +74,6 @@ public async Task QueueJoin(string rootUrl, PredictBodyIn body) public async IAsyncEnumerable QueueData(string sessionHash, CancellationToken stoppingToken) { - - const int heartbeatRate = 150; const int checkRate = 50; int heartbeatCount = 0; @@ -272,11 +270,17 @@ public async IAsyncEnumerable QueueData(string sessionHash, Cancella } } + public List? ClonseSession(string sessionHash) + { + Context.PendingEventIdsSession.TryRemove(sessionHash, out List? tmpIds); + return tmpIds; + } + private void RemovePendingEvent(string sessionHash, string pendingEventId) { lock (Context.PendingMessageLock) { - if (!Context.PendingEventIdsSession.TryGetValue(sessionHash, out List tmpIds)) + if (!Context.PendingEventIdsSession.TryGetValue(sessionHash, out List? tmpIds)) { return; }