Skip to content

Commit

Permalink
math/big: add (*Int).FillBytes
Browse files Browse the repository at this point in the history
Replaced almost every use of Bytes with FillBytes.

Note that the approved proposal was for

    func (*Int) FillBytes(buf []byte)

while this implements

    func (*Int) FillBytes(buf []byte) []byte

because the latter was far nicer to use in all callsites.

Fixes golang#35833

Change-Id: Ia912df123e5d79b763845312ea3d9a8051343c0a
Reviewed-on: https://go-review.googlesource.com/c/go/+/230397
Reviewed-by: Robert Griesemer <gri@golang.org>
  • Loading branch information
FiloSottile committed May 5, 2020
1 parent b5f7ff4 commit c9d5f60
Show file tree
Hide file tree
Showing 9 changed files with 106 additions and 74 deletions.
13 changes: 6 additions & 7 deletions src/crypto/elliptic/elliptic.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ var mask = []byte{0xff, 0x1, 0x3, 0x7, 0xf, 0x1f, 0x3f, 0x7f}
func GenerateKey(curve Curve, rand io.Reader) (priv []byte, x, y *big.Int, err error) {
N := curve.Params().N
bitSize := N.BitLen()
byteLen := (bitSize + 7) >> 3
byteLen := (bitSize + 7) / 8
priv = make([]byte, byteLen)

for x == nil {
Expand All @@ -304,23 +304,22 @@ func GenerateKey(curve Curve, rand io.Reader) (priv []byte, x, y *big.Int, err e

// Marshal converts a point into the uncompressed form specified in section 4.3.6 of ANSI X9.62.
func Marshal(curve Curve, x, y *big.Int) []byte {
byteLen := (curve.Params().BitSize + 7) >> 3
byteLen := (curve.Params().BitSize + 7) / 8

ret := make([]byte, 1+2*byteLen)
ret[0] = 4 // uncompressed point

xBytes := x.Bytes()
copy(ret[1+byteLen-len(xBytes):], xBytes)
yBytes := y.Bytes()
copy(ret[1+2*byteLen-len(yBytes):], yBytes)
x.FillBytes(ret[1 : 1+byteLen])
y.FillBytes(ret[1+byteLen : 1+2*byteLen])

return ret
}

// Unmarshal converts a point, serialized by Marshal, into an x, y pair.
// It is an error if the point is not in uncompressed form or is not on the curve.
// On error, x = nil.
func Unmarshal(curve Curve, data []byte) (x, y *big.Int) {
byteLen := (curve.Params().BitSize + 7) >> 3
byteLen := (curve.Params().BitSize + 7) / 8
if len(data) != 1+2*byteLen {
return
}
Expand Down
20 changes: 4 additions & 16 deletions src/crypto/rsa/pkcs1v15.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,7 @@ func EncryptPKCS1v15(rand io.Reader, pub *PublicKey, msg []byte) ([]byte, error)
m := new(big.Int).SetBytes(em)
c := encrypt(new(big.Int), pub, m)

copyWithLeftPad(em, c.Bytes())
return em, nil
return c.FillBytes(em), nil
}

// DecryptPKCS1v15 decrypts a plaintext using RSA and the padding scheme from PKCS#1 v1.5.
Expand Down Expand Up @@ -150,7 +149,7 @@ func decryptPKCS1v15(rand io.Reader, priv *PrivateKey, ciphertext []byte) (valid
return
}

em = leftPad(m.Bytes(), k)
em = m.FillBytes(make([]byte, k))
firstByteIsZero := subtle.ConstantTimeByteEq(em[0], 0)
secondByteIsTwo := subtle.ConstantTimeByteEq(em[1], 2)

Expand Down Expand Up @@ -256,8 +255,7 @@ func SignPKCS1v15(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed []b
return nil, err
}

copyWithLeftPad(em, c.Bytes())
return em, nil
return c.FillBytes(em), nil
}

// VerifyPKCS1v15 verifies an RSA PKCS#1 v1.5 signature.
Expand Down Expand Up @@ -286,7 +284,7 @@ func VerifyPKCS1v15(pub *PublicKey, hash crypto.Hash, hashed []byte, sig []byte)

c := new(big.Int).SetBytes(sig)
m := encrypt(new(big.Int), pub, c)
em := leftPad(m.Bytes(), k)
em := m.FillBytes(make([]byte, k))
// EM = 0x00 || 0x01 || PS || 0x00 || T

ok := subtle.ConstantTimeByteEq(em[0], 0)
Expand Down Expand Up @@ -323,13 +321,3 @@ func pkcs1v15HashInfo(hash crypto.Hash, inLen int) (hashLen int, prefix []byte,
}
return
}

// copyWithLeftPad copies src to the end of dest, padding with zero bytes as
// needed.
func copyWithLeftPad(dest, src []byte) {
numPaddingBytes := len(dest) - len(src)
for i := 0; i < numPaddingBytes; i++ {
dest[i] = 0
}
copy(dest[numPaddingBytes:], src)
}
17 changes: 7 additions & 10 deletions src/crypto/rsa/pss.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,20 +207,19 @@ func emsaPSSVerify(mHash, em []byte, emBits, sLen int, hash hash.Hash) error {
// Note that hashed must be the result of hashing the input message using the
// given hash function. salt is a random sequence of bytes whose length will be
// later used to verify the signature.
func signPSSWithSalt(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed, salt []byte) (s []byte, err error) {
func signPSSWithSalt(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed, salt []byte) ([]byte, error) {
emBits := priv.N.BitLen() - 1
em, err := emsaPSSEncode(hashed, emBits, salt, hash.New())
if err != nil {
return
return nil, err
}
m := new(big.Int).SetBytes(em)
c, err := decryptAndCheck(rand, priv, m)
if err != nil {
return
return nil, err
}
s = make([]byte, priv.Size())
copyWithLeftPad(s, c.Bytes())
return
s := make([]byte, priv.Size())
return c.FillBytes(s), nil
}

const (
Expand Down Expand Up @@ -296,11 +295,9 @@ func VerifyPSS(pub *PublicKey, hash crypto.Hash, digest []byte, sig []byte, opts
m := encrypt(new(big.Int), pub, s)
emBits := pub.N.BitLen() - 1
emLen := (emBits + 7) / 8
emBytes := m.Bytes()
if emLen < len(emBytes) {
if m.BitLen() > emLen*8 {
return ErrVerification
}
em := make([]byte, emLen)
copyWithLeftPad(em, emBytes)
em := m.FillBytes(make([]byte, emLen))
return emsaPSSVerify(digest, em, emBits, opts.saltLength(), hash.New())
}
32 changes: 5 additions & 27 deletions src/crypto/rsa/rsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -416,16 +416,9 @@ func EncryptOAEP(hash hash.Hash, random io.Reader, pub *PublicKey, msg []byte, l
m := new(big.Int)
m.SetBytes(em)
c := encrypt(new(big.Int), pub, m)
out := c.Bytes()

if len(out) < k {
// If the output is too small, we need to left-pad with zeros.
t := make([]byte, k)
copy(t[k-len(out):], out)
out = t
}

return out, nil
out := make([]byte, k)
return c.FillBytes(out), nil
}

// ErrDecryption represents a failure to decrypt a message.
Expand Down Expand Up @@ -597,12 +590,9 @@ func DecryptOAEP(hash hash.Hash, random io.Reader, priv *PrivateKey, ciphertext
lHash := hash.Sum(nil)
hash.Reset()

// Converting the plaintext number to bytes will strip any
// leading zeros so we may have to left pad. We do this unconditionally
// to avoid leaking timing information. (Although we still probably
// leak the number of leading zeros. It's not clear that we can do
// anything about this.)
em := leftPad(m.Bytes(), k)
// We probably leak the number of leading zeros.
// It's not clear that we can do anything about this.
em := m.FillBytes(make([]byte, k))

firstByteIsZero := subtle.ConstantTimeByteEq(em[0], 0)

Expand Down Expand Up @@ -643,15 +633,3 @@ func DecryptOAEP(hash hash.Hash, random io.Reader, priv *PrivateKey, ciphertext

return rest[index+1:], nil
}

// leftPad returns a new slice of length size. The contents of input are right
// aligned in the new slice.
func leftPad(input []byte, size int) (out []byte) {
n := len(input)
if n > size {
n = size
}
out = make([]byte, size)
copy(out[len(out)-n:], input)
return
}
7 changes: 2 additions & 5 deletions src/crypto/tls/key_schedule.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,11 +173,8 @@ func (p *nistParameters) SharedKey(peerPublicKey []byte) []byte {
}

xShared, _ := curve.ScalarMult(x, y, p.privateKey)
sharedKey := make([]byte, (curve.Params().BitSize+7)>>3)
xBytes := xShared.Bytes()
copy(sharedKey[len(sharedKey)-len(xBytes):], xBytes)

return sharedKey
sharedKey := make([]byte, (curve.Params().BitSize+7)/8)
return xShared.FillBytes(sharedKey)
}

type x25519Parameters struct {
Expand Down
7 changes: 2 additions & 5 deletions src/crypto/x509/sec1.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,10 @@ func MarshalECPrivateKey(key *ecdsa.PrivateKey) ([]byte, error) {
// marshalECPrivateKey marshals an EC private key into ASN.1, DER format and
// sets the curve ID to the given OID, or omits it if OID is nil.
func marshalECPrivateKeyWithOID(key *ecdsa.PrivateKey, oid asn1.ObjectIdentifier) ([]byte, error) {
privateKeyBytes := key.D.Bytes()
paddedPrivateKey := make([]byte, (key.Curve.Params().N.BitLen()+7)/8)
copy(paddedPrivateKey[len(paddedPrivateKey)-len(privateKeyBytes):], privateKeyBytes)

privateKey := make([]byte, (key.Curve.Params().N.BitLen()+7)/8)
return asn1.Marshal(ecPrivateKey{
Version: 1,
PrivateKey: paddedPrivateKey,
PrivateKey: key.D.FillBytes(privateKey),
NamedCurveOID: oid,
PublicKey: asn1.BitString{Bytes: elliptic.Marshal(key.Curve, key.X, key.Y)},
})
Expand Down
15 changes: 15 additions & 0 deletions src/math/big/int.go
Original file line number Diff line number Diff line change
Expand Up @@ -447,11 +447,26 @@ func (z *Int) SetBytes(buf []byte) *Int {
}

// Bytes returns the absolute value of x as a big-endian byte slice.
//
// To use a fixed length slice, or a preallocated one, use FillBytes.
func (x *Int) Bytes() []byte {
buf := make([]byte, len(x.abs)*_S)
return buf[x.abs.bytes(buf):]
}

// FillBytes sets buf to the absolute value of x, storing it as a zero-extended
// big-endian byte slice, and returns buf.
//
// If the absolute value of x doesn't fit in buf, FillBytes will panic.
func (x *Int) FillBytes(buf []byte) []byte {
// Clear whole buffer. (This gets optimized into a memclr.)
for i := range buf {
buf[i] = 0
}
x.abs.bytes(buf)
return buf
}

// BitLen returns the length of the absolute value of x in bits.
// The bit length of 0 is 0.
func (x *Int) BitLen() int {
Expand Down
54 changes: 54 additions & 0 deletions src/math/big/int_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1840,3 +1840,57 @@ func BenchmarkDiv(b *testing.B) {
})
}
}

func TestFillBytes(t *testing.T) {
checkResult := func(t *testing.T, buf []byte, want *Int) {
t.Helper()
got := new(Int).SetBytes(buf)
if got.CmpAbs(want) != 0 {
t.Errorf("got 0x%x, want 0x%x: %x", got, want, buf)
}
}
panics := func(f func()) (panic bool) {
defer func() { panic = recover() != nil }()
f()
return
}

for _, n := range []string{
"0",
"1000",
"0xffffffff",
"-0xffffffff",
"0xffffffffffffffff",
"0x10000000000000000",
"0xabababababababababababababababababababababababababa",
"0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff",
} {
t.Run(n, func(t *testing.T) {
t.Logf(n)
x, ok := new(Int).SetString(n, 0)
if !ok {
panic("invalid test entry")
}

// Perfectly sized buffer.
byteLen := (x.BitLen() + 7) / 8
buf := make([]byte, byteLen)
checkResult(t, x.FillBytes(buf), x)

// Way larger, checking all bytes get zeroed.
buf = make([]byte, 100)
for i := range buf {
buf[i] = 0xff
}
checkResult(t, x.FillBytes(buf), x)

// Too small.
if byteLen > 0 {
buf = make([]byte, byteLen-1)
if !panics(func() { x.FillBytes(buf) }) {
t.Errorf("expected panic for small buffer and value %x", x)
}
}
})
}
}
15 changes: 11 additions & 4 deletions src/math/big/nat.go
Original file line number Diff line number Diff line change
Expand Up @@ -1476,19 +1476,26 @@ func (z nat) expNNMontgomery(x, y, m nat) nat {
}

// bytes writes the value of z into buf using big-endian encoding.
// len(buf) must be >= len(z)*_S. The value of z is encoded in the
// slice buf[i:]. The number i of unused bytes at the beginning of
// buf is returned as result.
// The value of z is encoded in the slice buf[i:]. If the value of z
// cannot be represented in buf, bytes panics. The number i of unused
// bytes at the beginning of buf is returned as result.
func (z nat) bytes(buf []byte) (i int) {
i = len(buf)
for _, d := range z {
for j := 0; j < _S; j++ {
i--
buf[i] = byte(d)
if i >= 0 {
buf[i] = byte(d)
} else if byte(d) != 0 {
panic("math/big: buffer too small to fit value")
}
d >>= 8
}
}

if i < 0 {
i = 0
}
for i < len(buf) && buf[i] == 0 {
i++
}
Expand Down

0 comments on commit c9d5f60

Please sign in to comment.