Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix GKR-LogUp API #287

Merged
merged 7 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 59 additions & 15 deletions air/src/air/aux.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,39 +8,83 @@ use super::lagrange::LagrangeKernelRandElements;

/// Holds the randomly generated elements necessary to build the auxiliary trace.
///
/// Specifically, [`AuxRandElements`] currently supports 2 types of random elements:
/// Specifically, [`AuxRandElements`] currently supports 3 types of random elements:
/// - the ones needed to build the Lagrange kernel column (when using GKR to accelerate LogUp),
/// - the ones needed to build the "s" auxiliary column (when using GKR to accelerate LogUp),
/// - the ones needed to build all the other auxiliary columns
#[derive(Debug, Clone)]
pub struct AuxRandElements<E> {
rand_elements: Vec<E>,
lagrange: Option<LagrangeKernelRandElements<E>>,
gkr: Option<GkrRandElements<E>>,
}

impl<E> AuxRandElements<E> {
/// Creates a new [`AuxRandElements`], where the auxiliary trace doesn't contain a Lagrange
/// kernel column.
pub fn new(rand_elements: Vec<E>) -> Self {
Self { rand_elements, lagrange: None }
Self { rand_elements, gkr: None }
}

/// Creates a new [`AuxRandElements`], where the auxiliary trace contains a Lagrange kernel
/// column.
pub fn new_with_lagrange(
rand_elements: Vec<E>,
lagrange: Option<LagrangeKernelRandElements<E>>,
) -> Self {
Self { rand_elements, lagrange }
/// Creates a new [`AuxRandElements`], where the auxiliary trace contains columns needed when
/// using GKR to accelerate LogUp (i.e. a Lagrange kernel column and the "s" column).
pub fn new_with_gkr(rand_elements: Vec<E>, gkr: GkrRandElements<E>) -> Self {
Self { rand_elements, gkr: Some(gkr) }
}

/// Returns the random elements needed to build all columns other than the Lagrange kernel one.
/// Returns the random elements needed to build all columns other than the two GKR-related ones.
pub fn rand_elements(&self) -> &[E] {
&self.rand_elements
}

/// Returns the random elements needed to build the Lagrange kernel column.
pub fn lagrange(&self) -> Option<&LagrangeKernelRandElements<E>> {
self.lagrange.as_ref()
self.gkr.as_ref().map(|gkr| &gkr.lagrange)
}

/// Returns the random values used to linearly combine the openings returned from the GKR proof.
///
/// These correspond to the lambdas in our documentation.
pub fn gkr_openings_combining_randomness(&self) -> Option<&[E]> {
self.gkr.as_ref().map(|gkr| gkr.openings_combining_randomness.as_ref())
}
}

/// Holds all the random elements needed when using GKR to accelerate LogUp.
///
/// This consists of two sets of random values:
/// 1. The Lagrange kernel random elements (expanded on in [`LagrangeKernelRandElements`]), and
/// 2. The "openings combining randomness".
///
/// After the verifying the LogUp-GKR circuit, the verifier is left with unproven claims provided
/// nondeterministically by the prover about the evaluations of the MLE of the main trace columns at
/// the Lagrange kernel random elements. Those claims are (linearly) combined into one using the
/// openings combining randomness.
#[derive(Clone, Debug)]
pub struct GkrRandElements<E> {
lagrange: LagrangeKernelRandElements<E>,
openings_combining_randomness: Vec<E>,
}

impl<E> GkrRandElements<E> {
/// Constructs a new [`GkrRandElements`] from [`LagrangeKernelRandElements`], and the openings
/// combining randomness.
///
/// See [`GkrRandElements`] for a more detailed description.
pub fn new(
plafer marked this conversation as resolved.
Show resolved Hide resolved
lagrange: LagrangeKernelRandElements<E>,
openings_combining_randomness: Vec<E>,
) -> Self {
Self { lagrange, openings_combining_randomness }
}

/// Returns the random elements needed to build the Lagrange kernel column.
pub fn lagrange_kernel_rand_elements(&self) -> &LagrangeKernelRandElements<E> {
&self.lagrange
}

/// Returns the random values used to linearly combine the openings returned from the GKR proof.
pub fn openings_combining_randomness(&self) -> &[E] {
&self.openings_combining_randomness
}
}

Expand All @@ -61,7 +105,7 @@ pub trait GkrVerifier {
&self,
gkr_proof: Self::GkrProof,
public_coin: &mut impl RandomCoin<BaseField = E::BaseField, Hasher = Hasher>,
) -> Result<LagrangeKernelRandElements<E>, Self::Error>
) -> Result<GkrRandElements<E>, Self::Error>
where
E: FieldElement,
Hasher: ElementHasher<BaseField = E::BaseField>;
Expand All @@ -75,11 +119,11 @@ impl GkrVerifier for () {
&self,
_gkr_proof: Self::GkrProof,
_public_coin: &mut impl RandomCoin<BaseField = E::BaseField, Hasher = Hasher>,
) -> Result<LagrangeKernelRandElements<E>, Self::Error>
) -> Result<GkrRandElements<E>, Self::Error>
where
E: FieldElement,
Hasher: ElementHasher<BaseField = E::BaseField>,
{
Ok(LagrangeKernelRandElements::new(Vec::new()))
Ok(GkrRandElements::new(LagrangeKernelRandElements::default(), Vec::new()))
}
}
6 changes: 5 additions & 1 deletion air/src/air/lagrange/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,11 @@ impl<E: FieldElement> LagrangeKernelConstraints<E> {
}

