Skip to content
This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

Adjust message update on HttpGPT Chat to avoid crashes when the editor context changes #30

Merged
merged 2 commits into from
Mar 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 27 additions & 13 deletions Source/HttpGPT/Private/HttpGPTRequest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,14 @@ void UHttpGPTRequest::StopHttpGPTTask()
{
FScopeLock Lock(&Mutex);

if (!bIsActive)
if (!bIsTaskActive)
{
return;
}

UE_LOG(LogHttpGPT, Display, TEXT("%s (%d): Stopping task"), *FString(__func__), GetUniqueID());

bIsActive = false;
bIsTaskActive = false;

if (HttpRequest.IsValid())
{
Expand All @@ -73,13 +73,6 @@ void UHttpGPTRequest::StopHttpGPTTask()
SetReadyToDestroy();
}

const bool UHttpGPTRequest::IsTaskActive() const
{
FScopeLock Lock(&Mutex);

return bIsActive;
}

const FHttpGPTOptions UHttpGPTRequest::GetTaskOptions() const
{
return TaskOptions;
Expand All @@ -91,7 +84,7 @@ void UHttpGPTRequest::Activate()

UE_LOG(LogHttpGPT, Display, TEXT("%s (%d): Activating task"), *FString(__func__), GetUniqueID());

bIsActive = true;
bIsTaskActive = true;

if (HttpGPT::Internal::HasEmptyParam(Messages) || HttpGPT::Internal::HasEmptyParam(TaskOptions.APIKey))
{
Expand Down Expand Up @@ -132,7 +125,7 @@ void UHttpGPTRequest::SetReadyToDestroy()
#endif

bIsReadyToDestroy = true;
bIsActive = false;
bIsTaskActive = false;

Super::SetReadyToDestroy();
}
Expand Down Expand Up @@ -304,7 +297,7 @@ void UHttpGPTRequest::BindRequestCallbacks()
{
FScopeTryLock Lock(&Mutex);

if (!Lock.IsLocked() || !IsValid(this) || !bIsActive)
if (!Lock.IsLocked() || !IsValid(this) || !bIsTaskActive)
{
return;
}
Expand All @@ -319,7 +312,7 @@ void UHttpGPTRequest::BindRequestCallbacks()
{
FScopeTryLock Lock(&Mutex);

if (!Lock.IsLocked() || !IsValid(this) || !bIsActive)
if (!Lock.IsLocked() || !IsValid(this) || !bIsTaskActive)
{
return;
}
Expand Down Expand Up @@ -559,3 +552,24 @@ void UHttpGPTRequest::DeserializeSingleResponse(const FString& Content)
Response.Usage = FHttpGPTUsage((*UsageObj)->GetNumberField("prompt_tokens"), (*UsageObj)->GetNumberField("completion_tokens"), (*UsageObj)->GetNumberField("total_tokens"));
}
}

bool UHttpGPTTaskStatus::IsTaskActive(const UHttpGPTRequest* Test)
{
return IsValid(Test) && Test->bIsTaskActive;
}

bool UHttpGPTTaskStatus::IsTaskReadyToDestroy(const UHttpGPTRequest* Test)
{
return IsValid(Test) && Test->bIsReadyToDestroy;
}

bool UHttpGPTTaskStatus::IsTaskStillValid(const UHttpGPTRequest* Test)
{
bool bOutput = IsValid(Test) && !IsTaskReadyToDestroy(Test);

#if WITH_EDITOR
bOutput = bOutput && !Test->bEndingPIE;
#endif

return bOutput;
}
27 changes: 22 additions & 5 deletions Source/HttpGPT/Public/HttpGPTRequest.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
#pragma once

#include <CoreMinimal.h>
#include <Kismet/BlueprintAsyncActionBase.h>
#include <Interfaces/IHttpRequest.h>
#include <Kismet/BlueprintAsyncActionBase.h>
#include <Kismet/BlueprintFunctionLibrary.h>
#include "HttpGPTTypes.h"
#include "HttpGPTRequest.generated.h"

Expand All @@ -21,6 +22,8 @@ class HTTPGPT_API UHttpGPTRequest : public UBlueprintAsyncActionBase
{
GENERATED_BODY()

friend class UHttpGPTTaskStatus;

public:
UPROPERTY(BlueprintAssignable, Category = "HttpGPT")
FHttpGPTResponseDelegate ProcessCompleted;
Expand Down Expand Up @@ -55,9 +58,7 @@ class HTTPGPT_API UHttpGPTRequest : public UBlueprintAsyncActionBase
UFUNCTION(BlueprintCallable, Category = "HttpGPT", meta = (DisplayName = "Stop HttpGPT Task"))
void StopHttpGPTTask();

const bool IsTaskActive() const;

UFUNCTION(BlueprintPure, Category = "AzSpeech")
UFUNCTION(BlueprintPure, Category = "HttpGPT")
const FHttpGPTOptions GetTaskOptions() const;

virtual void Activate() override;
Expand Down Expand Up @@ -88,11 +89,27 @@ class HTTPGPT_API UHttpGPTRequest : public UBlueprintAsyncActionBase

bool bInitialized = false;
bool bIsReadyToDestroy = false;
bool bIsActive = false;
bool bIsTaskActive = false;

#if WITH_EDITOR
virtual void PrePIEEnded(bool bIsSimulating);

bool bEndingPIE = false;
#endif
};

UCLASS(NotPlaceable, Category = "HttpGPT")
class HTTPGPT_API UHttpGPTTaskStatus final : public UBlueprintFunctionLibrary
{
GENERATED_BODY()

public:
UFUNCTION(BlueprintPure, Category = "HttpGPT")
static bool IsTaskActive(const UHttpGPTRequest* Test);

UFUNCTION(BlueprintPure, Category = "HttpGPT")
static bool IsTaskReadyToDestroy(const UHttpGPTRequest* Test);

UFUNCTION(BlueprintPure, Category = "HttpGPT")
static bool IsTaskStillValid(const UHttpGPTRequest* Test);
};
70 changes: 45 additions & 25 deletions Source/HttpGPTEditor/Private/SHttpGPTChatView.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@ UHttpGPTMessagingHandler::UHttpGPTMessagingHandler(const FObjectInitializer& Obj

void UHttpGPTMessagingHandler::RequestSent()
{
Message.Content = "Waiting for response...";
OnMessageContentUpdated.ExecuteIfBound("Waiting for response...");
}

void UHttpGPTMessagingHandler::RequestFailed()
{
Message.Content = "Request Failed.\nPlease check the logs. (Enable internal logs in Project Settings -> Plugins -> HttpGPT).";
OnMessageContentUpdated.ExecuteIfBound("Request Failed.\nPlease check the logs. (Enable internal logs in Project Settings -> Plugins -> HttpGPT).");
Destroy();
}

void UHttpGPTMessagingHandler::ProcessUpdated(const FHttpGPTResponse& Response)
Expand All @@ -35,28 +36,27 @@ void UHttpGPTMessagingHandler::ProcessUpdated(const FHttpGPTResponse& Response)
void UHttpGPTMessagingHandler::ProcessCompleted(const FHttpGPTResponse& Response)
{
ProcessResponse(Response);

ScrollBoxReference.Reset();
Destroy();
}

void UHttpGPTMessagingHandler::ProcessResponse(const FHttpGPTResponse& Response)
{
if (!Response.bSuccess)
{
const FStringFormatOrderedArguments Arguments_ErrorDetails{
"Request Failed.",
"Please check the logs. (Enable internal logs in Project Settings -> Plugins -> HttpGPT).",
"Error Details: ",
"\tError Code: " + Response.Error.Code.ToString(),
"\tError Type: " + Response.Error.Type.ToString(),
"\tError Message: " + Response.Error.Message
FString("Request Failed."),
FString("Please check the logs. (Enable internal logs in Project Settings -> Plugins -> HttpGPT)."),
FString("Error Details: "),
FString("\tError Code: ") + Response.Error.Code.ToString(),
FString("\tError Type: ") + Response.Error.Type.ToString(),
FString("\tError Message: ") + Response.Error.Message
};

Message.Content = FString::Format(TEXT("{0}\n{1}\n\n{2}\n{3}\n{4}\n{5}"), Arguments_ErrorDetails);
OnMessageContentUpdated.ExecuteIfBound(FString::Format(TEXT("{0}\n{1}\n\n{2}\n{3}\n{4}\n{5}"), Arguments_ErrorDetails));
}
else if (Response.bSuccess && !HttpGPT::Internal::HasEmptyParam(Response.Choices))
{
Message = Response.Choices[0].Message;
OnMessageContentUpdated.ExecuteIfBound(Response.Choices[0].Message.Content);
}
else
{
Expand All @@ -69,10 +69,26 @@ void UHttpGPTMessagingHandler::ProcessResponse(const FHttpGPTResponse& Response)
}
}

void UHttpGPTMessagingHandler::Destroy()
{
#if ENGINE_MAJOR_VERSION >= 5
MarkAsGarbage();
#else
MarkPendingKill();
#endif
}

void SHttpGPTChatItem::Construct(const FArguments& InArgs)
{
Message = FHttpGPTMessage(InArgs._MessageRole, InArgs._InputText);

MessagingHandlerObject = NewObject<UHttpGPTMessagingHandler>();
MessagingHandlerObject->Message = FHttpGPTMessage(InArgs._MessageRole, InArgs._InputText);
MessagingHandlerObject->OnMessageContentUpdated.BindLambda(
[this](FString Content)
{
Message.Content = Content;
}
);

#if ENGINE_MAJOR_VERSION < 5
using FAppStyle = FEditorStyle;
Expand All @@ -84,7 +100,7 @@ void SHttpGPTChatItem::Construct(const FArguments& InArgs)
[
SNew(SVerticalBox)
+ SVerticalBox::Slot()
.Padding(MessagingHandlerObject->Message.Role == EHttpGPTRole::User ? FMargin(Slot_Padding * 16.f, Slot_Padding, Slot_Padding, Slot_Padding) : FMargin(Slot_Padding, Slot_Padding, Slot_Padding * 16.f, Slot_Padding))
.Padding(Message.Role == EHttpGPTRole::User ? FMargin(Slot_Padding * 16.f, Slot_Padding, Slot_Padding, Slot_Padding) : FMargin(Slot_Padding, Slot_Padding, Slot_Padding * 16.f, Slot_Padding))
[
SNew(SBorder)
.BorderImage(AppStyle.GetBrush("Menu.Background"))
Expand All @@ -96,7 +112,7 @@ void SHttpGPTChatItem::Construct(const FArguments& InArgs)
[
SNew(STextBlock)
.Font(FCoreStyle::GetDefaultFontStyle("Bold", 10))
.Text(FText::FromString(MessagingHandlerObject->Message.Role == EHttpGPTRole::User ? "User:" : "Assistant:"))
.Text(FText::FromString(Message.Role == EHttpGPTRole::User ? "User:" : "Assistant:"))
]
+ SVerticalBox::Slot()
.Padding(FMargin(Slot_Padding * 4, Slot_Padding, Slot_Padding, Slot_Padding))
Expand All @@ -113,7 +129,7 @@ void SHttpGPTChatItem::Construct(const FArguments& InArgs)

FText SHttpGPTChatItem::GetMessageText() const
{
return FText::FromString(MessagingHandlerObject->Message.Content);
return FText::FromString(Message.Content);
}

void SHttpGPTChatView::Construct([[maybe_unused]] const FArguments&)
Expand Down Expand Up @@ -203,12 +219,12 @@ FReply SHttpGPTChatView::HandleSendMessageButton()

RequestReference = UHttpGPTRequest::SendMessages_CustomOptions(GEditor->GetEditorWorldContext().World(), GetChatHistory(), Options);

RequestReference->ProgressStarted.AddDynamic(AssistantMessage->MessagingHandlerObject, &UHttpGPTMessagingHandler::ProcessUpdated);
RequestReference->ProgressUpdated.AddDynamic(AssistantMessage->MessagingHandlerObject, &UHttpGPTMessagingHandler::ProcessUpdated);
RequestReference->ProcessCompleted.AddDynamic(AssistantMessage->MessagingHandlerObject, &UHttpGPTMessagingHandler::ProcessCompleted);
RequestReference->ErrorReceived.AddDynamic(AssistantMessage->MessagingHandlerObject, &UHttpGPTMessagingHandler::ProcessCompleted);
RequestReference->RequestFailed.AddDynamic(AssistantMessage->MessagingHandlerObject, &UHttpGPTMessagingHandler::RequestFailed);
RequestReference->RequestSent.AddDynamic(AssistantMessage->MessagingHandlerObject, &UHttpGPTMessagingHandler::RequestSent);
RequestReference->ProgressStarted.AddDynamic(AssistantMessage->MessagingHandlerObject.Get(), &UHttpGPTMessagingHandler::ProcessUpdated);
RequestReference->ProgressUpdated.AddDynamic(AssistantMessage->MessagingHandlerObject.Get(), &UHttpGPTMessagingHandler::ProcessUpdated);
RequestReference->ProcessCompleted.AddDynamic(AssistantMessage->MessagingHandlerObject.Get(), &UHttpGPTMessagingHandler::ProcessCompleted);
RequestReference->ErrorReceived.AddDynamic(AssistantMessage->MessagingHandlerObject.Get(), &UHttpGPTMessagingHandler::ProcessCompleted);
RequestReference->RequestFailed.AddDynamic(AssistantMessage->MessagingHandlerObject.Get(), &UHttpGPTMessagingHandler::RequestFailed);
RequestReference->RequestSent.AddDynamic(AssistantMessage->MessagingHandlerObject.Get(), &UHttpGPTMessagingHandler::RequestSent);

RequestReference->Activate();

Expand All @@ -224,18 +240,22 @@ FReply SHttpGPTChatView::HandleSendMessageButton()

bool SHttpGPTChatView::IsSendMessageEnabled() const
{
return (!IsValid(RequestReference) || !RequestReference->IsTaskActive()) && !HttpGPT::Internal::HasEmptyParam(InputTextBox->GetText());
return (!RequestReference.IsValid() || !UHttpGPTTaskStatus::IsTaskActive(RequestReference.Get())) && !HttpGPT::Internal::HasEmptyParam(InputTextBox->GetText());
}

FReply SHttpGPTChatView::HandleClearChatButton()
{
ChatItems.Empty();
ChatBox->ClearChildren();

if (RequestReference)
if (RequestReference.IsValid())
{
RequestReference->StopHttpGPTTask();
}
else
{
RequestReference.Reset();
}

return FReply::Handled();
}
Expand All @@ -254,7 +274,7 @@ TArray<FHttpGPTMessage> SHttpGPTChatView::GetChatHistory() const

for (const auto& Item : ChatItems)
{
Output.Add(Item->MessagingHandlerObject->Message);
Output.Add(Item->Message);
}

return Output;
Expand Down
19 changes: 9 additions & 10 deletions Source/HttpGPTEditor/Private/SHttpGPTChatView.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
#include <Widgets/Layout/SScrollBox.h>
#include "SHttpGPTChatView.generated.h"

DECLARE_DELEGATE_OneParam(FMessageContentUpdated, FString);

UCLASS(MinimalAPI, NotBlueprintable, NotPlaceable, Category = "Implementation")
class UHttpGPTMessagingHandler : public UObject
{
Expand All @@ -19,6 +21,8 @@ class UHttpGPTMessagingHandler : public UObject
public:
explicit UHttpGPTMessagingHandler(const FObjectInitializer& ObjectInitializer = FObjectInitializer::Get());

FMessageContentUpdated OnMessageContentUpdated;

UFUNCTION()
void RequestSent();

Expand All @@ -31,16 +35,14 @@ class UHttpGPTMessagingHandler : public UObject
UFUNCTION()
void ProcessCompleted(const FHttpGPTResponse& Response);

FHttpGPTMessage Message;

TSharedPtr<SScrollBox> ScrollBoxReference;

void Destroy();

private:
void ProcessResponse(const FHttpGPTResponse& Response);
};

typedef UHttpGPTMessagingHandler* UHttpGPTMessagingHandlerPtr;

class SHttpGPTChatItem final : public SCompoundWidget
{
public:
Expand All @@ -55,7 +57,8 @@ class SHttpGPTChatItem final : public SCompoundWidget

FText GetMessageText() const;

UHttpGPTMessagingHandlerPtr MessagingHandlerObject;
TWeakObjectPtr<UHttpGPTMessagingHandler> MessagingHandlerObject;
FHttpGPTMessage Message;

private:
TSharedPtr<STextBlock> MessageBox;
Expand Down Expand Up @@ -96,9 +99,5 @@ class SHttpGPTChatView final : public SCompoundWidget

TArray<TSharedPtr<FString>> AvailableModels;

#if ENGINE_MAJOR_VERSION >= 5
TObjectPtr<class UHttpGPTRequest> RequestReference;
#else
class UHttpGPTRequest* RequestReference;
#endif
TWeakObjectPtr<class UHttpGPTRequest> RequestReference;
};