Skip to content

Commit

Permalink
Strip unused generic type context for VASigCookie (#99026)
Browse files Browse the repository at this point in the history
* Strip unused generic type context for VASigCookie

Fixes #98977

---------

Co-authored-by: Aaron Robinson <arobins@microsoft.com>
  • Loading branch information
jkotas and AaronRobinsonMSFT authored Mar 1, 2024
1 parent fcc0209 commit b00f084
Show file tree
Hide file tree
Showing 2 changed files with 164 additions and 7 deletions.
152 changes: 145 additions & 7 deletions src/coreclr/vm/ceeload.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4659,6 +4659,122 @@ PTR_VOID ReflectionModule::GetRvaField(RVA field) // virtual
// VASigCookies
// ===========================================================================

static bool TypeSignatureContainsGenericVariables(SigParser& sp);
static bool MethodSignatureContainsGenericVariables(SigParser& sp);

static bool TypeSignatureContainsGenericVariables(SigParser& sp)
{
STANDARD_VM_CONTRACT;

CorElementType et = ELEMENT_TYPE_END;
IfFailThrow(sp.GetElemType(&et));

if (CorIsPrimitiveType(et))
return false;

switch (et)
{
case ELEMENT_TYPE_OBJECT:
case ELEMENT_TYPE_STRING:
case ELEMENT_TYPE_TYPEDBYREF:
return false;

case ELEMENT_TYPE_BYREF:
case ELEMENT_TYPE_PTR:
case ELEMENT_TYPE_SZARRAY:
return TypeSignatureContainsGenericVariables(sp);

case ELEMENT_TYPE_VALUETYPE:
case ELEMENT_TYPE_CLASS:
IfFailThrow(sp.GetToken(NULL)); // Skip RID
return false;

case ELEMENT_TYPE_FNPTR:
return MethodSignatureContainsGenericVariables(sp);

case ELEMENT_TYPE_ARRAY:
{
if (TypeSignatureContainsGenericVariables(sp))
return true;

uint32_t rank;
IfFailThrow(sp.GetData(&rank)); // Get rank
if (rank)
{
uint32_t nsizes;
IfFailThrow(sp.GetData(&nsizes)); // Get # of sizes
while (nsizes--)
{
IfFailThrow(sp.GetData(NULL)); // Skip size
}

uint32_t nlbounds;
IfFailThrow(sp.GetData(&nlbounds)); // Get # of lower bounds
while (nlbounds--)
{
IfFailThrow(sp.GetData(NULL)); // Skip lower bounds
}
}
}
return false;

case ELEMENT_TYPE_GENERICINST:
{
if (TypeSignatureContainsGenericVariables(sp))
return true;

uint32_t argCnt;
IfFailThrow(sp.GetData(&argCnt)); // Get number of parameters
while (argCnt--)
{
if (TypeSignatureContainsGenericVariables(sp))
return true;
}
}
return false;

case ELEMENT_TYPE_INTERNAL:
IfFailThrow(sp.GetPointer(NULL));
return false;

case ELEMENT_TYPE_VAR:
case ELEMENT_TYPE_MVAR:
return true;

default:
// Return conservative answer for unhandled elements
_ASSERTE(!"Unexpected element type.");
return true;
}
}

static bool MethodSignatureContainsGenericVariables(SigParser& sp)
{
STANDARD_VM_CONTRACT;

uint32_t callConv = 0;
IfFailThrow(sp.GetCallingConvInfo(&callConv));

if (callConv & IMAGE_CEE_CS_CALLCONV_GENERIC)
{
// Generic signatures should never show up here, return conservative answer.
_ASSERTE(!"Unexpected generic signature.");
return true;
}

uint32_t numArgs = 0;
IfFailThrow(sp.GetData(&numArgs));

// iterate over the return type and parameters
for (uint32_t i = 0; i <= numArgs; i++)
{
if (TypeSignatureContainsGenericVariables(sp))
return true;
}

return false;
}

//==========================================================================
// Enregisters a VASig.
//==========================================================================
Expand All @@ -4667,15 +4783,39 @@ VASigCookie *Module::GetVASigCookie(Signature vaSignature, const SigTypeContext*
CONTRACT(VASigCookie*)
{
INSTANCE_CHECK;
THROWS;
GC_TRIGGERS;
MODE_ANY;
STANDARD_VM_CHECK;
POSTCONDITION(CheckPointer(RETVAL));
INJECT_FAULT(COMPlusThrowOM());
}
CONTRACT_END;

Module* pLoaderModule = ClassLoader::ComputeLoaderModuleWorker(this, mdTokenNil, typeContext->m_classInst, typeContext->m_methodInst);
SigTypeContext emptyContext;

Module* pLoaderModule = this;
if (!typeContext->IsEmpty())
{
// Strip the generic context if it is not actually used by the signature. It is nececessary for both:
// - Performance: allow more sharing of vasig cookies
// - Functionality: built-in runtime marshalling is disallowed for generic signatures
SigParser sigParser = vaSignature.CreateSigParser();
if (MethodSignatureContainsGenericVariables(sigParser))
{
pLoaderModule = ClassLoader::ComputeLoaderModuleWorker(this, mdTokenNil, typeContext->m_classInst, typeContext->m_methodInst);
}
else
{
typeContext = &emptyContext;
}
}
else
{
#ifdef _DEBUG
// The method signature should not contain any generic variables if the generic context is not provided.
SigParser sigParser = vaSignature.CreateSigParser();
_ASSERTE(!MethodSignatureContainsGenericVariables(sigParser));
#endif
}

VASigCookie *pCookie = GetVASigCookieWorker(this, pLoaderModule, vaSignature, typeContext);

RETURN pCookie;
Expand All @@ -4685,9 +4825,7 @@ VASigCookie *Module::GetVASigCookieWorker(Module* pDefiningModule, Module* pLoad
{
CONTRACT(VASigCookie*)
{
THROWS;
GC_TRIGGERS;
MODE_ANY;
STANDARD_VM_CHECK;
POSTCONDITION(CheckPointer(RETVAL));
INJECT_FAULT(COMPlusThrowOM());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ static BlittableGeneric<string> UnmanagedExportedFunctionBlittableGenericString(
return new() { X = Convert.ToInt32(arg) };
}

[UnmanagedCallersOnly]
static unsafe void UnmanagedExportedFunctionRefInt(int* pval, float arg)
{
*pval = Convert.ToInt32(arg);
}

class GenericCaller<T>
{
internal static unsafe T GenericCalli<U>(void* fnptr, U arg)
Expand All @@ -40,6 +46,11 @@ internal static unsafe BlittableGeneric<T> WrappedGenericCalli<U>(void* fnptr, U
{
return ((delegate* unmanaged<U, BlittableGeneric<T>>)fnptr)(arg);
}

internal static unsafe void NonGenericCalli<U>(void* fnptr, ref int val, float arg)
{
((delegate* unmanaged<ref int, float, void>)fnptr)(ref val, arg);
}
}

struct BlittableGeneric<T>
Expand Down Expand Up @@ -81,6 +92,14 @@ public static void RunGenericFunctionPointerTest(float inVal)
outVar = GenericCaller<string>.WrappedGenericCalli((delegate* unmanaged<float, BlittableGeneric<string>>)&UnmanagedExportedFunctionBlittableGenericString, inVal).X;
}
Assert.Equal(expectedValue, outVar);

outVar = 0;
Console.WriteLine("Testing non-GenericCalli with non-blittable argument in a generic caller");
unsafe
{
GenericCaller<string>.NonGenericCalli<string>((delegate* unmanaged<int*, float, void>)&UnmanagedExportedFunctionRefInt, ref outVar, inVal);
}
Assert.Equal(expectedValue, outVar);
}

[ConditionalFact(nameof(CanRunInvalidGenericFunctionPointerTest))]
Expand Down

0 comments on commit b00f084

Please sign in to comment.