269 lines
6.2 KiB
Go
269 lines
6.2 KiB
Go
package ecc
|
|
|
|
import (
|
|
"crypto/sha256"
|
|
"errors"
|
|
"math/big"
|
|
"math/bits"
|
|
)
|
|
|
|
//-------------------------------------------------------
|
|
// Ate loop counter (not used by every curve)
|
|
|
|
// NafDecomposition writes the non-adjacent form (NAF) of a into result,
// least-significant digit first, and returns how many digits were written.
// Every digit is -1, 0 or 1, and a equals the sum of result[i] * 2^i.
// a is not modified. result must have room for every digit: its length is
// not checked, so a too-short slice causes an index-out-of-range panic.
func NafDecomposition(a *big.Int, result []int8) int {
	var rem, unit big.Int
	rem.Set(a)
	unit.SetUint64(1)

	count := 0
	for ; rem.Sign() != 0; count++ {
		var digit int8
		if rem.Bit(0) == 1 {
			// Odd: pick the digit that leaves rem divisible by 4, so the
			// next digit is guaranteed to be zero.
			if rem.Bit(1) == 1 {
				digit = -1
				rem.Add(&rem, &unit)
			} else {
				digit = 1
			}
		}
		result[count] = digit
		rem.Rsh(&rem, 1)
	}
	return count
}
|
|
|
|
//-------------------------------------------------------
|
|
// GLV utils
|
|
|
|
// Lattice describes the Z-module spanned by the two vectors V1 and V2
// (each stored as a pair of big integers). Det holds the determinant
// associated with that basis.
type Lattice struct {
	V1  [2]big.Int // first basis vector
	V2  [2]big.Int // second basis vector
	Det big.Int    // determinant of the basis (V1, V2)
}
|
|
|
|
// PrecomputeLattice res such that res.V1, res.V2
|
|
// are short vectors satisfying v11+v12lambda=v21+v22lambda=0[r].
|
|
// cf https://www.iacr.org/archive/crypto2001/21390189.pdf
|
|
func PrecomputeLattice(r, lambda *big.Int, res *Lattice) {
|
|
|
|
var rst [2][3]big.Int
|
|
var tmp [3]big.Int
|
|
var quotient, remainder, sqroot, _r, _t big.Int
|
|
|
|
rst[0][0].Set(r)
|
|
rst[0][1].SetUint64(1)
|
|
rst[0][2].SetUint64(0)
|
|
|
|
rst[1][0].Set(lambda)
|
|
rst[1][1].SetUint64(0)
|
|
rst[1][2].SetUint64(1)
|
|
|
|
sqroot.Sqrt(r)
|
|
|
|
var one big.Int
|
|
one.SetUint64(1)
|
|
|
|
// r_i+1 = r_i-1 - q_i.r_i
|
|
// s_i+1 = s_i-1 - q_i.s_i
|
|
// t_i+1 = t_i-1 - q_i.s_i
|
|
for rst[1][0].Cmp(&sqroot) >= 1 {
|
|
|
|
quotient.Div(&rst[0][0], &rst[1][0])
|
|
remainder.Mod(&rst[0][0], &rst[1][0])
|
|
|
|
tmp[0].Set(&rst[1][0])
|
|
tmp[1].Set(&rst[1][1])
|
|
tmp[2].Set(&rst[1][2])
|
|
|
|
rst[1][0].Set(&remainder)
|
|
rst[1][1].Mul(&rst[1][1], "ient).Sub(&rst[0][1], &rst[1][1])
|
|
rst[1][2].Mul(&rst[1][2], "ient).Sub(&rst[0][2], &rst[1][2])
|
|
|
|
rst[0][0].Set(&tmp[0])
|
|
rst[0][1].Set(&tmp[1])
|
|
rst[0][2].Set(&tmp[2])
|
|
}
|
|
|
|
quotient.Div(&rst[0][0], &rst[1][0])
|
|
remainder.Mod(&rst[0][0], &rst[1][0])
|
|
_r.Set(&remainder)
|
|
_t.Mul(&rst[1][2], "ient).Sub(&rst[0][2], &_t)
|
|
|
|
res.V1[0].Set(&rst[1][0])
|
|
res.V1[1].Neg(&rst[1][2])
|
|
|
|
// take the shorter of [rst[0][0], rst[0][2]], [_r, _t]
|
|
tmp[1].Mul(&rst[0][2], &rst[0][2])
|
|
tmp[0].Mul(&rst[0][0], &rst[0][0]).Add(&tmp[1], &tmp[0])
|
|
tmp[2].Mul(&_r, &_r)
|
|
tmp[1].Mul(&_t, &_t).Add(&tmp[2], &tmp[1])
|
|
if tmp[0].Cmp(&tmp[1]) == 1 {
|
|
res.V2[0].Set(&_r)
|
|
res.V2[1].Neg(&_t)
|
|
} else {
|
|
res.V2[0].Set(&rst[0][0])
|
|
res.V2[1].Neg(&rst[0][2])
|
|
}
|
|
|
|
// sets determinant
|
|
tmp[0].Mul(&res.V1[1], &res.V2[0])
|
|
res.Det.Mul(&res.V1[0], &res.V2[1]).Sub(&res.Det, &tmp[0])
|
|
|
|
}
|
|
|
|
// SplitScalar outputs u,v such that u+vlambda=s[r].
|
|
// The method is to view s as (s,0) in ZxZ, and find a close
|
|
// vector w of (s,0) in <l>, where l is a sub Z-module of
|
|
// ker((a,b)->a+blambda[r]): then (u,v)=w-(s,0), and
|
|
// u+vlambda=s[r].
|
|
// cf https://www.iacr.org/archive/crypto2001/21390189.pdf
|
|
func SplitScalar(s *big.Int, l *Lattice) [2]big.Int {
|
|
|
|
var k1, k2 big.Int
|
|
k1.Mul(s, &l.V2[1])
|
|
k2.Mul(s, &l.V1[1]).Neg(&k2)
|
|
rounding(&k1, &l.Det, &k1)
|
|
rounding(&k2, &l.Det, &k2)
|
|
v := getVector(l, &k1, &k2)
|
|
v[0].Sub(s, &v[0])
|
|
v[1].Neg(&v[1])
|
|
return v
|
|
}
|
|
|
|
// rounding sets res to the integer closest to n/d. Exact halves round
// toward negative infinity (7/2 -> 3, -7/2 -> -4). It panics if d is zero.
// res may alias n or d; apart from that, n and d are left unchanged.
func rounding(n, d, res *big.Int) {
	var num, den, half, rem big.Int
	num.Set(n)
	den.Set(d)
	// Euclidean division always yields a remainder in [0, |d|), so the
	// "remainder > d/2" test below is only meaningful when d > 0. For d < 0,
	// flip both signs: n/d == (-n)/(-d), so the nearest integer is unchanged.
	if den.Sign() < 0 {
		num.Neg(&num)
		den.Neg(&den)
	}
	half.Rsh(&den, 1)
	res.DivMod(&num, &den, &rem)
	if rem.Cmp(&half) > 0 {
		res.Add(res, big.NewInt(1))
	}
}
|
|
|
|
// getVector returns axV1 + bxV2
|
|
func getVector(l *Lattice, a, b *big.Int) [2]big.Int {
|
|
var res [2]big.Int
|
|
var tmp big.Int
|
|
tmp.Mul(b, &l.V2[0])
|
|
res[0].Mul(a, &l.V1[0]).Add(&res[0], &tmp)
|
|
tmp.Mul(b, &l.V2[1])
|
|
res[1].Mul(a, &l.V1[1]).Add(&res[1], &tmp)
|
|
return res
|
|
}
|
|
|
|
// ExpandMsgXmd expands msg into a pseudo-random byte string of exactly
// lenInBytes bytes using expand_message_xmd instantiated with SHA-256, with
// dst as the domain separation tag.
//
// lenInBytes need not be a multiple of the 32-byte digest size: the last
// digest is truncated. An error is returned when lenInBytes is negative, when
// it would need more than 255 digests (lenInBytes > 255*32), or when dst is
// longer than 255 bytes.
//
// https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-06#section-5
// https://www.rfc-editor.org/rfc/rfc9380#section-5.3.1
// https://tools.ietf.org/html/rfc8017#section-4.1 (I2OSP/O2ISP)
func ExpandMsgXmd(msg, dst []byte, lenInBytes int) ([]byte, error) {
	if lenInBytes < 0 {
		return nil, errors.New("invalid lenInBytes")
	}
	h := sha256.New()
	ell := (lenInBytes + h.Size() - 1) / h.Size() // ceil(len_in_bytes / b_in_bytes)
	if ell > 255 {
		return nil, errors.New("invalid lenInBytes")
	}
	if len(dst) > 255 {
		return nil, errors.New("invalid domain size (>255 bytes)")
	}

	// DST_prime = DST || I2OSP(len(DST), 1); every hash below writes dst
	// followed by this one-byte length suffix.
	dstLen := []byte{uint8(len(dst))}

	// digest returns H(parts[0] || parts[1] || ...) as a freshly allocated slice.
	digest := func(parts ...[]byte) ([]byte, error) {
		h.Reset()
		for _, p := range parts {
			if _, err := h.Write(p); err != nil {
				return nil, err
			}
		}
		return h.Sum(nil), nil
	}

	// Z_pad = I2OSP(0, r_in_bytes), with r_in_bytes the hash block size
	// l_i_b_str = I2OSP(len_in_bytes, 2)
	// b_0 = H(Z_pad || msg || l_i_b_str || I2OSP(0, 1) || DST_prime)
	zPad := make([]byte, h.BlockSize())
	libStrAndZero := []byte{uint8(lenInBytes >> 8), uint8(lenInBytes), 0}
	b0, err := digest(zPad, msg, libStrAndZero, dst, dstLen)
	if err != nil {
		return nil, err
	}

	// b_1 = H(b_0 || I2OSP(1, 1) || DST_prime)
	bi, err := digest(b0, []byte{1}, dst, dstLen)
	if err != nil {
		return nil, err
	}

	res := make([]byte, lenInBytes)
	// copy stops at the end of res, truncating the digest when lenInBytes is
	// smaller than h.Size() (slicing res[:h.Size()] would panic there).
	copy(res, bi)

	strxor := make([]byte, h.Size())
	for i := 2; i <= ell; i++ {
		// b_i = H(strxor(b_0, b_(i - 1)) || I2OSP(i, 1) || DST_prime)
		for j := range strxor {
			strxor[j] = b0[j] ^ bi[j]
		}
		if bi, err = digest(strxor, []byte{uint8(i)}, dst, dstLen); err != nil {
			return nil, err
		}
		// Open-ended slice: the final digest is truncated when lenInBytes
		// is not a multiple of h.Size().
		copy(res[h.Size()*(i-1):], bi)
	}
	return res, nil
}
|
|
|
|
// NextPowerOfTwo returns the smallest power of two that is greater than or
// equal to n, and 1 when n is 0. It panics when that power does not fit in a
// uint64, i.e. when n > 1<<63.
func NextPowerOfTwo(n uint64) uint64 {
	switch {
	case n == 0:
		return 1
	case n&(n-1) == 0:
		// n is already a power of two
		return n
	}
	width := bits.Len64(n)
	if width == 64 {
		panic("next power of 2 overflows uint64")
	}
	return uint64(1) << uint(width)
}
|