diff --git a/src/slice/array.rs b/src/slice/array.rs new file mode 100644 index 000000000..bae9ea686 --- /dev/null +++ b/src/slice/array.rs @@ -0,0 +1,236 @@ +use crate::iter::plumbing::*; +use crate::iter::*; + +use super::{Iter, IterMut, ParallelSlice}; + +/// Parallel iterator over immutable non-overlapping chunks of a slice +#[derive(Debug)] +pub struct ArrayChunks<'data, T: Sync, const N: usize> { + iter: Iter<'data, [T; N]>, + rem: &'data [T], +} + +impl<'data, T: Sync, const N: usize> ArrayChunks<'data, T, N> { + pub(super) fn new(slice: &'data [T]) -> Self { + assert_ne!(N, 0); + let len = slice.len() / N; + let (fst, snd) = slice.split_at(len * N); + // SAFETY: We cast a slice of `len * N` elements into + // a slice of `len` many `N` elements chunks. + let array_slice: &'data [[T; N]] = unsafe { + let ptr = fst.as_ptr() as *const [T; N]; + ::std::slice::from_raw_parts(ptr, len) + }; + Self { + iter: array_slice.par_iter(), + rem: snd, + } + } + + /// Return the remainder of the original slice that is not going to be + /// returned by the iterator. The returned slice has at most `N-1` + /// elements. + pub fn remainder(&self) -> &'data [T] { + self.rem + } +} + +impl<'data, T: Sync, const N: usize> Clone for ArrayChunks<'data, T, N> { + fn clone(&self) -> Self { + ArrayChunks { + iter: self.iter.clone(), + rem: self.rem, + } + } +} + +impl<'data, T: Sync + 'data, const N: usize> ParallelIterator for ArrayChunks<'data, T, N> { + type Item = &'data [T; N]; + + fn drive_unindexed(self, consumer: C) -> C::Result + where + C: UnindexedConsumer, + { + bridge(self, consumer) + } + + fn opt_len(&self) -> Option { + Some(self.len()) + } +} + +impl<'data, T: Sync + 'data, const N: usize> IndexedParallelIterator for ArrayChunks<'data, T, N> { + fn drive(self, consumer: C) -> C::Result + where + C: Consumer, + { + bridge(self, consumer) + } + + fn len(&self) -> usize { + self.iter.len() + } + + fn with_producer(self, callback: CB) -> CB::Output + where + CB: ProducerCallback, + { + self.iter.with_producer(callback) + } +} + +/// Parallel iterator over immutable non-overlapping chunks of a slice +#[derive(Debug)] +pub struct ArrayChunksMut<'data, T: Send, const N: usize> { + iter: IterMut<'data, [T; N]>, + rem: &'data mut [T], +} + +impl<'data, T: Send, const N: usize> ArrayChunksMut<'data, T, N> { + pub(super) fn new(slice: &'data mut [T]) -> Self { + assert_ne!(N, 0); + let len = slice.len() / N; + let (fst, snd) = slice.split_at_mut(len * N); + // SAFETY: We cast a slice of `len * N` elements into + // a slice of `len` many `N` elements chunks. + let array_slice: &'data mut [[T; N]] = unsafe { + let ptr = fst.as_mut_ptr() as *mut [T; N]; + ::std::slice::from_raw_parts_mut(ptr, len) + }; + Self { + iter: array_slice.par_iter_mut(), + rem: snd, + } + } + + /// Return the remainder of the original slice that is not going to be + /// returned by the iterator. The returned slice has at most `N-1` + /// elements. + /// + /// Note that this has to consume `self` to return the original lifetime of + /// the data, which prevents this from actually being used as a parallel + /// iterator since that also consumes. This method is provided for parity + /// with `std::iter::ArrayChunksMut`, but consider calling `remainder()` or + /// `take_remainder()` as alternatives. + pub fn into_remainder(self) -> &'data mut [T] { + self.rem + } + + /// Return the remainder of the original slice that is not going to be + /// returned by the iterator. The returned slice has at most `N-1` + /// elements. + /// + /// Consider `take_remainder()` if you need access to the data with its + /// original lifetime, rather than borrowing through `&mut self` here. + pub fn remainder(&mut self) -> &mut [T] { + self.rem + } + + /// Return the remainder of the original slice that is not going to be + /// returned by the iterator. The returned slice has at most `N-1` + /// elements. Subsequent calls will return an empty slice. + pub fn take_remainder(&mut self) -> &'data mut [T] { + std::mem::replace(&mut self.rem, &mut []) + } +} + +impl<'data, T: Send + 'data, const N: usize> ParallelIterator for ArrayChunksMut<'data, T, N> { + type Item = &'data mut [T; N]; + + fn drive_unindexed(self, consumer: C) -> C::Result + where + C: UnindexedConsumer, + { + bridge(self, consumer) + } + + fn opt_len(&self) -> Option { + Some(self.len()) + } +} + +impl<'data, T: Send + 'data, const N: usize> IndexedParallelIterator + for ArrayChunksMut<'data, T, N> +{ + fn drive(self, consumer: C) -> C::Result + where + C: Consumer, + { + bridge(self, consumer) + } + + fn len(&self) -> usize { + self.iter.len() + } + + fn with_producer(self, callback: CB) -> CB::Output + where + CB: ProducerCallback, + { + self.iter.with_producer(callback) + } +} + +/// Parallel iterator over immutable overlapping windows of a slice +#[derive(Debug)] +pub struct ArrayWindows<'data, T: Sync, const N: usize> { + slice: &'data [T], +} + +impl<'data, T: Sync, const N: usize> ArrayWindows<'data, T, N> { + pub(super) fn new(slice: &'data [T]) -> Self { + ArrayWindows { slice } + } +} + +impl<'data, T: Sync, const N: usize> Clone for ArrayWindows<'data, T, N> { + fn clone(&self) -> Self { + ArrayWindows { ..*self } + } +} + +impl<'data, T: Sync + 'data, const N: usize> ParallelIterator for ArrayWindows<'data, T, N> { + type Item = &'data [T; N]; + + fn drive_unindexed(self, consumer: C) -> C::Result + where + C: UnindexedConsumer, + { + bridge(self, consumer) + } + + fn opt_len(&self) -> Option { + Some(self.len()) + } +} + +impl<'data, T: Sync + 'data, const N: usize> IndexedParallelIterator for ArrayWindows<'data, T, N> { + fn drive(self, consumer: C) -> C::Result + where + C: Consumer, + { + bridge(self, consumer) + } + + fn len(&self) -> usize { + assert!(N >= 1); + self.slice.len().saturating_sub(N - 1) + } + + fn with_producer(self, callback: CB) -> CB::Output + where + CB: ProducerCallback, + { + fn array(slice: &[T]) -> &[T; N] { + debug_assert_eq!(slice.len(), N); + let ptr = slice.as_ptr() as *const [T; N]; + unsafe { &*ptr } + } + + // FIXME: use our own producer and the standard `array_windows`, rust-lang/rust#75027 + self.slice + .par_windows(N) + .map(array::) + .with_producer(callback) + } +} diff --git a/src/slice/mod.rs b/src/slice/mod.rs index dab56deb3..ff38c8a05 100644 --- a/src/slice/mod.rs +++ b/src/slice/mod.rs @@ -5,6 +5,7 @@ //! //! [std::slice]: https://doc.rust-lang.org/stable/std/slice/ +mod array; mod chunks; mod mergesort; mod quicksort; @@ -12,6 +13,8 @@ mod rchunks; mod test; +pub use self::array::{ArrayChunks, ArrayChunksMut, ArrayWindows}; + use self::mergesort::par_mergesort; use self::quicksort::par_quicksort; use crate::iter::plumbing::*; @@ -71,6 +74,20 @@ pub trait ParallelSlice { } } + /// Returns a parallel iterator over all contiguous array windows of + /// length `N`. The windows overlap. + /// + /// # Examples + /// + /// ``` + /// use rayon::prelude::*; + /// let windows: Vec<_> = [1, 2, 3].par_array_windows().collect(); + /// assert_eq!(vec![&[1, 2], &[2, 3]], windows); + /// ``` + fn par_array_windows(&self) -> ArrayWindows<'_, T, N> { + ArrayWindows::new(self.as_parallel_slice()) + } + /// Returns a parallel iterator over at most `chunk_size` elements of /// `self` at a time. The chunks do not overlap. /// @@ -150,6 +167,24 @@ pub trait ParallelSlice { assert!(chunk_size != 0, "chunk_size must not be zero"); RChunksExact::new(chunk_size, self.as_parallel_slice()) } + + /// Returns a parallel iterator over `N`-element chunks of + /// `self` at a time. The chunks do not overlap. + /// + /// If `N` does not divide the length of the slice, then the + /// last up to `N-1` elements will be omitted and can be + /// retrieved from the remainder function of the iterator. + /// + /// # Examples + /// + /// ``` + /// use rayon::prelude::*; + /// let chunks: Vec<_> = [1, 2, 3, 4, 5].par_array_chunks().collect(); + /// assert_eq!(chunks, vec![&[1, 2], &[3, 4]]); + /// ``` + fn par_array_chunks(&self) -> ArrayChunks<'_, T, N> { + ArrayChunks::new(self.as_parallel_slice()) + } } impl ParallelSlice for [T] { @@ -275,6 +310,26 @@ pub trait ParallelSliceMut { RChunksExactMut::new(chunk_size, self.as_parallel_slice_mut()) } + /// Returns a parallel iterator over `N`-element chunks of + /// `self` at a time. The chunks are mutable and do not overlap. + /// + /// If `N` does not divide the length of the slice, then the + /// last up to `N-1` elements will be omitted and can be + /// retrieved from the remainder function of the iterator. + /// + /// # Examples + /// + /// ``` + /// use rayon::prelude::*; + /// let mut array = [1, 2, 3, 4, 5]; + /// array.par_array_chunks_mut() + /// .for_each(|[a, _, b]| std::mem::swap(a, b)); + /// assert_eq!(array, [3, 2, 1, 4, 5]); + /// ``` + fn par_array_chunks_mut(&mut self) -> ArrayChunksMut<'_, T, N> { + ArrayChunksMut::new(self.as_parallel_slice_mut()) + } + /// Sorts the slice in parallel. /// /// This sort is stable (i.e., does not reorder equal elements) and *O*(*n* \* log(*n*)) worst-case. diff --git a/src/slice/test.rs b/src/slice/test.rs index f74ca0f74..b27b74842 100644 --- a/src/slice/test.rs +++ b/src/slice/test.rs @@ -168,3 +168,25 @@ fn test_par_rchunks_exact_mut_remainder() { assert_eq!(c.take_remainder(), &[]); assert_eq!(c.len(), 2); } + +#[test] +fn test_par_array_chunks_remainder() { + let v: &[i32] = &[0, 1, 2, 3, 4]; + let c = v.par_array_chunks::<2>(); + assert_eq!(c.remainder(), &[4]); + assert_eq!(c.len(), 2); +} + +#[test] +fn test_par_array_chunks_mut_remainder() { + let v: &mut [i32] = &mut [0, 1, 2, 3, 4]; + let mut c = v.par_array_chunks_mut::<2>(); + assert_eq!(c.remainder(), &[4]); + assert_eq!(c.len(), 2); + assert_eq!(c.into_remainder(), &[4]); + + let mut c = v.par_array_chunks_mut::<2>(); + assert_eq!(c.take_remainder(), &[4]); + assert_eq!(c.take_remainder(), &[]); + assert_eq!(c.len(), 2); +} diff --git a/tests/clones.rs b/tests/clones.rs index 0d6c86487..eba309342 100644 --- a/tests/clones.rs +++ b/tests/clones.rs @@ -111,7 +111,9 @@ fn clone_vec() { check(v.par_chunks_exact(42)); check(v.par_rchunks(42)); check(v.par_rchunks_exact(42)); + check(v.par_array_chunks::<42>()); check(v.par_windows(42)); + check(v.par_array_windows::<42>()); check(v.par_split(|x| x % 3 == 0)); check(v.into_par_iter()); } diff --git a/tests/debug.rs b/tests/debug.rs index 14f37917b..bf195f326 100644 --- a/tests/debug.rs +++ b/tests/debug.rs @@ -121,13 +121,16 @@ fn debug_vec() { check(v.par_iter_mut()); check(v.par_chunks(42)); check(v.par_chunks_exact(42)); + check(v.par_array_chunks::<42>()); check(v.par_chunks_mut(42)); check(v.par_chunks_exact_mut(42)); + check(v.par_array_chunks_mut::<42>()); check(v.par_rchunks(42)); check(v.par_rchunks_exact(42)); check(v.par_rchunks_mut(42)); check(v.par_rchunks_exact_mut(42)); check(v.par_windows(42)); + check(v.par_array_windows::<42>()); check(v.par_split(|x| x % 3 == 0)); check(v.par_split_mut(|x| x % 3 == 0)); check(v.par_drain(..)); diff --git a/tests/producer_split_at.rs b/tests/producer_split_at.rs index d71050492..74e6101ad 100644 --- a/tests/producer_split_at.rs +++ b/tests/producer_split_at.rs @@ -187,6 +187,36 @@ fn slice_chunks_exact() { } } +#[test] +fn slice_array_chunks() { + use std::convert::{TryFrom, TryInto}; + fn check_len(s: &[i32]) + where + for<'a> &'a [i32; N]: PartialEq + TryFrom<&'a [i32]> + std::fmt::Debug, + { + // TODO: use https://github.com/rust-lang/rust/pull/74373 instead. + let v: Vec<_> = s + .chunks_exact(N) + .map(|s| s.try_into().ok().unwrap()) + .collect(); + check(&v, || s.par_array_chunks::()); + } + + let s: Vec<_> = (0..10).collect(); + check_len::<1>(&s); + check_len::<2>(&s); + check_len::<3>(&s); + check_len::<4>(&s); + check_len::<5>(&s); + check_len::<6>(&s); + check_len::<7>(&s); + check_len::<8>(&s); + check_len::<9>(&s); + check_len::<10>(&s); + check_len::<11>(&s); + check_len::<12>(&s); +} + #[test] fn slice_chunks_mut() { let mut s: Vec<_> = (0..10).collect(); @@ -213,6 +243,40 @@ fn slice_chunks_exact_mut() { } } +#[test] +fn slice_array_chunks_mut() { + use std::convert::{TryFrom, TryInto}; + fn check_len(s: &mut [i32], v: &mut [i32]) + where + for<'a> &'a mut [i32; N]: PartialEq + TryFrom<&'a mut [i32]> + std::fmt::Debug, + { + // TODO: use https://github.com/rust-lang/rust/pull/74373 instead. + let expected: Vec<_> = v + .chunks_exact_mut(N) + .map(|s| s.try_into().ok().unwrap()) + .collect(); + map_triples(expected.len() + 1, |i, j, k| { + Split::forward(s.par_array_chunks_mut::(), i, j, k, &expected); + Split::reverse(s.par_array_chunks_mut::(), i, j, k, &expected); + }); + } + + let mut s: Vec<_> = (0..10).collect(); + let mut v: Vec<_> = s.clone(); + check_len::<1>(&mut s, &mut v); + check_len::<2>(&mut s, &mut v); + check_len::<3>(&mut s, &mut v); + check_len::<4>(&mut s, &mut v); + check_len::<5>(&mut s, &mut v); + check_len::<6>(&mut s, &mut v); + check_len::<7>(&mut s, &mut v); + check_len::<8>(&mut s, &mut v); + check_len::<9>(&mut s, &mut v); + check_len::<10>(&mut s, &mut v); + check_len::<11>(&mut s, &mut v); + check_len::<12>(&mut s, &mut v); +} + #[test] fn slice_rchunks() { let s: Vec<_> = (0..10).collect(); @@ -264,6 +328,15 @@ fn slice_windows() { check(&v, || s.par_windows(2)); } +#[test] +fn slice_array_windows() { + use std::convert::TryInto; + let s: Vec<_> = (0..10).collect(); + // FIXME: use the standard `array_windows`, rust-lang/rust#75027 + let v: Vec<&[_; 2]> = s.windows(2).map(|s| s.try_into().unwrap()).collect(); + check(&v, || s.par_array_windows::<2>()); +} + #[test] fn vec() { let v: Vec<_> = (0..10).collect();