Skip to content

Commit

Permalink
Report a diagnostic for return types with HResult-like named structur…
Browse files Browse the repository at this point in the history
…es and provide a code-fix to do the correct marshalling (#90282)
  • Loading branch information
jkoritzinsky committed Aug 10, 2023
1 parent 2d03dc4 commit d303781
Show file tree
Hide file tree
Showing 30 changed files with 581 additions and 14 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Composition;
using System.Runtime.InteropServices;
using System.Text;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.CodeFixes;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.Editing;

namespace Microsoft.Interop.Analyzers
{
[ExportCodeFixProvider(LanguageNames.CSharp), Shared]
public sealed class AddMarshalAsToElementFixer : CodeFixProvider
{
public override FixAllProvider? GetFixAllProvider() => WellKnownFixAllProviders.BatchFixer;

public override ImmutableArray<string> FixableDiagnosticIds => ImmutableArray.Create(GeneratorDiagnostics.Ids.NotRecommendedGeneratedComInterfaceUsage);

public override async Task RegisterCodeFixesAsync(CodeFixContext context)
{
// Get the syntax root and semantic model
Document doc = context.Document;
SyntaxNode? root = await doc.GetSyntaxRootAsync(context.CancellationToken).ConfigureAwait(false);
if (root == null)
return;

SyntaxNode node = root.FindNode(context.Span);

foreach (var diagnostic in context.Diagnostics)
{
if (!diagnostic.Properties.TryGetValue(GeneratorDiagnosticProperties.AddMarshalAsAttribute, out string? addMarshalAsAttribute))
{
continue;
}

foreach (var unmanagedType in addMarshalAsAttribute.Split(','))
{
string unmanagedTypeName = unmanagedType.Trim();
context.RegisterCodeFix(
CodeAction.Create(
$"Add [MarshalAs(UnmanagedType.{unmanagedTypeName})]",
async ct =>
{
DocumentEditor editor = await DocumentEditor.CreateAsync(doc, ct).ConfigureAwait(false);
SyntaxGenerator gen = editor.Generator;
SyntaxNode marshalAsAttribute = gen.Attribute(
TypeNames.System_Runtime_InteropServices_MarshalAsAttribute,
gen.AttributeArgument(
gen.MemberAccessExpression(
gen.DottedName(TypeNames.System_Runtime_InteropServices_UnmanagedType),
gen.IdentifierName(unmanagedTypeName.Trim()))));
if (node.IsKind(SyntaxKind.MethodDeclaration))
{
editor.AddReturnAttribute(node, marshalAsAttribute);
}
else
{
editor.AddAttribute(node, marshalAsAttribute);
}
return editor.GetChangedDocument();
},
$"AddUnmanagedType.{unmanagedTypeName}"),
diagnostic);
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Collections.Immutable;
using System.Composition;
using System.Linq;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
Expand Down Expand Up @@ -111,6 +112,12 @@ private static async Task ConvertComImportToGeneratedComInterfaceAsync(DocumentE
var generatedDeclaration = member;

generatedDeclaration = AddExplicitDefaultBoolMarshalling(gen, method, generatedDeclaration, "VariantBool");

if (method.MethodImplementationFlags.HasFlag(MethodImplAttributes.PreserveSig))
{
generatedDeclaration = AddHResultStructAsErrorMarshalling(gen, method, generatedDeclaration);
}

editor.ReplaceNode(member, generatedDeclaration);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,13 @@ private static MemberDeclarationSyntax GenerateIUnknownDerivedAttributeApplicati
.WithTypeParameterList(context.ContainingSyntax.TypeParameters)
.AddAttributeLists(AttributeList(SingletonSeparatedList(s_iUnknownDerivedAttributeTemplate))));

private static bool IsHResultLikeType(ManagedTypeInfo type)
{
string typeName = type.FullTypeName.Split('.', ':')[^1];
return typeName.Equals("hr", StringComparison.OrdinalIgnoreCase)
|| typeName.Equals("hresult", StringComparison.OrdinalIgnoreCase);
}

private static IncrementalMethodStubGenerationContext CalculateStubInformation(MethodDeclarationSyntax syntax, IMethodSymbol symbol, int index, StubEnvironment environment, ManagedTypeInfo owningInterface, CancellationToken ct)
{
ct.ThrowIfCancellationRequested();
Expand Down Expand Up @@ -280,7 +287,7 @@ private static IncrementalMethodStubGenerationContext CalculateStubInformation(M
{
if ((returnSwappedSignatureElements[i].ManagedType is SpecialTypeInfo { SpecialType: SpecialType.System_Int32 or SpecialType.System_Enum } or EnumTypeInfo
&& returnSwappedSignatureElements[i].MarshallingAttributeInfo.Equals(NoMarshallingInfo.Instance))
|| (returnSwappedSignatureElements[i].ManagedType.FullTypeName.Split('.', ':').LastOrDefault()?.ToLowerInvariant() is "hr" or "hresult"))
|| (IsHResultLikeType(returnSwappedSignatureElements[i].ManagedType)))
{
generatorDiagnostics.ReportDiagnostic(DiagnosticInfo.Create(GeneratorDiagnostics.ComMethodManagedReturnWillBeOutVariable, symbol.Locations[0]));
}
Expand Down Expand Up @@ -310,6 +317,23 @@ private static IncrementalMethodStubGenerationContext CalculateStubInformation(M
})
};
}
else
{
// If our method is PreserveSig, we will notify the user if they are returning a type that may be an HRESULT type
// that is defined as a structure. These types used to work with built-in COM interop, but they do not work with
// source-generated interop as we now use the MemberFunction calling convention, which is more correct.
TypePositionInfo? managedReturnInfo = signatureContext.ElementTypeInformation.FirstOrDefault(e => e.IsManagedReturnPosition);
if (managedReturnInfo is { MarshallingAttributeInfo: UnmanagedBlittableMarshallingInfo, ManagedType: ValueTypeInfo valueType }
&& IsHResultLikeType(valueType))
{
generatorDiagnostics.ReportDiagnostic(DiagnosticInfo.Create(
GeneratorDiagnostics.HResultTypeWillBeTreatedAsStruct,
symbol.Locations[0],
ImmutableDictionary<string, string>.Empty.Add(GeneratorDiagnosticProperties.AddMarshalAsAttribute, "Error"),
valueType.DiagnosticFormattedName));
}
}

var direction = GetDirectionFromOptions(generatedComInterfaceAttributeData.Options);

// Ensure the size of collections are known at marshal / unmarshal in time.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
<Compile Include="..\Common\OperationExtensions.cs" Link="Common\OperationExtensions.cs" />
<Compile Include="..\Common\ConvertToSourceGeneratedInteropFixer.cs" Link="Common\ConvertToSourceGeneratedInteropFixer.cs" />
<Compile Include="..\Common\FixAllContextExtensions.cs" Link="Common\FixAllContextExtensions.cs" />
<Compile Include="$(CoreLibSharedDir)System\Index.cs" Link="Common\System\Index.cs" />
</ItemGroup>

<ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ internal static class ComInterfaceGeneratorHelpers
InteropGenerationOptions interopGenerationOptions = new(UseMarshalType: true);
generatorFactory = new MarshalAsMarshallingGeneratorFactory(interopGenerationOptions, generatorFactory);

generatorFactory = new StructAsHResultMarshallerFactory(generatorFactory);

IMarshallingGeneratorFactory elementFactory = new AttributedMarshallingModelGeneratorFactory(
// Since the char type in an array will not be part of the P/Invoke signature, we can
// use the regular blittable marshaller in all cases.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,16 @@ public class Ids
DiagnosticSeverity.Info,
isEnabledByDefault: true);

/// <inheritdoc cref="SR.HResultTypeWillBeTreatedAsStructMessage"/>
public static readonly DiagnosticDescriptor HResultTypeWillBeTreatedAsStruct =
new DiagnosticDescriptor(
Ids.NotRecommendedGeneratedComInterfaceUsage,
GetResourceString(nameof(SR.HResultTypeWillBeTreatedAsStructTitle)),
GetResourceString(nameof(SR.HResultTypeWillBeTreatedAsStructMessage)),
Category,
DiagnosticSeverity.Info,
isEnabledByDefault: true);

/// <summary>
/// Report diagnostic for invalid configuration for string marshalling.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;

namespace Microsoft.Interop
{
internal sealed class StructAsHResultMarshallerFactory : IMarshallingGeneratorFactory
{
private static readonly Marshaller s_marshaller = new();

private readonly IMarshallingGeneratorFactory _inner;

public StructAsHResultMarshallerFactory(IMarshallingGeneratorFactory inner)
{
_inner = inner;
}

public ResolvedGenerator Create(TypePositionInfo info, StubCodeContext context)
{
// Value type with MarshalAs(UnmanagedType.Error), to be marshalled as an unmanaged HRESULT.
if (info is { ManagedType: ValueTypeInfo, MarshallingAttributeInfo: MarshalAsInfo(UnmanagedType.Error, _) })
{
return ResolvedGenerator.Resolved(s_marshaller);
}

return _inner.Create(info, context);
}

private sealed class Marshaller : IMarshallingGenerator
{
public ManagedTypeInfo AsNativeType(TypePositionInfo info) => SpecialTypeInfo.Int32;

public IEnumerable<StatementSyntax> Generate(TypePositionInfo info, StubCodeContext context)
{
var (managed, unmanaged) = context.GetIdentifiers(info);

switch (context.CurrentStage)
{
case StubCodeContext.Stage.Marshal:
if (MarshallerHelpers.GetMarshalDirection(info, context) is MarshalDirection.ManagedToUnmanaged or MarshalDirection.Bidirectional)
{
// unmanaged = Unsafe.BitCast<managedType, int>(managed);
yield return ExpressionStatement(
AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
IdentifierName(unmanaged),
InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
ParseTypeName(TypeNames.System_Runtime_CompilerServices_Unsafe),
GenericName(Identifier("BitCast"),
TypeArgumentList(
SeparatedList(
new[]
{
info.ManagedType.Syntax,
AsNativeType(info).Syntax
})))),
ArgumentList(SingletonSeparatedList(Argument(IdentifierName(managed)))))));
}
break;
case StubCodeContext.Stage.Unmarshal:
if (MarshallerHelpers.GetMarshalDirection(info, context) is MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional)
{
// managed = Unsafe.BitCast<int, managedType>(unmanaged);
yield return ExpressionStatement(
AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
IdentifierName(managed),
InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
ParseTypeName(TypeNames.System_Runtime_CompilerServices_Unsafe),
GenericName(Identifier("BitCast"),
TypeArgumentList(
SeparatedList(
new[]
{
AsNativeType(info).Syntax,
info.ManagedType.Syntax
})))),
ArgumentList(SingletonSeparatedList(Argument(IdentifierName(unmanaged)))))));
}
break;
default:
break;
}
}

public SignatureBehavior GetNativeSignatureBehavior(TypePositionInfo info)
{
return info.IsByRef ? SignatureBehavior.PointerToNativeType : SignatureBehavior.NativeType;
}

public ValueBoundaryBehavior GetValueBoundaryBehavior(TypePositionInfo info, StubCodeContext context)
{
if (info.IsByRef)
{
return ValueBoundaryBehavior.AddressOfNativeIdentifier;
}

return ValueBoundaryBehavior.NativeIdentifier;
}

public bool IsSupported(TargetFramework target, Version version) => target == TargetFramework.Net && version.Major >= 8;

public ByValueMarshalKindSupport SupportsByValueMarshalKind(ByValueContentsMarshalKind marshalKind, TypePositionInfo info, StubCodeContext context, out GeneratorDiagnostic? diagnostic)
=> ByValueMarshalKindSupportDescriptor.Default.GetSupport(marshalKind, info, context, out diagnostic);

public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) => true;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using System.Runtime.InteropServices;

using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CodeActions;
Expand Down Expand Up @@ -259,5 +260,34 @@ static SyntaxNode GenerateMarshalAsUnmanagedTypeBoolAttribute(SyntaxGenerator ge
generator.DottedName(TypeNames.System_Runtime_InteropServices_UnmanagedType),
generator.IdentifierName(unmanagedTypeMemberIdentifier))));
}

