Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update entry point finder to support async Main #75808

Merged
merged 7 commits into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using Microsoft.CodeAnalysis;
using Microsoft.VisualStudio.LanguageServices.Implementation.ProjectSystem;

namespace Microsoft.VisualStudio.LanguageServices.CSharp.ProjectSystemShim;

internal sealed class CSharpEntryPointFinder(Compilation compilation)
: AbstractEntryPointFinder(compilation)
{
protected override bool MatchesMainMethodName(string name)
=> name == "Main";

public static IEnumerable<INamedTypeSymbol> FindEntryPoints(Compilation compilation)
{
// This differs from the VB implementation
// (Microsoft.VisualStudio.LanguageServices.VisualBasic.ProjectSystemShim.EntryPointFinder) because we don't
// ever consider forms entry points. Technically, this is wrong but it just doesn't matter since the ref
// assemblies are unlikely to have a random Main() method that matches
var visitor = new CSharpEntryPointFinder(compilation);
visitor.Visit(compilation.SourceModule.GlobalNamespace);
return visitor.EntryPoints;
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ public int GetValidStartupClasses(IntPtr[] classNames, ref int count)
{
var project = Workspace.CurrentSolution.GetRequiredProject(ProjectSystemProject.Id);
var compilation = project.GetRequiredCompilationAsync(CancellationToken.None).WaitAndGetResult(CancellationToken.None);
var entryPoints = EntryPointFinder.FindEntryPoints(compilation.SourceModule.GlobalNamespace);
var entryPoints = CSharpEntryPointFinder.FindEntryPoints(compilation);

// If classNames is NULL, then we need to populate the number of valid startup
// classes only
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.VisualStudio.LanguageServices.CSharp.ProjectSystemShim;
using Roslyn.Test.Utilities;
using Xunit;

namespace Roslyn.VisualStudio.CSharp.UnitTests.ProjectSystemShim;

[WorkItem("https://github.com/dotnet/roslyn/issues/35376")]
public sealed class EntryPointFinderTests
{
[Theory, CombinatorialData]
public void PositiveTests(
[CombinatorialValues("public", "private", "")] string accessibility,
[CombinatorialValues("void", "int", "System.Int32", "Int32", "ValueTask", "Task", "ValueTask<int>", "Task<int>")] string returnType,
[CombinatorialValues("string[] args", "string[] args1", "")] string parameters)
{
Validate($"static {accessibility} {returnType} Main({parameters})", entryPoints =>
{
Assert.Single(entryPoints);
Assert.Equal("C", entryPoints.Single().Name);
});
}

private static void NegativeTest(string signature)
=> Validate(signature, Assert.Empty);

private static void Validate(string signature, Action<IEnumerable<INamedTypeSymbol>> validate)
{
var compilation = CSharpCompilation.Create("Test", references: [TestBase.MscorlibRef]).AddSyntaxTrees(CSharpSyntaxTree.ParseText($$"""
using System;
using System.Threading.Tasks;

class C
{
{{signature}}
{
}
}
"""));

var entryPoints = CSharpEntryPointFinder.FindEntryPoints(compilation);
validate(entryPoints);
}

[Theory, CombinatorialData]
public void TestWrongName(
[CombinatorialValues("public", "private", "")] string accessibility,
[CombinatorialValues("void", "int", "System.Int32", "Int32", "ValueTask", "Task", "ValueTask<int>", "Task<int>")] string returnType,
[CombinatorialValues("string[] args", "string[] args1", "")] string parameters)
{
NegativeTest($"static {accessibility} {returnType} main({parameters})");
}

[Theory, CombinatorialData]
public void TestNotStatic(
[CombinatorialValues("public", "private", "")] string accessibility,
[CombinatorialValues("void", "int", "System.Int32", "Int32", "ValueTask", "Task", "ValueTask<int>", "Task<int>")] string returnType,
[CombinatorialValues("string[] args", "string[] args1", "")] string parameters)
{
NegativeTest($"{accessibility} {returnType} main({parameters})");
Copy link
Member

@akhera99 akhera99 Nov 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this be Main to be clear it's a "not static" test and not a "wrong name" test?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yup/ good catch.

CyrusNajmabadi marked this conversation as resolved.
Show resolved Hide resolved
}

[Theory, CombinatorialData]
public void TestInvalidReturnType(
[CombinatorialValues("public", "private", "")] string accessibility,
[CombinatorialValues("string", "Task<string>", "ValueTask<string>")] string returnType,
[CombinatorialValues("string[] args", "string[] args1", "")] string parameters)
{
NegativeTest($"static {accessibility} {returnType} Main({parameters})");
}

[Theory, CombinatorialData]
public void TestInvalidParameterTypes(
[CombinatorialValues("public", "private", "")] string accessibility,
[CombinatorialValues("void", "int", "System.Int32", "Int32", "ValueTask", "Task", "ValueTask<int>", "Task<int>")] string returnType,
[CombinatorialValues("string args", "string* args", "int[] args", "string[] args1, string[] args2")] string parameters)
{
NegativeTest($"static {accessibility} {returnType} Main({parameters})");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,75 +2,59 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

#nullable disable

using System.Collections.Generic;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.Shared.Extensions;

namespace Microsoft.VisualStudio.LanguageServices.Implementation.ProjectSystem;

internal abstract class AbstractEntryPointFinder : SymbolVisitor
internal abstract class AbstractEntryPointFinder(Compilation compilation) : SymbolVisitor
{
protected readonly HashSet<INamedTypeSymbol> EntryPoints = [];

private readonly KnownTaskTypes _knownTaskTypes = new(compilation);

protected abstract bool MatchesMainMethodName(string name);

public override void VisitNamespace(INamespaceSymbol symbol)
{
foreach (var member in symbol.GetMembers())
{
member.Accept(this);
}
}

public override void VisitNamedType(INamedTypeSymbol symbol)
{
foreach (var member in symbol.GetMembers())
{
member.Accept(this);
}
}

public override void VisitMethod(IMethodSymbol symbol)
{
// named Main
if (!MatchesMainMethodName(symbol.Name))
{
return;
}

// static
if (!symbol.IsStatic)
// Similar to the form `static void Main(string[] args)` (and varying permutations).
if (symbol.IsStatic &&
MatchesMainMethodName(symbol.Name) &&
HasValidReturnType(symbol) &&
symbol.Parameters is [{ Type: IArrayTypeSymbol { ElementType.SpecialType: SpecialType.System_String } }] or [])
{
return;
EntryPoints.Add(symbol.ContainingType);
}
}

// returns void or int
if (!symbol.ReturnsVoid && symbol.ReturnType.SpecialType != SpecialType.System_Int32)
{
return;
}
private bool HasValidReturnType(IMethodSymbol symbol)
{
// void
if (symbol.ReturnsVoid)
return true;

// parameterless or takes a string[]
if (symbol.Parameters.Length == 1)
{
var parameter = symbol.Parameters.Single();
if (parameter.Type is IArrayTypeSymbol)
{
var elementType = ((IArrayTypeSymbol)parameter.Type).ElementType;
var specialType = elementType.SpecialType;
var returnType = symbol.ReturnType;

if (specialType == SpecialType.System_String)
{
EntryPoints.Add(symbol.ContainingType);
}
}
}
// int
if (returnType.SpecialType == SpecialType.System_Int32)
return true;

if (!symbol.Parameters.Any())
{
EntryPoints.Add(symbol.ContainingType);
}
// Task or ValueTask
// Task<int> or ValueTask<int>
return _knownTaskTypes.IsTaskLike(returnType) &&
returnType.GetTypeArguments() is [] or [{ SpecialType: SpecialType.System_Int32 }];
}

protected abstract bool MatchesMainMethodName(string name);
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@ Imports Microsoft.CodeAnalysis
Imports Microsoft.VisualStudio.LanguageServices.Implementation.ProjectSystem

Namespace Microsoft.VisualStudio.LanguageServices.VisualBasic.ProjectSystemShim
Friend Class EntryPointFinder
Friend NotInheritable Class VisualBasicEntryPointFinder
Inherits AbstractEntryPointFinder

Private ReadOnly _findFormsOnly As Boolean

Public Sub New(findFormsOnly As Boolean)
Me._findFormsOnly = findFormsOnly
Public Sub New(compilation As Compilation, findFormsOnly As Boolean)
MyBase.New(compilation)
_findFormsOnly = findFormsOnly
End Sub

Protected Overrides Function MatchesMainMethodName(name As String) As Boolean
Expand All @@ -23,8 +24,10 @@ Namespace Microsoft.VisualStudio.LanguageServices.VisualBasic.ProjectSystemShim
Return String.Equals(name, "Main", StringComparison.OrdinalIgnoreCase)
End Function

Public Shared Function FindEntryPoints(symbol As INamespaceSymbol, findFormsOnly As Boolean) As IEnumerable(Of INamedTypeSymbol)
Dim visitor = New EntryPointFinder(findFormsOnly)
Public Shared Function FindEntryPoints(compilation As Compilation, findFormsOnly As Boolean) As IEnumerable(Of INamedTypeSymbol)
Dim visitor = New VisualBasicEntryPointFinder(compilation, findFormsOnly)
Dim symbol = compilation.SourceModule.GlobalNamespace

' Attempt to only search source symbols
' Some callers will give a symbol that is not part of a compilation
If symbol.ContainingCompilation IsNot Nothing Then
Expand All @@ -49,6 +52,5 @@ Namespace Microsoft.VisualStudio.LanguageServices.VisualBasic.ProjectSystemShim

MyBase.VisitNamedType(symbol)
End Sub

End Class
End Namespace

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ Namespace Microsoft.VisualStudio.LanguageServices.VisualBasic.ProjectSystemShim
ByVal pcActualItems As IntPtr,
findFormsOnly As Boolean)

Dim entryPoints = EntryPointFinder.FindEntryPoints(compilation.SourceModule.GlobalNamespace, findFormsOnly:=findFormsOnly)
Dim entryPoints = VisualBasicEntryPointFinder.FindEntryPoints(compilation, findFormsOnly:=findFormsOnly)

' If called with cItems = 0 and pcActualItems != NULL, GetEntryPointsList returns in pcActualItems the number of items available.
If cItems = 0 AndAlso pcActualItems <> Nothing Then
Expand Down
Loading