Skip to content

Commit

Permalink
Fix GKR-LogUp API (#287)
Browse files Browse the repository at this point in the history
  • Loading branch information
plafer committed Jun 24, 2024
1 parent ff5496b commit 00f2579
Show file tree
Hide file tree
Showing 14 changed files with 122 additions and 75 deletions.
74 changes: 59 additions & 15 deletions air/src/air/aux.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,39 +8,83 @@ use super::lagrange::LagrangeKernelRandElements;

/// Holds the randomly generated elements necessary to build the auxiliary trace.
///
/// Specifically, [`AuxRandElements`] currently supports 2 types of random elements:
/// Specifically, [`AuxRandElements`] currently supports 3 types of random elements:
/// - the ones needed to build the Lagrange kernel column (when using GKR to accelerate LogUp),
/// - the ones needed to build the "s" auxiliary column (when using GKR to accelerate LogUp),
/// - the ones needed to build all the other auxiliary columns
#[derive(Debug, Clone)]
pub struct AuxRandElements<E> {
rand_elements: Vec<E>,
lagrange: Option<LagrangeKernelRandElements<E>>,
gkr: Option<GkrRandElements<E>>,
}

impl<E> AuxRandElements<E> {
/// Creates a new [`AuxRandElements`], where the auxiliary trace doesn't contain a Lagrange
/// kernel column.
pub fn new(rand_elements: Vec<E>) -> Self {
Self { rand_elements, lagrange: None }
Self { rand_elements, gkr: None }
}

/// Creates a new [`AuxRandElements`], where the auxiliary trace contains a Lagrange kernel
/// column.
pub fn new_with_lagrange(
rand_elements: Vec<E>,
lagrange: Option<LagrangeKernelRandElements<E>>,
) -> Self {
Self { rand_elements, lagrange }
/// Creates a new [`AuxRandElements`], where the auxiliary trace contains columns needed when
/// using GKR to accelerate LogUp (i.e. a Lagrange kernel column and the "s" column).
pub fn new_with_gkr(rand_elements: Vec<E>, gkr: GkrRandElements<E>) -> Self {
Self { rand_elements, gkr: Some(gkr) }
}

/// Returns the random elements needed to build all columns other than the Lagrange kernel one.
/// Returns the random elements needed to build all columns other than the two GKR-related ones.
pub fn rand_elements(&self) -> &[E] {
&self.rand_elements
}

/// Returns the random elements needed to build the Lagrange kernel column.
pub fn lagrange(&self) -> Option<&LagrangeKernelRandElements<E>> {
self.lagrange.as_ref()
self.gkr.as_ref().map(|gkr| &gkr.lagrange)
}

/// Returns the random values used to linearly combine the openings returned from the GKR proof.
///
/// These correspond to the lambdas in our documentation.
pub fn gkr_openings_combining_randomness(&self) -> Option<&[E]> {
self.gkr.as_ref().map(|gkr| gkr.openings_combining_randomness.as_ref())
}
}

/// Holds all the random elements needed when using GKR to accelerate LogUp.
///
/// This consists of two sets of random values:
/// 1. The Lagrange kernel random elements (expanded on in [`LagrangeKernelRandElements`]), and
/// 2. The "openings combining randomness".
///
/// After the verifying the LogUp-GKR circuit, the verifier is left with unproven claims provided
/// nondeterministically by the prover about the evaluations of the MLE of the main trace columns at
/// the Lagrange kernel random elements. Those claims are (linearly) combined into one using the
/// openings combining randomness.
#[derive(Clone, Debug)]
pub struct GkrRandElements<E> {
lagrange: LagrangeKernelRandElements<E>,
openings_combining_randomness: Vec<E>,
}

impl<E> GkrRandElements<E> {
/// Constructs a new [`GkrRandElements`] from [`LagrangeKernelRandElements`], and the openings
/// combining randomness.
///
/// See [`GkrRandElements`] for a more detailed description.
pub fn new(
lagrange: LagrangeKernelRandElements<E>,
openings_combining_randomness: Vec<E>,
) -> Self {
Self { lagrange, openings_combining_randomness }
}

/// Returns the random elements needed to build the Lagrange kernel column.
pub fn lagrange_kernel_rand_elements(&self) -> &LagrangeKernelRandElements<E> {
&self.lagrange
}

/// Returns the random values used to linearly combine the openings returned from the GKR proof.
pub fn openings_combining_randomness(&self) -> &[E] {
&self.openings_combining_randomness
}
}

Expand All @@ -61,7 +105,7 @@ pub trait GkrVerifier {
&self,
gkr_proof: Self::GkrProof,
public_coin: &mut impl RandomCoin<BaseField = E::BaseField, Hasher = Hasher>,
) -> Result<LagrangeKernelRandElements<E>, Self::Error>
) -> Result<GkrRandElements<E>, Self::Error>
where
E: FieldElement,
Hasher: ElementHasher<BaseField = E::BaseField>;
Expand All @@ -75,11 +119,11 @@ impl GkrVerifier for () {
&self,
_gkr_proof: Self::GkrProof,
_public_coin: &mut impl RandomCoin<BaseField = E::BaseField, Hasher = Hasher>,
) -> Result<LagrangeKernelRandElements<E>, Self::Error>
) -> Result<GkrRandElements<E>, Self::Error>
where
E: FieldElement,
Hasher: ElementHasher<BaseField = E::BaseField>,
{
Ok(LagrangeKernelRandElements::new(Vec::new()))
Ok(GkrRandElements::new(LagrangeKernelRandElements::default(), Vec::new()))
}
}
6 changes: 5 additions & 1 deletion air/src/air/lagrange/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,11 @@ impl<E: FieldElement> LagrangeKernelConstraints<E> {
}

