Skip to content

Commit

Permalink
[kotlinx.serialization] Correctly handle square bracket syntax for fi…
Browse files Browse the repository at this point in the history
…le-level annotations

such as @UseContextualSerialization and @UseSerializers.

Fixes Kotlin/kotlinx.serialization#2783


Merge-request: KT-MR-17576
Merged-by: Leonid Startsev <leonid.startsev@jetbrains.com>
  • Loading branch information
sandwwraith authored and qodana-bot committed Aug 23, 2024
1 parent 5529c50 commit 87f948b
Show file tree
Hide file tree
Showing 7 changed files with 116 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,8 @@ abstract class AbstractSerialGenerator(val bindingContext: BindingContext?, val
.find { it.fqName == annotationFqName }
?: return emptyList()

@Suppress("UNCHECKED_CAST")
val typeList: List<KClassValue> = annotation.firstArgument()?.value as? List<KClassValue> ?: return emptyList()
return typeList.map { it.getArgumentType(declarationInFile.module) }
val typeList = annotation.firstArgument()?.value as? List<*> ?: return emptyList()
return typeList.filterIsInstance<KClassValue>().map { it.getArgumentType(declarationInFile.module) }
}

val contextualKClassListInCurrentFile: Set<KotlinType> by lazy {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,26 @@ class ContextualSerializersProvider(session: FirSession) : FirExtensionSessionCo
return additionalSerializersInScopeCache.getValue(file)
}

private fun FirExpression.unwrapArguments(): List<FirExpression> = when (this) {
is FirArrayLiteral -> arguments
is FirVarargArgumentsExpression -> arguments
else -> emptyList()
}

private fun getKClassListFromFileAnnotation(file: FirFile, annotationClassId: ClassId): List<ConeKotlinType> {
val annotation = file.symbol.resolvedAnnotationsWithArguments.getAnnotationByClassId(
annotationClassId, session
) ?: return emptyList()
val arguments = when (val argument = annotation.argumentMapping.mapping.values.firstOrNull()) {
is FirArrayLiteral -> argument.arguments
is FirVarargArgumentsExpression -> argument.arguments
else -> return emptyList()
val annotationArgument = annotation.argumentMapping.mapping.values.firstOrNull()
val arguments = annotationArgument?.unwrapArguments() ?: return emptyList()
val classes: List<FirGetClassCall> = arguments.flatMap {
when (it) {
is FirGetClassCall -> listOf(it)
is FirSpreadArgumentExpression -> it.expression.unwrapArguments().filterIsInstance<FirGetClassCall>()
else -> emptyList()
}
}
return arguments.mapNotNull { (it as? FirGetClassCall)?.getTargetType()?.fullyExpandedType(session) }
return classes.mapNotNull { it.getTargetType()?.fullyExpandedType(session) }
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// WITH_STDLIB

// FILE: a.kt

package a

import kotlinx.serialization.*
import kotlinx.serialization.descriptors.*
import kotlinx.serialization.encoding.*

object MultiplyingIntSerializer : KSerializer<Int> {
override val descriptor: SerialDescriptor
get() = PrimitiveSerialDescriptor("MultiplyingInt", PrimitiveKind.INT)

override fun deserialize(decoder: Decoder): Int {
return decoder.decodeInt() / 2
}

override fun serialize(encoder: Encoder, value: Int) {
encoder.encodeInt(value * 2)
}
}

data class Cont(val i: Int)

object ContSerializer: KSerializer<Cont> {
override fun deserialize(decoder: Decoder): Cont {
return Cont(decoder.decodeInt())
}

override val descriptor: SerialDescriptor = PrimitiveSerialDescriptor("ContSerializer", PrimitiveKind.INT)

override fun serialize(encoder: Encoder, value: Cont) {
encoder.encodeInt(value.i)
}
}

// FILE: test.kt

@file:UseContextualSerialization(forClasses = [Cont::class])
@file:UseSerializers(*[MultiplyingIntSerializer::class])

package a

import kotlinx.serialization.*
import kotlinx.serialization.json.*
import kotlinx.serialization.modules.*

@Serializable
class Holder(
val i: Int,
val c: Cont
)

fun testOnFile(): String {
val j = Json {
serializersModule = SerializersModule {
contextual(ContSerializer)
}
}
val h = Holder(3, Cont(4))
val str = j.encodeToString(
Holder.serializer(),
h
)
if ("""{"i":6,"c":4}""" != str) return str
val decoded = j.decodeFromString(Holder.serializer(), str)
if (decoded.i != h.i) return "i: ${decoded.i}"
if (decoded.c.i != h.c.i) return "c.i: ${decoded.c.i}"
return "OK"
}

fun box(): String {
return testOnFile()
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 87f948b

Please sign in to comment.