Skip to content

Commit

Permalink
Merge pull request #3589 from bylaws/wow-inval
Browse files Browse the repository at this point in the history
WOW64 backend code invalidation fixes
  • Loading branch information
Sonicadvance1 authored Apr 23, 2024
2 parents 68e543c + a7f4e99 commit 81c219c
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 41 deletions.
47 changes: 37 additions & 10 deletions Source/Windows/Common/InvalidationTracker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,27 @@

#include <FEXCore/Utils/LogManager.h>
#include <FEXCore/Utils/TypeDefines.h>
#include <FEXCore/Utils/SignalScopeGuards.h>
#include <FEXCore/Core/Context.h>
#include <FEXCore/Debug/InternalThreadState.h>
#include "InvalidationTracker.h"
#include <windef.h>
#include <winternl.h>

namespace FEX::Windows {
void InvalidationTracker::HandleMemoryProtectionNotification(FEXCore::Core::InternalThreadState* Thread, uint64_t Address, uint64_t Size,
ULONG Prot) {
InvalidationTracker::InvalidationTracker(FEXCore::Context::Context& CTX, const std::unordered_map<DWORD, FEXCore::Core::InternalThreadState*>& Threads)
: CTX {CTX}
, Threads {Threads} {}

void InvalidationTracker::HandleMemoryProtectionNotification(uint64_t Address, uint64_t Size, ULONG Prot) {
const auto AlignedBase = Address & FEXCore::Utils::FEX_PAGE_MASK;
const auto AlignedSize = (Address - AlignedBase + Size + FEXCore::Utils::FEX_PAGE_SIZE - 1) & FEXCore::Utils::FEX_PAGE_MASK;

if (Prot & (PAGE_EXECUTE | PAGE_EXECUTE_READ | PAGE_EXECUTE_READWRITE)) {
Thread->CTX->InvalidateGuestCodeRange(Thread, AlignedBase, AlignedSize);
std::scoped_lock Lock(CTX.GetCodeInvalidationMutex());
for (auto Thread : Threads) {
CTX.InvalidateGuestCodeRange(Thread.second, AlignedBase, AlignedSize);
}
}

if (Prot & PAGE_EXECUTE_READWRITE) {
Expand All @@ -28,26 +35,43 @@ void InvalidationTracker::HandleMemoryProtectionNotification(FEXCore::Core::Inte
}
}

void InvalidationTracker::InvalidateContainingSection(FEXCore::Core::InternalThreadState* Thread, uint64_t Address, bool Free) {
void InvalidationTracker::InvalidateContainingSection(uint64_t Address, bool Free) {
MEMORY_BASIC_INFORMATION Info;
if (NtQueryVirtualMemory(NtCurrentProcess(), reinterpret_cast<void*>(Address), MemoryBasicInformation, &Info, sizeof(Info), nullptr)) {
return;
}

const auto SectionBase = reinterpret_cast<uint64_t>(Info.AllocationBase);
const auto SectionSize = reinterpret_cast<uint64_t>(Info.BaseAddress) + Info.RegionSize - reinterpret_cast<uint64_t>(Info.AllocationBase);
Thread->CTX->InvalidateGuestCodeRange(Thread, SectionBase, SectionSize);
auto SectionSize = reinterpret_cast<uint64_t>(Info.BaseAddress) + Info.RegionSize - SectionBase;

while (!NtQueryVirtualMemory(NtCurrentProcess(), reinterpret_cast<void*>(SectionBase + SectionSize), MemoryBasicInformation, &Info,
sizeof(Info), nullptr) &&
reinterpret_cast<uint64_t>(Info.AllocationBase) == SectionBase) {
SectionSize += Info.RegionSize;
}
{
std::scoped_lock Lock(CTX.GetCodeInvalidationMutex());
for (auto Thread : Threads) {
CTX.InvalidateGuestCodeRange(Thread.second, SectionBase, SectionSize);
}
}

if (Free) {
std::scoped_lock Lock(RWXIntervalsLock);
RWXIntervals.Remove({SectionBase, SectionBase + SectionSize});
}
}

void InvalidationTracker::InvalidateAlignedInterval(FEXCore::Core::InternalThreadState* Thread, uint64_t Address, uint64_t Size, bool Free) {
void InvalidationTracker::InvalidateAlignedInterval(uint64_t Address, uint64_t Size, bool Free) {
const auto AlignedBase = Address & FEXCore::Utils::FEX_PAGE_MASK;
const auto AlignedSize = (Address - AlignedBase + Size + FEXCore::Utils::FEX_PAGE_SIZE - 1) & FEXCore::Utils::FEX_PAGE_MASK;
Thread->CTX->InvalidateGuestCodeRange(Thread, AlignedBase, AlignedSize);

{
std::scoped_lock Lock(CTX.GetCodeInvalidationMutex());
for (auto Thread : Threads) {
CTX.InvalidateGuestCodeRange(Thread.second, AlignedBase, AlignedSize);
}
}

if (Free) {
std::scoped_lock Lock(RWXIntervalsLock);
Expand Down Expand Up @@ -75,7 +99,7 @@ void InvalidationTracker::ReprotectRWXIntervals(uint64_t Address, uint64_t Size)
} while (Address < End);
}

bool InvalidationTracker::HandleRWXAccessViolation(FEXCore::Core::InternalThreadState* Thread, uint64_t FaultAddress) {
bool InvalidationTracker::HandleRWXAccessViolation(uint64_t FaultAddress) {
const bool NeedsInvalidate = [&](uint64_t Address) {
std::unique_lock Lock(RWXIntervalsLock);
const bool Enclosed = RWXIntervals.Query(Address).Enclosed;
Expand All @@ -93,7 +117,10 @@ bool InvalidationTracker::HandleRWXAccessViolation(FEXCore::Core::InternalThread

if (NeedsInvalidate) {
// RWXIntervalsLock cannot be held during invalidation
Thread->CTX->InvalidateGuestCodeRange(Thread, FaultAddress & FEXCore::Utils::FEX_PAGE_MASK, FEXCore::Utils::FEX_PAGE_SIZE);
std::scoped_lock Lock(CTX.GetCodeInvalidationMutex());
for (auto Thread : Threads) {
CTX.InvalidateGuestCodeRange(Thread.second, FaultAddress & FEXCore::Utils::FEX_PAGE_MASK, FEXCore::Utils::FEX_PAGE_SIZE);
}
return true;
}
return false;
Expand Down
17 changes: 12 additions & 5 deletions Source/Windows/Common/InvalidationTracker.h
Original file line number Diff line number Diff line change
@@ -1,28 +1,35 @@
// SPDX-License-Identifier: MIT
// FIXME TODO put in cpp
#pragma once

#include "IntervalList.h"
#include <mutex>
#include <unordered_map>

namespace FEXCore::Core {
struct InternalThreadState;
}

namespace FEXCore::Context {
class Context;
}

namespace FEX::Windows {
/**
* @brief Handles SMC and regular code invalidation
*/
class InvalidationTracker {
public:
void HandleMemoryProtectionNotification(FEXCore::Core::InternalThreadState* Thread, uint64_t Address, uint64_t Size, ULONG Prot);
void InvalidateContainingSection(FEXCore::Core::InternalThreadState* Thread, uint64_t Address, bool Free);
void InvalidateAlignedInterval(FEXCore::Core::InternalThreadState* Thread, uint64_t Address, uint64_t Size, bool Free);
InvalidationTracker(FEXCore::Context::Context& CTX, const std::unordered_map<DWORD, FEXCore::Core::InternalThreadState*>& Threads);
void HandleMemoryProtectionNotification(uint64_t Address, uint64_t Size, ULONG Prot);
void InvalidateContainingSection(uint64_t Address, bool Free);
void InvalidateAlignedInterval(uint64_t Address, uint64_t Size, bool Free);
void ReprotectRWXIntervals(uint64_t Address, uint64_t Size);
bool HandleRWXAccessViolation(FEXCore::Core::InternalThreadState* Thread, uint64_t FaultAddress);
bool HandleRWXAccessViolation(uint64_t FaultAddress);

private:
IntervalList<uint64_t> RWXIntervals;
std::mutex RWXIntervalsLock;
FEXCore::Context::Context& CTX;
const std::unordered_map<DWORD, FEXCore::Core::InternalThreadState*>& Threads;
};
} // namespace FEX::Windows
2 changes: 0 additions & 2 deletions Source/Windows/Defs/ntdll.def
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,5 @@ EXPORTS
NtSuspendThread
NtGetContextThread
NtContinue
NtQueryVirtualMemory
NtProtectVirtualMemory
__wine_dbg_output
__wine_unix_call
61 changes: 37 additions & 24 deletions Source/Windows/WOW64/Module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ desc: Implements the WOW64 BT module API using FEXCore
#include <atomic>
#include <mutex>
#include <utility>
#include <unordered_set>
#include <unordered_map>
#include <ntstatus.h>
#include <windef.h>
#include <winternl.h>
Expand Down Expand Up @@ -94,11 +94,12 @@ fextl::unique_ptr<FEXCore::Context::Context> CTX;
fextl::unique_ptr<FEX::DummyHandlers::DummySignalDelegator> SignalDelegator;
fextl::unique_ptr<WowSyscallHandler> SyscallHandler;

FEX::Windows::InvalidationTracker InvalidationTracker;
std::optional<FEX::Windows::InvalidationTracker> InvalidationTracker;
std::optional<FEX::Windows::CPUFeatures> CPUFeatures;

std::mutex ThreadSuspendLock;
std::unordered_set<DWORD> InitializedWOWThreads; // Set of TIDs, `ThreadSuspendLock` must be locked when accessing
std::mutex ThreadCreationMutex;
// Map of TIDs to their FEX thread state, `ThreadCreationMutex` must be locked when accessing
std::unordered_map<DWORD, FEXCore::Core::InternalThreadState*> Threads;

std::pair<NTSTATUS, TLS> GetThreadTLS(HANDLE Thread) {
THREAD_BASIC_INFORMATION Info;
Expand Down Expand Up @@ -393,7 +394,7 @@ class WowSyscallHandler : public FEXCore::HLE::SyscallHandler, public FEXCore::A
}

void MarkGuestExecutableRange(FEXCore::Core::InternalThreadState* Thread, uint64_t Start, uint64_t Length) override {
InvalidationTracker.ReprotectRWXIntervals(Start, Length);
InvalidationTracker->ReprotectRWXIntervals(Start, Length);
}
};

Expand Down Expand Up @@ -422,14 +423,16 @@ void BTCpuProcessInit() {
CTX->SetSignalDelegator(SignalDelegator.get());
CTX->SetSyscallHandler(SyscallHandler.get());
CTX->InitCore();
InvalidationTracker.emplace(*CTX, Threads);
CPUFeatures.emplace(*CTX);
}

NTSTATUS BTCpuThreadInit() {
GetTLS().ThreadState() = CTX->CreateThread(0, 0);
auto* Thread = CTX->CreateThread(0, 0);
GetTLS().ThreadState() = Thread;

std::scoped_lock Lock(ThreadSuspendLock);
InitializedWOWThreads.emplace(GetCurrentThreadId());
std::scoped_lock Lock(ThreadCreationMutex);
Threads.emplace(GetCurrentThreadId(), Thread);
return STATUS_SUCCESS;
}

Expand All @@ -446,8 +449,8 @@ NTSTATUS BTCpuThreadTerm(HANDLE Thread) {
}

const auto ThreadTID = reinterpret_cast<uint64_t>(Info.ClientId.UniqueThread);
std::scoped_lock Lock(ThreadSuspendLock);
InitializedWOWThreads.erase(ThreadTID);
std::scoped_lock Lock(ThreadCreationMutex);
Threads.erase(ThreadTID);
}

CTX->DestroyThread(TLS.ThreadState());
Expand Down Expand Up @@ -550,10 +553,10 @@ NTSTATUS BTCpuSuspendLocalThread(HANDLE Thread, ULONG* Count) {
return Err;
}

std::scoped_lock Lock(ThreadSuspendLock);
std::scoped_lock Lock(ThreadCreationMutex);

// If the thread hasn't yet been initialized, suspend it without special handling as it wont yet have entered the JIT
if (!InitializedWOWThreads.contains(ThreadTID)) {
if (!Threads.contains(ThreadTID)) {
return NtSuspendThread(Thread, Count);
}

Expand Down Expand Up @@ -615,15 +618,22 @@ NTSTATUS BTCpuResetToConsistentState(EXCEPTION_POINTERS* Ptrs) {
if (Exception->ExceptionCode == EXCEPTION_ACCESS_VIOLATION) {
const auto FaultAddress = static_cast<uint64_t>(Exception->ExceptionInformation[1]);

if (InvalidationTracker.HandleRWXAccessViolation(GetTLS().ThreadState(), FaultAddress)) {
LogMan::Msg::DFmt("Handled self-modifying code: pc: {:X} fault: {:X}", Context->Pc, FaultAddress);
NtContinue(Context, FALSE);
}

if (Context::HandleSuspendInterrupt(Context, FaultAddress)) {
LogMan::Msg::DFmt("Resumed from suspend");
NtContinue(Context, FALSE);
}

bool HandledRWX = false;
if (GetTLS().ThreadState()) {
std::scoped_lock Lock(ThreadCreationMutex);
HandledRWX = InvalidationTracker->HandleRWXAccessViolation(FaultAddress);
}

if (HandledRWX) {
LogMan::Msg::DFmt("Handled self-modifying code: pc: {:X} fault: {:X}", Context->Pc, FaultAddress);
NtContinue(Context, FALSE);
}
}

if (!IsAddressInJit(Context->Pc)) {
Expand All @@ -645,29 +655,32 @@ NTSTATUS BTCpuResetToConsistentState(EXCEPTION_POINTERS* Ptrs) {
}

void BTCpuFlushInstructionCache2(const void* Address, SIZE_T Size) {
InvalidationTracker.InvalidateAlignedInterval(GetTLS().ThreadState(), reinterpret_cast<uint64_t>(Address), static_cast<uint64_t>(Size), false);
std::scoped_lock Lock(ThreadCreationMutex);
InvalidationTracker->InvalidateAlignedInterval(reinterpret_cast<uint64_t>(Address), static_cast<uint64_t>(Size), false);
}

void BTCpuNotifyMemoryAlloc(void* Address, SIZE_T Size, ULONG Type, ULONG Prot) {
InvalidationTracker.HandleMemoryProtectionNotification(GetTLS().ThreadState(), reinterpret_cast<uint64_t>(Address),
static_cast<uint64_t>(Size), Prot);
std::scoped_lock Lock(ThreadCreationMutex);
InvalidationTracker->HandleMemoryProtectionNotification(reinterpret_cast<uint64_t>(Address), static_cast<uint64_t>(Size), Prot);
}

void BTCpuNotifyMemoryProtect(void* Address, SIZE_T Size, ULONG NewProt) {
InvalidationTracker.HandleMemoryProtectionNotification(GetTLS().ThreadState(), reinterpret_cast<uint64_t>(Address),
static_cast<uint64_t>(Size), NewProt);
std::scoped_lock Lock(ThreadCreationMutex);
InvalidationTracker->HandleMemoryProtectionNotification(reinterpret_cast<uint64_t>(Address), static_cast<uint64_t>(Size), NewProt);
}

void BTCpuNotifyMemoryFree(void* Address, SIZE_T Size, ULONG FreeType) {
std::scoped_lock Lock(ThreadCreationMutex);
if (!Size) {
InvalidationTracker.InvalidateContainingSection(GetTLS().ThreadState(), reinterpret_cast<uint64_t>(Address), true);
InvalidationTracker->InvalidateContainingSection(reinterpret_cast<uint64_t>(Address), true);
} else if (FreeType & MEM_DECOMMIT) {
InvalidationTracker.InvalidateAlignedInterval(GetTLS().ThreadState(), reinterpret_cast<uint64_t>(Address), static_cast<uint64_t>(Size), true);
InvalidationTracker->InvalidateAlignedInterval(reinterpret_cast<uint64_t>(Address), static_cast<uint64_t>(Size), true);
}
}

void BTCpuNotifyUnmapViewOfSection(void* Address, ULONG Flags) {
InvalidationTracker.InvalidateContainingSection(GetTLS().ThreadState(), reinterpret_cast<uint64_t>(Address), true);
std::scoped_lock Lock(ThreadCreationMutex);
InvalidationTracker->InvalidateContainingSection(reinterpret_cast<uint64_t>(Address), true);
}

BOOLEAN WINAPI BTCpuIsProcessorFeaturePresent(UINT Feature) {
Expand Down

0 comments on commit 81c219c

Please sign in to comment.