Skip to content

Commit

Permalink
Implement basic (recommended against) support for inheriting from COM…
Browse files Browse the repository at this point in the history
… interfaces defined in different assemblies (#105119)
  • Loading branch information
jkoritzinsky committed Aug 6, 2024
1 parent edbb2ba commit 1563fec
Show file tree
Hide file tree
Showing 25 changed files with 779 additions and 395 deletions.
20 changes: 15 additions & 5 deletions docs/project/list-of-diagnostics.md
Original file line number Diff line number Diff line change
Expand Up @@ -263,11 +263,21 @@ The diagnostic id values reserved for .NET Libraries analyzer warnings are `SYSL
| __`SYSLIB1222`__ | Constructor annotated with JsonConstructorAttribute is inaccessible. |
| __`SYSLIB1223`__ | Attributes deriving from JsonConverterAttribute are not supported by the source generator. |
| __`SYSLIB1224`__ | Types annotated with JsonSerializableAttribute must be classes deriving from JsonSerializerContext. |
| __`SYSLIB1225`__ | *`SYSLIB1220`-`SYSLIB229` reserved for System.Text.Json.SourceGeneration.* |
| __`SYSLIB1226`__ | *`SYSLIB1220`-`SYSLIB229` reserved for System.Text.Json.SourceGeneration.* |
| __`SYSLIB1227`__ | *`SYSLIB1220`-`SYSLIB229` reserved for System.Text.Json.SourceGeneration.* |
| __`SYSLIB1228`__ | *`SYSLIB1220`-`SYSLIB229` reserved for System.Text.Json.SourceGeneration.* |
| __`SYSLIB1229`__ | *`SYSLIB1220`-`SYSLIB229` reserved for System.Text.Json.SourceGeneration.* |
| __`SYSLIB1225`__ | *`SYSLIB1220`-`SYSLIB1229` reserved for System.Text.Json.SourceGeneration.* |
| __`SYSLIB1226`__ | *`SYSLIB1220`-`SYSLIB1229` reserved for System.Text.Json.SourceGeneration.* |
| __`SYSLIB1227`__ | *`SYSLIB1220`-`SYSLIB1229` reserved for System.Text.Json.SourceGeneration.* |
| __`SYSLIB1228`__ | *`SYSLIB1220`-`SYSLIB1229` reserved for System.Text.Json.SourceGeneration.* |
| __`SYSLIB1229`__ | *`SYSLIB1220`-`SYSLIB1229` reserved for System.Text.Json.SourceGeneration.* |
| __`SYSLIB1230`__ | Deriving from a `GeneratedComInterface`-attributed interface defined in another assembly is not supported. |
| __`SYSLIB1231`__ | *`SYSLIB1230`-`SYSLIB1239` reserved for Microsoft.Interop.ComInterfaceGenerator.* |
| __`SYSLIB1232`__ | *`SYSLIB1230`-`SYSLIB1239` reserved for Microsoft.Interop.ComInterfaceGenerator.* |
| __`SYSLIB1233`__ | *`SYSLIB1230`-`SYSLIB1239` reserved for Microsoft.Interop.ComInterfaceGenerator.* |
| __`SYSLIB1234`__ | *`SYSLIB1230`-`SYSLIB1239` reserved for Microsoft.Interop.ComInterfaceGenerator.* |
| __`SYSLIB1235`__ | *`SYSLIB1230`-`SYSLIB1239` reserved for Microsoft.Interop.ComInterfaceGenerator.* |
| __`SYSLIB1236`__ | *`SYSLIB1230`-`SYSLIB1239` reserved for Microsoft.Interop.ComInterfaceGenerator.* |
| __`SYSLIB1237`__ | *`SYSLIB1230`-`SYSLIB1239` reserved for Microsoft.Interop.ComInterfaceGenerator.* |
| __`SYSLIB1238`__ | *`SYSLIB1230`-`SYSLIB1239` reserved for Microsoft.Interop.ComInterfaceGenerator.* |
| __`SYSLIB1239`__ | *`SYSLIB1230`-`SYSLIB1239` reserved for Microsoft.Interop.ComInterfaceGenerator.* |

### Diagnostic Suppressions (`SYSLIBSUPPRESS****`)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ internal sealed record ComInterfaceContext
internal ComInterfaceInfo Info { get; init; }
internal ComInterfaceContext? Base { get; init; }
internal ComInterfaceOptions Options { get; init; }
internal bool IsExternallyDefined { get; init; }

private ComInterfaceContext(ComInterfaceInfo info, ComInterfaceContext? @base, ComInterfaceOptions options)
{
Expand Down Expand Up @@ -51,7 +52,7 @@ DiagnosticOr<ComInterfaceContext> AddContext(ComInterfaceInfo iface)

if (iface.BaseInterfaceKey is null)
{
var baselessCtx = DiagnosticOr<ComInterfaceContext>.From(new ComInterfaceContext(iface, null, iface.Options));
var baselessCtx = DiagnosticOr<ComInterfaceContext>.From(new ComInterfaceContext(iface, null, iface.Options) { IsExternallyDefined = iface.IsExternallyDefined });
nameToContextCache[iface.ThisInterfaceKey] = baselessCtx;
return baselessCtx;
}
Expand All @@ -75,7 +76,7 @@ DiagnosticOr<ComInterfaceContext> AddContext(ComInterfaceInfo iface)
}
DiagnosticOr<ComInterfaceContext> baseContext = baseCachedValue ?? baseReturnedValue;
Debug.Assert(baseContext.HasValue);
var ctx = DiagnosticOr<ComInterfaceContext>.From(new ComInterfaceContext(iface, baseContext.Value, iface.Options));
var ctx = DiagnosticOr<ComInterfaceContext>.From(new ComInterfaceContext(iface, baseContext.Value, iface.Options) { IsExternallyDefined = iface.IsExternallyDefined });
nameToContextCache[iface.ThisInterfaceKey] = ctx;
return ctx;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,14 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
{
return ComInterfaceInfo.From(data.Left.Symbol, data.Left.Syntax, data.Right, ct);
});
var interfaceSymbolsWithoutDiagnostics = context.FilterAndReportDiagnostics(interfaceSymbolOrDiagnostics);
var interfaceSymbolsToGenerateWithoutDiagnostics = context.FilterAndReportDiagnostics(interfaceSymbolOrDiagnostics);

var externalInterfaceSymbols = attributedInterfaces.SelectMany(static (data, ct) =>
{
return ComInterfaceInfo.CreateInterfaceInfoForBaseInterfacesInOtherCompilations(data.Symbol);
});

var interfaceSymbolsWithoutDiagnostics = interfaceSymbolsToGenerateWithoutDiagnostics.Concat(externalInterfaceSymbols);

var interfaceContextsOrDiagnostics = interfaceSymbolsWithoutDiagnostics
.Select((data, ct) => data.InterfaceInfo!)
Expand Down Expand Up @@ -76,7 +83,14 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
.SelectMany(static (data, ct) =>
{
return ComMethodContext.CalculateAllMethods(data, ct);
});
})
// Now that we've determined method offsets, we can remove all externally defined methods.
// We'll also filter out methods originally declared on externally defined base interfaces
// as we may not be able to emit them into our assembly.
.Where(context => !context.Method.OriginalDeclaringInterface.IsExternallyDefined);

// Now that we've determined method offsets, we can remove all externally defined interfaces.
var interfaceContextsToGenerate = interfaceContexts.Where(context => !context.IsExternallyDefined);

// A dictionary isn't incremental, but it will have symbols, so it will never be incremental anyway.
var methodInfoToSymbolMap = methodInfoAndSymbolGroupedByInterface
Expand All @@ -97,7 +111,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)

var interfaceAndMethodsContexts = comMethodContexts
.Collect()
.Combine(interfaceContexts.Collect())
.Combine(interfaceContextsToGenerate.Collect())
.SelectMany((data, ct) => GroupComContextsForInterfaceGeneration(data.Left, data.Right, ct));

// Generate the code for the managed-to-unmanaged stubs.
Expand All @@ -120,7 +134,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
.SelectMany((data, ct) => data.DeclaredMethods.SelectMany(m => m.ManagedToUnmanagedStub.Diagnostics).Union(data.DeclaredMethods.SelectMany(m => m.UnmanagedToManagedStub.Diagnostics))));

// Generate the native interface metadata for each [GeneratedComInterface]-attributed interface.
var nativeInterfaceInformation = interfaceContexts
var nativeInterfaceInformation = interfaceContextsToGenerate
.Select(static (data, ct) => data.Info)
.Select(GenerateInterfaceInformation)
.WithTrackingName(StepNames.GenerateInterfaceInformation)
Expand Down Expand Up @@ -150,14 +164,14 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
.WithComparer(SyntaxEquivalentComparer.Instance)
.SelectNormalized();

var iUnknownDerivedAttributeApplication = interfaceContexts
var iUnknownDerivedAttributeApplication = interfaceContextsToGenerate
.Select(static (data, ct) => data.Info)
.Select(GenerateIUnknownDerivedAttributeApplication)
.WithTrackingName(StepNames.GenerateIUnknownDerivedAttribute)
.WithComparer(SyntaxEquivalentComparer.Instance)
.SelectNormalized();

var filesToGenerate = interfaceContexts
var filesToGenerate = interfaceContextsToGenerate
.Zip(nativeInterfaceInformation)
.Zip(managedToNativeInterfaceImplementations)
.Zip(nativeToManagedVtableMethods)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Immutable;
using System.Diagnostics.CodeAnalysis;
using System.Threading;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using InterfaceInfo = (Microsoft.Interop.ComInterfaceInfo InterfaceInfo, Microsoft.CodeAnalysis.INamedTypeSymbol Symbol);
using DiagnosticOrInterfaceInfo = Microsoft.Interop.DiagnosticOr<(Microsoft.Interop.ComInterfaceInfo InterfaceInfo, Microsoft.CodeAnalysis.INamedTypeSymbol Symbol)>;
using System.Diagnostics;

namespace Microsoft.Interop
{
Expand All @@ -25,6 +28,7 @@ internal sealed record ComInterfaceInfo
public Guid InterfaceId { get; init; }
public ComInterfaceOptions Options { get; init; }
public Location DiagnosticLocation { get; init; }
public bool IsExternallyDefined { get; init; }

private ComInterfaceInfo(
ManagedTypeInfo type,
Expand Down Expand Up @@ -90,8 +94,8 @@ public static DiagnosticOrInterfaceInfo From(INamedTypeSymbol symbol, InterfaceD
if (!OptionsAreValid(symbol, syntax, interfaceAttributeData, baseAttributeData, out DiagnosticInfo? optionsDiagnostic))
return DiagnosticOrInterfaceInfo.From(optionsDiagnostic);

return DiagnosticOrInterfaceInfo.From(
(new ComInterfaceInfo(
InterfaceInfo info = (
new ComInterfaceInfo(
ManagedTypeInfo.CreateTypeInfoForTypeSymbol(symbol),
symbol.ToDisplayString(),
baseSymbol?.ToDisplayString(),
Expand All @@ -101,7 +105,72 @@ public static DiagnosticOrInterfaceInfo From(INamedTypeSymbol symbol, InterfaceD
guid ?? Guid.Empty,
interfaceAttributeData.Options,
syntax.Identifier.GetLocation()),
symbol));
symbol);

// Now that we've validated all of our requirements, we will check for some non-blocking scenarios
// and emit diagnostics.
ImmutableArray<DiagnosticInfo>.Builder nonFatalDiagnostics = ImmutableArray.CreateBuilder<DiagnosticInfo>();

// If there is a base interface and it is defined in another assembly,
// warn the user that they are in a scenario that has pitfalls.
// We check that either the base interface symbol is defined in a non-source assembly (ie an assembly referenced as metadata)
// or if it is defined in a different source assembly (ie another C# project in the same solution when loaded in an IDE)
// as Roslyn can provide the symbol information in either shape to us depending on the scenario.
if (baseSymbol is not null
&& (baseSymbol.ContainingAssembly is not ISourceAssemblySymbol
|| (baseSymbol.ContainingAssembly is ISourceAssemblySymbol { Compilation: Compilation baseComp }
&& baseComp != env.Compilation)))
{
nonFatalDiagnostics.Add(DiagnosticInfo.Create(
GeneratorDiagnostics.BaseInterfaceDefinedInOtherAssembly,
syntax.Identifier.GetLocation(),
symbol.ToDisplayString(),
baseSymbol.ToDisplayString()));
}

if (nonFatalDiagnostics.Count != 0)
{
// Report non-fatal diagnostics with the result.
return DiagnosticOrInterfaceInfo.From(info, nonFatalDiagnostics.ToArray());
}

// We have no non-fatal diagnostics, so return the result.
return DiagnosticOrInterfaceInfo.From(info);
}

public static ImmutableArray<InterfaceInfo> CreateInterfaceInfoForBaseInterfacesInOtherCompilations(
INamedTypeSymbol symbol)
{
if (!TryGetBaseComInterface(symbol, null, out INamedTypeSymbol? baseSymbol, out _) || baseSymbol is null)
return ImmutableArray<InterfaceInfo>.Empty;

if (SymbolEqualityComparer.Default.Equals(baseSymbol.ContainingAssembly, symbol.ContainingAssembly))
return ImmutableArray<InterfaceInfo>.Empty;

ImmutableArray<InterfaceInfo>.Builder builder = ImmutableArray.CreateBuilder<InterfaceInfo>();
while (baseSymbol is not null)
{
var thisSymbol = baseSymbol;
TryGetBaseComInterface(thisSymbol, null, out baseSymbol, out _);
var interfaceAttributeData = GeneratedComInterfaceCompilationData.GetAttributeDataFromInterfaceSymbol(thisSymbol);
builder.Add((
new ComInterfaceInfo(
ManagedTypeInfo.CreateTypeInfoForTypeSymbol(thisSymbol),
thisSymbol.ToDisplayString(),
baseSymbol?.ToDisplayString(),
null!,
default,
default,
Guid.Empty,
interfaceAttributeData.Options,
Location.None)
{
IsExternallyDefined = true
},
thisSymbol));
}

return builder.ToImmutable();
}

private static bool IsInPartialContext(INamedTypeSymbol symbol, InterfaceDeclarationSyntax syntax, [NotNullWhen(false)] out DiagnosticInfo? diagnostic)
Expand Down Expand Up @@ -219,8 +288,9 @@ private static bool OptionsAreValid(
/// <summary>
/// Returns true if there is 0 or 1 base Com interfaces (i.e. the inheritance is valid), and returns false when there are 2 or more base Com interfaces and sets <paramref name="diagnostic"/>.
/// </summary>
private static bool TryGetBaseComInterface(INamedTypeSymbol comIface, InterfaceDeclarationSyntax syntax, out INamedTypeSymbol? baseComIface, [NotNullWhen(false)] out DiagnosticInfo? diagnostic)
private static bool TryGetBaseComInterface(INamedTypeSymbol comIface, InterfaceDeclarationSyntax? syntax, out INamedTypeSymbol? baseComIface, [NotNullWhen(false)] out DiagnosticInfo? diagnostic)
{
diagnostic = null;
baseComIface = null;
foreach (var implemented in comIface.Interfaces)
{
Expand All @@ -230,17 +300,23 @@ private static bool TryGetBaseComInterface(INamedTypeSymbol comIface, InterfaceD
{
if (baseComIface is not null)
{
diagnostic = DiagnosticInfo.Create(
GeneratorDiagnostics.MultipleComInterfaceBaseTypes,
syntax.Identifier.GetLocation(),
comIface.ToDisplayString());
// If we're inspecting an external symbol,
// we don't have syntax.
// In that case, don't report a diagnostic. One will be reported
// when building that symbol's compilation.
if (syntax is not null)
{
diagnostic = DiagnosticInfo.Create(
GeneratorDiagnostics.MultipleComInterfaceBaseTypes,
syntax.Identifier.GetLocation(),
comIface.ToDisplayString());
}
return false;
}
baseComIface = implemented;
}
}
}
diagnostic = null;
return true;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ ImmutableArray<Builder> AddMethods(ComInterfaceContext iface, IEnumerable<ComMet
{
baseMethods = pair;
}

methods.AddRange(baseMethods);
startingIndex += baseMethods.Length;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,16 @@ private ComMethodInfo(
ct.ThrowIfCancellationRequested();
Debug.Assert(method is { IsStatic: false, MethodKind: MethodKind.Ordinary });

// For externally-defined contexts, we only need minimal information about the method
// to ensure that we have the right offsets for inheriting vtable types.
// Skip all validation as that will be done when that type is compiled.
if (ifaceContext.IsExternallyDefined)
{
return DiagnosticOr<(ComMethodInfo, IMethodSymbol)>.From((
new ComMethodInfo(null!, method.Name, method.GetAttributes().Select(AttributeInfo.From).ToImmutableArray().ToSequenceEqual(), false),
method));
}

// We only support methods that are defined in the same partial interface definition as the
// [GeneratedComInterface] attribute.
// This restriction not only makes finding the syntax for a given method cheaper,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ public class Ids
public const string AnalysisFailed = Prefix + "1093";
public const string BaseInterfaceFailedGeneration = Prefix + "1094";
public const string InvalidGeneratedComClassAttributeUsage = Prefix + "1095";
public const string BaseInterfaceDefinedInOtherAssembly = Prefix + "1230";
}

private const string Category = "ComInterfaceGenerator";
Expand Down Expand Up @@ -425,10 +426,10 @@ public class Ids
DiagnosticSeverity.Info,
isEnabledByDefault: true,
description: GetResourceString(nameof(SR.UnnecessaryMarshallingInfoDescription)),
customTags: new[]
{
customTags:
[
WellKnownDiagnosticTags.Unnecessary
});
]);

/// <inheritdoc cref="SR.UnnecessaryReturnMarshallingInfoMessage"/>
public static readonly DiagnosticDescriptor UnnecessaryReturnMarshallingInfo =
Expand All @@ -440,10 +441,10 @@ public class Ids
DiagnosticSeverity.Info,
isEnabledByDefault: true,
description: GetResourceString(nameof(SR.UnnecessaryMarshallingInfoDescription)),
customTags: new[]
{
customTags:
[
WellKnownDiagnosticTags.Unnecessary
});
]);

/// <inheritdoc cref="SR.SizeOfCollectionMustBeKnownAtMarshalTimeMessageOutParam"/>
public static readonly DiagnosticDescriptor SizeOfInCollectionMustBeDefinedAtCallOutParam =
Expand Down Expand Up @@ -496,6 +497,17 @@ public class Ids
isEnabledByDefault: true,
helpLinkUri: "aka.ms/GeneratedComInterfaceUsage");

/// <inheritdoc cref="SR.BaseInterfaceDefinedInOtherAssemblyMessage" />
public static readonly DiagnosticDescriptor BaseInterfaceDefinedInOtherAssembly =
new DiagnosticDescriptor(
Ids.BaseInterfaceDefinedInOtherAssembly,
GetResourceString(nameof(SR.BaseInterfaceDefinedInOtherAssemblyTitle)),
GetResourceString(nameof(SR.BaseInterfaceDefinedInOtherAssemblyMessage)),
Category,
DiagnosticSeverity.Warning,
isEnabledByDefault: true,
helpLinkUri: "aka.ms/GeneratedComInterfaceUsage");

/// <summary>
/// Report diagnostic for invalid configuration for string marshalling.
/// </summary>
Expand Down
Loading

0 comments on commit 1563fec

Please sign in to comment.