Skip to content

Commit

Permalink
Improve performance of set_bits by avoiding to set individual bits (#…
Browse files Browse the repository at this point in the history
…6288)

* bench

* fix: Optimize set_bits

* clippy

* clippyj

* miri

* fix: Optimize set_bits

* fix: Optimize set_bits

* fix: Optimize set_bits

* fix: Optimize set_bits

* fix: Optimize set_bits

* fix: Optimize set_bits

* fix: Optimize set_bits

* fix: Optimize set_bits

* miri

* miri

* miri

* miri

* miri

* miri

* miri

* miri

* miri

* miri

* miri

* address review comments

* address review comments

* address review comments

* Revert "address review comments"

This reverts commit ef2864f.

* address review comments

* address review comments

* address review comments

* address review comments

* address review comments

* address review comments

* address review comments

* address review comments

* address review comments

* address review comments

* address review comments

* address review comments

* Revert "address review comments"

This reverts commit a15db14.

* address review comments

* address review comments

* address review comments

* address review comments

* address review comments

* address review comments

* address review comments
  • Loading branch information
kazuyukitanimura committed Sep 15, 2024
1 parent ba85fa3 commit b4de692
Showing 1 changed file with 158 additions and 26 deletions.
184 changes: 158 additions & 26 deletions arrow-buffer/src/util/bit_mask.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,48 +17,144 @@

//! Utils for working with packed bit masks

use crate::bit_chunk_iterator::BitChunks;
use crate::bit_util::{ceil, get_bit, set_bit};
use crate::bit_util::ceil;

