Skip to content

Commit

Permalink
Refactor BinaryUtils.kt
Browse files Browse the repository at this point in the history
so that
1. Cache is queried before reading class file contents.
2. Handles not found classes.
3. Skip on non-jvm platforms.
  • Loading branch information
ting-yuan committed Sep 16, 2024
1 parent f47c2b6 commit 41345b8
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import java.io.File
import java.util.jar.*

@RunWith(Parameterized::class)
class KMPImplementedIT(useKSP2: Boolean) {
class KMPImplementedIT(val useKSP2: Boolean) {
@Rule
@JvmField
val project: TemporaryTestProject = TemporaryTestProject("kmp", useKSP2 = useKSP2)
Expand Down Expand Up @@ -132,6 +132,33 @@ class KMPImplementedIT(useKSP2: Boolean) {
}
}

@Test
fun testDefaultArgumentsImpl() {
Assume.assumeFalse(System.getProperty("os.name").startsWith("Windows", ignoreCase = true))
// FIXME: KSP1
Assume.assumeTrue(useKSP2)
val gradleRunner = GradleRunner.create().withProjectDir(project.root)

val newSrc = File(project.root, "workload-wasm/src/wasmJsMain/kotlin/com/example/AnnoOnProperty.kt")
newSrc.appendText(
"""
@Target(AnnotationTarget.PROPERTY)
annotation class OnProperty
class AnnoOnProperty {
@OnProperty
val value: Int = 0
}
""".trimIndent()
)

gradleRunner.withArguments(
"--configuration-cache-problems=warn",
"clean",
":workload-wasm:build"
).build()
}

@Test
fun testJsErrorLog() {
Assume.assumeFalse(System.getProperty("os.name").startsWith("Windows", ignoreCase = true))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,8 @@ import com.google.devtools.ksp.common.impl.*
import com.google.devtools.ksp.common.visitor.CollectAnnotatedSymbolsVisitor
import com.google.devtools.ksp.impl.symbol.java.KSAnnotationJavaImpl
import com.google.devtools.ksp.impl.symbol.kotlin.*
import com.google.devtools.ksp.impl.symbol.util.BinaryClassInfoCache
import com.google.devtools.ksp.impl.symbol.util.*
import com.google.devtools.ksp.impl.symbol.util.DeclarationOrdering
import com.google.devtools.ksp.impl.symbol.util.extractThrowsFromClassFile
import com.google.devtools.ksp.impl.symbol.util.hasAnnotation
import com.google.devtools.ksp.processing.KSBuiltIns
import com.google.devtools.ksp.processing.Resolver
import com.google.devtools.ksp.processing.SymbolProcessor
Expand All @@ -51,7 +49,6 @@ import org.jetbrains.kotlin.cli.jvm.compiler.KotlinCliJavaFileManagerImpl
import org.jetbrains.kotlin.fir.symbols.impl.FirCallableSymbol
import org.jetbrains.kotlin.fir.types.isRaw
import org.jetbrains.kotlin.fir.types.typeContext
import org.jetbrains.kotlin.load.java.structure.impl.JavaClassImpl
import org.jetbrains.kotlin.load.kotlin.JvmPackagePartSource
import org.jetbrains.kotlin.load.kotlin.TypeMappingMode
import org.jetbrains.kotlin.load.kotlin.getOptimalModeForReturnType
Expand All @@ -60,6 +57,7 @@ import org.jetbrains.kotlin.name.ClassId
import org.jetbrains.kotlin.name.FqName
import org.jetbrains.kotlin.name.FqNameUnsafe
import org.jetbrains.kotlin.name.Name
import org.jetbrains.kotlin.platform.jvm.JvmPlatform
import org.jetbrains.kotlin.psi.KtFile
import org.jetbrains.org.objectweb.asm.Opcodes

Expand Down Expand Up @@ -247,11 +245,8 @@ class ResolverAAImpl(
val fileManager = instance.javaFileManager
val parentClass = this.findParentOfType<KSClassDeclaration>()
val classId = (parentClass as KSClassDeclarationImpl).ktClassOrObjectSymbol.classId!!
val virtualFileContent = analyze {
(fileManager.findClass(classId, analysisScope) as JavaClassImpl).virtualFile!!.contentsToByteArray()
}
BinaryClassInfoCache.getCached(classId, virtualFileContent)
.fieldAccFlags[this.simpleName.asString()] ?: 0
BinaryClassInfoCache.getCached(classId, fileManager)
?.fieldAccFlags?.get(this.simpleName.asString()) ?: 0
}
else -> throw IllegalStateException("this function expects only KOTLIN_LIB or JAVA_LIB")
}
Expand All @@ -263,11 +258,8 @@ class ResolverAAImpl(
val fileManager = instance.javaFileManager
val parentClass = this.findParentOfType<KSClassDeclaration>()
val classId = (parentClass as KSClassDeclarationImpl).ktClassOrObjectSymbol.classId!!
val virtualFileContent = analyze {
(fileManager.findClass(classId, analysisScope) as JavaClassImpl).virtualFile!!.contentsToByteArray()
}
BinaryClassInfoCache.getCached(classId, virtualFileContent)
.methodAccFlags[this.simpleName.asString() + jvmDesc] ?: 0
BinaryClassInfoCache.getCached(classId, fileManager)
?.methodAccFlags?.get(this.simpleName.asString() + jvmDesc) ?: 0
}
else -> throw IllegalStateException("this function expects only KOTLIN_LIB or JAVA_LIB")
}
Expand Down Expand Up @@ -337,6 +329,12 @@ class ResolverAAImpl(
if (container.origin != Origin.KOTLIN_LIB) {
return container.declarations
}

// TODO: multiplatform
if (!isJvm) {
return container.declarations
}

require(container is AbstractKSDeclarationImpl)
val fileManager = instance.javaFileManager
var parentClass: KSNode = container
Expand All @@ -357,9 +355,7 @@ class ResolverAAImpl(
}

val classId = parentClass.ktClassOrObjectSymbol.classId ?: return container.declarations
val virtualFile = analyze {
(fileManager.findClass(classId, analysisScope) as? JavaClassImpl)?.virtualFile
} ?: return container.declarations
val virtualFile = classId.getVirtualFile(fileManager) ?: return container.declarations
val kotlinClass = classBinaryCache.getKotlinBinaryClass(virtualFile) ?: return container.declarations
val declarationOrdering = DeclarationOrdering(kotlinClass)

Expand Down Expand Up @@ -429,10 +425,9 @@ class ResolverAAImpl(
Origin.KOTLIN_LIB, Origin.JAVA_LIB -> {
val fileManager = javaFileManager
val parentClass = accessor.findParentOfType<KSClassDeclaration>()
val classId = (parentClass as KSClassDeclarationImpl).ktClassOrObjectSymbol.classId!!
val virtualFileContent = analyze {
(fileManager.findClass(classId, analysisScope) as JavaClassImpl).virtualFile!!.contentsToByteArray()
}
val classId = (parentClass as KSClassDeclarationImpl).ktClassOrObjectSymbol.classId
?: return emptySequence()
val virtualFileContent = classId.getFileContent(fileManager) ?: return emptySequence()
val jvmDesc = this.mapToJvmSignatureInternal(accessor)
extractThrowsFromClassFile(
virtualFileContent,
Expand Down Expand Up @@ -462,10 +457,9 @@ class ResolverAAImpl(
Origin.KOTLIN_LIB, Origin.JAVA_LIB -> {
val fileManager = javaFileManager
val parentClass = function.findParentOfType<KSClassDeclaration>()
val classId = (parentClass as KSClassDeclarationImpl).ktClassOrObjectSymbol.classId!!
val virtualFileContent = analyze {
(fileManager.findClass(classId, analysisScope) as JavaClassImpl).virtualFile!!.contentsToByteArray()
}
val classId = (parentClass as KSClassDeclarationImpl).ktClassOrObjectSymbol.classId
?: return emptySequence()
val virtualFileContent = classId.getFileContent(fileManager) ?: return emptySequence()
val jvmDesc = this.mapToJvmSignature(function)
extractThrowsFromClassFile(virtualFileContent, jvmDesc, function.simpleName.asString())
}
Expand Down Expand Up @@ -924,4 +918,6 @@ class ResolverAAImpl(
} else KSFunctionErrorImpl(function)
}
}

internal val isJvm = ktModule.targetPlatform.all { it is JvmPlatform }
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ import org.jetbrains.kotlin.analysis.api.symbols.KaSymbolModality
import org.jetbrains.kotlin.analysis.api.symbols.KaSymbolVisibility
import org.jetbrains.kotlin.analysis.api.symbols.receiverType
import org.jetbrains.kotlin.descriptors.annotations.AnnotationUseSiteTarget
import org.jetbrains.kotlin.load.java.structure.impl.JavaClassImpl
import org.jetbrains.kotlin.load.kotlin.JvmPackagePartSource
import org.jetbrains.kotlin.load.kotlin.KotlinJvmBinarySourceElement
import org.jetbrains.kotlin.psi.KtAnnotationEntry
Expand Down Expand Up @@ -140,6 +139,8 @@ class KSPropertyDeclarationImpl private constructor(internal val ktPropertySymbo
(ktPropertySymbol as? KaKotlinPropertySymbol)?.isLateInit == true -> true
ktPropertySymbol.modality == KaSymbolModality.ABSTRACT -> false
else -> {
if (!ResolverAAImpl.instance.isJvm)
return@lazy ktPropertySymbol.hasBackingField
val classId = when (
val containerSource =
(ktPropertySymbol as? KaFirKotlinPropertySymbol)?.firSymbol?.containerSource
Expand All @@ -149,12 +150,8 @@ class KSPropertyDeclarationImpl private constructor(internal val ktPropertySymbo
else -> null
} ?: return@lazy ktPropertySymbol.hasBackingField
val fileManager = ResolverAAImpl.instance.javaFileManager
val virtualFileContent = analyze {
(fileManager.findClass(classId, analysisScope) as JavaClassImpl)
.virtualFile!!.contentsToByteArray()
}
BinaryClassInfoCache.getCached(classId, virtualFileContent)
.fieldAccFlags.containsKey(simpleName.asString())
BinaryClassInfoCache.getCached(classId, fileManager)
?.fieldAccFlags?.containsKey(simpleName.asString()) ?: ktPropertySymbol.hasBackingField
}
}
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import com.google.devtools.ksp.impl.symbol.kotlin.resolved.KSAnnotationResolvedI
import com.google.devtools.ksp.impl.symbol.kotlin.resolved.KSClassifierParameterImpl
import com.google.devtools.ksp.impl.symbol.kotlin.resolved.KSClassifierReferenceResolvedImpl
import com.google.devtools.ksp.impl.symbol.util.getDocString
import com.google.devtools.ksp.impl.symbol.util.getFileContent
import com.google.devtools.ksp.symbol.*
import com.intellij.psi.PsiElement
import com.intellij.psi.PsiJavaFile
Expand Down Expand Up @@ -64,7 +65,6 @@ import org.jetbrains.kotlin.fir.symbols.SymbolInternals
import org.jetbrains.kotlin.fir.symbols.lazyResolveToPhase
import org.jetbrains.kotlin.fir.types.*
import org.jetbrains.kotlin.load.java.structure.JavaAnnotationArgument
import org.jetbrains.kotlin.load.java.structure.impl.JavaClassImpl
import org.jetbrains.kotlin.load.java.structure.impl.JavaUnknownAnnotationArgumentImpl
import org.jetbrains.kotlin.load.java.structure.impl.classFiles.BinaryClassSignatureParser
import org.jetbrains.kotlin.load.java.structure.impl.classFiles.BinaryJavaAnnotationVisitor
Expand Down Expand Up @@ -547,14 +547,16 @@ internal fun KaValueParameterSymbol.getDefaultValue(): KaAnnotationValue? {
}
// ClsMethodImpl means the psi is decompiled psi.
null, is ClsMemberImpl<*> -> {
// TODO: multiplatform
if (!ResolverAAImpl.instance.isJvm)
return@let null
val fileManager = ResolverAAImpl.instance.javaFileManager
val parentClass = this.getContainingKSSymbol()!!.findParentOfType<KSClassDeclaration>()
val classId = (parentClass as KSClassDeclarationImpl).ktClassOrObjectSymbol.classId!!
val file = analyze {
(fileManager.findClass(classId, analysisScope) as JavaClassImpl).virtualFile!!.contentsToByteArray()
}
val classId = (parentClass as KSClassDeclarationImpl).ktClassOrObjectSymbol.classId
?: return@let null
val fileContent = classId.getFileContent(fileManager) ?: return@let null
var defaultValue: JavaAnnotationArgument? = null
ClassReader(file).accept(
ClassReader(fileContent).accept(
object : ClassVisitor(Opcodes.API_VERSION) {
override fun visitMethod(
access: Int,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,15 @@ import com.google.devtools.ksp.impl.ResolverAAImpl
import com.google.devtools.ksp.impl.symbol.kotlin.AbstractKSDeclarationImpl
import com.google.devtools.ksp.impl.symbol.kotlin.KSFunctionDeclarationImpl
import com.google.devtools.ksp.impl.symbol.kotlin.KSPropertyDeclarationImpl
import com.google.devtools.ksp.impl.symbol.kotlin.analyze
import com.google.devtools.ksp.processing.Resolver
import com.google.devtools.ksp.symbol.KSAnnotated
import com.google.devtools.ksp.symbol.KSDeclaration
import com.google.devtools.ksp.symbol.KSType
import com.intellij.openapi.vfs.VirtualFile
import org.jetbrains.kotlin.cli.jvm.compiler.KotlinCliJavaFileManagerImpl
import org.jetbrains.kotlin.load.java.structure.impl.JavaClassImpl
import org.jetbrains.kotlin.load.kotlin.KotlinJvmBinaryClass
import org.jetbrains.kotlin.load.kotlin.VirtualFileKotlinClass
import org.jetbrains.kotlin.name.ClassId
import org.jetbrains.kotlin.name.Name
import org.jetbrains.org.objectweb.asm.ClassReader
Expand All @@ -48,16 +51,11 @@ data class BinaryClassInfo(
* Lookup cache for field names for deserialized classes.
* To check if a field has backing field, we need to look for binary field names, hence they are cached here.
*/
object BinaryClassInfoCache : KSObjectCache<ClassId, BinaryClassInfo>() {
fun getCached(
kotlinJvmBinaryClass: KotlinJvmBinaryClass,
) = getCached(
kotlinJvmBinaryClass.classId, (kotlinJvmBinaryClass as? VirtualFileKotlinClass)?.file?.contentsToByteArray()
)

fun getCached(classId: ClassId, virtualFileContent: ByteArray?) = cache.getOrPut(classId) {
object BinaryClassInfoCache : KSObjectCache<ClassId, BinaryClassInfo?>() {
fun getCached(classId: ClassId, fileManager: KotlinCliJavaFileManagerImpl) = cache.getOrPut(classId) {
val fieldAccFlags = mutableMapOf<String, Int>()
val methodAccFlags = mutableMapOf<String, Int>()
val virtualFileContent = classId.getFileContent(fileManager) ?: return@getOrPut null
ClassReader(virtualFileContent).accept(
object : ClassVisitor(Opcodes.API_VERSION) {
override fun visitField(
Expand Down Expand Up @@ -283,3 +281,12 @@ internal class DeclarationOrdering(
var STRICT_MODE = false
}
}

// Expensive; Use with caution.
internal fun ClassId.getFileContent(fileManager: KotlinCliJavaFileManagerImpl): ByteArray? =
getVirtualFile(fileManager)?.contentsToByteArray()

internal fun ClassId.getVirtualFile(fileManager: KotlinCliJavaFileManagerImpl): VirtualFile? =
analyze {
(fileManager.findClass(this@getVirtualFile, analysisScope) as? JavaClassImpl)?.virtualFile
}

0 comments on commit 41345b8

Please sign in to comment.