Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for const length arrays ([u8; N]) #782

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 96 additions & 25 deletions src/encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -769,7 +769,7 @@ macro_rules! length_delimited {
B: Buf,
{
check_wire_type(WireType::LengthDelimited, wire_type)?;
let mut value = Default::default();
let mut value = crate::encoding::sealed::Newable::new();
merge(wire_type, &mut value, buf, ctx)?;
values.push(value);
Ok(())
Expand Down Expand Up @@ -873,42 +873,63 @@ pub mod string {
}
}

pub trait BytesAdapter: sealed::BytesAdapter {}
pub trait BytesAdapter: Sized + 'static {
/// Create a new instance (required as `Default` is not available for all types)
fn new() -> Self;

mod sealed {
use super::{Buf, BufMut};
fn len(&self) -> usize;

pub trait BytesAdapter: Default + Sized + 'static {
fn len(&self) -> usize;
/// Replace contents of this buffer with the contents of another buffer.
fn replace_with<B>(&mut self, buf: B) -> Result<(), DecodeError>
where
B: Buf;

/// Replace contents of this buffer with the contents of another buffer.
fn replace_with<B>(&mut self, buf: B)
where
B: Buf;
/// Appends this buffer to the (contents of) other buffer.
fn append_to<B>(&self, buf: &mut B)
where
B: BufMut;

/// Appends this buffer to the (contents of) other buffer.
fn append_to<B>(&self, buf: &mut B)
where
B: BufMut;
fn is_empty(&self) -> bool {
self.len() == 0
}

fn clear(&mut self);
}

mod sealed {
pub trait Newable: Sized {
fn new() -> Self;
}

impl<T: super::BytesAdapter> Newable for T {
fn new() -> Self {
super::BytesAdapter::new()
}
}

fn is_empty(&self) -> bool {
self.len() == 0
impl Newable for alloc::string::String {
fn new() -> Self {
Default::default()
}
}
}

impl BytesAdapter for Bytes {}
impl BytesAdapter for Bytes {
fn new() -> Self {
Default::default()
}

impl sealed::BytesAdapter for Bytes {
fn len(&self) -> usize {
Buf::remaining(self)
}

fn replace_with<B>(&mut self, mut buf: B)
fn replace_with<B>(&mut self, mut buf: B) -> Result<(), DecodeError>
where
B: Buf,
{
*self = buf.copy_to_bytes(buf.remaining());

Ok(())
}

fn append_to<B>(&self, buf: &mut B)
Expand All @@ -917,22 +938,30 @@ impl sealed::BytesAdapter for Bytes {
{
buf.put(self.clone())
}

fn clear(&mut self) {
Bytes::clear(self)
}
}

impl BytesAdapter for Vec<u8> {}
impl BytesAdapter for Vec<u8> {
fn new() -> Self {
Default::default()
}

impl sealed::BytesAdapter for Vec<u8> {
fn len(&self) -> usize {
Vec::len(self)
}

fn replace_with<B>(&mut self, buf: B)
fn replace_with<B>(&mut self, buf: B) -> Result<(), DecodeError>
where
B: Buf,
{
self.clear();
Vec::clear(self);
self.reserve(buf.remaining());
self.put(buf);

Ok(())
}

fn append_to<B>(&self, buf: &mut B)
Expand All @@ -941,6 +970,46 @@ impl sealed::BytesAdapter for Vec<u8> {
{
buf.put(self.as_slice())
}

fn clear(&mut self) {
Vec::clear(self)
}
}

impl<const N: usize> BytesAdapter for [u8; N] {
fn new() -> Self {
[0u8; N]
}

fn len(&self) -> usize {
N
}

fn replace_with<B>(&mut self, buf: B) -> Result<(), DecodeError>
where
B: Buf,
{
if buf.remaining() != N {
return Err(DecodeError::new("invalid byte array length"));
}

self.copy_from_slice(buf.chunk());

Ok(())
}

fn append_to<B>(&self, buf: &mut B)
where
B: BufMut,
{
buf.put(&self[..])
}

fn clear(&mut self) {
for b in &mut self[..] {
*b = 0;
}
}
}

pub mod bytes {
Expand Down Expand Up @@ -985,7 +1054,8 @@ pub mod bytes {
// This is intended for A and B both being Bytes so it is zero-copy.
// Some combinations of A and B types may cause a double-copy,
// in which case merge_one_copy() should be used instead.
value.replace_with(buf.copy_to_bytes(len));
value.replace_with(buf.copy_to_bytes(len))?;

Ok(())
}

Expand All @@ -1007,7 +1077,8 @@ pub mod bytes {
let len = len as usize;

// If we must copy, make sure to copy only once.
value.replace_with(buf.take(len));
value.replace_with(buf.take(len))?;

Ok(())
}

Expand Down
41 changes: 41 additions & 0 deletions tests/src/derive_const.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
use prost::{encoding::BytesAdapter, Message};

// NOTE: [Message] still requires [Default], which is not implemented for [u8; N],
// so this will only work directly for [u8; <=32] for the moment...
// see: https://github.com/rust-lang/rust/issues/61415

/// Const array container A
#[derive(Clone, PartialEq, Message)]
pub struct TestA {
#[prost(bytes, required, tag = "1")]
pub b: [u8; 3],
}

/// Const array container B
#[derive(Clone, PartialEq, Message)]
pub struct TestB {
#[prost(bytes, required, tag = "1")]
pub b: [u8; 4],
}

// Test valid encode/decode
#[test]
fn const_array_encode_decode() {
let t = TestA { b: [1, 2, 3] };

let buff = t.encode_to_vec();

let t1 = TestA::decode(&*buff).unwrap();

assert_eq!(t, t1);
}

// test encode/decode length mismatch
#[test]
fn const_array_length_mismatch() {
let t = TestA { b: [1, 2, 3] };

let buff = t.encode_to_vec();

assert!(TestB::decode(&*buff).is_err());
}
2 changes: 2 additions & 0 deletions tests/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ mod debug;
#[cfg(test)]
mod deprecated_field;
#[cfg(test)]
mod derive_const;
#[cfg(test)]
mod generic_derive;
#[cfg(test)]
mod message_encoding;
Expand Down