Skip to content

Commit

Permalink
fix subtle bug in solver (#2)
Browse files Browse the repository at this point in the history
Co-authored-by: Ruben <ruben@polycephaly.org>
  • Loading branch information
MrPugh and Ruben committed Jun 10, 2023
1 parent 8550b1a commit 29f7225
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 16 deletions.
4 changes: 2 additions & 2 deletions ref/matrixmod.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@

#include "params.h"

#define pmod_mat_entry(M, M_r, M_c, r, c) M[M_c*r+c]
#define pmod_mat_entry(M, M_r, M_c, r, c) M[(M_c)*(r)+(c)]

#define pmod_mat_set_entry(M, M_r, M_c, r, c, v) (M[M_c*r+c] = v)
#define pmod_mat_set_entry(M, M_r, M_c, r, c, v) (M[(M_c)*(r)+(c)] = v)

#define pmod_mat_t GFq_t

Expand Down
93 changes: 79 additions & 14 deletions ref/util.c
Original file line number Diff line number Diff line change
Expand Up @@ -218,16 +218,16 @@ int solve(pmod_mat_t *A, pmod_mat_t *B_inv, pmod_mat_t *G0prime, GFq_t Amm)

for (int i = 0; i < MEDS_m; i++)
for (int j = 0; j < MEDS_n; j++)
N[j*MEDS_m + i] = MEDS_p - P0prime0[i*MEDS_n + j];
N[j*MEDS_m + i] = (MEDS_p - P0prime0[i*MEDS_n + j]) % MEDS_p;

LOG_MAT(N, MEDS_n, MEDS_m);
//LOG_MAT(N, MEDS_n, MEDS_m);


pmod_mat_t M[MEDS_n*(MEDS_m + MEDS_m + 2)] = {0};

for (int i = 0; i < MEDS_m; i++)
for (int j = 0; j < MEDS_n; j++)
M[j*(MEDS_m + MEDS_m + 2) + i] = MEDS_p - P0prime1[i*MEDS_n + j];
M[j*(MEDS_m + MEDS_m + 2) + i] = (MEDS_p - P0prime1[i*MEDS_n + j]) % MEDS_p;

for (int i = 0; i < MEDS_m; i++)
for (int j = 0; j < MEDS_n; j++)
Expand All @@ -240,22 +240,87 @@ int solve(pmod_mat_t *A, pmod_mat_t *B_inv, pmod_mat_t *G0prime, GFq_t Amm)
M[j*(MEDS_m + MEDS_m + 2) + MEDS_m + MEDS_n + 1] = ((uint32_t)P0prime1[(MEDS_m-1)*MEDS_n + j] * (uint32_t)Amm) % MEDS_p;


LOG_MAT(M, MEDS_n, MEDS_m + MEDS_m + 2);
//LOG_MAT(M, MEDS_n, MEDS_m + MEDS_m + 2);

if (pmod_mat_row_echelon_ct(M, MEDS_n, MEDS_m + MEDS_m + 2) < 0)
if (pmod_mat_syst_ct(M, MEDS_n-1, MEDS_m + MEDS_m + 2) < 0)
return -1;

LOG_MAT(M, MEDS_n, MEDS_m + MEDS_m + 2);
//LOG_MAT_FMT(M, MEDS_n, MEDS_m + MEDS_m + 2, "M part");

// eliminate last row
for (int r = 0; r < MEDS_n-1; r++)
{
uint64_t factor = pmod_mat_entry(M, MEDS_n, MEDS_m + MEDS_m + 2, MEDS_n-1, r);

// ignore last column
for (int c = MEDS_n-1; c < MEDS_m + MEDS_m + 1; c++)
{
uint64_t tmp0 = pmod_mat_entry(M, MEDS_n, MEDS_m + MEDS_m + 2, MEDS_n-1, c);
uint64_t tmp1 = pmod_mat_entry(M, MEDS_n, MEDS_m + MEDS_m + 2, r, c);

int64_t val = (tmp1 * factor) % MEDS_p;

val = tmp0 - val;

val += MEDS_p * (val < 0);

pmod_mat_set_entry(M, MEDS_n, MEDS_m + MEDS_m + 2, MEDS_n-1, c, val);
}

pmod_mat_set_entry(M, MEDS_n, MEDS_m + MEDS_m + 2, MEDS_n-1, r, 0);
}

// normalize last row
{
uint64_t val = pmod_mat_entry(M, MEDS_n, MEDS_m + MEDS_m + 2, MEDS_n-1, MEDS_n-1);

if (val == 0)
return -1;

val = GF_inv(val);

// ignore last column
for (int c = MEDS_n; c < MEDS_m + MEDS_m + 1; c++)
{
uint64_t tmp = pmod_mat_entry(M, MEDS_n, MEDS_m + MEDS_m + 2, MEDS_n-1, c);

tmp = (tmp * val) % MEDS_p;

pmod_mat_set_entry(M, MEDS_n, MEDS_m + MEDS_m + 2, MEDS_n-1, c, tmp);
}
}

pmod_mat_set_entry(M, MEDS_n, MEDS_m + MEDS_m + 2, MEDS_n-1, MEDS_n-1, 1);

M[MEDS_n*(MEDS_m + MEDS_m + 2)-1] = 0;

LOG_MAT(M, MEDS_n, MEDS_m + MEDS_m + 2);
//LOG_MAT_FMT(M, MEDS_n, MEDS_m + MEDS_m + 2, "M red");

// back substitute
for (int r = 0; r < MEDS_n-1; r++)
{
uint64_t factor = pmod_mat_entry(M, MEDS_n, MEDS_m + MEDS_m + 2, r, MEDS_n-1);

// ignore last column
for (int c = MEDS_n; c < MEDS_m + MEDS_m + 1; c++)
{
uint64_t tmp0 = pmod_mat_entry(M, MEDS_n, MEDS_m + MEDS_m + 2, MEDS_n-1, c);
uint64_t tmp1 = pmod_mat_entry(M, MEDS_n, MEDS_m + MEDS_m + 2, r, c);

int64_t val = (tmp0 * factor) % MEDS_p;

val = tmp1 - val;

val += MEDS_p * (val < 0);

pmod_mat_set_entry(M, MEDS_n, MEDS_m + MEDS_m + 2, r, c, val);
}

pmod_mat_set_entry(M, M_r, MEDS_m + MEDS_m + 2, r, MEDS_n-1, 0);
}

pmod_mat_back_substitution_ct(M, MEDS_n, MEDS_m + MEDS_m + 2);

LOG_MAT(M, MEDS_n, MEDS_m + MEDS_m + 2);
//LOG_MAT_FMT(M, MEDS_n, MEDS_m + MEDS_m + 2, "M done");


GFq_t sol[MEDS_n*MEDS_n + MEDS_m*MEDS_m] = {0};
Expand All @@ -271,7 +336,7 @@ int solve(pmod_mat_t *A, pmod_mat_t *B_inv, pmod_mat_t *G0prime, GFq_t Amm)
for (int i = 0; i < MEDS_n; i++)
sol[MEDS_n*MEDS_n - MEDS_n + i] = ((uint32_t)P0prime0[(MEDS_m-1)*MEDS_n + i] * (uint32_t)Amm) % MEDS_p;

LOG_VEC_FMT(sol, MEDS_n*MEDS_n + MEDS_m*MEDS_m, "initial sol");
//LOG_VEC_FMT(sol, MEDS_n*MEDS_n + MEDS_m*MEDS_m, "initial sol");


// incomplete blocks:
Expand All @@ -292,7 +357,7 @@ int solve(pmod_mat_t *A, pmod_mat_t *B_inv, pmod_mat_t *G0prime, GFq_t Amm)
sol[MEDS_n*MEDS_n - MEDS_n + i] = tmp % MEDS_p;
}