/// Holds the randomly generated elements needed to build the Lagrange kernel auxiliary column.
#[derive(Debug, Clone)]
///
/// The Lagrange kernel consists of evaluating the function $eq(x, r)$, where $x$ is the binary
/// decomposition of the row index, and $r$ is some random point. The "Lagrange kernel random
/// elements" refer to this (multidimensional) point $r$.
#[derive(Debug, Clone, Default)]
pub struct LagrangeKernelRandElements<E> {
elements: Vec<E>,
}
Expand Down
10 changes: 5 additions & 5 deletions air/src/air/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use math::{fft, ExtensibleField, ExtensionOf, FieldElement, StarkField, ToElemen
use crate::ProofOptions;

mod aux;
pub use aux::{AuxRandElements, GkrVerifier};
pub use aux::{AuxRandElements, GkrRandElements, GkrVerifier};

mod trace_info;
pub use trace_info::TraceInfo;
Expand Down Expand Up @@ -269,7 +269,7 @@ pub trait Air: Send + Sync {
main_frame: &EvaluationFrame<F>,
aux_frame: &EvaluationFrame<E>,
periodic_values: &[F],
aux_rand_elements: &[E],
aux_rand_elements: &AuxRandElements<E>,
result: &mut [E],
) where
F: FieldElement<BaseField = Self::BaseField>,
Expand Down Expand Up @@ -298,7 +298,7 @@ pub trait Air: Send + Sync {
#[allow(unused_variables)]
fn get_aux_assertions<E: FieldElement<BaseField = Self::BaseField>>(
&self,
aux_rand_elements: &[E],
aux_rand_elements: &AuxRandElements<E>,
) -> Vec<Assertion<E>> {
Vec::new()
}
Expand All @@ -309,7 +309,7 @@ pub trait Air: Send + Sync {
/// Returns the [`GkrVerifier`] to be used to verify the GKR proof.
///
/// Leave unimplemented if the `Air` doesn't use a GKR proof.
fn get_auxiliary_proof_verifier<E: FieldElement<BaseField = Self::BaseField>>(
fn get_gkr_proof_verifier<E: FieldElement<BaseField = Self::BaseField>>(
&self,
) -> Self::GkrVerifier {
unimplemented!("`get_auxiliary_proof_verifier()` must be implemented when the proof contains a GKR proof");
Expand Down Expand Up @@ -422,7 +422,7 @@ pub trait Air: Send + Sync {
/// combination of boundary constraints during constraint merging.
fn get_boundary_constraints<E: FieldElement<BaseField = Self::BaseField>>(
&self,
aux_rand_elements: Option<&[E]>,
aux_rand_elements: Option<&AuxRandElements<E>>,
composition_coefficients: &[E],
) -> BoundaryConstraints<E> {
BoundaryConstraints::new(
Expand Down
2 changes: 1 addition & 1 deletion air/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ mod air;
pub use air::{
Air, AirContext, Assertion, AuxRandElements, BoundaryConstraint, BoundaryConstraintGroup,
BoundaryConstraints, ConstraintCompositionCoefficients, ConstraintDivisor,
DeepCompositionCoefficients, EvaluationFrame, GkrVerifier,
DeepCompositionCoefficients, EvaluationFrame, GkrRandElements, GkrVerifier,
LagrangeConstraintsCompositionCoefficients, LagrangeKernelBoundaryConstraint,
LagrangeKernelConstraints, LagrangeKernelEvaluationFrame, LagrangeKernelRandElements,
LagrangeKernelTransitionConstraints, TraceInfo, TransitionConstraintDegree,
Expand Down
8 changes: 5 additions & 3 deletions examples/src/rescue_raps/air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

use core_utils::flatten_slice_elements;
use winterfell::{
math::ToElements, Air, AirContext, Assertion, EvaluationFrame, TraceInfo,
math::ToElements, Air, AirContext, Assertion, AuxRandElements, EvaluationFrame, TraceInfo,
TransitionConstraintDegree,
};

Expand Down Expand Up @@ -162,7 +162,7 @@ impl Air for RescueRapsAir {
main_frame: &EvaluationFrame<F>,
aux_frame: &EvaluationFrame<E>,
periodic_values: &[F],
aux_rand_elements: &[E],
aux_rand_elements: &AuxRandElements<E>,
result: &mut [E],
) where
F: FieldElement<BaseField = Self::BaseField>,
Expand All @@ -174,6 +174,8 @@ impl Air for RescueRapsAir {
let aux_current = aux_frame.current();
let aux_next = aux_frame.next();

let aux_rand_elements = aux_rand_elements.rand_elements();

let absorption_flag = periodic_values[1];

// We want to enforce that the absorbed values of the first hash chain are a
Expand Down Expand Up @@ -233,7 +235,7 @@ impl Air for RescueRapsAir {
]
}

fn get_aux_assertions<E>(&self, _aux_rand_elements: &[E]) -> Vec<Assertion<E>>
fn get_aux_assertions<E>(&self, _aux_rand_elements: &AuxRandElements<E>) -> Vec<Assertion<E>>
where
E: FieldElement<BaseField = Self::BaseField>,
{
Expand Down
16 changes: 8 additions & 8 deletions prover/benches/lagrange_kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ use std::time::Duration;

use air::{
Air, AirContext, Assertion, AuxRandElements, ConstraintCompositionCoefficients,
EvaluationFrame, FieldExtension, LagrangeKernelRandElements, ProofOptions, TraceInfo,
TransitionConstraintDegree,
EvaluationFrame, FieldExtension, GkrRandElements, LagrangeKernelRandElements, ProofOptions,
TraceInfo, TransitionConstraintDegree,
};
use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion};
use crypto::{hashers::Blake3_256, DefaultRandomCoin, RandomCoin};
Expand Down Expand Up @@ -144,7 +144,7 @@ impl Air for LagrangeKernelAir {
_main_frame: &EvaluationFrame<F>,
_aux_frame: &EvaluationFrame<E>,
_periodic_values: &[F],
_aux_rand_elements: &[E],
_aux_rand_elements: &AuxRandElements<E>,
_result: &mut [E],
) where
F: FieldElement<BaseField = Self::BaseField>,
Expand All @@ -155,7 +155,7 @@ impl Air for LagrangeKernelAir {

fn get_aux_assertions<E: FieldElement<BaseField = Self::BaseField>>(
&self,
_aux_rand_elements: &[E],
_aux_rand_elements: &AuxRandElements<E>,
) -> Vec<Assertion<E>> {
vec![Assertion::single(1, 0, E::ZERO)]
}
Expand Down Expand Up @@ -223,22 +223,22 @@ impl Prover for LagrangeProver {
&self,
main_trace: &Self::Trace,
public_coin: &mut Self::RandomCoin,
) -> (ProverGkrProof<Self>, LagrangeKernelRandElements<E>)
) -> (ProverGkrProof<Self>, GkrRandElements<E>)
where
E: FieldElement<BaseField = Self::BaseField>,
{
let main_trace = main_trace.main_segment();
let lagrange_kernel_rand_elements: Vec<E> = {
let lagrange_kernel_rand_elements = {
let log_trace_len = main_trace.num_rows().ilog2() as usize;
let mut rand_elements = Vec::with_capacity(log_trace_len);
for _ in 0..log_trace_len {
rand_elements.push(public_coin.draw().unwrap());
}

rand_elements
LagrangeKernelRandElements::new(rand_elements)
};

((), LagrangeKernelRandElements::new(lagrange_kernel_rand_elements))
((), GkrRandElements::new(lagrange_kernel_rand_elements, Vec::new()))
}

fn build_aux_trace<E>(
Expand Down
4 changes: 2 additions & 2 deletions prover/src/constraints/evaluator/boundary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

use alloc::{collections::BTreeMap, vec::Vec};

use air::{Air, ConstraintDivisor};
use air::{Air, AuxRandElements, ConstraintDivisor};
use math::{fft, ExtensionOf, FieldElement};

use super::StarkDomain;
Expand Down Expand Up @@ -35,7 +35,7 @@ impl<E: FieldElement> BoundaryConstraints<E> {
/// by an instance of AIR for a specific computation.
pub fn new<A: Air<BaseField = E::BaseField>>(
air: &A,
aux_rand_elements: Option<&[E]>,
aux_rand_elements: Option<&AuxRandElements<E>>,
composition_coefficients: &[E],
) -> Self {
// get constraints from the AIR instance
Expand Down
7 changes: 2 additions & 5 deletions prover/src/constraints/evaluator/default.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,7 @@ where
// constraint evaluations.
let boundary_constraints = BoundaryConstraints::new(
air,
aux_rand_elements
.as_ref()
.map(|aux_rand_elements| aux_rand_elements.rand_elements()),
aux_rand_elements.as_ref(),
&composition_coefficients.boundary,
);

Expand Down Expand Up @@ -378,8 +376,7 @@ where
periodic_values,
self.aux_rand_elements
.as_ref()
.expect("expected aux rand elements to be present")
.rand_elements(),
.expect("expected aux rand elements to be present"),
evaluations,
);

Expand Down
28 changes: 15 additions & 13 deletions prover/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,13 @@
#[macro_use]
extern crate alloc;

use air::AuxRandElements;
pub use air::{
proof, proof::Proof, Air, AirContext, Assertion, BoundaryConstraint, BoundaryConstraintGroup,
ConstraintCompositionCoefficients, ConstraintDivisor, DeepCompositionCoefficients,
EvaluationFrame, FieldExtension, LagrangeKernelRandElements, ProofOptions, TraceInfo,
TransitionConstraintDegree,
};
use air::{AuxRandElements, GkrRandElements};
pub use crypto;
use crypto::{ElementHasher, RandomCoin};
use fri::FriProver;
Expand Down Expand Up @@ -205,7 +205,7 @@ pub trait Prover {
&self,
main_trace: &Self::Trace,
public_coin: &mut Self::RandomCoin,
) -> (ProverGkrProof<Self>, LagrangeKernelRandElements<E>)
) -> (ProverGkrProof<Self>, GkrRandElements<E>)
where
E: FieldElement<BaseField = Self::BaseField>,
{
Expand Down Expand Up @@ -280,7 +280,7 @@ pub trait Prover {
let pub_inputs = self.get_pub_inputs(&trace);
let pub_inputs_elements = pub_inputs.to_elements();

// create an instance of AIR for the provided parameters. this takes a generic description
// create an instance of AIR for the provided parameters. This takes a generic description
// of the computation (provided via AIR type), and creates a description of a specific
// execution of the computation for the provided public inputs.
let air = Self::Air::new(trace.info().clone(), pub_inputs, self.options().clone());
Expand Down Expand Up @@ -310,22 +310,24 @@ pub trait Prover {
// build the auxiliary trace segment, and append the resulting segments to trace commitment
// and trace polynomial table structs
let aux_trace_with_metadata = if air.trace_info().is_multi_segment() {
let (gkr_proof, lagrange_rand_elements) =
if air.context().has_lagrange_kernel_aux_column() {
let (gkr_proof, lagrange_rand_elements) =
maybe_await!(self.generate_gkr_proof(&trace, channel.public_coin()));
let (gkr_proof, aux_rand_elements) = if air.context().has_lagrange_kernel_aux_column() {
let (gkr_proof, gkr_rand_elements) =
maybe_await!(self.generate_gkr_proof(&trace, channel.public_coin()));

(Some(gkr_proof), Some(lagrange_rand_elements))
} else {
(None, None)
};
let rand_elements = air
.get_aux_rand_elements(channel.public_coin())
.expect("failed to draw random elements for the auxiliary trace segment");

let aux_rand_elements =
AuxRandElements::new_with_gkr(rand_elements, gkr_rand_elements);

let aux_rand_elements = {
(Some(gkr_proof), aux_rand_elements)
} else {
let rand_elements = air
.get_aux_rand_elements(channel.public_coin())
.expect("failed to draw random elements for the auxiliary trace segment");

AuxRandElements::new_with_lagrange(rand_elements, lagrange_rand_elements)
(None, AuxRandElements::new(rand_elements))
};

let aux_trace = maybe_await!(self.build_aux_trace(&trace, &aux_rand_elements));
Expand Down
Loading