From e496a6536288b0b23cff6766842acc55d3a09b44 Mon Sep 17 00:00:00 2001 From: armfazh Date: Fri, 4 Mar 2022 13:23:35 -0800 Subject: [PATCH 1/2] Adds Set/Copy methods for group element and scalar --- group/group.go | 4 ++++ group/ristretto255.go | 18 ++++++++++++++++++ group/short.go | 17 +++++++++++++++++ 3 files changed, 39 insertions(+) diff --git a/group/group.go b/group/group.go index 0bfaea6d9..fb56eba25 100644 --- a/group/group.go +++ b/group/group.go @@ -31,6 +31,8 @@ type Group interface { // Element represents an abstract element of a prime-order group. type Element interface { + Set(Element) Element + Copy() Element IsIdentity() bool IsEqual(Element) bool Add(Element, Element) Element @@ -45,6 +47,8 @@ type Element interface { // Scalar represents an integer scalar. type Scalar interface { + Set(Scalar) Scalar + Copy() Scalar IsEqual(Scalar) bool SetUint64(uint64) Add(Scalar, Scalar) Scalar diff --git a/group/ristretto255.go b/group/ristretto255.go index 6099b53af..cbe4be00b 100644 --- a/group/ristretto255.go +++ b/group/ristretto255.go @@ -121,6 +121,15 @@ func (e *ristrettoElement) IsEqual(x Element) bool { return e.p.Equals(&x.(*ristrettoElement).p) } +func (e *ristrettoElement) Set(x Element) Element { + e.p.Set(&x.(*ristrettoElement).p) + return e +} + +func (e *ristrettoElement) Copy() Element { + return &ristrettoElement{*new(r255.Point).Set(&e.p)} +} + func (e *ristrettoElement) Add(x Element, y Element) Element { e.p.Add(&x.(*ristrettoElement).p, &y.(*ristrettoElement).p) return e @@ -164,6 +173,15 @@ func (s *ristrettoScalar) IsEqual(x Scalar) bool { return s.s.Equals(&x.(*ristrettoScalar).s) } +func (s *ristrettoScalar) Set(x Scalar) Scalar { + s.s.Set(&x.(*ristrettoScalar).s) + return s +} + +func (s *ristrettoScalar) Copy() Scalar { + return &ristrettoScalar{*new(r255.Scalar).Set(&s.s)} +} + func (s *ristrettoScalar) Add(x Scalar, y Scalar) Scalar { s.s.Add(&x.(*ristrettoScalar).s, &y.(*ristrettoScalar).s) return s diff --git a/group/short.go b/group/short.go index fc25b681c..151a749e6 100644 --- a/group/short.go +++ b/group/short.go @@ -132,6 +132,14 @@ func (e *wElt) IsEqual(o Element) bool { return e.x.Cmp(oo.x) == 0 && e.y.Cmp(oo.y) == 0 } +func (e *wElt) Set(a Element) Element { + aa := e.cvtElt(a) + e.x.Set(aa.x) + e.y.Set(aa.y) + return e +} + +func (e *wElt) Copy() Element { return e.wG.zeroElement().Set(e) } func (e *wElt) Add(a, b Element) Element { aa, bb := e.cvtElt(a), e.cvtElt(b) e.x, e.y = e.c.Add(aa.x, aa.y, bb.x, bb.y) @@ -225,6 +233,15 @@ func (s *wScl) fromBig(b *big.Int) { } } +func (s *wScl) Set(a Scalar) Scalar { + aa := s.cvtScl(a) + if err := s.UnmarshalBinary(aa.k); err != nil { + panic(err) + } + return s +} + +func (s *wScl) Copy() Scalar { return s.wG.zeroScalar().Set(s) } func (s *wScl) Add(a, b Scalar) Scalar { aa, bb := s.cvtScl(a), s.cvtScl(b) r := new(big.Int) From 8afee8aad0e3e1b42bd548e1e368bcbc13c2b784 Mon Sep 17 00:00:00 2001 From: armfazh Date: Fri, 4 Mar 2022 13:42:34 -0800 Subject: [PATCH 2/2] Updates CopyBlinds to use group.Copy --- oprf/client.go | 3 +-- oprf/oprf.go | 8 ++++++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/oprf/client.go b/oprf/client.go index 78aae0c6e..981808c0d 100644 --- a/oprf/client.go +++ b/oprf/client.go @@ -64,8 +64,7 @@ func (c client) blind(inputs [][]byte, blinds []Blind) (*FinalizeData, *Evaluati return finData, evalReq, nil } -func (c client) unblind(serUnblindeds [][]byte, blindeds []group.Element, blind []Blind) error { - var err error +func (c client) unblind(serUnblindeds [][]byte, blindeds []group.Element, blind []Blind) (err error) { invBlind := c.params.group.NewScalar() U := c.params.group.NewElement() diff --git a/oprf/oprf.go b/oprf/oprf.go index 2bbbc36a3..88341f747 100644 --- a/oprf/oprf.go +++ b/oprf/oprf.go @@ -270,8 +270,12 @@ type FinalizeData struct { // CopyBlinds copies the serialized blinds to use when determinstically // invoking DeterministicBlind. -func (d FinalizeData) CopyBlinds() []Blind { - return d.blinds +func (f FinalizeData) CopyBlinds() []Blind { + out := make([]Blind, len(f.blinds)) + for i, b := range f.blinds { + out[i] = b.Copy() + } + return out } // EvaluationRequest contains the blinded elements to be evaluated by the Server.