/// Sets all bits on `write_data` in the range `[offset_write..offset_write+len]` to be equal to the
/// bits in `data` in the range `[offset_read..offset_read+len]`
/// returns the number of `0` bits `data[offset_read..offset_read+len]`
/// `offset_write`, `offset_read`, and `len` are in terms of bits
pub fn set_bits(
write_data: &mut [u8],
data: &[u8],
offset_write: usize,
offset_read: usize,
len: usize,
) -> usize {
assert!(offset_write + len <= write_data.len() * 8);
assert!(offset_read + len <= data.len() * 8);
let mut null_count = 0;

let mut bits_to_align = offset_write % 8;
if bits_to_align > 0 {
bits_to_align = std::cmp::min(len, 8 - bits_to_align);
let mut acc = 0;
while len > acc {
// SAFETY: the arguments to `set_upto_64bits` are within the valid range because
// (offset_write + acc) + (len - acc) == offset_write + len <= write_data.len() * 8
// (offset_read + acc) + (len - acc) == offset_read + len <= data.len() * 8
let (n, len_set) = unsafe {
set_upto_64bits(
write_data,
data,
offset_write + acc,
offset_read + acc,
len - acc,
)
};
null_count += n;
acc += len_set;
}
let mut write_byte_index = ceil(offset_write + bits_to_align, 8);

// Set full bytes provided by bit chunk iterator (which iterates in 64 bits at a time)
let chunks = BitChunks::new(data, offset_read + bits_to_align, len - bits_to_align);
chunks.iter().for_each(|chunk| {
null_count += chunk.count_zeros();
write_data[write_byte_index..write_byte_index + 8].copy_from_slice(&chunk.to_le_bytes());
write_byte_index += 8;
});

// Set individual bits both to align write_data to a byte offset and the remainder bits not covered by the bit chunk iterator
let remainder_offset = len - chunks.remainder_len();
(0..bits_to_align)
.chain(remainder_offset..len)
.for_each(|i| {
if get_bit(data, offset_read + i) {
set_bit(write_data, offset_write + i);

null_count
}

/// Similar to `set_bits` but sets only upto 64 bits, actual number of bits set may vary.
/// Returns a pair of the number of `0` bits and the number of bits set
///
/// # Safety
/// The caller must ensure all arguments are within the valid range.
#[inline]
unsafe fn set_upto_64bits(
write_data: &mut [u8],
data: &[u8],
offset_write: usize,
offset_read: usize,
len: usize,
) -> (usize, usize) {
let read_byte = offset_read / 8;
let read_shift = offset_read % 8;
let write_byte = offset_write / 8;
let write_shift = offset_write % 8;

if len >= 64 {
let chunk = unsafe { (data.as_ptr().add(read_byte) as *const u64).read_unaligned() };
if read_shift == 0 {
if write_shift == 0 {
// no shifting necessary
let len = 64;
let null_count = chunk.count_zeros() as usize;
unsafe { write_u64_bytes(write_data, write_byte, chunk) };
(null_count, len)
} else {
null_count += 1;
// only write shifting necessary
let len = 64 - write_shift;
let chunk = chunk << write_shift;
let null_count = len - chunk.count_ones() as usize;
unsafe { or_write_u64_bytes(write_data, write_byte, chunk) };
(null_count, len)
}
});
} else if write_shift == 0 {
// only read shifting necessary
let len = 64 - 8; // 56 bits so the next set_upto_64bits call will see write_shift == 0
let chunk = (chunk >> read_shift) & 0x00FFFFFFFFFFFFFF; // 56 bits mask
let null_count = len - chunk.count_ones() as usize;
unsafe { write_u64_bytes(write_data, write_byte, chunk) };
(null_count, len)
} else {
let len = 64 - std::cmp::max(read_shift, write_shift);
let chunk = (chunk >> read_shift) << write_shift;
let null_count = len - chunk.count_ones() as usize;
unsafe { or_write_u64_bytes(write_data, write_byte, chunk) };
(null_count, len)
}
} else if len == 1 {
let byte_chunk = (unsafe { data.get_unchecked(read_byte) } >> read_shift) & 1;
unsafe { *write_data.get_unchecked_mut(write_byte) |= byte_chunk << write_shift };
((byte_chunk ^ 1) as usize, 1)
} else {
let len = std::cmp::min(len, 64 - std::cmp::max(read_shift, write_shift));
let bytes = ceil(len + read_shift, 8);
// SAFETY: the args of `read_bytes_to_u64` are valid as read_byte + bytes <= data.len()
let chunk = unsafe { read_bytes_to_u64(data, read_byte, bytes) };
let mask = u64::MAX >> (64 - len);
let chunk = (chunk >> read_shift) & mask; // masking to read `len` bits only
let chunk = chunk << write_shift; // shifting back to align with `write_data`
let null_count = len - chunk.count_ones() as usize;
let bytes = ceil(len + write_shift, 8);
for (i, c) in chunk.to_le_bytes().iter().enumerate().take(bytes) {
unsafe { *write_data.get_unchecked_mut(write_byte + i) |= c };
}
(null_count, len)
}
}

null_count as usize
/// # Safety
/// The caller must ensure all arguments are within the valid range.
#[inline]
unsafe fn read_bytes_to_u64(data: &[u8], offset: usize, count: usize) -> u64 {
debug_assert!(count <= 8);
let mut tmp = std::mem::MaybeUninit::<u64>::new(0);
let src = data.as_ptr().add(offset);
unsafe {
std::ptr::copy_nonoverlapping(src, tmp.as_mut_ptr() as *mut u8, count);
tmp.assume_init()
}
}

/// # Safety
/// The caller must ensure `data` has `offset..(offset + 8)` range
#[inline]
unsafe fn write_u64_bytes(data: &mut [u8], offset: usize, chunk: u64) {
let ptr = data.as_mut_ptr().add(offset) as *mut u64;
ptr.write_unaligned(chunk);
}

/// Similar to `write_u64_bytes`, but this method ORs the offset addressed `data` and `chunk`
/// instead of overwriting
///
/// # Safety
/// The caller must ensure `data` has `offset..(offset + 8)` range
#[inline]
unsafe fn or_write_u64_bytes(data: &mut [u8], offset: usize, chunk: u64) {
let ptr = data.as_mut_ptr().add(offset);
let chunk = chunk | (*ptr) as u64;
(ptr as *mut u64).write_unaligned(chunk);
}

#[cfg(test)]
Expand Down Expand Up @@ -185,4 +281,40 @@ mod tests {
assert_eq!(destination, expected_data);
assert_eq!(result, expected_null_count);
}

#[test]
fn test_set_upto_64bits() {
// len >= 64
let write_data: &mut [u8] = &mut [0; 9];
let data: &[u8] = &[
0b00000001, 0b00000001, 0b00000001, 0b00000001, 0b00000001, 0b00000001, 0b00000001,
0b00000001, 0b00000001,
];
let offset_write = 1;
let offset_read = 0;
let len = 65;
let (n, len_set) =
unsafe { set_upto_64bits(write_data, data, offset_write, offset_read, len) };
assert_eq!(n, 55);
assert_eq!(len_set, 63);
assert_eq!(
write_data,
&[
0b00000010, 0b00000010, 0b00000010, 0b00000010, 0b00000010, 0b00000010, 0b00000010,
0b00000010, 0b00000000
]
);

// len = 1
let write_data: &mut [u8] = &mut [0b00000000];
let data: &[u8] = &[0b00000001];
let offset_write = 1;
let offset_read = 0;
let len = 1;
let (n, len_set) =
unsafe { set_upto_64bits(write_data, data, offset_write, offset_read, len) };
assert_eq!(n, 0);
assert_eq!(len_set, 1);
assert_eq!(write_data, &[0b00000010]);
}
}

0 comments on commit b4de692

Please sign in to comment.