Skip to content

Commit

Permalink
Improve readability of protobuf decoding exception messages (#2768)
Browse files Browse the repository at this point in the history
- Wrap definition of wire types with enum class ProtoWireType to show both name and id in message,
- Catch and rethrow any ProtobufDecodingException in most decodeXXX functions, with proto number and type name in new exception message.
  • Loading branch information
xiaozhikang0916 committed Aug 22, 2024
1 parent b931598 commit 0b5145c
Show file tree
Hide file tree
Showing 8 changed files with 530 additions and 160 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,29 @@ import kotlinx.serialization.modules.*
import kotlinx.serialization.protobuf.*

internal typealias ProtoDesc = Long
internal const val VARINT = 0
internal const val i64 = 1
internal const val SIZE_DELIMITED = 2
internal const val i32 = 5

internal enum class ProtoWireType(val typeId: Int) {
INVALID(-1),
VARINT(0),
i64(1),
SIZE_DELIMITED(2),
i32(5),
;

companion object {
fun from(typeId: Int): ProtoWireType {
return ProtoWireType.entries.find { it.typeId == typeId } ?: INVALID
}
}

fun wireIntWithTag(tag: Int): Int {
return ((tag shl 3) or typeId)
}

override fun toString(): String {
return "${this.name}($typeId)"
}
}

internal const val ID_HOLDER_ONE_OF = -2

Expand Down Expand Up @@ -104,7 +123,7 @@ internal fun extractProtoId(descriptor: SerialDescriptor, index: Int, zeroBasedD
return result
}

internal class ProtobufDecodingException(message: String) : SerializationException(message)
internal class ProtobufDecodingException(message: String, e: Throwable? = null) : SerializationException(message, e)

internal expect fun Int.reverseBytes(): Int
internal expect fun Long.reverseBytes(): Long
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,41 +122,53 @@ internal open class ProtobufDecoder(
}

override fun beginStructure(descriptor: SerialDescriptor): CompositeDecoder {
return when (descriptor.kind) {
StructureKind.LIST -> {
val tag = currentTagOrDefault
return if (this.descriptor.kind == StructureKind.LIST && tag != MISSING_TAG && this.descriptor != descriptor) {
val reader = makeDelimited(reader, tag)
// repeated decoder expects the first tag to be read already
reader.readTag()
// all elements always have id = 1
RepeatedDecoder(proto, reader, ProtoDesc(1, ProtoIntegerType.DEFAULT), descriptor)

} else if (reader.currentType == SIZE_DELIMITED && descriptor.getElementDescriptor(0).isPackable) {
val sliceReader = ProtobufReader(reader.objectInput())
PackedArrayDecoder(proto, sliceReader, descriptor)

} else {
RepeatedDecoder(proto, reader, tag, descriptor)
return try {
when (descriptor.kind) {
StructureKind.LIST -> {
val tag = currentTagOrDefault
return if (this.descriptor.kind == StructureKind.LIST && tag != MISSING_TAG && this.descriptor != descriptor) {
val reader = makeDelimited(reader, tag)
// repeated decoder expects the first tag to be read already
reader.readTag()
// all elements always have id = 1
RepeatedDecoder(proto, reader, ProtoDesc(1, ProtoIntegerType.DEFAULT), descriptor)

} else if (reader.currentType == ProtoWireType.SIZE_DELIMITED && descriptor.getElementDescriptor(0).isPackable) {
val sliceReader = ProtobufReader(reader.objectInput())
PackedArrayDecoder(proto, sliceReader, descriptor)

} else {
RepeatedDecoder(proto, reader, tag, descriptor)
}
}
}
StructureKind.CLASS, StructureKind.OBJECT, is PolymorphicKind -> {
val tag = currentTagOrDefault
// Do not create redundant copy
if (tag == MISSING_TAG && this.descriptor == descriptor) return this
if (tag.isOneOf) {
// If a tag is annotated as oneof
// [tag.protoId] here is overwritten with index-based default id in
// [kotlinx.serialization.protobuf.internal.HelpersKt.extractParameters]
// and restored the real id from index2IdMap, set by [decodeElementIndex]
val rawIndex = tag.protoId - 1
val restoredTag = index2IdMap?.get(rawIndex)?.let { tag.overrideId(it) } ?: tag
return OneOfPolymorphicReader(proto, reader, restoredTag, descriptor)

StructureKind.CLASS, StructureKind.OBJECT, is PolymorphicKind -> {
val tag = currentTagOrDefault
// Do not create redundant copy
if (tag == MISSING_TAG && this.descriptor == descriptor) return this
if (tag.isOneOf) {
// If a tag is annotated as oneof
// [tag.protoId] here is overwritten with index-based default id in
// [kotlinx.serialization.protobuf.internal.HelpersKt.extractParameters]
// and restored the real id from index2IdMap, set by [decodeElementIndex]
val rawIndex = tag.protoId - 1
val restoredTag = index2IdMap?.get(rawIndex)?.let { tag.overrideId(it) } ?: tag
return OneOfPolymorphicReader(proto, reader, restoredTag, descriptor)
}
return ProtobufDecoder(proto, makeDelimited(reader, tag), descriptor)
}
return ProtobufDecoder(proto, makeDelimited(reader, tag), descriptor)

StructureKind.MAP -> MapEntryReader(
proto,
makeDelimitedForced(reader, currentTagOrDefault),
currentTagOrDefault,
descriptor
)

else -> throw SerializationException("Primitives are not supported at top-level")
}
StructureKind.MAP -> MapEntryReader(proto, makeDelimitedForced(reader, currentTagOrDefault), currentTagOrDefault, descriptor)
else -> throw SerializationException("Primitives are not supported at top-level")
} catch (e: ProtobufDecodingException) {
throw ProtobufDecodingException("Fail to begin structure for ${descriptor.serialName} in ${this.descriptor.serialName} at proto number ${currentTagOrDefault.protoId}", e)
}
}

Expand All @@ -173,41 +185,51 @@ internal open class ProtobufDecoder(
override fun decodeTaggedByte(tag: ProtoDesc): Byte = decodeTaggedInt(tag).toByte()
override fun decodeTaggedShort(tag: ProtoDesc): Short = decodeTaggedInt(tag).toShort()
override fun decodeTaggedInt(tag: ProtoDesc): Int {
return if (tag == MISSING_TAG) {
reader.readInt32NoTag()
} else {
reader.readInt(tag.integerType)
return decodeOrThrow(tag) {
if (tag == MISSING_TAG) {
reader.readInt32NoTag()
} else {
reader.readInt(tag.integerType)
}
}
}
override fun decodeTaggedLong(tag: ProtoDesc): Long {
return if (tag == MISSING_TAG) {
reader.readLongNoTag()
} else {
reader.readLong(tag.integerType)
return decodeOrThrow(tag) {
if (tag == MISSING_TAG) {
reader.readLongNoTag()
} else {
reader.readLong(tag.integerType)
}
}
}

override fun decodeTaggedFloat(tag: ProtoDesc): Float {
return if (tag == MISSING_TAG) {
reader.readFloatNoTag()
} else {
reader.readFloat()
return decodeOrThrow(tag) {
if (tag == MISSING_TAG) {
reader.readFloatNoTag()
} else {
reader.readFloat()
}
}
}
override fun decodeTaggedDouble(tag: ProtoDesc): Double {
return if (tag == MISSING_TAG) {
reader.readDoubleNoTag()
} else {
reader.readDouble()
return decodeOrThrow(tag) {
if (tag == MISSING_TAG) {
reader.readDoubleNoTag()
} else {
reader.readDouble()
}
}
}
override fun decodeTaggedChar(tag: ProtoDesc): Char = decodeTaggedInt(tag).toChar()

override fun decodeTaggedString(tag: ProtoDesc): String {
return if (tag == MISSING_TAG) {
reader.readStringNoTag()
} else {
reader.readString()
return decodeOrThrow(tag) {
if (tag == MISSING_TAG) {
reader.readStringNoTag()
} else {
reader.readString()
}
}
}

Expand All @@ -218,22 +240,49 @@ internal open class ProtobufDecoder(
override fun <T> decodeSerializableValue(deserializer: DeserializationStrategy<T>): T = decodeSerializableValue(deserializer, null)

@Suppress("UNCHECKED_CAST")
override fun <T> decodeSerializableValue(deserializer: DeserializationStrategy<T>, previousValue: T?): T = when {
deserializer is MapLikeSerializer<*, *, *, *> -> {
deserializeMap(deserializer as DeserializationStrategy<T>, previousValue)
override fun <T> decodeSerializableValue(deserializer: DeserializationStrategy<T>, previousValue: T?): T = try {
when {
deserializer is MapLikeSerializer<*, *, *, *> -> {
deserializeMap(deserializer as DeserializationStrategy<T>, previousValue)
}

deserializer.descriptor == ByteArraySerializer().descriptor -> deserializeByteArray(previousValue as ByteArray?) as T
deserializer is AbstractCollectionSerializer<*, *, *> ->
(deserializer as AbstractCollectionSerializer<*, T, *>).merge(this, previousValue)

else -> deserializer.deserialize(this)
}
} catch (e: ProtobufDecodingException) {
val currentTag = currentTagOrDefault
val msg = if (descriptor != deserializer.descriptor) {
// Decoding child element
if (descriptor.kind == StructureKind.LIST && deserializer.descriptor.kind != StructureKind.MAP) {
// Decoding repeated field
"Error while decoding index ${currentTag.protoId - 1} in repeated field of ${deserializer.descriptor.serialName}"
} else if (descriptor.kind == StructureKind.MAP) {
// Decoding map field
val index = (currentTag.protoId - 1) / 2
val field = if ((currentTag.protoId - 1) % 2 == 0) { "key" } else "value"
"Error while decoding $field of index $index in map field of ${deserializer.descriptor.serialName}"
} else {
// Decoding common class
"Error while decoding ${deserializer.descriptor.serialName} at proto number ${currentTag.protoId} of ${descriptor.serialName}"
}
} else {
// Decoding self
"Error while decoding ${descriptor.serialName}"
}
deserializer.descriptor == ByteArraySerializer().descriptor -> deserializeByteArray(previousValue as ByteArray?) as T
deserializer is AbstractCollectionSerializer<*, *, *> ->
(deserializer as AbstractCollectionSerializer<*, T, *>).merge(this, previousValue)
else -> deserializer.deserialize(this)
throw ProtobufDecodingException(msg, e)
}

private fun deserializeByteArray(previousValue: ByteArray?): ByteArray {
val tag = currentTagOrDefault
val array = if (tag == MISSING_TAG) {
reader.readByteArrayNoTag()
} else {
reader.readByteArray()
val array = decodeOrThrow(tag) {
if (tag == MISSING_TAG) {
reader.readByteArrayNoTag()
} else {
reader.readByteArray()
}
}
return if (previousValue == null) array else previousValue + array
}
Expand All @@ -252,29 +301,33 @@ internal open class ProtobufDecoder(
override fun SerialDescriptor.getTag(index: Int) = extractParameters(index)

override fun decodeElementIndex(descriptor: SerialDescriptor): Int {
while (true) {
val protoId = reader.readTag()
if (protoId == -1) { // EOF
return elementMarker.nextUnmarkedIndex()
}
val index = getIndexByNum(protoId)
if (index == -1) { // not found
reader.skipElement()
} else {
if (descriptor.extractParameters(index).isOneOf) {
/**
* While decoding message with one-of field,
* the proto id read from wire data cannot be easily found
* in the properties of this type,
* So the index of this one-of property and the id read from the wire
* are saved in this map, then restored in [beginStructure]
* and passed to [OneOfPolymorphicReader] to get the actual deserializer.
*/
index2IdMap?.put(index, protoId)
try {
while (true) {
val protoId = reader.readTag()
if (protoId == -1) { // EOF
return elementMarker.nextUnmarkedIndex()
}
val index = getIndexByNum(protoId)
if (index == -1) { // not found
reader.skipElement()
} else {
if (descriptor.extractParameters(index).isOneOf) {
/**
* While decoding message with one-of field,
* the proto id read from wire data cannot be easily found
* in the properties of this type,
* So the index of this one-of property and the id read from the wire
* are saved in this map, then restored in [beginStructure]
* and passed to [OneOfPolymorphicReader] to get the actual deserializer.
*/
index2IdMap?.put(index, protoId)
}
elementMarker.mark(index)
return index
}
elementMarker.mark(index)
return index
}
} catch (e: ProtobufDecodingException) {
throw ProtobufDecodingException("Fail to get element index for ${descriptor.serialName} in ${this.descriptor.serialName}", e)
}
}

Expand All @@ -296,6 +349,19 @@ internal open class ProtobufDecoder(
}
return false
}

private inline fun <T> decodeOrThrow(tag: ProtoDesc, action: (tag: ProtoDesc) -> T): T {
try {
return action(tag)
} catch (e: ProtobufDecodingException) {
rethrowException(tag, e)
}
}

@Suppress("NOTHING_TO_INLINE")
private inline fun rethrowException(tag: ProtoDesc, e: ProtobufDecodingException): Nothing {
throw ProtobufDecodingException("Error while decoding proto number ${tag.protoId} of ${descriptor.serialName}", e)
}
}

private class RepeatedDecoder(
Expand Down
Loading

0 comments on commit 0b5145c

Please sign in to comment.