Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 130 additions & 15 deletions collections/src/vec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,13 +136,18 @@ impl<T: BorshDeserialize> BorshDeserialize for TrailingVec<T> {
unsafe impl<T, C> SchemaWrite<C> for TrailingVec<T>
where
C: ConfigCore,
T: SchemaWrite<C>,
T: SchemaWrite<C, Src = T>,
{
type Src = Self;

#[inline(always)]
fn size_of(src: &Self::Src) -> WriteResult<usize> {
let expected_size = src.0.len().saturating_mul(core::mem::size_of::<T>());
// Sum the serialized size of each element, matching the per-element
// decoding performed by the read side.
let mut expected_size = 0usize;
for item in src.0.iter() {
expected_size = expected_size.saturating_add(<T as SchemaWrite<C>>::size_of(item)?);
}

// `Vec` capacity is limited to `isize::MAX`.
if expected_size > isize::MAX as usize {
Expand All @@ -156,12 +161,13 @@ where

#[inline(always)]
fn write(mut writer: impl Writer, src: &Self::Src) -> WriteResult<()> {
// SAFETY: Serializing a slice `[T]` without a length prefix.
unsafe {
writer
.write_slice_t(src.0.as_slice())
.map_err(wincode::WriteError::Io)
// Serialize each item via its schema so the written bytes match the
// per-element decoding performed by the read side.
for item in src.0.iter() {
<T as SchemaWrite<C>>::write(&mut writer, item)?;
}

Ok(())
}
}

Expand Down Expand Up @@ -259,14 +265,20 @@ macro_rules! prefixed_vec_type {
unsafe impl<T, C> SchemaWrite<C> for $name<T>
where
C: ConfigCore,
T: SchemaWrite<C>,
T: SchemaWrite<C, Src = T>,
{
type Src = Self;

#[inline(always)]
fn size_of(src: &Self::Src) -> WriteResult<usize> {
let expected_size = core::mem::size_of::<$prefix_type>().saturating_add(
src.0.len().saturating_mul(size_of::<T>()));
// Start with the length prefix, then sum the serialized size of
// each element, matching the per-element decoding performed by
// the read side.
let mut expected_size = core::mem::size_of::<$prefix_type>();
for item in src.0.iter() {
expected_size = expected_size
.saturating_add(<T as SchemaWrite<C>>::size_of(item)?);
}

// `Vec` capacity is limited to `isize::MAX`.
if expected_size > isize::MAX as usize {
Expand All @@ -285,12 +297,13 @@ macro_rules! prefixed_vec_type {
&$prefix_type::try_from(src.0.len())
.map_err(|_| write_length_encoding_overflow(stringify!($prefix_type::MAX)))?,
)?;
// SAFETY: Serializing a slice `[T]`.
unsafe {
writer
.write_slice_t(src.0.as_slice())
.map_err(wincode::WriteError::Io)
// Serialize each item via its schema so the written bytes match
// the per-element decoding performed by the read side.
for item in src.0.iter() {
<T as SchemaWrite<C>>::write(&mut writer, item)?;
}

Ok(())
}
}

Expand Down Expand Up @@ -343,6 +356,8 @@ prefixed_vec_type!(U64PrefixedVec, u64);

#[cfg(test)]
mod tests {
use alloc::vec;

use borsh::{BorshDeserialize, BorshSerialize};
use core::mem::size_of;
use wincode::WriteError;
Expand Down Expand Up @@ -515,4 +530,104 @@ mod tests {
assert_eq!(serialized.len(), 8);
assert_eq!(serialized.as_slice(), &[!(0u64); 8]);
}

/// A non-POD element type: its wincode-serialized size (5 bytes: a `u8`
/// followed by a little-endian `u32`) differs from `size_of::<NonPod>()`
/// (8 bytes, due to `u32` alignment padding). Serializing a `Vec` of these
/// with a raw byte copy would write the in-memory padding and drift the
/// parse boundary against the per-element read side; serializing each
/// element through its schema keeps writer and reader in sync.
#[cfg(feature = "wincode")]
#[derive(Clone, Debug, Eq, PartialEq, wincode::SchemaRead, wincode::SchemaWrite)]
struct NonPod {
a: u8,
b: u32,
}

/// The regression cases below wrap the vec in an outer struct with a
/// trailing scalar field to prove the parse boundary is preserved.
#[cfg(feature = "wincode")]
#[derive(Debug, Eq, PartialEq, wincode::SchemaRead, wincode::SchemaWrite)]
struct TrailingWrapper {
// `TrailingVec` must be the last field, so the trailing scalar comes
// before it in the struct layout.
trailing_marker: u64,
items: TrailingVec<NonPod>,
}

#[cfg(feature = "wincode")]
#[derive(Debug, Eq, PartialEq, wincode::SchemaRead, wincode::SchemaWrite)]
struct PrefixedWrapper {
items: U16PrefixedVec<NonPod>,
// A field *after* the vec: only decodable if the vec consumed exactly
// the bytes it wrote.
trailing_marker: u64,
}

#[cfg(feature = "wincode")]
#[test]
fn trailing_vec_wincode_non_pod_round_trip() {
// Guard the premise of this test: the element is genuinely non-POD.
assert_ne!(size_of::<NonPod>(), 5);

let items = TrailingVec::from(vec![
NonPod {
a: 1,
b: 0x1122_3344,
},
NonPod {
a: 2,
b: 0x5566_7788,
},
NonPod {
a: 3,
b: 0x99aa_bbcc,
},
]);
let original = TrailingWrapper {
trailing_marker: 0xdead_beef_cafe_babe,
items,
};

let bytes = wincode::serialize(&original).unwrap();
// 8 (marker) + 3 * 5 (each `NonPod` serialized) == 23 bytes, *not*
// 8 + 3 * size_of::<NonPod>().
assert_eq!(bytes.len(), 8 + 3 * 5);

let decoded = wincode::deserialize::<TrailingWrapper>(&bytes).unwrap();
assert_eq!(decoded, original);
}

#[cfg(feature = "wincode")]
#[test]
fn prefixed_vec_wincode_non_pod_round_trip() {
// Guard the premise of this test: the element is genuinely non-POD.
assert_ne!(size_of::<NonPod>(), 5);

let items = U16PrefixedVec::from(vec![
NonPod {
a: 10,
b: 0x0102_0304,
},
NonPod {
a: 20,
b: 0x0506_0708,
},
]);
let original = PrefixedWrapper {
items,
trailing_marker: 0x0011_2233_4455_6677,
};

let bytes = wincode::serialize(&original).unwrap();
// 2 (u16 prefix) + 2 * 5 (each `NonPod`) + 8 (trailing marker) == 20
// bytes, *not* 2 + 2 * size_of::<NonPod>() + 8.
assert_eq!(bytes.len(), 2 + 2 * 5 + 8);

let decoded = wincode::deserialize::<PrefixedWrapper>(&bytes).unwrap();
assert_eq!(decoded, original);
// The trailing field decoded correctly, proving the prefixed vec
// consumed exactly the bytes it wrote.
assert_eq!(decoded.trailing_marker, 0x0011_2233_4455_6677);
}
}
Loading