Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for default implementation of static virtuals with method constraints #89061

Merged
merged 8 commits into from
Jul 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions src/coreclr/inc/enum_class_flags.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

#ifndef ENUM_CLASS_FLAGS_OPERATORS
#define ENUM_CLASS_FLAGS_OPERATORS

template <typename T>
inline auto operator& (T left, T right) -> decltype(T::support_use_as_flags)
{
return static_cast<T>(static_cast<int>(left) & static_cast<int>(right));
}

template <typename T>
inline auto operator| (T left, T right) -> decltype(T::support_use_as_flags)
{
return static_cast<T>(static_cast<int>(left) | static_cast<int>(right));
}

template <typename T>
inline auto operator^ (T left, T right) -> decltype(T::support_use_as_flags)
{
return static_cast<T>(static_cast<int>(left) ^ static_cast<int>(right));
}

template <typename T>
inline auto operator~ (T value) -> decltype(T::support_use_as_flags)
{
return static_cast<T>(~static_cast<int>(value));
}

template <typename T>
inline auto operator |= (T& left, T right) -> const decltype(T::support_use_as_flags)&
{
left = left | right;
return left;
}

template <typename T>
inline auto operator &= (T& left, T right) -> const decltype(T::support_use_as_flags)&
{
left = left & right;
return left;
}

template <typename T>
inline auto operator ^= (T& left, T right) -> const decltype(T::support_use_as_flags)&
{
left = left ^ right;
return left;
}

#endif /* ENUM_CLASS_FLAGS_OPERATORS */
6 changes: 3 additions & 3 deletions src/coreclr/vm/genericdict.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1144,9 +1144,9 @@ Dictionary::PopulateEntry(
pResolvedMD = constraintType.GetMethodTable()->ResolveVirtualStaticMethod(
ownerType.GetMethodTable(),
pMethod,
/* allowNullResult */ TRUE,
/* verifyImplemented */ FALSE,
/* allowVariantMatches */ TRUE,
ResolveVirtualStaticMethodFlags::AllowNullResult |
ResolveVirtualStaticMethodFlags::AllowVariantMatches |
ResolveVirtualStaticMethodFlags::InstantiateResultOverFinalMethodDesc,
&uniqueResolution);

// If we couldn't get an exact result, fall back to using a stub to make the exact function call
Expand Down
94 changes: 68 additions & 26 deletions src/coreclr/vm/methodtable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6210,22 +6210,26 @@ MethodTable::FindDispatchImpl(

// Try exact match first
MethodDesc *pDefaultMethod = NULL;

FindDefaultInterfaceImplementationFlags flags = FindDefaultInterfaceImplementationFlags::InstantiateFoundMethodDesc;
if (throwOnConflict)
flags = flags | FindDefaultInterfaceImplementationFlags::ThrowOnConflict;

BOOL foundDefaultInterfaceImplementation = FindDefaultInterfaceImplementation(
pIfcMD, // the interface method being resolved
pIfcMT, // the interface being resolved
&pDefaultMethod,
FALSE, // allowVariance
throwOnConflict);
flags);

// If there's no exact match, try a variant match
if (!foundDefaultInterfaceImplementation && pIfcMT->HasVariance())
{
flags = flags | FindDefaultInterfaceImplementationFlags::AllowVariance;
foundDefaultInterfaceImplementation = FindDefaultInterfaceImplementation(
pIfcMD, // the interface method being resolved
pIfcMT, // the interface being resolved
&pDefaultMethod,
TRUE, // allowVariance
throwOnConflict);
flags);
}

