Skip to content

Commit

Permalink
Allow devirt into abstract classes if we saw a non-abstract child (do…
Browse files Browse the repository at this point in the history
…tnet#108379)

We avoid devirtualizing into abstract classes because whole program view might have optimized away the method bodies and devirtualizing them doesn't lead to anything good.

However, if the whole program view had a non-abstract child of this, we can no longer optimize this out and devirtualization should be fine.

Fixes issue encountered in dotnet#108153 (comment)
  • Loading branch information
MichalStrehovsky authored and sirntar committed Oct 3, 2024
1 parent 7cb420f commit 8a8c219
Show file tree
Hide file tree
Showing 4 changed files with 201 additions and 16 deletions.
1 change: 1 addition & 0 deletions src/coreclr/tools/Common/Compiler/MethodExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ public static bool NotCallableWithoutOwningEEType(this MethodDesc method)
TypeDesc owningType = method.OwningType;
return !method.Signature.IsStatic && /* Static methods don't have this */
!owningType.IsValueType && /* Value type instance methods take a ref to data */
!owningType.IsInterface && /* Interface MethodTable can be optimized away but the instance method can still be callable (`this` is of a non-interface type) */
!owningType.IsArrayTypeWithoutGenericInterfaces() && /* Type loader can make these at runtime */
(owningType is not MetadataType mdType || !mdType.IsModuleType) && /* Compiler parks some instance methods on the <Module> type */
!method.IsSharedByGenericInstantiations; /* Current impl limitation; can be lifted */
Expand Down
39 changes: 23 additions & 16 deletions src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/ILScanner.cs
Original file line number Diff line number Diff line change
Expand Up @@ -541,16 +541,27 @@ public ScannedDevirtualizationManager(NodeFactory factory, ImmutableArray<Depend
}

TypeDesc canonType = type.ConvertToCanonForm(CanonicalFormKind.Specific);
TypeDesc baseType;

if (!canonType.IsDefType || !((MetadataType)canonType).IsAbstract)
if (canonType is not MetadataType { IsAbstract: true })
{
_canonConstructedTypes.Add(canonType.GetClosestDefType());
baseType = canonType.BaseType;
while (baseType != null)
{
baseType = baseType.ConvertToCanonForm(CanonicalFormKind.Specific);
if (!_canonConstructedTypes.Add(baseType))
break;
baseType = baseType.BaseType;
}
}

TypeDesc baseType = canonType.BaseType;
bool added = true;
while (baseType != null && added)
baseType = canonType.BaseType;
while (baseType != null)
{
baseType = baseType.ConvertToCanonForm(CanonicalFormKind.Specific);
added = _unsealedTypes.Add(baseType);
if (!_unsealedTypes.Add(baseType))
break;
baseType = baseType.BaseType;
}

Expand Down Expand Up @@ -686,20 +697,16 @@ public override bool IsEffectivelySealed(MethodDesc method)

protected override MethodDesc ResolveVirtualMethod(MethodDesc declMethod, DefType implType, out CORINFO_DEVIRTUALIZATION_DETAIL devirtualizationDetail)
{
MethodDesc result = base.ResolveVirtualMethod(declMethod, implType, out devirtualizationDetail);
if (result != null)
// If we would resolve into a type that wasn't seen as allocated, don't allow devirtualization.
// It would go past what we scanned in the scanner and that doesn't lead to good things.
if (!_canonConstructedTypes.Contains(implType.ConvertToCanonForm(CanonicalFormKind.Specific)))
{
// If we would resolve into a type that wasn't seen as allocated, don't allow devirtualization.
// It would go past what we scanned in the scanner and that doesn't lead to good things.
if (!_canonConstructedTypes.Contains(result.OwningType.ConvertToCanonForm(CanonicalFormKind.Specific)))
{
// FAILED_BUBBLE_IMPL_NOT_REFERENCEABLE is close enough...
devirtualizationDetail = CORINFO_DEVIRTUALIZATION_DETAIL.CORINFO_DEVIRTUALIZATION_FAILED_BUBBLE_IMPL_NOT_REFERENCEABLE;
return null;
}
// FAILED_BUBBLE_IMPL_NOT_REFERENCEABLE is close enough...
devirtualizationDetail = CORINFO_DEVIRTUALIZATION_DETAIL.CORINFO_DEVIRTUALIZATION_FAILED_BUBBLE_IMPL_NOT_REFERENCEABLE;
return null;
}

return result;
return base.ResolveVirtualMethod(declMethod, implType, out devirtualizationDetail);
}

public override bool CanReferenceConstructedMethodTable(TypeDesc type)
Expand Down
34 changes: 34 additions & 0 deletions src/tests/nativeaot/SmokeTests/UnitTests/Devirtualization.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class Devirtualization
{
internal static int Run()
{
TestDevirtualizationIntoAbstract.Run();
RegressionBug73076.Run();
RegressionGenericHierarchy.Run();
DevirtualizationCornerCaseTests.Run();
Expand All @@ -19,6 +20,39 @@ internal static int Run()
return 100;
}

class TestDevirtualizationIntoAbstract
{
class Something { }

abstract class Base
{
[MethodImpl(MethodImplOptions.NoInlining)]
public virtual Type GetSomething() => typeof(Something);
}

sealed class Derived : Base { }

class Unrelated : Base
{
public override Type GetSomething() => typeof(Unrelated);
}

public static void Run()
{
TestUnrelated(new Unrelated());

// We were getting a scanning failure because GetSomething got devirtualized into
// Base.GetSomething, but that's unreachable.
Test(null);

[MethodImpl(MethodImplOptions.NoInlining)]
static Type Test(Derived d) => d?.GetSomething();

[MethodImpl(MethodImplOptions.NoInlining)]
static Type TestUnrelated(Base d) => d?.GetSomething();
}
}

class RegressionBug73076
{
interface IFactory
Expand Down
143 changes: 143 additions & 0 deletions src/tests/nativeaot/SmokeTests/UnitTests/Interfaces.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ public static int Run()

TestPublicAndNonpublicDifference.Run();
TestDefaultInterfaceMethods.Run();
TestDefaultInterfaceMethodsDevirtNoInline.Run();
TestDefaultInterfaceMethodsNoDevirt.Run();
TestDefaultInterfaceVariance.Run();
TestVariantInterfaceOptimizations.Run();
TestSharedInterfaceMethods.Run();
Expand Down Expand Up @@ -581,6 +583,147 @@ public static void Run()
}
}

class TestDefaultInterfaceMethodsDevirtNoInline
{
interface IFoo
{
[MethodImpl(MethodImplOptions.NoInlining)]
int GetNumber() => 42;
}

interface IBar : IFoo
{
[MethodImpl(MethodImplOptions.NoInlining)]
int IFoo.GetNumber() => 43;
}

class Foo : IFoo { }
class Bar : IBar { }

class Baz : IFoo
{
[MethodImpl(MethodImplOptions.NoInlining)]
public int GetNumber() => 100;
}

interface IFoo<T>
{
[MethodImpl(MethodImplOptions.NoInlining)]
Type GetInterfaceType() => typeof(IFoo<T>);
}

class Foo<T> : IFoo<T> { }

class Base : IFoo
{
[MethodImpl(MethodImplOptions.NoInlining)]
int IFoo.GetNumber() => 100;
}

class Derived : Base, IBar { }

public static void Run()
{
Console.WriteLine("Testing default interface methods that can be devirtualized but not inlined...");

typeof(IFoo).ToString();

if (((IFoo)new Foo()).GetNumber() != 42)
throw new Exception();

if (((IFoo)new Bar()).GetNumber() != 43)
throw new Exception();

if (((IFoo)new Baz()).GetNumber() != 100)
throw new Exception();

if (((IFoo)new Derived()).GetNumber() != 100)
throw new Exception();

if (((IFoo<object>)new Foo<object>()).GetInterfaceType() != typeof(IFoo<object>))
throw new Exception();

if (((IFoo<int>)new Foo<int>()).GetInterfaceType() != typeof(IFoo<int>))
throw new Exception();
}
}

class TestDefaultInterfaceMethodsNoDevirt
{
interface IFoo
{
int GetNumber() => 42;
}

interface IBar : IFoo
{
int IFoo.GetNumber() => 43;
}

class Foo : IFoo { }
class Bar : IBar { }

class Baz : IFoo
{
public int GetNumber() => 100;
}

interface IFoo<T>
{
Type GetInterfaceType() => typeof(IFoo<T>);
}

class Foo<T> : IFoo<T> { }

class Base : IFoo
{
int IFoo.GetNumber() => 100;
}

class Derived : Base, IBar { }

public static void Run()
{
Console.WriteLine("Testing default interface methods that cannot be devirtualized...");

if (GetFoo().GetNumber() != 42)
throw new Exception();

[MethodImpl(MethodImplOptions.NoInlining)]
static IFoo GetFoo() => new Foo();

if (GetBar().GetNumber() != 43)
throw new Exception();

[MethodImpl(MethodImplOptions.NoInlining)]
static IFoo GetBar() => new Bar();

if (GetBaz().GetNumber() != 100)
throw new Exception();

[MethodImpl(MethodImplOptions.NoInlining)]
static IFoo GetBaz() => new Baz();

if (GetDerived().GetNumber() != 100)
throw new Exception();

[MethodImpl(MethodImplOptions.NoInlining)]
static IFoo GetDerived() => new Derived();

if (GetFooObject().GetInterfaceType() != typeof(IFoo<object>))
throw new Exception();

[MethodImpl(MethodImplOptions.NoInlining)]
static IFoo<object> GetFooObject() => new Foo<object>();

if (GetFooInt().GetInterfaceType() != typeof(IFoo<int>))
throw new Exception();

[MethodImpl(MethodImplOptions.NoInlining)]
static IFoo<int> GetFooInt() => new Foo<int>();
}
}

class TestDefaultInterfaceVariance
{
class Foo : IVariant<string>, IVariant<object>
Expand Down

0 comments on commit 8a8c219

Please sign in to comment.