Skip to content

Commit

Permalink
GROOVY-11301, GROOVY-11365: SC: method reference to private or protected
Browse files Browse the repository at this point in the history
of outer or upper class (using access bridge)
  • Loading branch information
eric-milles committed Sep 18, 2024
1 parent 92b9bf4 commit 754d819
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.codehaus.groovy.ast.expr.MethodCallExpression;
import org.codehaus.groovy.ast.expr.MethodReferenceExpression;
import org.codehaus.groovy.ast.tools.GeneralUtils;
import org.codehaus.groovy.classgen.AsmClassGenerator;
import org.codehaus.groovy.classgen.asm.BytecodeHelper;
import org.codehaus.groovy.classgen.asm.MethodReferenceExpressionWriter;
import org.codehaus.groovy.classgen.asm.WriterController;
Expand Down Expand Up @@ -111,12 +112,21 @@ public void writeMethodReferenceExpression(final MethodReferenceExpression metho
} else {
// TODO: move the findMethodRefMethod and checking to StaticTypeCheckingVisitor
methodRefMethod = findMethodRefMethod(methodRefName, parametersWithExactType, typeOrTargetRef, typeOrTargetRefType);
if (methodReferenceExpression.getNodeMetaData(StaticTypesMarker.PV_METHODS_ACCESS) != null) { // GROOVY-11301, GROOVY-11365: access bridge indicated
Map<MethodNode,MethodNode> bridgeMethods = typeOrTargetRefType.redirect().getNodeMetaData(StaticCompilationMetadataKeys.PRIVATE_BRIDGE_METHODS);
if (bridgeMethods != null) methodRefMethod = bridgeMethods.getOrDefault(methodRefMethod, methodRefMethod); // bridge may not have been generated
}
}

validate(methodReferenceExpression, typeOrTargetRefType, methodRefName, methodRefMethod, parametersWithExactType,
resolveClassNodeGenerics(extractPlaceholders(functionalType), null, abstractMethod.getReturnType()));