if (foundDefaultInterfaceImplementation)
Expand Down Expand Up @@ -6324,10 +6328,13 @@ namespace
MethodTable *pMT,
MethodDesc *interfaceMD,
MethodTable *interfaceMT,
BOOL allowVariance,
FindDefaultInterfaceImplementationFlags findDefaultImplementationFlags,
MethodDesc **candidateMD,
ClassLoadLevel level)
{
bool allowVariance = (findDefaultImplementationFlags & FindDefaultInterfaceImplementationFlags::AllowVariance) != FindDefaultInterfaceImplementationFlags::None;
bool instantiateMethodInstantiation = (findDefaultImplementationFlags & FindDefaultInterfaceImplementationFlags::InstantiateFoundMethodDesc) != FindDefaultInterfaceImplementationFlags::None;

*candidateMD = NULL;

MethodDesc *candidateMaybe = NULL;
Expand Down Expand Up @@ -6418,11 +6425,20 @@ namespace
else
{
// Static virtual methods don't record MethodImpl slots so they need special handling
ResolveVirtualStaticMethodFlags resolveVirtualStaticMethodFlags = ResolveVirtualStaticMethodFlags::None;
if (allowVariance)
{
resolveVirtualStaticMethodFlags |= ResolveVirtualStaticMethodFlags::AllowVariantMatches;
}
if (instantiateMethodInstantiation)
{
resolveVirtualStaticMethodFlags |= ResolveVirtualStaticMethodFlags::InstantiateResultOverFinalMethodDesc;
}

candidateMaybe = pMT->TryResolveVirtualStaticMethodOnThisType(
interfaceMT,
interfaceMD,
/* verifyImplemented */ FALSE,
/* allowVariance */ allowVariance,
resolveVirtualStaticMethodFlags,
/* level */ level);
}
}
Expand Down Expand Up @@ -6461,8 +6477,7 @@ BOOL MethodTable::FindDefaultInterfaceImplementation(
MethodDesc *pInterfaceMD,
MethodTable *pInterfaceMT,
MethodDesc **ppDefaultMethod,
BOOL allowVariance,
BOOL throwOnConflict,
FindDefaultInterfaceImplementationFlags findDefaultImplementationFlags,
ClassLoadLevel level
)
{
Expand All @@ -6478,12 +6493,13 @@ BOOL MethodTable::FindDefaultInterfaceImplementation(
} CONTRACT_END;

#ifdef FEATURE_DEFAULT_INTERFACES
bool allowVariance = (findDefaultImplementationFlags & FindDefaultInterfaceImplementationFlags::AllowVariance) != FindDefaultInterfaceImplementationFlags::None;
CQuickArray<MatchCandidate> candidates;
unsigned candidatesCount = 0;

// Check the current method table itself
MethodDesc *candidateMaybe = NULL;
if (IsInterface() && TryGetCandidateImplementation(this, pInterfaceMD, pInterfaceMT, allowVariance, &candidateMaybe, level))
if (IsInterface() && TryGetCandidateImplementation(this, pInterfaceMD, pInterfaceMT, findDefaultImplementationFlags, &candidateMaybe, level))
{
_ASSERTE(candidateMaybe != NULL);

Expand Down Expand Up @@ -6523,7 +6539,7 @@ BOOL MethodTable::FindDefaultInterfaceImplementation(
MethodTable *pCurMT = it.GetInterface(pMT, level);

MethodDesc *pCurMD = NULL;
if (TryGetCandidateImplementation(pCurMT, pInterfaceMD, pInterfaceMT, allowVariance, &pCurMD, level))
if (TryGetCandidateImplementation(pCurMT, pInterfaceMD, pInterfaceMT, findDefaultImplementationFlags, &pCurMD, level))
{
//
// Found a match. But is it a more specific match (we want most specific interfaces)
Expand Down Expand Up @@ -6619,6 +6635,8 @@ BOOL MethodTable::FindDefaultInterfaceImplementation(
}
else if (pBestCandidateMT != candidates[i].pMT)
{
bool throwOnConflict = (findDefaultImplementationFlags & FindDefaultInterfaceImplementationFlags::ThrowOnConflict) != FindDefaultInterfaceImplementationFlags::None;

if (throwOnConflict)
ThrowExceptionForConflictingOverride(this, pInterfaceMT, pInterfaceMD);

Expand Down Expand Up @@ -8875,12 +8893,15 @@ MethodDesc *
MethodTable::ResolveVirtualStaticMethod(
MethodTable* pInterfaceType,
MethodDesc* pInterfaceMD,
BOOL allowNullResult,
BOOL verifyImplemented,
BOOL allowVariantMatches,
ResolveVirtualStaticMethodFlags resolveVirtualStaticMethodFlags,
BOOL* uniqueResolution,
ClassLoadLevel level)
{
bool verifyImplemented = (resolveVirtualStaticMethodFlags & ResolveVirtualStaticMethodFlags::VerifyImplemented) != ResolveVirtualStaticMethodFlags::None;
bool allowVariantMatches = (resolveVirtualStaticMethodFlags & ResolveVirtualStaticMethodFlags::AllowVariantMatches) != ResolveVirtualStaticMethodFlags::None;
bool instantiateMethodParameters = (resolveVirtualStaticMethodFlags & ResolveVirtualStaticMethodFlags::InstantiateResultOverFinalMethodDesc) != ResolveVirtualStaticMethodFlags::None;
bool allowNullResult = (resolveVirtualStaticMethodFlags & ResolveVirtualStaticMethodFlags::AllowNullResult) != ResolveVirtualStaticMethodFlags::None;

if (uniqueResolution != nullptr)
{
*uniqueResolution = TRUE;
Expand Down Expand Up @@ -8912,7 +8933,7 @@ MethodTable::ResolveVirtualStaticMethod(
// Search for match on a per-level in the type hierarchy
for (MethodTable* pMT = this; pMT != nullptr; pMT = pMT->GetParentMethodTable())
{
MethodDesc* pMD = pMT->TryResolveVirtualStaticMethodOnThisType(pInterfaceType, pInterfaceMD, verifyImplemented, /*allowVariance*/ FALSE, level);
MethodDesc* pMD = pMT->TryResolveVirtualStaticMethodOnThisType(pInterfaceType, pInterfaceMD, resolveVirtualStaticMethodFlags & ~ResolveVirtualStaticMethodFlags::AllowVariantMatches, level);
if (pMD != nullptr)
{
return pMD;
Expand Down Expand Up @@ -8956,7 +8977,7 @@ MethodTable::ResolveVirtualStaticMethod(
{
// Variant or equivalent matching interface found
// Attempt to resolve on variance matched interface
pMD = pMT->TryResolveVirtualStaticMethodOnThisType(pItfInMap, pInterfaceMD, verifyImplemented, /*allowVariance*/ FALSE, level);
pMD = pMT->TryResolveVirtualStaticMethodOnThisType(pItfInMap, pInterfaceMD, resolveVirtualStaticMethodFlags & ~ResolveVirtualStaticMethodFlags::AllowVariantMatches, level);
if (pMD != nullptr)
{
return pMD;
Expand All @@ -8970,12 +8991,25 @@ MethodTable::ResolveVirtualStaticMethod(
BOOL allowVariantMatchInDefaultImplementationLookup = FALSE;
do
{
FindDefaultInterfaceImplementationFlags findDefaultImplementationFlags = FindDefaultInterfaceImplementationFlags::None;
if (allowVariantMatchInDefaultImplementationLookup)
{
findDefaultImplementationFlags |= FindDefaultInterfaceImplementationFlags::AllowVariance;
}
if (uniqueResolution == nullptr)
{
findDefaultImplementationFlags |= FindDefaultInterfaceImplementationFlags::ThrowOnConflict;
}
if (instantiateMethodParameters)
{
findDefaultImplementationFlags |= FindDefaultInterfaceImplementationFlags::InstantiateFoundMethodDesc;
}

BOOL haveUniqueDefaultImplementation = FindDefaultInterfaceImplementation(
pInterfaceMD,
pInterfaceType,
&pMDDefaultImpl,
/* allowVariance */ allowVariantMatchInDefaultImplementationLookup,
/* throwOnConflict */ uniqueResolution == nullptr,
findDefaultImplementationFlags,
level);
if (haveUniqueDefaultImplementation || (pMDDefaultImpl != nullptr && (verifyImplemented || uniqueResolution != nullptr)))
{
Expand Down Expand Up @@ -9018,8 +9052,12 @@ MethodTable::ResolveVirtualStaticMethod(
// Try to locate the appropriate MethodImpl matching a given interface static virtual method.
// Returns nullptr on failure.
MethodDesc*
MethodTable::TryResolveVirtualStaticMethodOnThisType(MethodTable* pInterfaceType, MethodDesc* pInterfaceMD, BOOL verifyImplemented, BOOL allowVariance, ClassLoadLevel level)
MethodTable::TryResolveVirtualStaticMethodOnThisType(MethodTable* pInterfaceType, MethodDesc* pInterfaceMD, ResolveVirtualStaticMethodFlags resolveVirtualStaticMethodFlags, ClassLoadLevel level)
{
bool instantiateMethodParameters = (resolveVirtualStaticMethodFlags & ResolveVirtualStaticMethodFlags::InstantiateResultOverFinalMethodDesc) != ResolveVirtualStaticMethodFlags::None;
bool allowVariance = (resolveVirtualStaticMethodFlags & ResolveVirtualStaticMethodFlags::AllowVariantMatches) != ResolveVirtualStaticMethodFlags::None;
bool verifyImplemented = (resolveVirtualStaticMethodFlags & ResolveVirtualStaticMethodFlags::VerifyImplemented) != ResolveVirtualStaticMethodFlags::None;

HRESULT hr = S_OK;
IMDInternalImport* pMDInternalImport = GetMDImport();
HENUMInternalMethodImplHolder hEnumMethodImpl(pMDInternalImport);
Expand Down Expand Up @@ -9148,7 +9186,7 @@ MethodTable::TryResolveVirtualStaticMethodOnThisType(MethodTable* pInterfaceType
COMPlusThrow(kTypeLoadException, E_FAIL);
}

if (!verifyImplemented)
if (!verifyImplemented && instantiateMethodParameters)
{
pMethodImpl = pMethodImpl->FindOrCreateAssociatedMethodDesc(
pMethodImpl,
Expand Down Expand Up @@ -9202,9 +9240,7 @@ MethodTable::VerifyThatAllVirtualStaticMethodsAreImplemented()
!ResolveVirtualStaticMethod(
pInterfaceMT,
pMD,
/* allowNullResult */ TRUE,
/* verifyImplemented */ TRUE,
/* allowVariantMatches */ FALSE,
ResolveVirtualStaticMethodFlags::AllowNullResult | ResolveVirtualStaticMethodFlags::VerifyImplemented,
/* uniqueResolution */ &uniqueResolution,
/* level */ CLASS_LOAD_EXACTPARENTS)))
{
Expand Down Expand Up @@ -9240,12 +9276,18 @@ MethodTable::TryResolveConstraintMethodApprox(
_ASSERTE(!thInterfaceType.IsTypeDesc());
_ASSERTE(thInterfaceType.IsInterface());
BOOL uniqueResolution;

ResolveVirtualStaticMethodFlags flags = ResolveVirtualStaticMethodFlags::AllowVariantMatches
| ResolveVirtualStaticMethodFlags::InstantiateResultOverFinalMethodDesc;
if (pfForceUseRuntimeLookup != NULL)
{
flags |= ResolveVirtualStaticMethodFlags::AllowNullResult;
}

MethodDesc *result = ResolveVirtualStaticMethod(
thInterfaceType.GetMethodTable(),
pInterfaceMD,
/* allowNullResult */pfForceUseRuntimeLookup != NULL,
/* verifyImplemented */ FALSE,
/* allowVariantMatches */ TRUE,
flags,
&uniqueResolution);
if (result == NULL || !uniqueResolution)
{
Expand Down
33 changes: 26 additions & 7 deletions src/coreclr/vm/methodtable.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "contractimpl.h"
#include "generics.h"
#include "gcinfotypes.h"
#include "enum_class_flags.h"

/*
* Forward Declarations
Expand Down Expand Up @@ -63,6 +64,28 @@ class ClassFactoryBase;
class ArgDestination;
enum class WellKnownAttribute : DWORD;

enum class ResolveVirtualStaticMethodFlags
{
None = 0,
AllowNullResult = 1,
VerifyImplemented = 2,
AllowVariantMatches = 4,
InstantiateResultOverFinalMethodDesc = 8,

support_use_as_flags // Enable the template functions in enum_class_flags.h
};


enum class FindDefaultInterfaceImplementationFlags
{
None,
AllowVariance = 1,
ThrowOnConflict = 2,
InstantiateFoundMethodDesc = 4,

support_use_as_flags // Enable the template functions in enum_class_flags.h
};

//============================================================================
// This is the in-memory structure of a class and it will evolve.
//============================================================================
Expand Down Expand Up @@ -2084,7 +2107,6 @@ class MethodTable
MethodDesc *GetMethodDescForComInterfaceMethod(MethodDesc *pItfMD, bool fNullOk);
#endif // FEATURE_COMINTEROP


// Resolve virtual static interface method pInterfaceMD on this type.
//
// Specify allowNullResult to return NULL instead of throwing if the there is no implementation
Expand All @@ -2096,9 +2118,7 @@ class MethodTable
MethodDesc *ResolveVirtualStaticMethod(
MethodTable* pInterfaceType,
MethodDesc* pInterfaceMD,
BOOL allowNullResult,
BOOL verifyImplemented = FALSE,
BOOL allowVariantMatches = TRUE,
ResolveVirtualStaticMethodFlags resolveVirtualStaticMethodFlags,
BOOL *uniqueResolution = NULL,
ClassLoadLevel level = CLASS_LOADED);

Expand Down Expand Up @@ -2178,8 +2198,7 @@ class MethodTable
MethodDesc *pInterfaceMD,
MethodTable *pObjectMT,
MethodDesc **ppDefaultMethod,
BOOL allowVariance,
BOOL throwOnConflict,
FindDefaultInterfaceImplementationFlags findDefaultImplementationFlags,
ClassLoadLevel level = CLASS_LOADED);
#endif // DACCESS_COMPILE

Expand Down Expand Up @@ -2219,7 +2238,7 @@ class MethodTable

// Try to resolve a given static virtual method override on this type. Return nullptr
// when not found.
MethodDesc *TryResolveVirtualStaticMethodOnThisType(MethodTable* pInterfaceType, MethodDesc* pInterfaceMD, BOOL verifyImplemented, BOOL allowVariance, ClassLoadLevel level);
MethodDesc *TryResolveVirtualStaticMethodOnThisType(MethodTable* pInterfaceType, MethodDesc* pInterfaceMD, ResolveVirtualStaticMethodFlags resolveVirtualStaticMethodFlags, ClassLoadLevel level);

public:
static MethodDesc *MapMethodDeclToMethodImpl(MethodDesc *pMDDecl);
Expand Down
5 changes: 2 additions & 3 deletions src/coreclr/vm/runtimehandles.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1102,9 +1102,8 @@ extern "C" MethodDesc* QCALLTYPE RuntimeTypeHandle_GetInterfaceMethodImplementat
pResult = typeHandle.GetMethodTable()->ResolveVirtualStaticMethod(
thOwnerOfMD.GetMethodTable(),
pMD,
/* allowNullResult */ TRUE,
/* verifyImplemented*/ FALSE,
/* allowVariantMatches */ TRUE);
ResolveVirtualStaticMethodFlags::AllowNullResult |
ResolveVirtualStaticMethodFlags::AllowVariantMatches);
}
else
{
Expand Down
5 changes: 3 additions & 2 deletions src/coreclr/vm/typedesc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1611,8 +1611,9 @@ BOOL TypeVarTypeDesc::SatisfiesConstraints(SigTypeContext *pTypeContextOfConstra
if (pMD->IsVirtual() &&
pMD->IsStatic() &&
(pMD->IsAbstract() && !thElem.AsMethodTable()->ResolveVirtualStaticMethod(
pInterfaceMT, pMD, /* allowNullResult */ TRUE, /* verifyImplemented */ TRUE,
/*allowVariantMatches*/ TRUE, /*uniqueResolution*/ NULL, CLASS_DEPENDENCIES_LOADED)))
pInterfaceMT, pMD,
ResolveVirtualStaticMethodFlags::AllowNullResult | ResolveVirtualStaticMethodFlags::VerifyImplemented | ResolveVirtualStaticMethodFlags::AllowVariantMatches,
/*uniqueResolution*/ NULL, CLASS_DEPENDENCIES_LOADED)))
{
virtualStaticResolutionCheckFailed = true;
break;
Expand Down
Loading
Loading