protected static SyntaxNode AddHResultStructAsErrorMarshalling(SyntaxGenerator generator, IMethodSymbol methodSymbol, SyntaxNode generatedDeclaration)
{
if (methodSymbol.ReturnType is { TypeKind: TypeKind.Struct }
&& IsHResultLikeType(methodSymbol.ReturnType)
&& !methodSymbol.GetReturnTypeAttributes().Any(attr => attr.AttributeClass?.ToDisplayString() == TypeNames.System_Runtime_InteropServices_MarshalAsAttribute))
{
generatedDeclaration = generator.AddReturnAttributes(generatedDeclaration,
GeneratedMarshalAsUnmanagedTypeErrorAttribute(generator));
}

return generatedDeclaration;


static bool IsHResultLikeType(ITypeSymbol type)
{
string typeName = type.Name;
return typeName.Equals("hr", StringComparison.OrdinalIgnoreCase)
|| typeName.Equals("hresult", StringComparison.OrdinalIgnoreCase);
}

// MarshalAs(UnmanagedType.Error)
static SyntaxNode GeneratedMarshalAsUnmanagedTypeErrorAttribute(SyntaxGenerator generator)
=> generator.Attribute(TypeNames.System_Runtime_InteropServices_MarshalAsAttribute,
generator.AttributeArgument(
generator.MemberAccessExpression(
generator.DottedName(TypeNames.System_Runtime_InteropServices_UnmanagedType),
generator.IdentifierName(nameof(UnmanagedType.Error)))));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -883,4 +883,10 @@
<data name="ComMethodReturningIntWillBeOutParameterTitle" xml:space="preserve">
<value>The return value in the managed definition will be converted to an additional 'out' parameter at the end of the parameter list when calling the unmanaged COM method.</value>
</data>
<data name="HResultTypeWillBeTreatedAsStructMessage" xml:space="preserve">
<value>The type '{0}' will be treated as a struct in the native signature, not as a native HRESULT. To treat this as an HRESULT, add '[return:MarshalAs(UnmanagedType.Error)]' to the method.</value>
</data>
<data name="HResultTypeWillBeTreatedAsStructTitle" xml:space="preserve">
<value>This type will be treated as a struct in the native signature, not as a native HRESULT</value>
</data>
</root>
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,16 @@
<target state="translated">Poskytnutý graf obsahuje cykly a nelze ho řadit topologicky.</target>
<note />
</trans-unit>
<trans-unit id="HResultTypeWillBeTreatedAsStructMessage">
<source>The type '{0}' will be treated as a struct in the native signature, not as a native HRESULT. To treat this as an HRESULT, add '[return:MarshalAs(UnmanagedType.Error)]' to the method.</source>
<target state="new">The type '{0}' will be treated as a struct in the native signature, not as a native HRESULT. To treat this as an HRESULT, add '[return:MarshalAs(UnmanagedType.Error)]' to the method.</target>
<note />
</trans-unit>
<trans-unit id="HResultTypeWillBeTreatedAsStructTitle">
<source>This type will be treated as a struct in the native signature, not as a native HRESULT</source>
<target state="new">This type will be treated as a struct in the native signature, not as a native HRESULT</target>
<note />
</trans-unit>
<trans-unit id="InAttributeNotSupportedWithoutOutBlittableArray">
<source>The '[In]' attribute is not supported unless the '[Out]' attribute is also used. Blittable arrays cannot be marshalled as '[In]' only.</source>
<target state="translated">Atribut [In] není podporován, pokud není použit také atribut [Out]. Blittable arrays nelze zařadit pouze jako [In].</target>
Expand Down
Loading

0 comments on commit d303781

Please sign in to comment.