/// Holds the randomly generated elements needed to build the Lagrange kernel auxiliary column.
#[derive(Debug, Clone)]
///
/// The Lagrange kernel consists of evaluating the function $eq(x, r)$, where $x$ is the binary
/// decomposition of the row index, and $r$ is some random point. The "Lagrange kernel random
/// elements" refer to this (multidimensional) point $r$.
#[derive(Debug, Clone, Default)]
pub struct LagrangeKernelRandElements<E> {
elements: Vec<E>,
}
Expand Down
10 changes: 5 additions & 5 deletions air/src/air/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use math::{fft, ExtensibleField, ExtensionOf, FieldElement, StarkField, ToElemen
use crate::ProofOptions;

mod aux;
pub use aux::{AuxRandElements, GkrVerifier};
pub use aux::{AuxRandElements, GkrRandElements, GkrVerifier};

mod trace_info;
pub use trace_info::TraceInfo;
Expand Down Expand Up @@ -269,7 +269,7 @@ pub trait Air: Send + Sync {
main_frame: &EvaluationFrame<F>,
aux_frame: &EvaluationFrame<E>,
periodic_values: &[F],
aux_rand_elements: &[E],
aux_rand_elements: &AuxRandElements<E>,
result: &mut [E],
) where
F: FieldElement<BaseField = Self::BaseField>,
Expand Down Expand Up @@ -298,7 +298,7 @@ pub trait Air: Send + Sync {
#[allow(unused_variables)]
fn get_aux_assertions<E: FieldElement<BaseField = Self::BaseField>>(
&self,
aux_rand_elements: &[E],
aux_rand_elements: &AuxRandElements<E>,
) -> Vec<Assertion<E>> {
Vec::new()
}
Expand All @@ -309,7 +309,7 @@ pub trait Air: Send + Sync {
/// Returns the [`GkrVerifier`] to be used to verify the GKR proof.
///
/// Leave unimplemented if the `Air` doesn't use a GKR proof.
fn get_auxiliary_proof_verifier<E: FieldElement<BaseField = Self::BaseField>>(
fn get_gkr_proof_verifier<E: FieldElement<BaseField = Self::BaseField>>(
&self,
) -> Self::GkrVerifier {
unimplemented!("`get_auxiliary_proof_verifier()` must be implemented when the proof contains a GKR proof");
Expand Down Expand Up @@ -422,7 +422,7 @@ pub trait Air: Send + Sync {
/// combination of boundary constraints during constraint merging.
fn get_boundary_constraints<E: FieldElement<BaseField = Self::BaseField>>(
&self,
aux_rand_elements: Option<&[E]>,
aux_rand_elements: Option<&AuxRandElements<E>>,
composition_coefficients: &[E],
) -> BoundaryConstraints<E> {
BoundaryConstraints::new(
Expand Down
2 changes: 1 addition & 1 deletion air/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ mod air;
pub use air::{
Air, AirContext, Assertion, AuxRandElements, BoundaryConstraint, BoundaryConstraintGroup,
BoundaryConstraints, ConstraintCompositionCoefficients, ConstraintDivisor,
DeepCompositionCoefficients, EvaluationFrame, GkrVerifier,
DeepCompositionCoefficients, EvaluationFrame, GkrRandElements, GkrVerifier,
LagrangeConstraintsCompositionCoefficients, LagrangeKernelBoundaryConstraint,
LagrangeKernelConstraints, LagrangeKernelEvaluationFrame, LagrangeKernelRandElements,
LagrangeKernelTransitionConstraints, TraceInfo, TransitionConstraintDegree,
Expand Down
8 changes: 5 additions & 3 deletions examples/src/rescue_raps/air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

use core_utils::flatten_slice_elements;
use winterfell::{
math::ToElements, Air, AirContext, Assertion, EvaluationFrame, TraceInfo,
math::ToElements, Air, AirContext, Assertion, AuxRandElements, EvaluationFrame, TraceInfo,
TransitionConstraintDegree,
};

Expand Down Expand Up @@ -162,7 +162,7 @@ impl Air for RescueRapsAir {
main_frame: &EvaluationFrame<F>,
aux_frame: &EvaluationFrame<E>,
periodic_values: &[F],
aux_rand_elements: &[E],
aux_rand_elements: &AuxRandElements<E>,
result: &mut [E],
) where
F: FieldElement<BaseField = Self::BaseField>,
Expand All @@ -174,6 +174,8 @@ impl Air for RescueRapsAir {
let aux_current = aux_frame.current();
let aux_next = aux_frame.next();

let aux_rand_elements = aux_rand_elements.rand_elements();

let absorption_flag = periodic_values[1];

// We want to enforce that the absorbed values of the first hash chain are a
Expand Down Expand Up @@ -233,7 +235,7 @@ impl Air for RescueRapsAir {
]
}

fn get_aux_assertions<E>(&self, _aux_rand_elements: &[E]) -> Vec<Assertion<E>>
fn get_aux_assertions<E>(&self, _aux_rand_elements: &AuxRandElements<E>) -> Vec<Assertion<E>>
where
E: FieldElement<BaseField = Self::BaseField>,
{
Expand Down
16 changes: 8 additions & 8 deletions prover/benches/lagrange_kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ use std::time::Duration;

use air::{
Air, AirContext, Assertion, AuxRandElements, ConstraintCompositionCoefficients,
EvaluationFrame, FieldExtension, LagrangeKernelRandElements, ProofOptions, TraceInfo,
TransitionConstraintDegree,
EvaluationFrame, FieldExtension, GkrRandElements, LagrangeKernelRandElements, ProofOptions,
TraceInfo, TransitionConstraintDegree,
};
use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion};
use crypto::{hashers::Blake3_256, DefaultRandomCoin, RandomCoin};
Expand Down Expand Up @@ -144,7 +144,7 @@ impl Air for LagrangeKernelAir {
_main_frame: &EvaluationFrame<F>,
_aux_frame: &EvaluationFrame<E>,
_periodic_values: &[F],
_aux_rand_elements: &[E],
_aux_rand_elements: &AuxRandElements<E>,
_result: &mut [E],
) where
F: FieldElement<BaseField = Self::BaseField>,
Expand All @@ -155,7 +155,7 @@ impl Air for LagrangeKernelAir {

fn get_aux_assertions<E: FieldElement<BaseField = Self::BaseField>>(
&self,
_aux_rand_elements: &[E],
_aux_rand_elements: &AuxRandElements<E>,
) -> Vec<Assertion<E>> {
vec![Assertion::single(1, 0, E::ZERO)]
}
Expand Down Expand Up @@ -223,22 +223,22 @@ impl Prover for LagrangeProver {
&self,
main_trace: &Self::Trace,
public_coin: &mut Self::RandomCoin,
) -> (ProverGkrProof<Self>, LagrangeKernelRandElements<E>)
) -> (ProverGkrProof<Self>, GkrRandElements<E>)
where
E: FieldElement<BaseField = Self::BaseField>,
{
let main_trace = main_trace.main_segment();
let lagrange_kernel_rand_elements: Vec<E> = {
let lagrange_kernel_rand_elements = {
let log_trace_len = main_trace.num_rows().ilog2() as usize;
let mut rand_elements = Vec::with_capacity(log_trace_len);
for _ in 0..log_trace_len {
rand_elements.push(public_coin.draw().unwrap());
}

rand_elements
LagrangeKernelRandElements::new(rand_elements)
};

((), LagrangeKernelRandElements::new(lagrange_kernel_rand_elements))
((), GkrRandElements::new(lagrange_kernel_rand_elements, Vec::new()))
}

fn build_aux_trace<E>(
Expand Down
4 changes: 2 additions & 2 deletions prover/src/constraints/evaluator/boundary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

use alloc::{collections::BTreeMap, vec::Vec};

use air::{Air, ConstraintDivisor};
use air::{Air, AuxRandElements, ConstraintDivisor};
use math::{fft, ExtensionOf, FieldElement};

use super::StarkDomain;
Expand Down Expand Up @@ -35,7 +35,7 @@ impl<E: FieldElement> BoundaryConstraints<E> {
/// by an instance of AIR for a specific computation.
pub fn new<A: Air<BaseField = E::BaseField>>(
air: &A,
aux_rand_elements: Option<&[E]>,
aux_rand_elements: Option<&AuxRandElements<E>>,
composition_coefficients: &[E],
) -> Self {
// get constraints from the AIR instance
Expand Down
7 changes: 2 additions & 5 deletions prover/src/constraints/evaluator/default.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,7 @@ where
// constraint evaluations.
let boundary_constraints = BoundaryConstraints::new(
air,
aux_rand_elements
.as_ref()
.map(|aux_rand_elements| aux_rand_elements.rand_elements()),
aux_rand_elements.as_ref(),
&composition_coefficients.boundary,
);

Expand Down Expand Up @@ -378,8 +376,7 @@ where
periodic_values,
self.aux_rand_elements
.as_ref()
.expect("expected aux rand elements to be present")
.rand_elements(),
.expect("expected aux rand elements to be present"),
evaluations,
);

Expand Down
28 changes: 15 additions & 13 deletions prover/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,13 @@
#[macro_use]
extern crate alloc;

use air::AuxRandElements;
pub use air::{
proof, proof::Proof, Air, AirContext, Assertion, BoundaryConstraint, BoundaryConstraintGroup,
ConstraintCompositionCoefficients, ConstraintDivisor, DeepCompositionCoefficients,
EvaluationFrame, FieldExtension, LagrangeKernelRandElements, ProofOptions, TraceInfo,
TransitionConstraintDegree,
};
use air::{AuxRandElements, GkrRandElements};
pub use crypto;
use crypto::{ElementHasher, RandomCoin};
use fri::FriProver;
Expand Down Expand Up @@ -205,7 +205,7 @@ pub trait Prover {
&self,
main_trace: &Self::Trace,
public_coin: &mut Self::RandomCoin,
) -> (ProverGkrProof<Self>, LagrangeKernelRandElements<E>)
) -> (ProverGkrProof<Self>, GkrRandElements<E>)
where
E: FieldElement<BaseField = Self::BaseField>,
{
Expand Down Expand Up @@ -280,7 +280,7 @@ pub trait Prover {
let pub_inputs = self.get_pub_inputs(&trace);
let pub_inputs_elements = pub_inputs.to_elements();

// create an instance of AIR for the provided parameters. this takes a generic description
// create an instance of AIR for the provided parameters. This takes a generic description
// of the computation (provided via AIR type), and creates a description of a specific
// execution of the computation for the provided public inputs.
let air = Self::Air::new(trace.info().clone(), pub_inputs, self.options().clone());
Expand Down Expand Up @@ -310,22 +310,24 @@ pub trait Prover {
// build the auxiliary trace segment, and append the resulting segments to trace commitment
// and trace polynomial table structs
let aux_trace_with_metadata = if air.trace_info().is_multi_segment() {
let (gkr_proof, lagrange_rand_elements) =
if air.context().has_lagrange_kernel_aux_column() {
let (gkr_proof, lagrange_rand_elements) =
maybe_await!(self.generate_gkr_proof(&trace, channel.public_coin()));
let (gkr_proof, aux_rand_elements) = if air.context().has_lagrange_kernel_aux_column() {
let (gkr_proof, gkr_rand_elements) =
maybe_await!(self.generate_gkr_proof(&trace, channel.public_coin()));

(Some(gkr_proof), Some(lagrange_rand_elements))
} else {
(None, None)
};
let rand_elements = air
.get_aux_rand_elements(channel.public_coin())
.expect("failed to draw random elements for the auxiliary trace segment");

let aux_rand_elements =
AuxRandElements::new_with_gkr(rand_elements, gkr_rand_elements);

let aux_rand_elements = {
(Some(gkr_proof), aux_rand_elements)
} else {
let rand_elements = air
.get_aux_rand_elements(channel.public_coin())
.expect("failed to draw random elements for the auxiliary trace segment");

AuxRandElements::new_with_lagrange(rand_elements, lagrange_rand_elements)
(None, AuxRandElements::new(rand_elements))
};

let aux_trace = maybe_await!(self.build_aux_trace(&trace, &aux_rand_elements));
Expand Down
Loading

0 comments on commit 00f2579

Please sign in to comment.