Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MerkleTree, Segment, ColMatrix changes for GPU RPO support #185

Merged
merged 1 commit into from
Apr 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion crypto/src/merkle/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,9 @@ pub struct MerkleTree<H: Hasher> {
// ================================================================================================

impl<H: Hasher> MerkleTree<H> {
// CONSTRUCTOR
// CONSTRUCTORS
// --------------------------------------------------------------------------------------------

/// Returns new Merkle tree built from the provide leaves using hash function specified by the
/// `H` generic parameter.
///
Expand Down Expand Up @@ -126,6 +127,31 @@ impl<H: Hasher> MerkleTree<H> {
Ok(MerkleTree { nodes, leaves })
}

/// Forms a MerkleTree from a list of nodes and leaves.
///
/// Nodes are supplied as a vector where the root is stored at position 1.
///
/// # Errors
/// Returns an error if:
/// * Fewer than two leaves were provided.
/// * Number of leaves is not a power of two.
///
/// # Panics
/// Panics if nodes doesn't have the same length as leaves.
pub fn from_raw_parts(
nodes: Vec<H::Digest>,
leaves: Vec<H::Digest>,
) -> Result<Self, MerkleTreeError> {
if leaves.len() < 2 {
return Err(MerkleTreeError::TooFewLeaves(2, leaves.len()));
}
if !leaves.len().is_power_of_two() {
return Err(MerkleTreeError::NumberOfLeavesNotPowerOfTwo(leaves.len()));
}
assert_eq!(nodes.len(), leaves.len());
Ok(MerkleTree { nodes, leaves })
}

// PUBLIC ACCESSORS
// --------------------------------------------------------------------------------------------

Expand Down
2 changes: 1 addition & 1 deletion prover/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ use std::time::Instant;
mod domain;
pub use domain::StarkDomain;

mod matrix;
pub mod matrix;
pub use matrix::{ColMatrix, RowMatrix};

mod constraints;
Expand Down
11 changes: 11 additions & 0 deletions prover/src/matrix/col_matrix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,17 @@ impl<E: FieldElement> ColMatrix<E> {
}
}

/// Merges a column to the end of the matrix provided its length matches the matrix.
///
/// # Panics
/// Panics if the column has a different length to other columns in the matrix.
pub fn merge_column(&mut self, column: Vec<E>) {
if let Some(first_column) = self.columns.first() {
assert_eq!(first_column.len(), column.len());
}
self.columns.push(column);
}

// ITERATION
// --------------------------------------------------------------------------------------------

Expand Down
2 changes: 1 addition & 1 deletion prover/src/matrix/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
// LICENSE file in the root directory of this source tree.

mod row_matrix;
pub use row_matrix::RowMatrix;
pub use row_matrix::{get_evaluation_offsets, RowMatrix};

mod col_matrix;
pub use col_matrix::{ColMatrix, ColumnIter, MultiColumnIter};
Expand Down
10 changes: 6 additions & 4 deletions prover/src/matrix/row_matrix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ impl<E: FieldElement> RowMatrix<E> {

// pre-compute offsets for each row
let poly_size = polys.num_rows();
let offsets = get_offsets::<E>(poly_size, blowup_factor, E::BaseField::GENERATOR);
let offsets =
get_evaluation_offsets::<E>(poly_size, blowup_factor, E::BaseField::GENERATOR);

// compute twiddles for polynomial evaluation
let twiddles = fft::get_twiddles::<E::BaseField>(polys.num_rows());
Expand Down Expand Up @@ -86,7 +87,8 @@ impl<E: FieldElement> RowMatrix<E> {

// pre-compute offsets for each row
let poly_size = polys.num_rows();
let offsets = get_offsets::<E>(poly_size, domain.trace_to_lde_blowup(), domain.offset());
let offsets =
get_evaluation_offsets::<E>(poly_size, domain.trace_to_lde_blowup(), domain.offset());

// build matrix segments by evaluating all polynomials
let segments = build_segments::<E, N>(polys, domain.trace_twiddles(), &offsets);
Expand Down Expand Up @@ -207,7 +209,7 @@ impl<E: FieldElement> RowMatrix<E> {
/// factor and domain offset.
///
/// When `concurrent` feature is enabled, offsets are computed in multiple threads.
fn get_offsets<E: FieldElement>(
pub fn get_evaluation_offsets<E: FieldElement>(
poly_size: usize,
blowup_factor: usize,
domain_offset: E::BaseField,
Expand Down Expand Up @@ -299,7 +301,7 @@ fn transpose<B: StarkField, const N: usize>(mut segments: Vec<Segment<B, N>>) ->
for i in 0..rows_per_batch {
let row_idx = i + row_offset;
for j in 0..num_segs {
let v = &segments[j].data()[row_idx];
let v = &segments[j][row_idx];
batch[i * num_segs + j].copy_from_slice(v);
}
}
Expand Down
79 changes: 61 additions & 18 deletions prover/src/matrix/segments.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
// LICENSE file in the root directory of this source tree.

use super::ColMatrix;
use core::ops::Deref;
use math::{fft::fft_inputs::FftInputs, FieldElement, StarkField};
use utils::{collections::Vec, group_vector_elements, uninit_vector};

Expand Down Expand Up @@ -31,10 +32,11 @@ pub struct Segment<B: StarkField, const N: usize> {
}

impl<B: StarkField, const N: usize> Segment<B, N> {
// CONSTRUCTOR
// CONSTRUCTORS
// --------------------------------------------------------------------------------------------
/// Instantiates a new [Segment] by evaluating polynomials from the provided [Matrix] starting
/// at the specified offset.

/// Instantiates a new [Segment] by evaluating polynomials from the provided [ColMatrix]
/// starting at the specified offset.
///
/// The offset is assumed to be an offset into the view of the matrix where extension field
/// elements are decomposed into base field elements. This offset must be compatible with the
Expand All @@ -54,11 +56,59 @@ impl<B: StarkField, const N: usize> Segment<B, N> {
{
let poly_size = polys.num_rows();
let domain_size = offsets.len();
assert!(domain_size.is_power_of_two());
assert!(domain_size > poly_size);
assert_eq!(poly_size, twiddles.len() * 2);
assert!(poly_offset < polys.num_base_cols());

// allocate memory for the segment
let data = if polys.num_base_cols() - poly_offset >= N {
// if we will fill the entire segment, we allocate uninitialized memory
unsafe { uninit_vector::<[B; N]>(domain_size) }
} else {
// but if some columns in the segment will remain unfilled, we allocate memory initialized
// to zeros to make sure we don't end up with memory with undefined values
group_vector_elements(B::zeroed_vector(N * domain_size))
};

Self::new_with_buffer(data, polys, poly_offset, offsets, twiddles)
}

/// Instantiates a new [Segment] using the provided data buffer by evaluating polynomials in
/// the [ColMatrix] starting at the specified offset.
///
/// The offset is assumed to be an offset into the view of the matrix where extension field
/// elements are decomposed into base field elements. This offset must be compatible with the
/// values supplied into [Matrix::get_base_element()] method.
///
/// Evaluation is performed over the domain specified by the provided twiddles and offsets.
///
/// # Panics
/// Panics if:
/// - `poly_offset` greater than or equal to the number of base field columns in `polys`.
/// - Number of offsets is not a power of two.
/// - Number of offsets is smaller than or equal to the polynomial size.
/// - The number of twiddles is not half the size of the polynomial size.
/// - Number of offsets is smaller than the length of the data buffer
pub fn new_with_buffer<E>(
irakliyk marked this conversation as resolved.
Show resolved Hide resolved
data_buffer: Vec<[B; N]>,
polys: &ColMatrix<E>,
poly_offset: usize,
offsets: &[B],
twiddles: &[B],
) -> Self
where
E: FieldElement<BaseField = B>,
{
let poly_size = polys.num_rows();
let domain_size = offsets.len();
let mut data = data_buffer;

assert!(domain_size.is_power_of_two());
assert!(domain_size > poly_size);
assert_eq!(poly_size, twiddles.len() * 2);
assert!(poly_offset < polys.num_base_cols());
assert_eq!(data.len(), domain_size);

// determine the number of polynomials to add to this segment; this number can be either N,
// or smaller than N when there are fewer than N polynomials remaining to be processed
Expand All @@ -69,16 +119,6 @@ impl<B: StarkField, const N: usize> Segment<B, N> {
N
};

// allocate memory for the segment
let mut data = if num_polys == N {
// if we will fill the entire segment, we allocate uninitialized memory
unsafe { uninit_vector::<[B; N]>(domain_size) }
} else {
// but if some columns in the segment will remain unfilled, we allocate memory initialized
// to zeros to make sure we don't end up with memory with undefined values
group_vector_elements(B::zeroed_vector(N * domain_size))
};

// evaluate the polynomials either in a single thread or multiple threads, depending
// on whether `concurrent` feature is enabled and domain size is greater than 1024;

Expand Down Expand Up @@ -122,11 +162,6 @@ impl<B: StarkField, const N: usize> Segment<B, N> {
self.data.len()
}

/// Returns the data in this segment as a slice of arrays.
pub fn data(&self) -> &[[B; N]] {
&self.data
}

/// Returns the underlying vector of arrays for this segment.
pub fn into_data(self) -> Vec<[B; N]> {
self.data
Expand Down Expand Up @@ -172,6 +207,14 @@ impl<B: StarkField, const N: usize> Segment<B, N> {
}
}

impl<B: StarkField, const N: usize> Deref for Segment<B, N> {
type Target = Vec<[B; N]>;

fn deref(&self) -> &Self::Target {
&self.data
}
}

// CONCURRENT FFT IMPLEMENTATION
// ================================================================================================

Expand Down