LOG_VEC_FMT(sol, MEDS_n*MEDS_n + MEDS_m*MEDS_m, "incomplete blocks");
//LOG_VEC_FMT(sol, MEDS_n*MEDS_n + MEDS_m*MEDS_m, "incomplete blocks");


// complete blocks:
Expand All @@ -315,7 +380,7 @@ int solve(pmod_mat_t *A, pmod_mat_t *B_inv, pmod_mat_t *G0prime, GFq_t Amm)
sol[MEDS_n*MEDS_n - block*MEDS_n + i] = tmp % MEDS_p;
}

LOG_VEC_FMT(sol, MEDS_n*MEDS_n + MEDS_m*MEDS_m, "complete blocks");
//LOG_VEC_FMT(sol, MEDS_n*MEDS_n + MEDS_m*MEDS_m, "complete blocks");


for (int i = 0; i < MEDS_m * MEDS_m; i++)
Expand All @@ -324,8 +389,8 @@ int solve(pmod_mat_t *A, pmod_mat_t *B_inv, pmod_mat_t *G0prime, GFq_t Amm)
for (int i = 0; i < MEDS_n * MEDS_n; i++)
B_inv[i] = sol[i];

LOG_MAT(A, MEDS_m, MEDS_m);
LOG_MAT(B_inv, MEDS_n, MEDS_n);
//LOG_MAT(A, MEDS_m, MEDS_m);
//LOG_MAT(B_inv, MEDS_n, MEDS_n);

return 0;
}
Expand Down

0 comments on commit 29f7225

Please sign in to comment.