Merge pull request #3663 from BenPope/schema-regstry-protobuf-file-descriptor

schema_registry/proto: Accept protobuf as an encoded file descriptor
BenPope committed Feb 2, 2022
2 parents b978b47 + dae8b89 commit 0773cf9
Showing 4 changed files with 278 additions and 4 deletions.
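With this change, the schema registry accepts a protobuf schema definition either as .proto source text or as a base64-encoded, serialized FileDescriptorProto. A minimal Python sketch of registering a schema via the new path (the registry URL and subject name are illustrative assumptions, not part of this commit):

    import base64
    import requests
    from google.protobuf import descriptor_pb2

    # Build a FileDescriptorProto equivalent to:
    #   package example;
    #   message Payload { optional int64 val = 1; }
    fdp = descriptor_pb2.FileDescriptorProto()
    fdp.name = "payload.proto"
    fdp.package = "example"
    msg = fdp.message_type.add()
    msg.name = "Payload"
    field = msg.field.add()
    field.name = "val"
    field.number = 1
    field.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL
    field.type = descriptor_pb2.FieldDescriptorProto.TYPE_INT64

    # Post the encoded descriptor where .proto source text would normally go.
    requests.post(
        "http://localhost:8081/subjects/payload-value/versions",
        json={
            "schemaType": "PROTOBUF",
            "schema": base64.b64encode(fdp.SerializeToString()).decode(),
        },
    ).raise_for_status()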
21 changes: 18 additions & 3 deletions src/v/pandaproxy/schema_registry/protobuf.cc
@@ -11,14 +11,18 @@

#include "pandaproxy/schema_registry/protobuf.h"

#include "pandaproxy/logger.h"
#include "pandaproxy/schema_registry/errors.h"
#include "pandaproxy/schema_registry/sharded_store.h"
#include "utils/base64.h"
#include "vlog.h"

#include <seastar/core/coroutine.hh>

#include <fmt/ostream.h>
#include <google/protobuf/compiler/parser.h>
#include <google/protobuf/descriptor.h>
#include <google/protobuf/descriptor.pb.h>
#include <google/protobuf/io/tokenizer.h>
#include <google/protobuf/io/zero_copy_stream.h>

@@ -132,11 +136,21 @@ class parser {
schema_def_input_stream is{schema.def()};
io_error_collector error_collector;
pb::io::Tokenizer t{&is, &error_collector};
_parser.RecordErrorsTo(&error_collector);

// Attempt to parse as a .proto file
if (!_parser.Parse(&t, &_fdp)) {
    // The textual parse failed; base64 decode the schema
    std::string_view b64_def{
      schema.def().raw()().data(), schema.def().raw()().size()};
    auto bytes_def = base64_to_bytes(b64_def);

    // Attempt to parse as an encoded FileDescriptorProto.pb
    if (!_fdp.ParseFromArray(
          bytes_def.data(), static_cast<int>(bytes_def.size()))) {
        throw as_exception(error_collector.error());
    }
}

_fdp.set_name(schema.sub()());
return _fdp;
}
@@ -184,8 +198,9 @@ ss::future<const pb::FileDescriptor*> build_file_with_refs(
ss::future<const pb::FileDescriptor*> import_schema(
pb::DescriptorPool& dp, sharded_store& store, canonical_schema schema) {
try {
// schema is still referenced in the catch block below, so it isn't moved from
co_return co_await build_file_with_refs(dp, store, schema);
} catch (const exception& e) {
vlog(plog.warn, "Failed to decode schema: {}", e.what());
throw as_exception(invalid_schema(schema));
}
}
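The parser hunk above attempts the textual .proto grammar first; only if that fails does it base64-decode the definition and try to parse the bytes as a serialized FileDescriptorProto. A rough Python sketch of that fallback branch (the function name is illustrative, not part of the codebase):

    import base64
    from google.protobuf import descriptor_pb2

    def parse_as_encoded_fdp(schema_def: str) -> descriptor_pb2.FileDescriptorProto:
        # Mirror of the C++ fallback: base64-decode, then parse the bytes as a
        # serialized FileDescriptorProto. Malformed descriptor bytes raise
        # google.protobuf.message.DecodeError.
        fdp = descriptor_pb2.FileDescriptorProto()
        fdp.ParseFromString(base64.b64decode(schema_def))
        return fdp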
231 changes: 231 additions & 0 deletions tests/rptest/clients/python_librdkafka_serde_client.py
@@ -0,0 +1,231 @@
# Copyright 2022 Vectorized, Inc.
#
# Use of this software is governed by the Business Source License
# included in the file licenses/BSL.md
#
# As of the Change Date specified in that file, in accordance with
# the Business Source License, use of this software will be governed
# by the Apache License, Version 2.0

import argparse
import logging
from collections import OrderedDict
from enum import Enum
from uuid import uuid4

from confluent_kafka import DeserializingConsumer, SerializingProducer
from confluent_kafka.serialization import StringDeserializer, StringSerializer
from confluent_kafka.schema_registry import SchemaRegistryClient
from confluent_kafka.schema_registry.avro import AvroDeserializer, AvroSerializer
from confluent_kafka.schema_registry.protobuf import ProtobufDeserializer, ProtobufSerializer

from google.protobuf.descriptor_pb2 import FieldDescriptorProto
from google.protobuf import proto_builder


class AvroPayload(OrderedDict):
    def __init__(self, val: int):
        OrderedDict.__init__(self)
        self['val'] = val

@property
def val(self):
return self['val']


AVRO_SCHEMA = '''{
"type": "record",
"name": "payload",
"fields": [
{"name": "val", "type": "int"}
]
}
'''

ProtobufPayloadClass = proto_builder.MakeSimpleProtoClass(
OrderedDict([('val', FieldDescriptorProto.TYPE_INT64)]),
full_name="example.Payload")


def make_protobuf_payload(val: int):
p = ProtobufPayloadClass()
p.val = val
return p


class SchemaType(Enum):
AVRO = 1
PROTOBUF = 2


class SerdeClient:
"""
SerdeClient produces and consumes payloads in Avro or Protobuf format.
The expected offset is stored in each payload and checked on consume.
"""
def __init__(self,
brokers,
schema_registry_url,
schema_type: SchemaType,
*,
topic=str(uuid4()),
group=str(uuid4()),
logger=logging.getLogger("SerdeClient")):
self.logger = logger
self.brokers = brokers
self.sr_client = SchemaRegistryClient({'url': schema_registry_url})
self.schema_type = schema_type
self.topic = topic
self.group = group

self.produced = 0
self.acked = 0
self.consumed = 0

def _make_serializer(self):
return {
SchemaType.AVRO:
AvroSerializer(self.sr_client, AVRO_SCHEMA),
SchemaType.PROTOBUF:
ProtobufSerializer(ProtobufPayloadClass, self.sr_client)
}[self.schema_type]

def _make_deserializer(self):
return {
SchemaType.AVRO:
AvroDeserializer(self.sr_client,
AVRO_SCHEMA,
from_dict=lambda d, _: AvroPayload(d['val'])),
SchemaType.PROTOBUF:
ProtobufDeserializer(ProtobufPayloadClass)
}[self.schema_type]

def _make_payload(self, val: int):
return {
SchemaType.AVRO: AvroPayload(val),
SchemaType.PROTOBUF: make_protobuf_payload(val)
}[self.schema_type]

def produce(self, count: int):
def increment(err, msg):
assert err is None
assert msg is not None
assert msg.offset() == self.acked
self.logger.debug("Acked offset %d", msg.offset())
self.acked += 1

producer = SerializingProducer({
'bootstrap.servers':
self.brokers,
'key.serializer':
StringSerializer('utf_8'),
'value.serializer':
self._make_serializer()
})

self.logger.info("Producing %d %s records to topic %s", count,
self.schema_type.name, self.topic)
for i in range(count):
# Prevent overflow of buffer
while len(producer) > 50000:
# Serve on_delivery callbacks from previous calls to produce()
producer.poll(0.1)

producer.produce(topic=self.topic,
key=str(uuid4()),
value=self._make_payload(i),
on_delivery=increment)
self.produced += 1

self.logger.info("Flushing records...")
producer.flush()
self.logger.info("Records flushed: %d", self.produced)
while self.acked < count:
producer.poll(0.01)
self.logger.info("Records acked: %d", self.acked)

def consume(self, count: int):
consumer = DeserializingConsumer({
'bootstrap.servers':
self.brokers,
'key.deserializer':
StringDeserializer('utf_8'),
'value.deserializer':
self._make_deserializer(),
'group.id':
self.group,
'auto.offset.reset':
"earliest"
})
consumer.subscribe([self.topic])

self.logger.info("Consuming %d %s records from topic %s with group %s",
count, self.schema_type.name, self.topic, self.group)
while self.consumed < count:
msg = consumer.poll(1)
if msg is None:
continue
payload = msg.value()
self.logger.debug("Consumed %d at %d", payload.val, msg.offset())
assert payload.val == self.consumed
self.consumed += 1

consumer.close()

def run(self, count: int):
self.produce(count)
assert self.produced == count
assert self.acked == count
self.consume(count)
assert self.consumed == count


def main(args):
handler = logging.StreamHandler()
handler.setFormatter(
logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
handler.setLevel(logging.DEBUG)
logger = logging.getLogger("SerdeClient")
logger.addHandler(handler)

p = SerdeClient(args.bootstrap_servers,
args.schema_registry,
SchemaType[args.protocol],
topic=args.topic,
group=args.group,
logger=logger)
p.run(args.count)


if __name__ == '__main__':
parser = argparse.ArgumentParser(description="SerdeClient")
parser.add_argument('-b',
dest="bootstrap_servers",
required=True,
help="Bootstrap broker(s) (host[:port])")
parser.add_argument('-s',
dest="schema_registry",
required=True,
help="Schema Registry (http(s)://host[:port]")
parser.add_argument('-p',
dest="protocol",
default=SchemaType.AVRO.name,
choices=SchemaType._member_names_,
help="Topic name")
parser.add_argument('-t',
dest="topic",
default=str(uuid4()),
help="Topic name")
parser.add_argument('-g',
dest="group",
default=str(uuid4()),
help="Topic name")
parser.add_argument('-c',
dest="count",
default=1,
type=int,
help="Number of messages to send")

main(parser.parse_args())
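For an ad-hoc run outside ducktape, the argparse entry point above can be driven directly; for example (broker and registry addresses are assumptions for a local cluster):

    python python_librdkafka_serde_client.py -b localhost:9092 \
        -s http://localhost:8081 -p PROTOBUF -c 10

This produces ten Protobuf records to a fresh UUID-named topic, then consumes them back and asserts each payload's value matches its offset.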
28 changes: 28 additions & 0 deletions tests/rptest/tests/schema_registry_test.py
@@ -20,6 +20,7 @@

from rptest.clients.types import TopicSpec
from rptest.clients.kafka_cli_tools import KafkaCliTools
from rptest.clients.python_librdkafka_serde_client import SerdeClient, SchemaType
from rptest.tests.redpanda_test import RedpandaTest
from rptest.services.redpanda import ResourceSettings

@@ -1013,3 +1014,30 @@ def test_protobuf(self):
self.logger.info(result_raw)
assert result_raw.status_code == requests.codes.ok
assert result_raw.json() == [2]

@cluster(num_nodes=3)
def test_serde_client(self):
"""
Verify basic serialize/deserialize round-trips for each schema type
"""
protocols = [SchemaType.AVRO, SchemaType.PROTOBUF]
topics = [f"serde-topic-{x.name}" for x in protocols]
self._create_topics(topics)
schema_reg = self.redpanda.schema_reg().split(',', 1)[0]
for i in range(len(protocols)):
self.logger.info(
f"Connecting to redpanda: {self.redpanda.brokers()} schema_reg: {schema_reg}"
)
client = SerdeClient(self.redpanda.brokers(),
schema_reg,
protocols[i],
topic=topics[i],
logger=self.logger)
client.run(2)
schema = self._get_subjects_subject_versions_version(
f"{topics[i]}-value", "latest")
self.logger.info(schema.json())
if protocols[i] == SchemaType.AVRO:
assert schema.json().get("schemaType") is None
else:
assert schema.json()["schemaType"] == protocols[i].name
2 changes: 1 addition & 1 deletion tests/setup.py
@@ -16,7 +16,7 @@
'ducktape@git+https://github.com/vectorizedio/ducktape.git@6e2af9173a79feb8661c4c7a5776080721710a43',
'prometheus-client==0.9.0', 'pyyaml==5.3.1', 'kafka-python==2.0.2',
'crc32c==2.2', 'confluent-kafka==1.7.0', 'zstandard==0.15.2',
'xxhash==2.0.2'
'xxhash==2.0.2', 'protobuf==3.19.3', 'fastavro==1.4.9'
],
scripts=[],
)
