Skip to content

Commit

Permalink
修改Context公开方式,把关闭session放在GradioApp实现
Browse files Browse the repository at this point in the history
  • Loading branch information
imxcstar committed Jun 6, 2024
1 parent 03c0b6b commit 8d02058
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 15 deletions.
20 changes: 10 additions & 10 deletions Gradio.Net.AspNetCore/GradioServiceExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -86,17 +86,17 @@ public static WebApplication UseGradio(this WebApplication webApplication
context.Response.Headers.Add("Content-Type", "text/event-stream");
StreamWriter streamWriter = new(context.Response.Body);
var sessionHash = context.Request.Query["session_hash"].FirstOrDefault();
await foreach (SSEMessage message in app.QueueData(sessionHash, stoppingToken))
{
await streamWriter.WriteLineAsync(message.ProcessMsg());
await streamWriter.FlushAsync();
}
await streamWriter.WriteLineAsync(new CloseStreamMessage().ProcessMsg());
StreamWriter streamWriter = new(context.Response.Body);
var sessionHash = context.Request.Query["session_hash"].FirstOrDefault();
await foreach (SSEMessage message in app.QueueData(sessionHash, stoppingToken))
{
await streamWriter.WriteLineAsync(message.ProcessMsg());
await streamWriter.FlushAsync();
Context.PendingEventIdsSession.TryRemove(sessionHash, out _);
});
}
await streamWriter.WriteLineAsync(new CloseStreamMessage().ProcessMsg());
await streamWriter.FlushAsync();
app.ClonseSession(sessionHash);
});

webApplication.MapPost("/upload", async (HttpRequest request, [FromServices] GradioApp app) =>
{
Expand Down
4 changes: 2 additions & 2 deletions Gradio.Net/Context.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@

namespace Gradio.Net;

public static class Context
internal static class Context
{
internal static object PendingMessageLock = new object();
public static ConcurrentDictionary<string, List<string>> PendingEventIdsSession { get; private set; } = new ConcurrentDictionary<string, List<string>>();
internal static ConcurrentDictionary<string, List<string>> PendingEventIdsSession { get; private set; } = new ConcurrentDictionary<string, List<string>>();
internal static ConcurrentDictionary<string,string> DownloadableFiles { get; private set; } = new ConcurrentDictionary<string,string>();
internal static ConcurrentDictionary<string, EventResult> EventResults { get; private set; } = new ConcurrentDictionary<string, EventResult>();
internal static Channel<Event> EventChannel { get; private set; } = Channel.CreateUnbounded<Event>();
Expand Down
10 changes: 7 additions & 3 deletions Gradio.Net/GradioApp.cs
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,6 @@ public async Task<QueueJoinOut> QueueJoin(string rootUrl, PredictBodyIn body)

public async IAsyncEnumerable<SSEMessage> QueueData(string sessionHash, CancellationToken stoppingToken)
{


const int heartbeatRate = 150;
const int checkRate = 50;
int heartbeatCount = 0;
Expand Down Expand Up @@ -272,11 +270,17 @@ public async IAsyncEnumerable<SSEMessage> QueueData(string sessionHash, Cancella
}
}

public List<string>? ClonseSession(string sessionHash)
{
Context.PendingEventIdsSession.TryRemove(sessionHash, out List<string>? tmpIds);
return tmpIds;
}

private void RemovePendingEvent(string sessionHash, string pendingEventId)
{
lock (Context.PendingMessageLock)
{
if (!Context.PendingEventIdsSession.TryGetValue(sessionHash, out List<string> tmpIds))
if (!Context.PendingEventIdsSession.TryGetValue(sessionHash, out List<string>? tmpIds))
{
return;
}
Expand Down

0 comments on commit 8d02058

Please sign in to comment.