Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove From impls resulting in silent field element conversions #243

Merged
merged 1 commit into from
Feb 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion air/src/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ mod tests {
]);
let expected = vec![
BaseElement::from(ext_fri),
BaseElement::from(grinding_factor as u32),
BaseElement::from(grinding_factor),
BaseElement::from(blowup_factor as u32),
BaseElement::from(num_queries as u32),
];
Expand Down
18 changes: 15 additions & 3 deletions air/src/proof/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,22 @@ impl Context {
// --------------------------------------------------------------------------------------------
/// Creates a new context for a computation described by the specified field, trace info, and
/// proof options.
///
/// # Panics
/// Panics if either trace length or the LDE domain size implied by the trace length and the
/// blowup factor is greater then [u32::MAX].
pub fn new<B: StarkField>(trace_info: &TraceInfo, options: ProofOptions) -> Self {
// TODO: return errors instead of panicking?

let trace_length = trace_info.length();
assert!(trace_length <= u32::MAX as usize, "trace length too big");
irakliyk marked this conversation as resolved.
Show resolved Hide resolved

let lde_domain_size = trace_length * options.blowup_factor();
assert!(lde_domain_size <= u32::MAX as usize, "LDE domain size too big");
irakliyk marked this conversation as resolved.
Show resolved Hide resolved

Context {
trace_layout: trace_info.layout().clone(),
trace_length: trace_info.length(),
trace_length,
trace_meta: trace_info.meta().to_vec(),
field_modulus_bytes: B::get_modulus_le_bytes(),
options,
Expand Down Expand Up @@ -117,7 +129,7 @@ impl<E: StarkField> ToElements<E> for Context {

// convert proof options and trace length to elements
result.append(&mut self.options.to_elements());
result.push(E::from(self.trace_length as u64));
result.push(E::from(self.trace_length as u32));
irakliyk marked this conversation as resolved.
Show resolved Hide resolved

// convert trace metadata to elements; this is done by breaking trace metadata into chunks
// of bytes which are slightly smaller than the number of bytes needed to encode a field
Expand Down Expand Up @@ -257,7 +269,7 @@ mod tests {
BaseElement::from(1_u32), // lower bits of field modulus
BaseElement::from(u32::MAX), // upper bits of field modulus
BaseElement::from(ext_fri),
BaseElement::from(grinding_factor as u32),
BaseElement::from(grinding_factor),
BaseElement::from(blowup_factor as u32),
BaseElement::from(num_queries as u32),
BaseElement::from(trace_length as u32),
Expand Down
5 changes: 3 additions & 2 deletions crypto/benches/merkle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use winter_crypto::{build_merkle_nodes, concurrent, hashers::Blake3_256, Hasher}
type Blake3 = Blake3_256<BaseElement>;
type Blake3Digest = <Blake3 as Hasher>::Digest;

#[allow(clippy::needless_range_loop)]
pub fn merkle_tree_construction(c: &mut Criterion) {
let mut merkle_group = c.benchmark_group("merkle tree construction");

Expand All @@ -26,10 +27,10 @@ pub fn merkle_tree_construction(c: &mut Criterion) {
res
};
merkle_group.bench_with_input(BenchmarkId::new("sequential", size), &data, |b, i| {
b.iter(|| build_merkle_nodes::<Blake3>(&i))
b.iter(|| build_merkle_nodes::<Blake3>(i))
});
merkle_group.bench_with_input(BenchmarkId::new("concurrent", size), &data, |b, i| {
b.iter(|| concurrent::build_merkle_nodes::<Blake3>(&i))
b.iter(|| concurrent::build_merkle_nodes::<Blake3>(i))
});
}
}
Expand Down
2 changes: 1 addition & 1 deletion crypto/src/hash/griffin/griffin64_256_jive/digest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

use super::{Digest, DIGEST_SIZE};
use core::slice;
use math::{fields::f64::BaseElement, StarkField};
use math::fields::f64::BaseElement;
use utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};

// DIGEST TRAIT IMPLEMENTATIONS
Expand Down
7 changes: 4 additions & 3 deletions crypto/src/hash/griffin/griffin64_256_jive/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use proptest::prelude::*;

use rand_utils::{rand_array, rand_value};

#[allow(clippy::needless_range_loop)]
#[test]
fn mds_inv_test() {
let mut mul_result = [[BaseElement::new(0); STATE_WIDTH]; STATE_WIDTH];
Expand Down Expand Up @@ -196,15 +197,15 @@ fn apply_mds_naive(state: &mut [BaseElement; STATE_WIDTH]) {

proptest! {
#[test]
fn mds_freq_proptest(a in any::<[u64;STATE_WIDTH]>()) {
fn mds_freq_proptest(a in any::<[u64; STATE_WIDTH]>()) {

let mut v1 = [BaseElement::ZERO;STATE_WIDTH];
let mut v1 = [BaseElement::ZERO; STATE_WIDTH];
let mut v2;

for i in 0..STATE_WIDTH {
v1[i] = BaseElement::new(a[i]);
}
v2 = v1.clone();
v2 = v1;

apply_mds_naive(&mut v1);
GriffinJive64_256::apply_linear(&mut v2);
Expand Down
2 changes: 1 addition & 1 deletion crypto/src/hash/rescue/rp64_256/digest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

use super::{Digest, DIGEST_SIZE};
use core::slice;
use math::{fields::f64::BaseElement, StarkField};
use math::fields::f64::BaseElement;
use utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};

// DIGEST TRAIT IMPLEMENTATIONS
Expand Down
6 changes: 3 additions & 3 deletions crypto/src/hash/rescue/rp64_256/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,15 +191,15 @@ fn apply_mds_naive(state: &mut [BaseElement; STATE_WIDTH]) {

proptest! {
#[test]
fn mds_freq_proptest(a in any::<[u64;STATE_WIDTH]>()) {
fn mds_freq_proptest(a in any::<[u64; STATE_WIDTH]>()) {

let mut v1 = [BaseElement::ZERO;STATE_WIDTH];
let mut v1 = [BaseElement::ZERO; STATE_WIDTH];
let mut v2;

for i in 0..STATE_WIDTH {
v1[i] = BaseElement::new(a[i]);
}
v2 = v1.clone();
v2 = v1;

apply_mds_naive(&mut v1);
Rp64_256::apply_mds(&mut v2);
Expand Down
2 changes: 1 addition & 1 deletion crypto/src/hash/rescue/rp64_256_jive/digest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

use super::{Digest, DIGEST_SIZE};
use core::slice;
use math::{fields::f64::BaseElement, StarkField};
use math::fields::f64::BaseElement;
use utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};

// DIGEST TRAIT IMPLEMENTATIONS
Expand Down
9 changes: 5 additions & 4 deletions crypto/src/hash/rescue/rp64_256_jive/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use proptest::prelude::*;

use rand_utils::{rand_array, rand_value};

#[allow(clippy::needless_range_loop)]
#[test]
fn mds_inv_test() {
let mut mul_result = [[BaseElement::new(0); STATE_WIDTH]; STATE_WIDTH];
Expand All @@ -36,7 +37,7 @@ fn mds_inv_test() {
#[test]
fn test_alphas() {
let e: BaseElement = rand_value();
let e_exp = e.exp(ALPHA.into());
let e_exp = e.exp(ALPHA);
assert_eq!(e, e_exp.exp(INV_ALPHA));
}

Expand Down Expand Up @@ -189,15 +190,15 @@ fn apply_mds_naive(state: &mut [BaseElement; STATE_WIDTH]) {

proptest! {
#[test]
fn mds_freq_proptest(a in any::<[u64;STATE_WIDTH]>()) {
fn mds_freq_proptest(a in any::<[u64; STATE_WIDTH]>()) {

let mut v1 = [BaseElement::ZERO;STATE_WIDTH];
let mut v1 = [BaseElement::ZERO; STATE_WIDTH];
let mut v2;

for i in 0..STATE_WIDTH {
v1[i] = BaseElement::new(a[i]);
}
v2 = v1.clone();
v2 = v1;

apply_mds_naive(&mut v1);
RpJive64_256::apply_mds(&mut v2);
Expand Down
2 changes: 1 addition & 1 deletion crypto/src/merkle/concurrent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ mod tests {
proptest! {
#[test]
fn build_merkle_nodes_concurrent(ref data in vec(any::<[u8; 32]>(), 256..257).no_shrink()) {
let leaves = ByteDigest::bytes_as_digests(&data).to_vec();
let leaves = ByteDigest::bytes_as_digests(data).to_vec();
let sequential = super::super::build_merkle_nodes::<Sha3_256<BaseElement>>(&leaves);
let concurrent = super::build_merkle_nodes::<Sha3_256<BaseElement>>(&leaves);
assert_eq!(concurrent, sequential);
Expand Down
4 changes: 2 additions & 2 deletions examples/src/lamport/aggregate/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,8 @@ fn apply_message_acc(
let m0_bit = state[0];
let m1_bit = state[1];

state[0] = BaseElement::from((m0 >> (cycle_num + 1)) & 1);
state[1] = BaseElement::from((m1 >> (cycle_num + 1)) & 1);
state[0] = BaseElement::new((m0 >> (cycle_num + 1)) & 1);
state[1] = BaseElement::new((m1 >> (cycle_num + 1)) & 1);
state[2] += power_of_two * m0_bit;
state[3] += power_of_two * m1_bit;
}
Expand Down
2 changes: 1 addition & 1 deletion examples/src/lamport/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ pub fn message_to_elements(message: &[u8]) -> [BaseElement; 2] {
let checksum = m0.count_zeros() + m1.count_zeros();
let m1 = m1 | ((checksum as u128) << 119);

[BaseElement::from(m0), BaseElement::from(m1)]
[BaseElement::new(m0), BaseElement::new(m1)]
}

/// Reduces a list of public key elements to a single 32-byte value. The reduction is done
Expand Down
4 changes: 2 additions & 2 deletions examples/src/lamport/threshold/air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,8 @@ impl Air for LamportThresholdAir {
let mut m1_bits = Vec::with_capacity(SIG_CYCLE_LEN);
for i in 0..SIG_CYCLE_LEN {
let cycle_num = i / HASH_CYCLE_LEN;
m0_bits.push(BaseElement::from((m0 >> cycle_num) & 1));
m1_bits.push(BaseElement::from((m1 >> cycle_num) & 1));
m0_bits.push(BaseElement::new((m0 >> cycle_num) & 1));
m1_bits.push(BaseElement::new((m1 >> cycle_num) & 1));
}
result.push(m0_bits);
result.push(m1_bits);
Expand Down
6 changes: 3 additions & 3 deletions examples/src/lamport/threshold/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,8 @@ fn update_sig_verification_state(
} else {
// for the 8th step of very cycle do the following:

let m0_bit = BaseElement::from((sig_info.m0 >> cycle_num) & 1);
let m1_bit = BaseElement::from((sig_info.m1 >> cycle_num) & 1);
let m0_bit = BaseElement::new((sig_info.m0 >> cycle_num) & 1);
let m1_bit = BaseElement::new((sig_info.m1 >> cycle_num) & 1);
let mp_bit = merkle_path_idx[0];

// copy next set of public keys into the registers computing hash of the public key
Expand Down Expand Up @@ -345,7 +345,7 @@ fn update_merkle_path_index(
let index_bit = state[0];
// the cycle is offset by +1 because the first node in the Merkle path is redundant and we
// get it by hashing the public key
state[0] = BaseElement::from((index >> (cycle_num + 1)) & 1);
state[0] = BaseElement::new((index >> (cycle_num + 1)) & 1);
state[1] += power_of_two * index_bit;
}

Expand Down
2 changes: 1 addition & 1 deletion fri/src/folding/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ where
// build offset inverses and twiddles used during polynomial interpolation
let inv_offsets = get_inv_offsets(values.len(), domain_offset, N);
let inv_twiddles = get_inv_twiddles::<B>(N);
let len_offset = E::inv((N as u64).into());
let len_offset = E::inv((N as u32).into());

let mut result = unsafe { uninit_vector(values.len()) };
iter_mut!(result)
Expand Down
35 changes: 23 additions & 12 deletions math/src/fft/concurrent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,36 +50,47 @@ pub fn evaluate_poly_with_offset<B: StarkField, E: FieldElement<BaseField = B>>(

/// Uses FFT algorithm to interpolate a polynomial from provided `values`; the interpolation
/// is done in-place, meaning `values` are updated with polynomial coefficients.
pub fn interpolate_poly<B, E>(v: &mut [E], inv_twiddles: &[B])
///
/// # Panics
/// Panics if the length of `values` is greater than [u32::MAX].
pub fn interpolate_poly<B, E>(values: &mut [E], inv_twiddles: &[B])
where
B: StarkField,
E: FieldElement<BaseField = B>,
{
split_radix_fft(v, inv_twiddles);
let inv_length = E::inv((v.len() as u64).into());
v.par_iter_mut().for_each(|e| *e *= inv_length);
permute(v);
assert!(values.len() <= u32::MAX as usize, "too many values");
irakliyk marked this conversation as resolved.
Show resolved Hide resolved

split_radix_fft(values, inv_twiddles);
let inv_length = E::inv((values.len() as u32).into());
irakliyk marked this conversation as resolved.
Show resolved Hide resolved
values.par_iter_mut().for_each(|e| *e *= inv_length);
permute(values);
}

/// Uses FFT algorithm to interpolate a polynomial from provided `values` over the domain defined
/// by `inv_twiddles` and offset by `domain_offset` factor.
///
///
/// # Panics
/// Panics if the length of `values` is greater than [u32::MAX].
pub fn interpolate_poly_with_offset<B, E>(values: &mut [E], inv_twiddles: &[B], domain_offset: B)
where
B: StarkField,
E: FieldElement<BaseField = B>,
{
assert!(values.len() <= u32::MAX as usize, "too many values");

split_radix_fft(values, inv_twiddles);
permute(values);

let domain_offset = E::inv(domain_offset.into());
let inv_len = E::inv((values.len() as u64).into());
let inv_len = E::inv((values.len() as u32).into());
let batch_size = values.len() / rayon::current_num_threads().next_power_of_two();

values.par_chunks_mut(batch_size).enumerate().for_each(|(i, batch)| {
let mut offset = domain_offset.exp(((i * batch_size) as u64).into()) * inv_len;
for coeff in batch.iter_mut() {
*coeff = *coeff * offset;
offset = offset * domain_offset;
*coeff *= offset;
offset *= domain_offset;
}
});
}
Expand Down Expand Up @@ -136,7 +147,7 @@ pub(super) fn split_radix_fft<B: StarkField, E: FieldElement<BaseField = B>>(
// apply inner FFTs
values
.par_chunks_mut(outer_len)
.for_each(|row| row.fft_in_place_raw(&twiddles, stretch, stretch, 0));
.for_each(|row| row.fft_in_place_raw(twiddles, stretch, stretch, 0));

// transpose inner x inner x stretch square matrix
transpose_square_stretch(values, inner_len, stretch);
Expand All @@ -149,10 +160,10 @@ pub(super) fn split_radix_fft<B: StarkField, E: FieldElement<BaseField = B>>(
let mut outer_twiddle = inner_twiddle;
for element in row.iter_mut().skip(1) {
*element = (*element).mul_base(outer_twiddle);
outer_twiddle = outer_twiddle * inner_twiddle;
outer_twiddle *= inner_twiddle;
}
}
row.fft_in_place(&twiddles);
row.fft_in_place(twiddles);
});
}

Expand Down Expand Up @@ -216,7 +227,7 @@ fn clone_and_shift<E: FieldElement>(source: &[E], destination: &mut [E], offset:
let mut factor = offset.exp(((i * batch_size) as u64).into());
for (s, d) in source.iter().zip(destination.iter_mut()) {
*d = (*s).mul_base(factor);
factor = factor * offset;
factor *= offset;
}
});
}
13 changes: 11 additions & 2 deletions math/src/fft/serial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,16 @@ where

/// Interpolates `evaluations` over a domain of length `evaluations.len()` in the field specified
/// `B` into a polynomial in coefficient form using the FFT algorithm.
///
/// # Panics
/// Panics if the length of `evaluations` is greater than [u32::MAX].
pub fn interpolate_poly<B, E>(evaluations: &mut [E], inv_twiddles: &[B])
where
B: StarkField,
E: FieldElement<BaseField = B>,
{
let inv_length = B::inv((evaluations.len() as u64).into());
assert!(evaluations.len() <= u32::MAX as usize, "too many evaluations");
let inv_length = B::inv((evaluations.len() as u32).into());
evaluations.fft_in_place(inv_twiddles);
evaluations.shift_by(inv_length);
evaluations.permute();
Expand All @@ -71,6 +75,9 @@ where
/// Interpolates `evaluations` over a domain of length `evaluations.len()` and shifted by
/// `domain_offset` in the field specified by `B` into a polynomial in coefficient form using
/// the FFT algorithm.
///
/// # Panics
/// Panics if the length of `evaluations` is greater than [u32::MAX].
pub fn interpolate_poly_with_offset<B, E>(
evaluations: &mut [E],
inv_twiddles: &[B],
Expand All @@ -79,11 +86,13 @@ pub fn interpolate_poly_with_offset<B, E>(
B: StarkField,
E: FieldElement<BaseField = B>,
{
assert!(evaluations.len() <= u32::MAX as usize, "too many evaluations");

evaluations.fft_in_place(inv_twiddles);
evaluations.permute();

let domain_offset = B::inv(domain_offset);
let offset = B::inv((evaluations.len() as u64).into());
let offset = B::inv((evaluations.len() as u32).into());

evaluations.shift_by_series(offset, domain_offset);
}
12 changes: 0 additions & 12 deletions math/src/field/extensions/cubic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -301,18 +301,6 @@ impl<B: ExtensibleField<3>> From<B> for CubeExtension<B> {
}
}

impl<B: ExtensibleField<3>> From<u128> for CubeExtension<B> {
fn from(value: u128) -> Self {
Self(B::from(value), B::ZERO, B::ZERO)
}
}

impl<B: ExtensibleField<3>> From<u64> for CubeExtension<B> {
fn from(value: u64) -> Self {
Self(B::from(value), B::ZERO, B::ZERO)
}
}

impl<B: ExtensibleField<3>> From<u32> for CubeExtension<B> {
fn from(value: u32) -> Self {
Self(B::from(value), B::ZERO, B::ZERO)
Expand Down
Loading
Loading