if (isExtensionMethod(methodRefMethod)) {
if (isBridgeMethod(methodRefMethod)) {
targetIsArgument = true; // GROOVY-11301, GROOVY-11365
if (isClassExpression) { // method expects an instance argument
methodRefMethod = addSyntheticMethodForDGSM(methodRefMethod);
}
} else if (isExtensionMethod(methodRefMethod)) {
ExtensionMethodNode extensionMethodNode = (ExtensionMethodNode) methodRefMethod;
methodRefMethod = extensionMethodNode.getExtensionMethodNode();
boolean isStatic = extensionMethodNode.isStaticExtension();
Expand Down Expand Up @@ -208,6 +218,8 @@ private void validate(final MethodReferenceExpression methodReference, final Cla
addFatalError(error, methodReference);
} else if (methodNode.isVoidMethod() && !ClassHelper.isPrimitiveVoid(samReturnType)) {
addFatalError("Invalid return type: void is not convertible to " + samReturnType.getText(), methodReference);
} else if (!AsmClassGenerator.isMemberDirectlyAccessible(methodNode.getModifiers(), methodNode.getDeclaringClass(), controller.getClassNode())) {
addFatalError("Cannot access method: " + methodName + " of class: " + methodNode.getDeclaringClass().getText(), methodReference); // GROOVY-11365
} else if (samParameters.length > 0 && isTypeReferringInstanceMethod(methodReference.getExpression(), methodNode) && !isAssignableTo(samParameters[0].getType(), targetType)) {
throw new RuntimeParserException("Invalid receiver type: " + samParameters[0].getType().getText() + " is not compatible with " + targetType.getText(), methodReference.getExpression());
}
Expand Down Expand Up @@ -422,6 +434,11 @@ private void addFatalError(final String msg, final ASTNode node) {

//--------------------------------------------------------------------------

private static boolean isBridgeMethod(final MethodNode mn) {
int staticSynthetic = Opcodes.ACC_STATIC | Opcodes.ACC_SYNTHETIC;
return ((mn.getModifiers() & staticSynthetic) == staticSynthetic) && mn.getName().startsWith("access$");
}

private static boolean isConstructorReference(final String name) {
return "new".equals(name);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -591,44 +591,40 @@ private static void addPrivateFieldOrMethodAccess(final Expression source, final
}

/**
* Checks for private field access from inner or outer class.
* Checks for private field access from closure or nestmate.
*/
private void checkOrMarkPrivateAccess(final Expression source, final FieldNode fn, final boolean lhsOfAssignment) {
if (fn != null && fn.isPrivate() && !fn.isSynthetic()) {
ClassNode declaringClass = fn.getDeclaringClass();
ClassNode enclosingClass = typeCheckingContext.getEnclosingClassNode();
if (declaringClass == enclosingClass && typeCheckingContext.getEnclosingClosure() == null) return;
if (declaringClass == enclosingClass || getOutermost(declaringClass) == getOutermost(enclosingClass)) {
if (declaringClass == enclosingClass ? typeCheckingContext.getEnclosingClosure() != null : getOutermost(declaringClass) == getOutermost(enclosingClass)) {
StaticTypesMarker accessKind = lhsOfAssignment ? PV_FIELDS_MUTATION : PV_FIELDS_ACCESS;
addPrivateFieldOrMethodAccess(source, declaringClass, accessKind, fn);
}
}
}

/**
* Checks for private method call from inner or outer class.
* Checks for private or protected method access from closure or nestmate.
*/
private void checkOrMarkPrivateAccess(final Expression source, final MethodNode mn) {
ClassNode declaringClass = mn.getDeclaringClass();
ClassNode enclosingClassNode = typeCheckingContext.getEnclosingClassNode();
if (declaringClass != enclosingClassNode || typeCheckingContext.getEnclosingClosure() != null) {
int mods = mn.getModifiers();
boolean sameModule = declaringClass.getModule() == enclosingClassNode.getModule();
String packageName = declaringClass.getPackageName();
if (packageName == null) {
packageName = "";
}
if (Modifier.isPrivate(mods) && sameModule) {
ClassNode enclosingClass = typeCheckingContext.getEnclosingClassNode();
if (declaringClass != enclosingClass || typeCheckingContext.getEnclosingClosure() != null) {
if (mn.isPrivate()
&& declaringClass.getModule() == enclosingClass.getModule()) {
addPrivateFieldOrMethodAccess(source, declaringClass, PV_METHODS_ACCESS, mn);
} else if (Modifier.isProtected(mods) && !packageName.equals(enclosingClassNode.getPackageName())
&& !implementsInterfaceOrIsSubclassOf(enclosingClassNode, declaringClass)) {
ClassNode cn = enclosingClassNode;
while ((cn = cn.getOuterClass()) != null) {
} else if (mn.isProtected()
&& !inSamePackage(enclosingClass, declaringClass)
&& (!implementsInterfaceOrIsSubclassOf(enclosingClass, declaringClass)
|| typeCheckingContext.getEnclosingClosure() != null)) {
ClassNode cn = enclosingClass;
do {
if (implementsInterfaceOrIsSubclassOf(cn, declaringClass)) {
addPrivateFieldOrMethodAccess(source, cn, PV_METHODS_ACCESS, mn);
break;
}
}
} while ((cn = cn.getOuterClass()) != null);
}
}
}
Expand Down Expand Up @@ -2651,7 +2647,8 @@ && isStringType(getType(nameExpr))) {

ClassNode ownerType = receiverType;
candidates.stream()
.map(candidate -> {
.peek(candidate -> checkOrMarkPrivateAccess(expression, candidate)) // GROOVY-11365
.map (candidate -> {
ClassNode returnType = candidate.getReturnType();
if (!candidate.isStatic() && GenericsUtils.hasUnresolvedGenerics(returnType)) {
Map<GenericsTypeName, GenericsType> spec = new HashMap<>(); // GROOVY-11364
Expand Down
36 changes: 29 additions & 7 deletions src/test/groovy/transform/stc/MethodReferenceTest.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -1526,22 +1526,44 @@ final class MethodReferenceTest {

@Test // GROOVY-11301
void testInnerClassPrivateMethodReference() {
def script = '''
assertScript shell, '''
@CompileStatic
class C {
static class D {
private static String m() { 'D' }
}
@CompileStatic
static main(args) {
Supplier<String> str = D::m
assert str.get() == 'D'
}
}
'''
if (Runtime.version().feature() < 15) {
shouldFail(shell, IllegalAccessError, script)
} else {
assertScript(shell, script)
}
}

@Test // GROOVY-11365
void testInnerClassProtectedMethodReference() {
assertScript shell, '''package p
abstract class A<E> {
protected E op(E e) { result = e }
protected E result
}
true
'''
assertScript shell, '''
@CompileStatic
class C extends p.A<Integer> {
void test() {
def runnable = { ->
Consumer<Integer> consumer = this::op
consumer.accept(42) // IllegalAccessError
}
runnable.run()
assert result == Integer.valueOf(42)
}
}
new C().test()
'''
}
}

0 comments on commit 754d819

Please sign in to comment.