Skip to content

Commit

Permalink
Merge pull request #185 from andrewmilson/gpu-rpo-support
Browse files Browse the repository at this point in the history
MerkleTree, Segment, ColMatrix changes for GPU RPO support
  • Loading branch information
irakliyk committed Apr 16, 2023
2 parents 9b33ffe + f0a5ccb commit 7c1a56d
Show file tree
Hide file tree
Showing 6 changed files with 107 additions and 25 deletions.
28 changes: 27 additions & 1 deletion crypto/src/merkle/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,9 @@ pub struct MerkleTree<H: Hasher> {
// ================================================================================================

impl<H: Hasher> MerkleTree<H> {
// CONSTRUCTOR
// CONSTRUCTORS
// --------------------------------------------------------------------------------------------

/// Returns new Merkle tree built from the provide leaves using hash function specified by the
/// `H` generic parameter.
///
Expand Down Expand Up @@ -126,6 +127,31 @@ impl<H: Hasher> MerkleTree<H> {
Ok(MerkleTree { nodes, leaves })
}

/// Forms a MerkleTree from a list of nodes and leaves.
///
/// Nodes are supplied as a vector where the root is stored at position 1.
///
/// # Errors
/// Returns an error if:
/// * Fewer than two leaves were provided.
/// * Number of leaves is not a power of two.
///
/// # Panics
/// Panics if nodes doesn't have the same length as leaves.
pub fn from_raw_parts(
nodes: Vec<H::Digest>,
leaves: Vec<H::Digest>,
) -> Result<Self, MerkleTreeError> {
if leaves.len() < 2 {
return Err(MerkleTreeError::TooFewLeaves(2, leaves.len()));
}
if !leaves.len().is_power_of_two() {
return Err(MerkleTreeError::NumberOfLeavesNotPowerOfTwo(leaves.len()));
}
assert_eq!(nodes.len(), leaves.len());
Ok(MerkleTree { nodes, leaves })
}

// PUBLIC ACCESSORS
// --------------------------------------------------------------------------------------------

Expand Down
2 changes: 1 addition & 1 deletion prover/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ use std::time::Instant;
mod domain;
pub use domain::StarkDomain;

mod matrix;
pub mod matrix;
pub use matrix::{ColMatrix, RowMatrix};

mod constraints;
Expand Down
11 changes: 11 additions & 0 deletions prover/src/matrix/col_matrix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,17 @@ impl<E: FieldElement> ColMatrix<E> {
}
}

/// Merges a column to the end of the matrix provided its length matches the matrix.
///
/// # Panics
/// Panics if the column has a different length to other columns in the matrix.
pub fn merge_column(&mut self, column: Vec<E>) {
if let Some(first_column) = self.columns.first() {
assert_eq!(first_column.len(), column.len());
}
self.columns.push(column);
}

// ITERATION
// --------------------------------------------------------------------------------------------

Expand Down
2 changes: 1 addition & 1 deletion prover/src/matrix/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
// LICENSE file in the root directory of this source tree.

mod row_matrix;
pub use row_matrix::RowMatrix;
pub use row_matrix::{get_evaluation_offsets, RowMatrix};

mod col_matrix;
pub use col_matrix::{ColMatrix, ColumnIter, MultiColumnIter};
Expand Down
10 changes: 6 additions & 4 deletions prover/src/matrix/row_matrix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ impl<E: FieldElement> RowMatrix<E> {

// pre-compute offsets for each row
let poly_size = polys.num_rows();
let offsets = get_offsets::<E>(poly_size, blowup_factor, E::BaseField::GENERATOR);
let offsets =
get_evaluation_offsets::<E>(poly_size, blowup_factor, E::BaseField::GENERATOR);

// compute twiddles for polynomial evaluation
let twiddles = fft::get_twiddles::<E::BaseField>(polys.num_rows());
Expand Down Expand Up @@ -86,7 +87,8 @@ impl<E: FieldElement> RowMatrix<E> {

// pre-compute offsets for each row
let poly_size = polys.num_rows();
let offsets = get_offsets::<E>(poly_size, domain.trace_to_lde_blowup(), domain.offset());
let offsets =
get_evaluation_offsets::<E>(poly_size, domain.trace_to_lde_blowup(), domain.offset());

// build matrix segments by evaluating all polynomials
let segments = build_segments::<E, N>(polys, domain.trace_twiddles(), &offsets);
Expand Down Expand Up @@ -207,7 +209,7 @@ impl<E: FieldElement> RowMatrix<E> {
/// factor and domain offset.
///
/// When `concurrent` feature is enabled, offsets are computed in multiple threads.
fn get_offsets<E: FieldElement>(
pub fn get_evaluation_offsets<E: FieldElement>(
poly_size: usize,
blowup_factor: usize,
domain_offset: E::BaseField,
Expand Down Expand Up @@ -299,7 +301,7 @@ fn transpose<B: StarkField, const N: usize>(mut segments: Vec<Segment<B, N>>) ->
for i in 0..rows_per_batch {
let row_idx = i + row_offset;
for j in 0..num_segs {
let v = &segments[j].data()[row_idx];
let v = &segments[j][row_idx];
batch[i * num_segs + j].copy_from_slice(v);
}
}
Expand Down
79 changes: 61 additions & 18 deletions prover/src/matrix/segments.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
// LICENSE file in the root directory of this source tree.

use super::ColMatrix;
use core::ops::Deref;
use math::{fft::fft_inputs::FftInputs, FieldElement, StarkField};
use utils::{collections::Vec, group_vector_elements, uninit_vector};

Expand Down Expand Up @@ -31,10 +32,11 @@ pub struct Segment<B: StarkField, const N: usize> {
}

impl<B: StarkField, const N: usize> Segment<B, N> {
// CONSTRUCTOR
// CONSTRUCTORS
// --------------------------------------------------------------------------------------------
/// Instantiates a new [Segment] by evaluating polynomials from the provided [Matrix] starting
/// at the specified offset.

/// Instantiates a new [Segment] by evaluating polynomials from the provided [ColMatrix]
/// starting at the specified offset.
///
/// The offset is assumed to be an offset into the view of the matrix where extension field
/// elements are decomposed into base field elements. This offset must be compatible with the
Expand All @@ -54,11 +56,59 @@ impl<B: StarkField, const N: usize> Segment<B, N> {
{
let poly_size = polys.num_rows();
let domain_size = offsets.len();
assert!(domain_size.is_power_of_two());
assert!(domain_size > poly_size);
assert_eq!(poly_size, twiddles.len() * 2);
assert!(poly_offset < polys.num_base_cols());

// allocate memory for the segment
let data = if polys.num_base_cols() - poly_offset >= N {
// if we will fill the entire segment, we allocate uninitialized memory
unsafe { uninit_vector::<[B; N]>(domain_size) }
} else {
// but if some columns in the segment will remain unfilled, we allocate memory initialized
// to zeros to make sure we don't end up with memory with undefined values
group_vector_elements(B::zeroed_vector(N * domain_size))
};

Self::new_with_buffer(data, polys, poly_offset, offsets, twiddles)
}

/// Instantiates a new [Segment] using the provided data buffer by evaluating polynomials in
/// the [ColMatrix] starting at the specified offset.
///
/// The offset is assumed to be an offset into the view of the matrix where extension field
/// elements are decomposed into base field elements. This offset must be compatible with the
/// values supplied into [Matrix::get_base_element()] method.
///
/// Evaluation is performed over the domain specified by the provided twiddles and offsets.
///
/// # Panics
/// Panics if:
/// - `poly_offset` greater than or equal to the number of base field columns in `polys`.
/// - Number of offsets is not a power of two.
/// - Number of offsets is smaller than or equal to the polynomial size.
/// - The number of twiddles is not half the size of the polynomial size.
/// - Number of offsets is smaller than the length of the data buffer
pub fn new_with_buffer<E>(
data_buffer: Vec<[B; N]>,
polys: &ColMatrix<E>,
poly_offset: usize,
offsets: &[B],
twiddles: &[B],
) -> Self
where
E: FieldElement<BaseField = B>,
{
let poly_size = polys.num_rows();
let domain_size = offsets.len();
let mut data = data_buffer;

assert!(domain_size.is_power_of_two());
assert!(domain_size > poly_size);
assert_eq!(poly_size, twiddles.len() * 2);
assert!(poly_offset < polys.num_base_cols());
assert_eq!(data.len(), domain_size);

// determine the number of polynomials to add to this segment; this number can be either N,
// or smaller than N when there are fewer than N polynomials remaining to be processed
Expand All @@ -69,16 +119,6 @@ impl<B: StarkField, const N: usize> Segment<B, N> {
N
};

// allocate memory for the segment
let mut data = if num_polys == N {
// if we will fill the entire segment, we allocate uninitialized memory
unsafe { uninit_vector::<[B; N]>(domain_size) }
} else {
// but if some columns in the segment will remain unfilled, we allocate memory initialized
// to zeros to make sure we don't end up with memory with undefined values
group_vector_elements(B::zeroed_vector(N * domain_size))
};

// evaluate the polynomials either in a single thread or multiple threads, depending
// on whether `concurrent` feature is enabled and domain size is greater than 1024;

Expand Down Expand Up @@ -122,11 +162,6 @@ impl<B: StarkField, const N: usize> Segment<B, N> {
self.data.len()
}

/// Returns the data in this segment as a slice of arrays.
pub fn data(&self) -> &[[B; N]] {
&self.data
}

/// Returns the underlying vector of arrays for this segment.
pub fn into_data(self) -> Vec<[B; N]> {
self.data
Expand Down Expand Up @@ -172,6 +207,14 @@ impl<B: StarkField, const N: usize> Segment<B, N> {
}
}

impl<B: StarkField, const N: usize> Deref for Segment<B, N> {
type Target = Vec<[B; N]>;

fn deref(&self) -> &Self::Target {
&self.data
}
}

// CONCURRENT FFT IMPLEMENTATION
// ================================================================================================

Expand Down

0 comments on commit 7c1a56d

Please sign in to comment.