diff --git a/air/src/options.rs b/air/src/options.rs index a34f8939b..19d0aa18a 100644 --- a/air/src/options.rs +++ b/air/src/options.rs @@ -326,7 +326,7 @@ mod tests { ]); let expected = vec![ BaseElement::from(ext_fri), - BaseElement::from(grinding_factor as u32), + BaseElement::from(grinding_factor), BaseElement::from(blowup_factor as u32), BaseElement::from(num_queries as u32), ]; diff --git a/air/src/proof/context.rs b/air/src/proof/context.rs index 549e3b2a9..d60f54b04 100644 --- a/air/src/proof/context.rs +++ b/air/src/proof/context.rs @@ -27,10 +27,22 @@ impl Context { // -------------------------------------------------------------------------------------------- /// Creates a new context for a computation described by the specified field, trace info, and /// proof options. + /// + /// # Panics + /// Panics if either trace length or the LDE domain size implied by the trace length and the + /// blowup factor is greater then [u32::MAX]. pub fn new(trace_info: &TraceInfo, options: ProofOptions) -> Self { + // TODO: return errors instead of panicking? + + let trace_length = trace_info.length(); + assert!(trace_length <= u32::MAX as usize, "trace length too big"); + + let lde_domain_size = trace_length * options.blowup_factor(); + assert!(lde_domain_size <= u32::MAX as usize, "LDE domain size too big"); + Context { trace_layout: trace_info.layout().clone(), - trace_length: trace_info.length(), + trace_length, trace_meta: trace_info.meta().to_vec(), field_modulus_bytes: B::get_modulus_le_bytes(), options, @@ -117,7 +129,7 @@ impl ToElements for Context { // convert proof options and trace length to elements result.append(&mut self.options.to_elements()); - result.push(E::from(self.trace_length as u64)); + result.push(E::from(self.trace_length as u32)); // convert trace metadata to elements; this is done by breaking trace metadata into chunks // of bytes which are slightly smaller than the number of bytes needed to encode a field @@ -257,7 +269,7 @@ mod tests { BaseElement::from(1_u32), // lower bits of field modulus BaseElement::from(u32::MAX), // upper bits of field modulus BaseElement::from(ext_fri), - BaseElement::from(grinding_factor as u32), + BaseElement::from(grinding_factor), BaseElement::from(blowup_factor as u32), BaseElement::from(num_queries as u32), BaseElement::from(trace_length as u32), diff --git a/crypto/benches/merkle.rs b/crypto/benches/merkle.rs index df86c8767..c44db8af3 100644 --- a/crypto/benches/merkle.rs +++ b/crypto/benches/merkle.rs @@ -12,6 +12,7 @@ use winter_crypto::{build_merkle_nodes, concurrent, hashers::Blake3_256, Hasher} type Blake3 = Blake3_256; type Blake3Digest = ::Digest; +#[allow(clippy::needless_range_loop)] pub fn merkle_tree_construction(c: &mut Criterion) { let mut merkle_group = c.benchmark_group("merkle tree construction"); @@ -26,10 +27,10 @@ pub fn merkle_tree_construction(c: &mut Criterion) { res }; merkle_group.bench_with_input(BenchmarkId::new("sequential", size), &data, |b, i| { - b.iter(|| build_merkle_nodes::(&i)) + b.iter(|| build_merkle_nodes::(i)) }); merkle_group.bench_with_input(BenchmarkId::new("concurrent", size), &data, |b, i| { - b.iter(|| concurrent::build_merkle_nodes::(&i)) + b.iter(|| concurrent::build_merkle_nodes::(i)) }); } } diff --git a/crypto/src/hash/griffin/griffin64_256_jive/digest.rs b/crypto/src/hash/griffin/griffin64_256_jive/digest.rs index 91ab4216e..3a599a0ef 100644 --- a/crypto/src/hash/griffin/griffin64_256_jive/digest.rs +++ b/crypto/src/hash/griffin/griffin64_256_jive/digest.rs @@ -5,7 +5,7 @@ use super::{Digest, DIGEST_SIZE}; use core::slice; -use math::{fields::f64::BaseElement, StarkField}; +use math::fields::f64::BaseElement; use utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; // DIGEST TRAIT IMPLEMENTATIONS diff --git a/crypto/src/hash/griffin/griffin64_256_jive/tests.rs b/crypto/src/hash/griffin/griffin64_256_jive/tests.rs index fd621cc86..c33e0acc5 100644 --- a/crypto/src/hash/griffin/griffin64_256_jive/tests.rs +++ b/crypto/src/hash/griffin/griffin64_256_jive/tests.rs @@ -12,6 +12,7 @@ use proptest::prelude::*; use rand_utils::{rand_array, rand_value}; +#[allow(clippy::needless_range_loop)] #[test] fn mds_inv_test() { let mut mul_result = [[BaseElement::new(0); STATE_WIDTH]; STATE_WIDTH]; @@ -196,15 +197,15 @@ fn apply_mds_naive(state: &mut [BaseElement; STATE_WIDTH]) { proptest! { #[test] - fn mds_freq_proptest(a in any::<[u64;STATE_WIDTH]>()) { + fn mds_freq_proptest(a in any::<[u64; STATE_WIDTH]>()) { - let mut v1 = [BaseElement::ZERO;STATE_WIDTH]; + let mut v1 = [BaseElement::ZERO; STATE_WIDTH]; let mut v2; for i in 0..STATE_WIDTH { v1[i] = BaseElement::new(a[i]); } - v2 = v1.clone(); + v2 = v1; apply_mds_naive(&mut v1); GriffinJive64_256::apply_linear(&mut v2); diff --git a/crypto/src/hash/rescue/rp64_256/digest.rs b/crypto/src/hash/rescue/rp64_256/digest.rs index 91ab4216e..3a599a0ef 100644 --- a/crypto/src/hash/rescue/rp64_256/digest.rs +++ b/crypto/src/hash/rescue/rp64_256/digest.rs @@ -5,7 +5,7 @@ use super::{Digest, DIGEST_SIZE}; use core::slice; -use math::{fields::f64::BaseElement, StarkField}; +use math::fields::f64::BaseElement; use utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; // DIGEST TRAIT IMPLEMENTATIONS diff --git a/crypto/src/hash/rescue/rp64_256/tests.rs b/crypto/src/hash/rescue/rp64_256/tests.rs index 49bb11273..2067b41c7 100644 --- a/crypto/src/hash/rescue/rp64_256/tests.rs +++ b/crypto/src/hash/rescue/rp64_256/tests.rs @@ -191,15 +191,15 @@ fn apply_mds_naive(state: &mut [BaseElement; STATE_WIDTH]) { proptest! { #[test] - fn mds_freq_proptest(a in any::<[u64;STATE_WIDTH]>()) { + fn mds_freq_proptest(a in any::<[u64; STATE_WIDTH]>()) { - let mut v1 = [BaseElement::ZERO;STATE_WIDTH]; + let mut v1 = [BaseElement::ZERO; STATE_WIDTH]; let mut v2; for i in 0..STATE_WIDTH { v1[i] = BaseElement::new(a[i]); } - v2 = v1.clone(); + v2 = v1; apply_mds_naive(&mut v1); Rp64_256::apply_mds(&mut v2); diff --git a/crypto/src/hash/rescue/rp64_256_jive/digest.rs b/crypto/src/hash/rescue/rp64_256_jive/digest.rs index 91ab4216e..3a599a0ef 100644 --- a/crypto/src/hash/rescue/rp64_256_jive/digest.rs +++ b/crypto/src/hash/rescue/rp64_256_jive/digest.rs @@ -5,7 +5,7 @@ use super::{Digest, DIGEST_SIZE}; use core::slice; -use math::{fields::f64::BaseElement, StarkField}; +use math::fields::f64::BaseElement; use utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; // DIGEST TRAIT IMPLEMENTATIONS diff --git a/crypto/src/hash/rescue/rp64_256_jive/tests.rs b/crypto/src/hash/rescue/rp64_256_jive/tests.rs index 8148173f6..6c1b7f4a1 100644 --- a/crypto/src/hash/rescue/rp64_256_jive/tests.rs +++ b/crypto/src/hash/rescue/rp64_256_jive/tests.rs @@ -12,6 +12,7 @@ use proptest::prelude::*; use rand_utils::{rand_array, rand_value}; +#[allow(clippy::needless_range_loop)] #[test] fn mds_inv_test() { let mut mul_result = [[BaseElement::new(0); STATE_WIDTH]; STATE_WIDTH]; @@ -36,7 +37,7 @@ fn mds_inv_test() { #[test] fn test_alphas() { let e: BaseElement = rand_value(); - let e_exp = e.exp(ALPHA.into()); + let e_exp = e.exp(ALPHA); assert_eq!(e, e_exp.exp(INV_ALPHA)); } @@ -189,15 +190,15 @@ fn apply_mds_naive(state: &mut [BaseElement; STATE_WIDTH]) { proptest! { #[test] - fn mds_freq_proptest(a in any::<[u64;STATE_WIDTH]>()) { + fn mds_freq_proptest(a in any::<[u64; STATE_WIDTH]>()) { - let mut v1 = [BaseElement::ZERO;STATE_WIDTH]; + let mut v1 = [BaseElement::ZERO; STATE_WIDTH]; let mut v2; for i in 0..STATE_WIDTH { v1[i] = BaseElement::new(a[i]); } - v2 = v1.clone(); + v2 = v1; apply_mds_naive(&mut v1); RpJive64_256::apply_mds(&mut v2); diff --git a/crypto/src/merkle/concurrent.rs b/crypto/src/merkle/concurrent.rs index 9633dfd44..c2302dd43 100644 --- a/crypto/src/merkle/concurrent.rs +++ b/crypto/src/merkle/concurrent.rs @@ -82,7 +82,7 @@ mod tests { proptest! { #[test] fn build_merkle_nodes_concurrent(ref data in vec(any::<[u8; 32]>(), 256..257).no_shrink()) { - let leaves = ByteDigest::bytes_as_digests(&data).to_vec(); + let leaves = ByteDigest::bytes_as_digests(data).to_vec(); let sequential = super::super::build_merkle_nodes::>(&leaves); let concurrent = super::build_merkle_nodes::>(&leaves); assert_eq!(concurrent, sequential); diff --git a/examples/src/lamport/aggregate/prover.rs b/examples/src/lamport/aggregate/prover.rs index 4120fae0a..ac5dea2fa 100644 --- a/examples/src/lamport/aggregate/prover.rs +++ b/examples/src/lamport/aggregate/prover.rs @@ -232,8 +232,8 @@ fn apply_message_acc( let m0_bit = state[0]; let m1_bit = state[1]; - state[0] = BaseElement::from((m0 >> (cycle_num + 1)) & 1); - state[1] = BaseElement::from((m1 >> (cycle_num + 1)) & 1); + state[0] = BaseElement::new((m0 >> (cycle_num + 1)) & 1); + state[1] = BaseElement::new((m1 >> (cycle_num + 1)) & 1); state[2] += power_of_two * m0_bit; state[3] += power_of_two * m1_bit; } diff --git a/examples/src/lamport/signature.rs b/examples/src/lamport/signature.rs index f20534501..e7d93dbe9 100644 --- a/examples/src/lamport/signature.rs +++ b/examples/src/lamport/signature.rs @@ -187,7 +187,7 @@ pub fn message_to_elements(message: &[u8]) -> [BaseElement; 2] { let checksum = m0.count_zeros() + m1.count_zeros(); let m1 = m1 | ((checksum as u128) << 119); - [BaseElement::from(m0), BaseElement::from(m1)] + [BaseElement::new(m0), BaseElement::new(m1)] } /// Reduces a list of public key elements to a single 32-byte value. The reduction is done diff --git a/examples/src/lamport/threshold/air.rs b/examples/src/lamport/threshold/air.rs index 1c593290c..41fa86926 100644 --- a/examples/src/lamport/threshold/air.rs +++ b/examples/src/lamport/threshold/air.rs @@ -228,8 +228,8 @@ impl Air for LamportThresholdAir { let mut m1_bits = Vec::with_capacity(SIG_CYCLE_LEN); for i in 0..SIG_CYCLE_LEN { let cycle_num = i / HASH_CYCLE_LEN; - m0_bits.push(BaseElement::from((m0 >> cycle_num) & 1)); - m1_bits.push(BaseElement::from((m1 >> cycle_num) & 1)); + m0_bits.push(BaseElement::new((m0 >> cycle_num) & 1)); + m1_bits.push(BaseElement::new((m1 >> cycle_num) & 1)); } result.push(m0_bits); result.push(m1_bits); diff --git a/examples/src/lamport/threshold/prover.rs b/examples/src/lamport/threshold/prover.rs index 2ccd34a85..4cd26ef14 100644 --- a/examples/src/lamport/threshold/prover.rs +++ b/examples/src/lamport/threshold/prover.rs @@ -245,8 +245,8 @@ fn update_sig_verification_state( } else { // for the 8th step of very cycle do the following: - let m0_bit = BaseElement::from((sig_info.m0 >> cycle_num) & 1); - let m1_bit = BaseElement::from((sig_info.m1 >> cycle_num) & 1); + let m0_bit = BaseElement::new((sig_info.m0 >> cycle_num) & 1); + let m1_bit = BaseElement::new((sig_info.m1 >> cycle_num) & 1); let mp_bit = merkle_path_idx[0]; // copy next set of public keys into the registers computing hash of the public key @@ -345,7 +345,7 @@ fn update_merkle_path_index( let index_bit = state[0]; // the cycle is offset by +1 because the first node in the Merkle path is redundant and we // get it by hashing the public key - state[0] = BaseElement::from((index >> (cycle_num + 1)) & 1); + state[0] = BaseElement::new((index >> (cycle_num + 1)) & 1); state[1] += power_of_two * index_bit; } diff --git a/fri/src/folding/mod.rs b/fri/src/folding/mod.rs index b152123b2..4d6f78182 100644 --- a/fri/src/folding/mod.rs +++ b/fri/src/folding/mod.rs @@ -90,7 +90,7 @@ where // build offset inverses and twiddles used during polynomial interpolation let inv_offsets = get_inv_offsets(values.len(), domain_offset, N); let inv_twiddles = get_inv_twiddles::(N); - let len_offset = E::inv((N as u64).into()); + let len_offset = E::inv((N as u32).into()); let mut result = unsafe { uninit_vector(values.len()) }; iter_mut!(result) diff --git a/math/src/fft/concurrent.rs b/math/src/fft/concurrent.rs index f5bbd1014..caab7f81d 100644 --- a/math/src/fft/concurrent.rs +++ b/math/src/fft/concurrent.rs @@ -50,36 +50,47 @@ pub fn evaluate_poly_with_offset>( /// Uses FFT algorithm to interpolate a polynomial from provided `values`; the interpolation /// is done in-place, meaning `values` are updated with polynomial coefficients. -pub fn interpolate_poly(v: &mut [E], inv_twiddles: &[B]) +/// +/// # Panics +/// Panics if the length of `values` is greater than [u32::MAX]. +pub fn interpolate_poly(values: &mut [E], inv_twiddles: &[B]) where B: StarkField, E: FieldElement, { - split_radix_fft(v, inv_twiddles); - let inv_length = E::inv((v.len() as u64).into()); - v.par_iter_mut().for_each(|e| *e *= inv_length); - permute(v); + assert!(values.len() <= u32::MAX as usize, "too many values"); + + split_radix_fft(values, inv_twiddles); + let inv_length = E::inv((values.len() as u32).into()); + values.par_iter_mut().for_each(|e| *e *= inv_length); + permute(values); } /// Uses FFT algorithm to interpolate a polynomial from provided `values` over the domain defined /// by `inv_twiddles` and offset by `domain_offset` factor. +/// +/// +/// # Panics +/// Panics if the length of `values` is greater than [u32::MAX]. pub fn interpolate_poly_with_offset(values: &mut [E], inv_twiddles: &[B], domain_offset: B) where B: StarkField, E: FieldElement, { + assert!(values.len() <= u32::MAX as usize, "too many values"); + split_radix_fft(values, inv_twiddles); permute(values); let domain_offset = E::inv(domain_offset.into()); - let inv_len = E::inv((values.len() as u64).into()); + let inv_len = E::inv((values.len() as u32).into()); let batch_size = values.len() / rayon::current_num_threads().next_power_of_two(); values.par_chunks_mut(batch_size).enumerate().for_each(|(i, batch)| { let mut offset = domain_offset.exp(((i * batch_size) as u64).into()) * inv_len; for coeff in batch.iter_mut() { - *coeff = *coeff * offset; - offset = offset * domain_offset; + *coeff *= offset; + offset *= domain_offset; } }); } @@ -136,7 +147,7 @@ pub(super) fn split_radix_fft>( // apply inner FFTs values .par_chunks_mut(outer_len) - .for_each(|row| row.fft_in_place_raw(&twiddles, stretch, stretch, 0)); + .for_each(|row| row.fft_in_place_raw(twiddles, stretch, stretch, 0)); // transpose inner x inner x stretch square matrix transpose_square_stretch(values, inner_len, stretch); @@ -149,10 +160,10 @@ pub(super) fn split_radix_fft>( let mut outer_twiddle = inner_twiddle; for element in row.iter_mut().skip(1) { *element = (*element).mul_base(outer_twiddle); - outer_twiddle = outer_twiddle * inner_twiddle; + outer_twiddle *= inner_twiddle; } } - row.fft_in_place(&twiddles); + row.fft_in_place(twiddles); }); } @@ -216,7 +227,7 @@ fn clone_and_shift(source: &[E], destination: &mut [E], offset: let mut factor = offset.exp(((i * batch_size) as u64).into()); for (s, d) in source.iter().zip(destination.iter_mut()) { *d = (*s).mul_base(factor); - factor = factor * offset; + factor *= offset; } }); } diff --git a/math/src/fft/serial.rs b/math/src/fft/serial.rs index af14c768f..b138edd9e 100644 --- a/math/src/fft/serial.rs +++ b/math/src/fft/serial.rs @@ -57,12 +57,16 @@ where /// Interpolates `evaluations` over a domain of length `evaluations.len()` in the field specified /// `B` into a polynomial in coefficient form using the FFT algorithm. +/// +/// # Panics +/// Panics if the length of `evaluations` is greater than [u32::MAX]. pub fn interpolate_poly(evaluations: &mut [E], inv_twiddles: &[B]) where B: StarkField, E: FieldElement, { - let inv_length = B::inv((evaluations.len() as u64).into()); + assert!(evaluations.len() <= u32::MAX as usize, "too many evaluations"); + let inv_length = B::inv((evaluations.len() as u32).into()); evaluations.fft_in_place(inv_twiddles); evaluations.shift_by(inv_length); evaluations.permute(); @@ -71,6 +75,9 @@ where /// Interpolates `evaluations` over a domain of length `evaluations.len()` and shifted by /// `domain_offset` in the field specified by `B` into a polynomial in coefficient form using /// the FFT algorithm. +/// +/// # Panics +/// Panics if the length of `evaluations` is greater than [u32::MAX]. pub fn interpolate_poly_with_offset( evaluations: &mut [E], inv_twiddles: &[B], @@ -79,11 +86,13 @@ pub fn interpolate_poly_with_offset( B: StarkField, E: FieldElement, { + assert!(evaluations.len() <= u32::MAX as usize, "too many evaluations"); + evaluations.fft_in_place(inv_twiddles); evaluations.permute(); let domain_offset = B::inv(domain_offset); - let offset = B::inv((evaluations.len() as u64).into()); + let offset = B::inv((evaluations.len() as u32).into()); evaluations.shift_by_series(offset, domain_offset); } diff --git a/math/src/field/extensions/cubic.rs b/math/src/field/extensions/cubic.rs index 97e3f5d60..dfb9a584a 100644 --- a/math/src/field/extensions/cubic.rs +++ b/math/src/field/extensions/cubic.rs @@ -301,18 +301,6 @@ impl> From for CubeExtension { } } -impl> From for CubeExtension { - fn from(value: u128) -> Self { - Self(B::from(value), B::ZERO, B::ZERO) - } -} - -impl> From for CubeExtension { - fn from(value: u64) -> Self { - Self(B::from(value), B::ZERO, B::ZERO) - } -} - impl> From for CubeExtension { fn from(value: u32) -> Self { Self(B::from(value), B::ZERO, B::ZERO) diff --git a/math/src/field/extensions/quadratic.rs b/math/src/field/extensions/quadratic.rs index a68c12d80..a8683b617 100644 --- a/math/src/field/extensions/quadratic.rs +++ b/math/src/field/extensions/quadratic.rs @@ -295,18 +295,6 @@ impl> From for QuadExtension { } } -impl> From for QuadExtension { - fn from(value: u128) -> Self { - Self(B::from(value), B::ZERO) - } -} - -impl> From for QuadExtension { - fn from(value: u64) -> Self { - Self(B::from(value), B::ZERO) - } -} - impl> From for QuadExtension { fn from(value: u32) -> Self { Self(B::from(value), B::ZERO) diff --git a/math/src/field/f128/mod.rs b/math/src/field/f128/mod.rs index e74ab5894..ffe63de56 100644 --- a/math/src/field/f128/mod.rs +++ b/math/src/field/f128/mod.rs @@ -327,14 +327,6 @@ impl ExtensibleField<3> for BaseElement { // TYPE CONVERSIONS // ================================================================================================ -impl From for BaseElement { - /// Converts a 128-bit value into a field element. If the value is greater than or equal to - /// the field modulus, modular reduction is silently performed. - fn from(value: u128) -> Self { - BaseElement::new(value) - } -} - impl From for BaseElement { /// Converts a 64-bit value into a field element. fn from(value: u64) -> Self { @@ -363,16 +355,6 @@ impl From for BaseElement { } } -impl From<[u8; 16]> for BaseElement { - /// Converts the value encoded in an array of 16 bytes into a field element. The bytes - /// are assumed to be in little-endian byte order. If the value is greater than or equal - /// to the field modulus, modular reduction is silently performed. - fn from(bytes: [u8; 16]) -> Self { - let value = u128::from_le_bytes(bytes); - BaseElement::from(value) - } -} - impl<'a> TryFrom<&'a [u8]> for BaseElement { type Error = String; diff --git a/math/src/field/f128/tests.rs b/math/src/field/f128/tests.rs index 570d92b52..b045816fe 100644 --- a/math/src/field/f128/tests.rs +++ b/math/src/field/f128/tests.rs @@ -8,7 +8,6 @@ use super::{ StarkField, Vec, M, }; use crate::field::{ExtensionOf, QuadExtension}; -use core::convert::TryFrom; use num_bigint::BigUint; use rand_utils::{rand_value, rand_vector}; use utils::SliceReader; @@ -26,7 +25,7 @@ fn add() { assert_eq!(BaseElement::from(5u8), BaseElement::from(2u8) + BaseElement::from(3u8)); // test overflow - let t = BaseElement::from(BaseElement::MODULUS - 1); + let t = BaseElement::new(BaseElement::MODULUS - 1); assert_eq!(BaseElement::ZERO, t + BaseElement::ONE); assert_eq!(BaseElement::ONE, t + BaseElement::from(2u8)); @@ -49,7 +48,7 @@ fn sub() { assert_eq!(BaseElement::from(2u8), BaseElement::from(5u8) - BaseElement::from(3u8)); // test underflow - let expected = BaseElement::from(BaseElement::MODULUS - 2); + let expected = BaseElement::new(BaseElement::MODULUS - 2); assert_eq!(expected, BaseElement::from(3u8) - BaseElement::from(5u8)); } @@ -65,13 +64,13 @@ fn mul() { // test overflow let m = BaseElement::MODULUS; - let t = BaseElement::from(m - 1); + let t = BaseElement::new(m - 1); assert_eq!(BaseElement::ONE, t * t); - assert_eq!(BaseElement::from(m - 2), t * BaseElement::from(2u8)); - assert_eq!(BaseElement::from(m - 4), t * BaseElement::from(4u8)); + assert_eq!(BaseElement::new(m - 2), t * BaseElement::from(2u8)); + assert_eq!(BaseElement::new(m - 4), t * BaseElement::from(4u8)); let t = (m + 1) / 2; - assert_eq!(BaseElement::ONE, BaseElement::from(t) * BaseElement::from(2u8)); + assert_eq!(BaseElement::ONE, BaseElement::new(t) * BaseElement::from(2u8)); // test random values let v1: Vec = rand_vector(1000); @@ -116,7 +115,7 @@ fn conjugate() { #[test] fn get_root_of_unity() { let root_40 = BaseElement::get_root_of_unity(40); - assert_eq!(BaseElement::from(23953097886125630542083529559205016746u128), root_40); + assert_eq!(BaseElement::new(23953097886125630542083529559205016746u128), root_40); assert_eq!(BaseElement::ONE, root_40.exp(u128::pow(2, 40))); let root_39 = BaseElement::get_root_of_unity(39); @@ -251,7 +250,8 @@ impl BaseElement { pub fn from_big_uint(value: BigUint) -> Self { let bytes = value.to_bytes_le(); let mut buffer = [0u8; 16]; - buffer[0..bytes.len()].copy_from_slice(&bytes); - BaseElement::try_from(buffer).unwrap() + buffer[..bytes.len()].copy_from_slice(&bytes); + let value = u128::from_le_bytes(buffer); + BaseElement::new(value) } } diff --git a/math/src/field/f62/mod.rs b/math/src/field/f62/mod.rs index 36ed6f061..dec435b85 100644 --- a/math/src/field/f62/mod.rs +++ b/math/src/field/f62/mod.rs @@ -18,8 +18,10 @@ use core::{ slice, }; use utils::{ - collections::Vec, string::ToString, AsBytes, ByteReader, ByteWriter, Deserializable, - DeserializationError, Randomizable, Serializable, + collections::Vec, + string::{String, ToString}, + AsBytes, ByteReader, ByteWriter, Deserializable, DeserializationError, Randomizable, + Serializable, }; #[cfg(feature = "serde")] @@ -58,7 +60,7 @@ const G: u64 = 4421547261963328785; /// backing type is `u64`. #[derive(Copy, Clone, Default)] #[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] -#[cfg_attr(feature = "serde", serde(from = "u64", into = "u64"))] +#[cfg_attr(feature = "serde", serde(try_from = "u64", into = "u64"))] pub struct BaseElement(u64); impl BaseElement { @@ -406,36 +408,6 @@ impl ExtensibleField<3> for BaseElement { // TYPE CONVERSIONS // ================================================================================================ -impl From for BaseElement { - /// Converts a 128-bit value into a field element. If the value is greater than or equal to - /// the field modulus, modular reduction is silently performed. - fn from(value: u128) -> Self { - // make sure the value is < 4M^2 - 4M + 1; this is overly conservative and a single - // subtraction of (M * 2^65) should be enough, but this needs to be proven - const M4: u128 = (2 * M as u128).pow(2) - 4 * (M as u128) + 1; - const Q: u128 = (2 * M as u128).pow(2) - 4 * (M as u128); - let mut v = value; - while v >= M4 { - v -= Q; - } - - // apply similar reduction as during multiplication; as output we get z = v * R^{-1} mod M, - // so we need to Montgomery-multiply it be R^3 to get z = v * R mod M - let q = (((v as u64) as u128) * U) as u64; - let z = v + (q as u128) * (M as u128); - let z = mul((z >> 64) as u64, R3); - BaseElement(z) - } -} - -impl From for BaseElement { - /// Converts a 64-bit value into a field element. If the value is greater than or equal to - /// the field modulus, modular reduction is silently performed. - fn from(value: u64) -> Self { - BaseElement::new(value) - } -} - impl From for BaseElement { /// Converts a 32-bit value into a field element. fn from(value: u32) -> Self { @@ -457,17 +429,6 @@ impl From for BaseElement { } } -impl From<[u8; 8]> for BaseElement { - /// Converts the value encoded in an array of 8 bytes into a field element. The bytes are - /// assumed to encode the element in the canonical representation in little-endian byte order. - /// If the value is greater than or equal to the field modulus, modular reduction is silently - /// performed. - fn from(bytes: [u8; 8]) -> Self { - let value = u64::from_le_bytes(bytes); - BaseElement::new(value) - } -} - impl From for u128 { fn from(value: BaseElement) -> Self { value.as_int() as u128 @@ -480,6 +441,29 @@ impl From for u64 { } } +impl TryFrom for BaseElement { + type Error = String; + + fn try_from(value: u64) -> Result { + if value >= M { + Err(format!( + "invalid field element: value {value} is greater than or equal to the field modulus" + )) + } else { + Ok(Self::new(value)) + } + } +} + +impl TryFrom<[u8; 8]> for BaseElement { + type Error = String; + + fn try_from(bytes: [u8; 8]) -> Result { + let value = u64::from_le_bytes(bytes); + Self::try_from(value) + } +} + impl<'a> TryFrom<&'a [u8]> for BaseElement { type Error = DeserializationError; diff --git a/math/src/field/f62/tests.rs b/math/src/field/f62/tests.rs index 4de0e77dd..23fb445dd 100644 --- a/math/src/field/f62/tests.rs +++ b/math/src/field/f62/tests.rs @@ -23,7 +23,7 @@ fn add() { assert_eq!(BaseElement::from(5u8), BaseElement::from(2u8) + BaseElement::from(3u8)); // test overflow - let t = BaseElement::from(BaseElement::MODULUS - 1); + let t = BaseElement::new(BaseElement::MODULUS - 1); assert_eq!(BaseElement::ZERO, t + BaseElement::ONE); assert_eq!(BaseElement::ONE, t + BaseElement::from(2u8)); } @@ -38,7 +38,7 @@ fn sub() { assert_eq!(BaseElement::from(2u8), BaseElement::from(5u8) - BaseElement::from(3u8)); // test underflow - let expected = BaseElement::from(BaseElement::MODULUS - 2); + let expected = BaseElement::new(BaseElement::MODULUS - 2); assert_eq!(expected, BaseElement::from(3u8) - BaseElement::from(5u8)); } @@ -54,13 +54,13 @@ fn mul() { // test overflow let m = BaseElement::MODULUS; - let t = BaseElement::from(m - 1); + let t = BaseElement::new(m - 1); assert_eq!(BaseElement::ONE, t * t); - assert_eq!(BaseElement::from(m - 2), t * BaseElement::from(2u8)); - assert_eq!(BaseElement::from(m - 4), t * BaseElement::from(4u8)); + assert_eq!(BaseElement::new(m - 2), t * BaseElement::from(2u8)); + assert_eq!(BaseElement::new(m - 4), t * BaseElement::from(4u8)); let t = (m + 1) / 2; - assert_eq!(BaseElement::ONE, BaseElement::from(t) * BaseElement::from(2u8)); + assert_eq!(BaseElement::ONE, BaseElement::new(t) * BaseElement::from(2u8)); } #[test] @@ -212,13 +212,6 @@ fn get_root_of_unity() { // SERIALIZATION AND DESERIALIZATION // ------------------------------------------------------------------------------------------------ -#[test] -fn from_u128() { - let v = u128::MAX; - let e = BaseElement::from(v); - assert_eq!((v % super::M as u128) as u64, e.as_int()); -} - #[test] fn try_from_slice() { let bytes = vec![1, 0, 0, 0, 0, 0, 0, 0]; @@ -303,8 +296,8 @@ proptest! { #[test] fn add_proptest(a in any::(), b in any::()) { - let v1 = BaseElement::from(a); - let v2 = BaseElement::from(b); + let v1 = BaseElement::new(a); + let v2 = BaseElement::new(b); let result = v1 + v2; let expected = (a % super::M + b % super::M) % super::M; @@ -313,8 +306,8 @@ proptest! { #[test] fn sub_proptest(a in any::(), b in any::()) { - let v1 = BaseElement::from(a); - let v2 = BaseElement::from(b); + let v1 = BaseElement::new(a); + let v2 = BaseElement::new(b); let result = v1 - v2; let a = a % super::M; @@ -326,8 +319,8 @@ proptest! { #[test] fn mul_proptest(a in any::(), b in any::()) { - let v1 = BaseElement::from(a); - let v2 = BaseElement::from(b); + let v1 = BaseElement::new(a); + let v2 = BaseElement::new(b); let result = v1 * v2; let expected = (((a as u128) * (b as u128)) % super::M as u128) as u64; @@ -336,7 +329,7 @@ proptest! { #[test] fn exp_proptest(a in any::(), b in any::()) { - let result = BaseElement::from(a).exp(b); + let result = BaseElement::new(a).exp(b); let b = BigUint::from(b); let m = BigUint::from(super::M); @@ -346,7 +339,7 @@ proptest! { #[test] fn inv_proptest(a in any::()) { - let a = BaseElement::from(a); + let a = BaseElement::new(a); let b = a.inv(); let expected = if a == BaseElement::ZERO { BaseElement::ZERO } else { BaseElement::ONE }; @@ -359,17 +352,11 @@ proptest! { prop_assert_eq!(a % super::M, e.as_int()); } - #[test] - fn from_u128_proptest(v in any::()) { - let e = BaseElement::from(v); - assert_eq!((v % super::M as u128) as u64, e.as_int()); - } - // QUADRATIC EXTENSION // -------------------------------------------------------------------------------------------- #[test] fn quad_mul_inv_proptest(a0 in any::(), a1 in any::()) { - let a = QuadExtension::::new(BaseElement::from(a0), BaseElement::from(a1)); + let a = QuadExtension::::new(BaseElement::new(a0), BaseElement::new(a1)); let b = a.inv(); let expected = if a == QuadExtension::::ZERO { @@ -384,7 +371,7 @@ proptest! { // -------------------------------------------------------------------------------------------- #[test] fn cube_mul_inv_proptest(a0 in any::(), a1 in any::(), a2 in any::()) { - let a = CubeExtension::::new(BaseElement::from(a0), BaseElement::from(a1), BaseElement::from(a2)); + let a = CubeExtension::::new(BaseElement::new(a0), BaseElement::new(a1), BaseElement::new(a2)); let b = a.inv(); let expected = if a == CubeExtension::::ZERO { diff --git a/math/src/field/f64/mod.rs b/math/src/field/f64/mod.rs index df24030aa..86d1a9a12 100644 --- a/math/src/field/f64/mod.rs +++ b/math/src/field/f64/mod.rs @@ -23,8 +23,10 @@ use core::{ slice, }; use utils::{ - collections::Vec, string::ToString, AsBytes, ByteReader, ByteWriter, Deserializable, - DeserializationError, Randomizable, Serializable, + collections::Vec, + string::{String, ToString}, + AsBytes, ByteReader, ByteWriter, Deserializable, DeserializationError, Randomizable, + Serializable, }; #[cfg(feature = "serde")] @@ -54,12 +56,15 @@ const ELEMENT_BYTES: usize = core::mem::size_of::(); /// The backing type is `u64` but the internal values are always in the range [0, M). #[derive(Copy, Clone, Default)] #[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] -#[cfg_attr(feature = "serde", serde(from = "u64", into = "u64"))] +#[cfg_attr(feature = "serde", serde(try_from = "u64", into = "u64"))] pub struct BaseElement(u64); impl BaseElement { /// Creates a new field element from the provided `value`; the value is converted into /// Montgomery representation. + /// + /// If the value is greater than or equal to the field modulus, modular reduction is + /// silently performed. pub const fn new(value: u64) -> BaseElement { Self(mont_red_cst((value as u128) * (R2 as u128))) } @@ -75,6 +80,12 @@ impl BaseElement { self.0 } + /// Returns canonical integer representation of this field element. + #[inline(always)] + pub const fn as_int(&self) -> u64 { + mont_to_int(self.0) + } + /// Computes an exponentiation to the power 7. This is useful for computing Rescue-Prime /// S-Box over this field. #[inline(always)] @@ -88,7 +99,7 @@ impl BaseElement { /// Multiplies an element that is less than 2^32 by a field element. This implementation /// is faster as it avoids the use of Montgomery reduction. #[inline(always)] - pub fn mul_small(self, rhs: u32) -> Self { + pub const fn mul_small(self, rhs: u32) -> Self { let s = (self.inner() as u128) * (rhs as u128); let s_hi = (s >> 64) as u64; let s_lo = s as u64; @@ -273,18 +284,9 @@ impl StarkField for BaseElement { M.to_le_bytes().to_vec() } - // Converts a field element in Montgomery form to canonical form. That is, given x, it computes - // x/2^64 modulo M. This is exactly what mont_red_cst does only that it does it more efficiently - // using the fact that a field element in Montgomery form is stored as a u64 and thus one can - // use this to simplify mont_red_cst in this case. #[inline] fn as_int(&self) -> Self::PositiveInteger { - let x = self.0; - let (a, e) = x.overflowing_add(x << 32); - let b = a.wrapping_sub(a >> 32).wrapping_sub(e as u64); - - let (r, c) = 0u64.overflowing_sub(b); - r.wrapping_sub(0u32.wrapping_sub(c as u32) as u64) + mont_to_int(self.0) } } @@ -513,23 +515,6 @@ impl ExtensibleField<3> for BaseElement { // TYPE CONVERSIONS // ================================================================================================ -impl From for BaseElement { - /// Converts a 128-bit value into a field element. - fn from(x: u128) -> Self { - //const R3: u128 = 1 (= 2^192 mod M );// thus we get that mont_red_var((mont_red_var(x) as u128) * R3) becomes - //Self(mont_red_var(mont_red_var(x) as u128)) // Variable time implementation - Self(mont_red_cst(mont_red_cst(x) as u128)) // Constant time implementation - } -} - -impl From for BaseElement { - /// Converts a 64-bit value into a field element. If the value is greater than or equal to - /// the field modulus, modular reduction is silently performed. - fn from(value: u64) -> Self { - Self::new(value) - } -} - impl From for BaseElement { /// Converts a 32-bit value into a field element. fn from(value: u32) -> Self { @@ -551,14 +536,26 @@ impl From for BaseElement { } } -impl From<[u8; 8]> for BaseElement { - /// Converts the value encoded in an array of 8 bytes into a field element. The bytes are - /// assumed to encode the element in the canonical representation in little-endian byte order. - /// If the value is greater than or equal to the field modulus, modular reduction is silently - /// performed. - fn from(bytes: [u8; 8]) -> Self { +impl TryFrom for BaseElement { + type Error = String; + + fn try_from(value: u64) -> Result { + if value >= M { + Err(format!( + "invalid field element: value {value} is greater than or equal to the field modulus" + )) + } else { + Ok(Self::new(value)) + } + } +} + +impl TryFrom<[u8; 8]> for BaseElement { + type Error = String; + + fn try_from(bytes: [u8; 8]) -> Result { let value = u64::from_le_bytes(bytes); - Self::new(value) + Self::try_from(value) } } @@ -583,16 +580,8 @@ impl<'a> TryFrom<&'a [u8]> for BaseElement { bytes.len(), ))); } - let value = bytes - .try_into() - .map(u64::from_le_bytes) - .map_err(|error| DeserializationError::UnknownError(format!("{error}")))?; - if value >= M { - return Err(DeserializationError::InvalidValue(format!( - "invalid field element: value {value} is greater than or equal to the field modulus" - ))); - } - Ok(Self::new(value)) + let bytes: [u8; 8] = bytes.try_into().expect("slice to array conversion failed"); + bytes.try_into().map_err(DeserializationError::InvalidValue) } } @@ -617,7 +606,7 @@ impl AsBytes for BaseElement { } // SERIALIZATION / DESERIALIZATION -// ------------------------------------------------------------------------------------------------ +// ================================================================================================ impl Serializable for BaseElement { fn write_into(&self, target: &mut W) { @@ -638,6 +627,9 @@ impl Deserializable for BaseElement { } } +// HELPER FUNCTIONS +// ================================================================================================ + /// Squares the base N number of times and multiplies the result by the tail value. #[inline(always)] fn exp_acc(base: BaseElement, tail: BaseElement) -> BaseElement { @@ -677,10 +669,23 @@ const fn mont_red_cst(x: u128) -> u64 { r.wrapping_sub(0u32.wrapping_sub(c as u32) as u64) } +// Converts a field element in Montgomery form to canonical form. That is, given x, it computes +// x/2^64 modulo M. This is exactly what mont_red_cst does only that it does it more efficiently +// using the fact that a field element in Montgomery form is stored as a u64 and thus one can +// use this to simplify mont_red_cst in this case. +#[inline(always)] +const fn mont_to_int(x: u64) -> u64 { + let (a, e) = x.overflowing_add(x << 32); + let b = a.wrapping_sub(a >> 32).wrapping_sub(e as u64); + + let (r, c) = 0u64.overflowing_sub(b); + r.wrapping_sub(0u32.wrapping_sub(c as u32) as u64) +} + /// Test of equality between two BaseField elements; return value is /// 0xFFFFFFFFFFFFFFFF if the two values are equal, or 0 otherwise. #[inline(always)] -pub fn equals(lhs: u64, rhs: u64) -> u64 { +fn equals(lhs: u64, rhs: u64) -> u64 { let t = lhs ^ rhs; !((((t | t.wrapping_neg()) as i64) >> 63) as u64) } diff --git a/math/src/field/f64/tests.rs b/math/src/field/f64/tests.rs index 5e08720a5..2f5691dac 100644 --- a/math/src/field/f64/tests.rs +++ b/math/src/field/f64/tests.rs @@ -45,7 +45,7 @@ fn sub() { #[test] fn neg() { assert_eq!(BaseElement::ZERO, -BaseElement::ZERO); - assert_eq!(BaseElement::from(super::M - 1), -BaseElement::ONE); + assert_eq!(BaseElement::new(super::M - 1), -BaseElement::ONE); let r: BaseElement = rand_value(); assert_eq!(r, -(-r)); @@ -63,20 +63,20 @@ fn mul() { // test overflow let m = BaseElement::MODULUS; - let t = BaseElement::from(m - 1); + let t = BaseElement::new(m - 1); assert_eq!(BaseElement::ONE, t * t); - assert_eq!(BaseElement::from(m - 2), t * BaseElement::from(2u8)); - assert_eq!(BaseElement::from(m - 4), t * BaseElement::from(4u8)); + assert_eq!(BaseElement::new(m - 2), t * BaseElement::from(2u8)); + assert_eq!(BaseElement::new(m - 4), t * BaseElement::from(4u8)); let t = (m + 1) / 2; - assert_eq!(BaseElement::ONE, BaseElement::from(t) * BaseElement::from(2u8)); + assert_eq!(BaseElement::ONE, BaseElement::new(t) * BaseElement::from(2u8)); } #[test] fn mul_small() { // test overflow let m = BaseElement::MODULUS; - let t = BaseElement::from(m - 1); + let t = BaseElement::new(m - 1); let a = u32::MAX; let expected = BaseElement::new(a as u64) * t; @@ -149,13 +149,6 @@ fn get_root_of_unity() { // SERIALIZATION AND DESERIALIZATION // ------------------------------------------------------------------------------------------------ -#[test] -fn from_u128() { - let v = u128::MAX; - let e = BaseElement::from(v); - assert_eq!((v % super::M as u128) as u64, e.as_int()); -} - #[test] fn try_from_slice() { let bytes = vec![1, 0, 0, 0, 0, 0, 0, 0]; @@ -380,8 +373,8 @@ proptest! { #[test] fn add_proptest(a in any::(), b in any::()) { - let v1 = BaseElement::from(a); - let v2 = BaseElement::from(b); + let v1 = BaseElement::new(a); + let v2 = BaseElement::new(b); let result = v1 + v2; let expected = (((a as u128) + (b as u128)) % (super::M as u128)) as u64; @@ -390,8 +383,8 @@ proptest! { #[test] fn sub_proptest(a in any::(), b in any::()) { - let v1 = BaseElement::from(a); - let v2 = BaseElement::from(b); + let v1 = BaseElement::new(a); + let v2 = BaseElement::new(b); let result = v1 - v2; let a = a % super::M; @@ -403,7 +396,7 @@ proptest! { #[test] fn neg_proptest(a in any::()) { - let v = BaseElement::from(a); + let v = BaseElement::new(a); let expected = super::M - (a % super::M); prop_assert_eq!(expected, (-v).as_int()); @@ -411,8 +404,8 @@ proptest! { #[test] fn mul_proptest(a in any::(), b in any::()) { - let v1 = BaseElement::from(a); - let v2 = BaseElement::from(b); + let v1 = BaseElement::new(a); + let v2 = BaseElement::new(b); let result = v1 * v2; let expected = (((a as u128) * (b as u128)) % super::M as u128) as u64; @@ -421,7 +414,7 @@ proptest! { #[test] fn mul_small_proptest(a in any::(), b in any::()) { - let v1 = BaseElement::from(a); + let v1 = BaseElement::new(a); let v2 = b; let result = v1.mul_small(v2); @@ -431,7 +424,7 @@ proptest! { #[test] fn double_proptest(x in any::()) { - let v = BaseElement::from(x); + let v = BaseElement::new(x); let result = v.double(); let expected = (((x as u128) * 2) % super::M as u128) as u64; @@ -440,7 +433,7 @@ proptest! { #[test] fn exp_proptest(a in any::(), b in any::()) { - let result = BaseElement::from(a).exp(b); + let result = BaseElement::new(a).exp(b); let b = BigUint::from(b); let m = BigUint::from(super::M); @@ -450,7 +443,7 @@ proptest! { #[test] fn inv_proptest(a in any::()) { - let a = BaseElement::from(a); + let a = BaseElement::new(a); let b = a.inv(); let expected = if a == BaseElement::ZERO { BaseElement::ZERO } else { BaseElement::ONE }; @@ -463,17 +456,11 @@ proptest! { prop_assert_eq!(a % super::M, e.as_int()); } - #[test] - fn from_u128_proptest(v in any::()) { - let e = BaseElement::from(v); - assert_eq!((v % super::M as u128) as u64, e.as_int()); - } - // QUADRATIC EXTENSION // -------------------------------------------------------------------------------------------- #[test] fn quad_mul_inv_proptest(a0 in any::(), a1 in any::()) { - let a = QuadExtension::::new(BaseElement::from(a0), BaseElement::from(a1)); + let a = QuadExtension::::new(BaseElement::new(a0), BaseElement::new(a1)); let b = a.inv(); let expected = if a == QuadExtension::::ZERO { @@ -486,7 +473,7 @@ proptest! { #[test] fn quad_square_proptest(a0 in any::(), a1 in any::()) { - let a = QuadExtension::::new(BaseElement::from(a0), BaseElement::from(a1)); + let a = QuadExtension::::new(BaseElement::new(a0), BaseElement::new(a1)); let expected = a * a; prop_assert_eq!(expected, a.square()); @@ -496,7 +483,7 @@ proptest! { // -------------------------------------------------------------------------------------------- #[test] fn cube_mul_inv_proptest(a0 in any::(), a1 in any::(), a2 in any::()) { - let a = CubeExtension::::new(BaseElement::from(a0), BaseElement::from(a1), BaseElement::from(a2)); + let a = CubeExtension::::new(BaseElement::new(a0), BaseElement::new(a1), BaseElement::new(a2)); let b = a.inv(); let expected = if a == CubeExtension::::ZERO { @@ -509,7 +496,7 @@ proptest! { #[test] fn cube_square_proptest(a0 in any::(), a1 in any::(), a2 in any::()) { - let a = CubeExtension::::new(BaseElement::from(a0), BaseElement::from(a1), BaseElement::from(a2)); + let a = CubeExtension::::new(BaseElement::new(a0), BaseElement::new(a1), BaseElement::new(a2)); let expected = a * a; prop_assert_eq!(expected, a.square()); diff --git a/math/src/field/traits.rs b/math/src/field/traits.rs index 68c92d530..124e2980b 100644 --- a/math/src/field/traits.rs +++ b/math/src/field/traits.rs @@ -46,8 +46,6 @@ pub trait FieldElement: + MulAssign + DivAssign + Neg - + From - + From + From + From + From diff --git a/math/src/polynom/tests.rs b/math/src/polynom/tests.rs index dce797114..b3631a0c4 100644 --- a/math/src/polynom/tests.rs +++ b/math/src/polynom/tests.rs @@ -12,12 +12,12 @@ use utils::collections::Vec; #[test] fn eval() { - let x = BaseElement::from(11269864713250585702u128); + let x = BaseElement::new(11269864713250585702u128); let poly: [BaseElement; 4] = [ - BaseElement::from(384863712573444386u128), - BaseElement::from(7682273369345308472u128), - BaseElement::from(13294661765012277990u128), - BaseElement::from(16234810094004944758u128), + BaseElement::new(384863712573444386u128), + BaseElement::new(7682273369345308472u128), + BaseElement::new(13294661765012277990u128), + BaseElement::new(16234810094004944758u128), ]; assert_eq!(BaseElement::ZERO, super::eval(&poly[..0], x)); @@ -40,14 +40,14 @@ fn eval() { #[test] fn add() { let poly1: [BaseElement; 3] = [ - BaseElement::from(384863712573444386u128), - BaseElement::from(7682273369345308472u128), - BaseElement::from(13294661765012277990u128), + BaseElement::new(384863712573444386u128), + BaseElement::new(7682273369345308472u128), + BaseElement::new(13294661765012277990u128), ]; let poly2: [BaseElement; 3] = [ - BaseElement::from(9918505539874556741u128), - BaseElement::from(16401861429499852246u128), - BaseElement::from(12181445947541805654u128), + BaseElement::new(9918505539874556741u128), + BaseElement::new(16401861429499852246u128), + BaseElement::new(12181445947541805654u128), ]; // same degree @@ -66,14 +66,14 @@ fn add() { #[test] fn sub() { let poly1: [BaseElement; 3] = [ - BaseElement::from(384863712573444386u128), - BaseElement::from(7682273369345308472u128), - BaseElement::from(13294661765012277990u128), + BaseElement::new(384863712573444386u128), + BaseElement::new(7682273369345308472u128), + BaseElement::new(13294661765012277990u128), ]; let poly2: [BaseElement; 3] = [ - BaseElement::from(9918505539874556741u128), - BaseElement::from(16401861429499852246u128), - BaseElement::from(12181445947541805654u128), + BaseElement::new(9918505539874556741u128), + BaseElement::new(16401861429499852246u128), + BaseElement::new(12181445947541805654u128), ]; // same degree @@ -92,14 +92,14 @@ fn sub() { #[test] fn mul() { let poly1: [BaseElement; 3] = [ - BaseElement::from(384863712573444386u128), - BaseElement::from(7682273369345308472u128), - BaseElement::from(13294661765012277990u128), + BaseElement::new(384863712573444386u128), + BaseElement::new(7682273369345308472u128), + BaseElement::new(13294661765012277990u128), ]; let poly2: [BaseElement; 3] = [ - BaseElement::from(9918505539874556741u128), - BaseElement::from(16401861429499852246u128), - BaseElement::from(12181445947541805654u128), + BaseElement::new(9918505539874556741u128), + BaseElement::new(16401861429499852246u128), + BaseElement::new(12181445947541805654u128), ]; // same degree @@ -134,14 +134,14 @@ fn mul() { #[test] fn div() { let poly1 = vec![ - BaseElement::from(384863712573444386u128), - BaseElement::from(7682273369345308472u128), - BaseElement::from(13294661765012277990u128), + BaseElement::new(384863712573444386u128), + BaseElement::new(7682273369345308472u128), + BaseElement::new(13294661765012277990u128), ]; let poly2 = vec![ - BaseElement::from(9918505539874556741u128), - BaseElement::from(16401861429499852246u128), - BaseElement::from(12181445947541805654u128), + BaseElement::new(9918505539874556741u128), + BaseElement::new(16401861429499852246u128), + BaseElement::new(12181445947541805654u128), ]; // divide degree 4 by degree 2 @@ -153,8 +153,8 @@ fn div() { assert_eq!(poly1[..2].to_vec(), super::div(&poly3, &poly2)); // divide degree 3 by degree 3 - let poly3 = super::mul_by_scalar(&poly1, BaseElement::from(11269864713250585702u128)); - assert_eq!(vec![BaseElement::from(11269864713250585702u128)], super::div(&poly3, &poly1)); + let poly3 = super::mul_by_scalar(&poly1, BaseElement::new(11269864713250585702u128)); + assert_eq!(vec![BaseElement::new(11269864713250585702u128)], super::div(&poly3, &poly1)); } #[test] diff --git a/prover/src/matrix/segments.rs b/prover/src/matrix/segments.rs index 39eac1fe9..2b6687a75 100644 --- a/prover/src/matrix/segments.rs +++ b/prover/src/matrix/segments.rs @@ -229,6 +229,7 @@ mod concurrent { /// In-place recursive FFT with permuted output. /// Adapted from: https://github.com/0xProject/OpenZKP/tree/master/algebra/primefield/src/fft + #[allow(clippy::needless_range_loop)] pub fn split_radix_fft(data: &mut [[B; N]], twiddles: &[B]) { // generator of the domain should be in the middle of twiddles let n = data.len(); @@ -246,7 +247,7 @@ mod concurrent { // apply inner FFTs data.par_chunks_mut(outer_len) - .for_each(|row| row.fft_in_place_raw(&twiddles, stretch, stretch, 0)); + .for_each(|row| row.fft_in_place_raw(twiddles, stretch, stretch, 0)); // transpose inner x inner x stretch square matrix transpose_square_stretch(data, inner_len, stretch); @@ -259,12 +260,12 @@ mod concurrent { let mut outer_twiddle = inner_twiddle; for element in row.iter_mut().skip(1) { for col_idx in 0..N { - element[col_idx] = element[col_idx] * outer_twiddle; + element[col_idx] *= outer_twiddle; } - outer_twiddle = outer_twiddle * inner_twiddle; + outer_twiddle *= inner_twiddle; } } - row.fft_in_place(&twiddles) + row.fft_in_place(twiddles) }); }