diff --git a/src/Microsoft.Cci.Extensions/Writers/CSharp/CSDeclarationWriter.Methods.cs b/src/Microsoft.Cci.Extensions/Writers/CSharp/CSDeclarationWriter.Methods.cs index 33e48c558c7..d2cecded023 100644 --- a/src/Microsoft.Cci.Extensions/Writers/CSharp/CSDeclarationWriter.Methods.cs +++ b/src/Microsoft.Cci.Extensions/Writers/CSharp/CSDeclarationWriter.Methods.cs @@ -33,7 +33,14 @@ private void WriteMethodDefinition(IMethodDefinition method) return; } - if (!method.ContainingTypeDefinition.IsInterface) + if (method.ContainingTypeDefinition.IsInterface) + { + if (method.IsMethodUnsafe()) + { + WriteKeyword("unsafe"); + } + } + else { if (!method.IsExplicitInterfaceMethod() && !method.IsStaticConstructor) { @@ -42,6 +49,7 @@ private void WriteMethodDefinition(IMethodDefinition method) WriteMethodModifiers(method); } + WriteInterfaceMethodModifiers(method); WriteMethodDefinitionSignature(method); WriteMethodBody(method); @@ -245,8 +253,11 @@ private void WriteInterfaceMethodModifiers(IMethodDefinition method) private void WriteMethodModifiers(IMethodDefinition method) { - if (method.IsMethodUnsafe()) + if (method.IsMethodUnsafe() || + (method.IsConstructor && IsBaseConstructorCallUnsafe(method.ContainingTypeDefinition))) + { WriteKeyword("unsafe"); + } if (method.IsStatic) WriteKeyword("static"); @@ -344,6 +355,11 @@ private void WritePrivateConstructor(ITypeDefinition type) return; WriteVisibility(TypeMemberVisibility.Assembly); + if (IsBaseConstructorCallUnsafe(type)) + { + WriteKeyword("unsafe"); + } + WriteIdentifier(((INamedEntity)type).Name); WriteSymbol("("); WriteSymbol(")"); @@ -353,24 +369,7 @@ private void WritePrivateConstructor(ITypeDefinition type) private void WriteBaseConstructorCall(ITypeDefinition type) { - if (!_forCompilation) - return; - - ITypeDefinition baseType = type.BaseClasses.FirstOrDefault().GetDefinitionOrNull(); - - if (baseType == null) - return; - - var ctors = baseType.Methods.Where(m => m.IsConstructor && _filter.Include(m) && !m.Attributes.Any(a => a.IsObsoleteWithUsageTreatedAsCompilationError())); - - var defaultCtor = ctors.Where(c => c.ParameterCount == 0); - - // Don't need a base call if we have a default constructor - if (defaultCtor.Any()) - return; - - var ctor = ctors.FirstOrDefault(); - + var ctor = GetBaseConstructorForCall(type); if (ctor == null) return; @@ -382,6 +381,52 @@ private void WriteBaseConstructorCall(ITypeDefinition type) WriteSymbol(")"); } + private bool IsBaseConstructorCallUnsafe(ITypeDefinition type) + { + var constructor = GetBaseConstructorForCall(type); + if (constructor == null) + { + return false; + } + + foreach (var parameter in constructor.Parameters) + { + if (parameter.Type.IsUnsafeType()) + { + return true; + } + } + + return false; + } + + private IMethodDefinition GetBaseConstructorForCall(ITypeDefinition type) + { + if (!_forCompilation) + { + // No need to generate a call to a base constructor. + return null; + } + + var baseType = type.BaseClasses.FirstOrDefault().GetDefinitionOrNull(); + if (baseType == null) + { + // No base type to worry about. + return null; + } + + var constructors = baseType.Methods.Where( + m => m.IsConstructor && _filter.Include(m) && !m.Attributes.Any(a => a.IsObsoleteWithUsageTreatedAsCompilationError())); + + if (constructors.Any(c => c.ParameterCount == 0)) + { + // Don't need a base call if base class has a default constructor. + return null; + } + + return constructors.FirstOrDefault(); + } + /// /// When generated .notsupported.cs files, we need to generate calls to the base constructor. /// However, if the base constructor doesn't accept null, passing default(T) will cause a compile