Skip to content

Commit

Permalink
Add const-generic par_array_chunks
Browse files Browse the repository at this point in the history
  • Loading branch information
cuviper committed Aug 22, 2022
1 parent c00b997 commit 4a18b9b
Show file tree
Hide file tree
Showing 6 changed files with 145 additions and 0 deletions.
82 changes: 82 additions & 0 deletions src/slice/array.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
#![cfg(has_min_const_generics)]

use crate::iter::plumbing::*;
use crate::iter::*;

use super::Iter;

/// Parallel iterator over immutable non-overlapping chunks of a slice
#[derive(Debug)]
pub struct ArrayChunks<'data, T: Sync, const N: usize> {
iter: Iter<'data, [T; N]>,
rem: &'data [T],
}

impl<'data, T: Sync, const N: usize> ArrayChunks<'data, T, N> {
pub(super) fn new(slice: &'data [T]) -> Self {
assert_ne!(N, 0);
let len = slice.len() / N;
let (fst, snd) = slice.split_at(len * N);
// SAFETY: We cast a slice of `len * N` elements into
// a slice of `len` many `N` elements chunks.
let array_slice: &'data [[T; N]] = unsafe {
let ptr = fst.as_ptr() as *const [T; N];
::std::slice::from_raw_parts(ptr, len)
};
Self {
iter: array_slice.par_iter(),
rem: snd,
}
}

/// Return the remainder of the original slice that is not going to be
/// returned by the iterator. The returned slice has at most `N-1`
/// elements.
pub fn remainder(&self) -> &'data [T] {
self.rem
}
}

impl<'data, T: Sync, const N: usize> Clone for ArrayChunks<'data, T, N> {
fn clone(&self) -> Self {
ArrayChunks {
iter: self.iter.clone(),
rem: self.rem,
}
}
}

impl<'data, T: Sync + 'data, const N: usize> ParallelIterator for ArrayChunks<'data, T, N> {
type Item = &'data [T; N];

fn drive_unindexed<C>(self, consumer: C) -> C::Result
where
C: UnindexedConsumer<Self::Item>,
{
bridge(self, consumer)
}

fn opt_len(&self) -> Option<usize> {
Some(self.len())
}
}

impl<'data, T: Sync + 'data, const N: usize> IndexedParallelIterator for ArrayChunks<'data, T, N> {
fn drive<C>(self, consumer: C) -> C::Result
where
C: Consumer<Self::Item>,
{
bridge(self, consumer)
}

fn len(&self) -> usize {
self.iter.len()
}

fn with_producer<CB>(self, callback: CB) -> CB::Output
where
CB: ProducerCallback<Self::Item>,
{
self.iter.with_producer(callback)
}
}
23 changes: 23 additions & 0 deletions src/slice/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,17 @@
//!
//! [std::slice]: https://doc.rust-lang.org/stable/std/slice/

mod array;
mod chunks;
mod mergesort;
mod quicksort;
mod rchunks;

mod test;

#[cfg(min_const_generics)]
pub use self::array::ArrayChunks;

use self::mergesort::par_mergesort;
use self::quicksort::par_quicksort;
use crate::iter::plumbing::*;
Expand Down Expand Up @@ -146,6 +150,25 @@ pub trait ParallelSlice<T: Sync> {
assert!(chunk_size != 0, "chunk_size must not be zero");
RChunksExact::new(chunk_size, self.as_parallel_slice())
}

/// Returns a parallel iterator over `N`-element chunks of
/// `self` at a time. The chunks do not overlap.
///
/// If `N` does not divide the length of the slice, then the
/// last up to `N-1` elements will be omitted and can be
/// retrieved from the remainder function of the iterator.
///
/// # Examples
///
/// ```
/// use rayon::prelude::*;
/// let chunks: Vec<_> = [1, 2, 3, 4, 5].par_array_chunks().collect();
/// assert_eq!(chunks, vec![&[1, 2], &[3, 4]]);
/// ```
#[cfg(has_min_const_generics)]
fn par_array_chunks<const N: usize>(&self) -> ArrayChunks<'_, T, N> {
ArrayChunks::new(self.as_parallel_slice())
}
}

impl<T: Sync> ParallelSlice<T> for [T] {
Expand Down
8 changes: 8 additions & 0 deletions src/slice/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,11 @@ fn test_par_rchunks_exact_mut_remainder() {
assert_eq!(c.take_remainder(), &[]);
assert_eq!(c.len(), 2);
}

#[test]
fn test_par_array_chunks_remainder() {
let v: &[i32] = &[0, 1, 2, 3, 4];
let c = v.par_array_chunks::<2>();
assert_eq!(c.remainder(), &[4]);
assert_eq!(c.len(), 2);
}
1 change: 1 addition & 0 deletions tests/clones.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ fn clone_vec() {
check(v.par_chunks_exact(42));
check(v.par_rchunks(42));
check(v.par_rchunks_exact(42));
check(v.par_array_chunks::<42>());
check(v.par_windows(42));
check(v.par_split(|x| x % 3 == 0));
check(v.into_par_iter());
Expand Down
1 change: 1 addition & 0 deletions tests/debug.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ fn debug_vec() {
check(v.par_iter_mut());
check(v.par_chunks(42));
check(v.par_chunks_exact(42));
check(v.par_array_chunks::<42>());
check(v.par_chunks_mut(42));
check(v.par_chunks_exact_mut(42));
check(v.par_rchunks(42));
Expand Down
30 changes: 30 additions & 0 deletions tests/producer_split_at.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,36 @@ fn slice_chunks_exact() {
}
}

#[test]
fn slice_array_chunks() {
use std::convert::{TryFrom, TryInto};
fn check_len<const N: usize>(s: &[i32])
where
for<'a> &'a [i32; N]: PartialEq + TryFrom<&'a [i32]> + std::fmt::Debug,
{
// TODO: use https://github.com/rust-lang/rust/pull/74373 instead.
let v: Vec<_> = s
.chunks_exact(N)
.map(|s| s.try_into().ok().unwrap())
.collect();
check(&v, || s.par_array_chunks::<N>());
}

let s: Vec<_> = (0..10).collect();
check_len::<1>(&s);
check_len::<2>(&s);
check_len::<3>(&s);
check_len::<4>(&s);
check_len::<5>(&s);
check_len::<6>(&s);
check_len::<7>(&s);
check_len::<8>(&s);
check_len::<9>(&s);
check_len::<10>(&s);
check_len::<11>(&s);
check_len::<12>(&s);
}

#[test]
fn slice_chunks_mut() {
let mut s: Vec<_> = (0..10).collect();
Expand Down

0 comments on commit 4a18b9b

Please sign in to comment.