Skip to content

Commit

Permalink
Reuse previous lambda body binding if possible
Browse files Browse the repository at this point in the history
  • Loading branch information
cston committed Apr 16, 2020
1 parent 5c681cc commit b43120a
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System;
using System.Collections.Immutable;
using Microsoft.CodeAnalysis.CSharp.Symbols;
using Roslyn.Utilities;

namespace Microsoft.CodeAnalysis.CSharp
{
Expand Down Expand Up @@ -54,6 +55,16 @@ protected override UnboundLambdaState WithCachingCore(bool includeCache)
return new QueryUnboundLambdaState(Binder, _rangeVariableMap, _parameters, _bodyFactory, includeCache);
}

protected override BoundExpression GetReusableLambdaExpressionBody(BoundBlock body)
{
return null;
}

protected override BoundBlock CreateBlockFromExpression(LambdaSymbol lambdaSymbol, Binder lambdaBodyBinder, BoundExpression expression, DiagnosticBag diagnostics)
{
throw ExceptionUtilities.Unreachable;
}

protected override BoundBlock BindLambdaBody(LambdaSymbol lambdaSymbol, Binder lambdaBodyBinder, DiagnosticBag diagnostics)
{
return _bodyFactory(lambdaSymbol, lambdaBodyBinder, diagnostics);
Expand Down
15 changes: 11 additions & 4 deletions src/Compilers/CSharp/Portable/Binder/Binder_Statements.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
using System.Collections.Immutable;
using System.Diagnostics;
using System.Linq;
using System.Threading.Tasks;
using System.Transactions;
using Microsoft.CodeAnalysis.CSharp.Symbols;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.PooledObjects;
Expand Down Expand Up @@ -3155,7 +3153,7 @@ private static bool IsValidExpressionBody(SyntaxNode expressionSyntax, BoundExpr
}

/// <summary>
/// Binds an expression-bodied member with expression e as either { return e;} or { e; }.
/// Binds an expression-bodied member with expression e as either { return e; } or { e; }.
/// </summary>
internal virtual BoundBlock BindExpressionBodyAsBlock(ArrowExpressionClauseSyntax expressionBody,
DiagnosticBag diagnostics)
Expand All @@ -3179,7 +3177,7 @@ static BoundBlock bindExpressionBodyAsBlockInternal(ArrowExpressionClauseSyntax
}

/// <summary>
/// Binds a lambda with expression e as either { return e;} or { e; }.
/// Binds a lambda with expression e as either { return e; } or { e; }.
/// </summary>
public BoundBlock BindLambdaExpressionAsBlock(ExpressionSyntax body, DiagnosticBag diagnostics)
{
Expand All @@ -3195,6 +3193,15 @@ public BoundBlock BindLambdaExpressionAsBlock(ExpressionSyntax body, DiagnosticB
return bodyBinder.CreateBlockFromExpression(body, bodyBinder.GetDeclaredLocalsForScope(body), refKind, expression, expressionSyntax, diagnostics);
}

public BoundBlock BindLambdaExpressionAsBlockContinued(ExpressionSyntax body, BoundExpression expression, DiagnosticBag diagnostics)
{
Binder bodyBinder = this.GetBinder(body);
Debug.Assert(bodyBinder != null);

Debug.Assert(body.Kind() != SyntaxKind.RefExpression);
return bodyBinder.CreateBlockFromExpression(body, bodyBinder.GetDeclaredLocalsForScope(body), RefKind.None, expression, body, diagnostics);
}

private BindValueKind GetRequiredReturnValueKind(RefKind refKind)
{
BindValueKind requiredValueKind = BindValueKind.RValue;
Expand Down
107 changes: 66 additions & 41 deletions src/Compilers/CSharp/Portable/BoundTree/UnboundLambda.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,23 +33,20 @@ internal sealed partial class BoundLocalFunctionStatement : IBoundLambdaOrFuncti
BoundBlock IBoundLambdaOrFunction.Body { get => this.Body; }
}

internal struct InferredLambdaReturnType
internal readonly struct InferredLambdaReturnType
{
internal readonly int NumExpressions;
internal readonly bool HadExpressionlessReturn;
internal readonly bool HasReturnValue;
internal readonly RefKind RefKind;
internal readonly TypeWithAnnotations TypeWithAnnotations;
internal readonly ImmutableArray<DiagnosticInfo> UseSiteDiagnostics;

internal InferredLambdaReturnType(
int numExpressions,
bool hadExpressionlessReturn,
bool hasReturnValue,
RefKind refKind,
TypeWithAnnotations typeWithAnnotations,
ImmutableArray<DiagnosticInfo> useSiteDiagnostics)
{
NumExpressions = numExpressions;
HadExpressionlessReturn = hadExpressionlessReturn;
HasReturnValue = hasReturnValue;
RefKind = refKind;
TypeWithAnnotations = typeWithAnnotations;
UseSiteDiagnostics = useSiteDiagnostics;
Expand Down Expand Up @@ -163,7 +160,7 @@ internal static InferredLambdaReturnType InferReturnType(ArrayBuilder<(BoundRetu
BoundNode node, CSharpCompilation compilation, ConversionsBase conversions, TypeSymbol delegateType, bool isAsync)
{
var types = ArrayBuilder<(BoundExpression, TypeWithAnnotations)>.GetInstance();
bool hasReturnWithoutArgument = false;
bool hasReturnValue = false;
RefKind refKind = RefKind.None;
foreach (var (returnStatement, type) in returnTypes)
{
Expand All @@ -173,19 +170,17 @@ internal static InferredLambdaReturnType InferReturnType(ArrayBuilder<(BoundRetu
refKind = rk;
}

if ((object)type.Type == NoReturnExpression)
{
hasReturnWithoutArgument = true;
}
else
if ((object)type.Type != NoReturnExpression)
{
hasReturnValue = true;
types.Add((returnStatement.ExpressionOpt, type));
}
}

HashSet<DiagnosticInfo> useSiteDiagnostics = null;
var bestType = CalculateReturnType(compilation, conversions, delegateType, types, isAsync, node, ref useSiteDiagnostics);
return new InferredLambdaReturnType(types.Count, hasReturnWithoutArgument, refKind, bestType, useSiteDiagnostics.AsImmutableOrEmpty());
types.Free();
return new InferredLambdaReturnType(hasReturnValue, refKind, bestType, useSiteDiagnostics.AsImmutableOrEmpty());
}

private static TypeWithAnnotations CalculateReturnType(
Expand Down Expand Up @@ -465,6 +460,9 @@ internal UnboundLambdaState WithCaching(bool includeCache)
public abstract Location ParameterLocation(int index);
public abstract TypeWithAnnotations ParameterTypeWithAnnotations(int index);
public abstract RefKind RefKind(int index);

protected abstract BoundExpression GetReusableLambdaExpressionBody(BoundBlock body);
protected abstract BoundBlock CreateBlockFromExpression(LambdaSymbol lambdaSymbol, Binder lambdaBodyBinder, BoundExpression expression, DiagnosticBag diagnostics);
protected abstract BoundBlock BindLambdaBody(LambdaSymbol lambdaSymbol, Binder lambdaBodyBinder, DiagnosticBag diagnostics);

public virtual void GenerateAnonymousFunctionConversionError(DiagnosticBag diagnostics, TypeSymbol targetType)
Expand Down Expand Up @@ -541,20 +539,36 @@ private bool DelegateNeedsReturn(MethodSymbol invokeMethod)

private BoundLambda ReallyBind(NamedTypeSymbol delegateType)
{
// When binding for real (not for return inference), we cannot reuse a body of a lambda
// previously bound for return type inference because we have not converted the returned
// expression(s) to the return type.
var invokeMethod = DelegateInvokeMethod(delegateType);
var returnType = DelegateReturnTypeWithAnnotations(invokeMethod, out RefKind refKind);

LambdaSymbol lambdaSymbol = CreateLambdaSymbol(
delegateType,
Binder.ContainingMemberOrLambda,
out MethodSymbol invokeMethod,
out TypeWithAnnotations returnType,
out DiagnosticBag diagnostics,
out ReturnInferenceCacheKey cacheKey);
LambdaSymbol lambdaSymbol;
Binder lambdaBodyBinder;
BoundBlock block;

CSharpCompilation compilation = Binder.Compilation;
Binder lambdaBodyBinder = new ExecutableCodeBinder(_unboundLambda.Syntax, lambdaSymbol, ParameterBinder(lambdaSymbol, Binder));
var diagnostics = DiagnosticBag.GetInstance();
var compilation = Binder.Compilation;
var cacheKey = ReturnInferenceCacheKey.Create(delegateType, IsAsync);

// When binding for real (not for return inference), there is still
// a good chance that we could reuse a body of a lambda previous bound for
// return type inference.
if (_returnInferenceCache.TryGetValue(cacheKey, out BoundLambda returnInferenceLambda) &&
GetReusableLambdaExpressionBody(returnInferenceLambda.Body) is BoundExpression expression &&
(lambdaSymbol = returnInferenceLambda.Symbol).RefKind == refKind &&
(object)LambdaSymbol.InferenceFailureReturnType != lambdaSymbol.ReturnType &&
lambdaSymbol.ReturnTypeWithAnnotations.Equals(returnType, TypeCompareKind.ConsiderEverything))
{
lambdaBodyBinder = returnInferenceLambda.Binder;
block = CreateBlockFromExpression(lambdaSymbol, lambdaBodyBinder, expression, diagnostics);
diagnostics.AddRange(returnInferenceLambda.Diagnostics);
}
else
{
lambdaSymbol = CreateLambdaSymbol(Binder.ContainingMemberOrLambda, returnType, diagnostics, cacheKey.ParameterTypes, cacheKey.ParameterRefKinds, refKind);
lambdaBodyBinder = new ExecutableCodeBinder(_unboundLambda.Syntax, lambdaSymbol, ParameterBinder(lambdaSymbol, Binder));
block = BindLambdaBody(lambdaSymbol, lambdaBodyBinder, diagnostics);
}

if (lambdaSymbol.RefKind == CodeAnalysis.RefKind.RefReadOnly)
{
Expand Down Expand Up @@ -583,8 +597,6 @@ private BoundLambda ReallyBind(NamedTypeSymbol delegateType)
ParameterHelpers.EnsureNullableAttributeExists(compilation, lambdaSymbol, lambdaParameters, diagnostics, modifyCompilation: false);
// Note: we don't need to warn on annotations used in #nullable disable context for lambdas, as this is handled in binding already

BoundBlock block = BindLambdaBody(lambdaSymbol, lambdaBodyBinder, diagnostics);

((ExecutableCodeBinder)lambdaBodyBinder).ValidateIteratorMethods(diagnostics);
ValidateUnsafeParameters(diagnostics, cacheKey.ParameterTypes);

Expand Down Expand Up @@ -626,16 +638,6 @@ private BoundLambda ReallyBind(NamedTypeSymbol delegateType)
return result;
}

private LambdaSymbol CreateLambdaSymbol(NamedTypeSymbol delegateType, Symbol containingSymbol, out MethodSymbol invokeMethod, out TypeWithAnnotations returnType, out DiagnosticBag diagnostics, out ReturnInferenceCacheKey cacheKey)
{
invokeMethod = DelegateInvokeMethod(delegateType);
returnType = DelegateReturnTypeWithAnnotations(invokeMethod, out RefKind refKind);
diagnostics = DiagnosticBag.GetInstance();
cacheKey = ReturnInferenceCacheKey.Create(delegateType, IsAsync);

return CreateLambdaSymbol(containingSymbol, returnType, diagnostics, cacheKey.ParameterTypes, cacheKey.ParameterRefKinds, refKind);
}

internal LambdaSymbol CreateLambdaSymbol(
Symbol containingSymbol,
TypeWithAnnotations returnType,
Expand All @@ -655,7 +657,10 @@ internal LambdaSymbol CreateLambdaSymbol(

internal LambdaSymbol CreateLambdaSymbol(NamedTypeSymbol delegateType, Symbol containingSymbol)
{
return CreateLambdaSymbol(delegateType, containingSymbol, out _, out _, out _, out _);
var invokeMethod = DelegateInvokeMethod(delegateType);
var returnType = DelegateReturnTypeWithAnnotations(invokeMethod, out RefKind refKind);
var cacheKey = ReturnInferenceCacheKey.Create(delegateType, IsAsync);
return CreateLambdaSymbol(containingSymbol, returnType, new DiagnosticBag(), cacheKey.ParameterTypes, cacheKey.ParameterRefKinds, refKind);
}

private void ValidateUnsafeParameters(DiagnosticBag diagnostics, ImmutableArray<TypeWithAnnotations> targetParameterTypes)
Expand Down Expand Up @@ -741,6 +746,7 @@ private BoundLambda ReallyInferReturnType(
var block = BindLambdaBody(lambdaSymbol, lambdaBodyBinder, diagnostics);
return (lambdaSymbol, block, lambdaBodyBinder, diagnostics);
}

public BoundLambda BindForReturnTypeInference(NamedTypeSymbol delegateType)
{
var cacheKey = ReturnInferenceCacheKey.Create(delegateType, IsAsync);
Expand Down Expand Up @@ -927,7 +933,7 @@ BoundLambda rebind(BoundLambda lambda)
returnType = DelegateReturnTypeWithAnnotations(invokeMethod, out refKind);
if (!returnType.HasType || returnType.Type.ContainsTypeParameter())
{
var t = (inferredReturnType.HadExpressionlessReturn || inferredReturnType.NumExpressions == 0)
var t = !inferredReturnType.HasReturnValue
? this.Binder.Compilation.GetSpecialType(SpecialType.System_Void)
: this.Binder.CreateErrorType();
returnType = TypeWithAnnotations.Create(t);
Expand All @@ -943,7 +949,7 @@ BoundLambda rebind(BoundLambda lambda)
diagnostics.ToReadOnlyAndFree(),
lambdaBodyBinder,
delegateType,
new InferredLambdaReturnType(inferredReturnType.NumExpressions, inferredReturnType.HadExpressionlessReturn, refKind, returnType, ImmutableArray<DiagnosticInfo>.Empty))
new InferredLambdaReturnType(inferredReturnType.HasReturnValue, refKind, returnType, ImmutableArray<DiagnosticInfo>.Empty))
{ WasCompilerGenerated = _unboundLambda.WasCompilerGenerated };
}
}
Expand Down Expand Up @@ -1216,6 +1222,25 @@ protected override UnboundLambdaState WithCachingCore(bool includeCache)
return new PlainUnboundLambdaState(unboundLambda: null, Binder, _parameterNames, _parameterIsDiscardOpt, _parameterTypesWithAnnotations, _parameterRefKinds, _isAsync, includeCache);
}

protected override BoundExpression GetReusableLambdaExpressionBody(BoundBlock body)
{
if (IsExpressionLambda)
{
var statements = body.Statements;
if (statements.Length == 1 &&
statements[0] is BoundReturnStatement { RefKind: Microsoft.CodeAnalysis.RefKind.None, ExpressionOpt: BoundExpression expr })
{
return expr;
}
}
return null;
}

protected override BoundBlock CreateBlockFromExpression(LambdaSymbol lambdaSymbol, Binder lambdaBodyBinder, BoundExpression expression, DiagnosticBag diagnostics)
{
return lambdaBodyBinder.BindLambdaExpressionAsBlockContinued((ExpressionSyntax)this.Body, expression, diagnostics);
}

protected override BoundBlock BindLambdaBody(LambdaSymbol lambdaSymbol, Binder lambdaBodyBinder, DiagnosticBag diagnostics)
{
if (this.IsExpressionLambda)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ static class Ext

[ConditionalFactAttribute(typeof(IsRelease))]
[WorkItem(40495, "https://github.com/dotnet/roslyn/issues/40495")]
public void NestedLambdas()
public void NestedLambdas_01()
{
var source =
@"#nullable enable
Expand All @@ -302,6 +302,43 @@ static void Main()
Enumerable.Range(0, 1).Sum(f =>
Enumerable.Range(0, 1).Count(g => true)))))));
}
}";
var comp = CreateCompilation(source);
comp.VerifyDiagnostics();
}

[ConditionalFactAttribute(typeof(IsRelease))]
[WorkItem(1083969, "https://devdiv.visualstudio.com/DevDiv/_workitems/edit/1083969")]
public void NestedLambdas_02()
{
var source =
@"using System.Collections.Generic;
using System.Linq;
class Program
{
static void F(IEnumerable<int[]> x)
{
x.GroupBy(y => y[1]).SelectMany(x =>
x.GroupBy(y => y[2]).SelectMany(x =>
x.GroupBy(y => y[3]).SelectMany(x =>
x.GroupBy(y => y[4]).SelectMany(x =>
x.GroupBy(y => y[5]).SelectMany(x =>
x.GroupBy(y => y[6]).SelectMany(x =>
x.GroupBy(y => y[7]).SelectMany(x =>
x.GroupBy(y => y[8]).SelectMany(x =>
x.GroupBy(y => y[9]).SelectMany(x =>
x.GroupBy(y => y[0]).SelectMany(x =>
x.GroupBy(y => y[1]).SelectMany(x =>
x.GroupBy(y => y[2]).SelectMany(x =>
x.GroupBy(y => y[3]).SelectMany(x =>
x.GroupBy(y => y[4]).SelectMany(x =>
x.GroupBy(y => y[5]).SelectMany(x =>
x.GroupBy(y => y[6]).SelectMany(x =>
x.GroupBy(y => y[7]).SelectMany(x =>
x.GroupBy(y => y[8]).SelectMany(x =>
x.GroupBy(y => y[9]).SelectMany(x =>
x.GroupBy(y => y[0]).Select(x => x.Average(z => z[0])))))))))))))))))))));
}
}";
var comp = CreateCompilation(source);
comp.VerifyDiagnostics();
Expand Down

0 comments on commit b43120a

Please sign in to comment.