From c9d5f60eaa4450ccf1ce878d55b4c6a12843f2f3 Mon Sep 17 00:00:00 2001 From: Filippo Valsorda Date: Mon, 27 Apr 2020 21:52:38 -0400 Subject: [PATCH] math/big: add (*Int).FillBytes Replaced almost every use of Bytes with FillBytes. Note that the approved proposal was for func (*Int) FillBytes(buf []byte) while this implements func (*Int) FillBytes(buf []byte) []byte because the latter was far nicer to use in all callsites. Fixes #35833 Change-Id: Ia912df123e5d79b763845312ea3d9a8051343c0a Reviewed-on: https://go-review.googlesource.com/c/go/+/230397 Reviewed-by: Robert Griesemer --- src/crypto/elliptic/elliptic.go | 13 ++++---- src/crypto/rsa/pkcs1v15.go | 20 +++--------- src/crypto/rsa/pss.go | 17 +++++------ src/crypto/rsa/rsa.go | 32 +++---------------- src/crypto/tls/key_schedule.go | 7 ++--- src/crypto/x509/sec1.go | 7 ++--- src/math/big/int.go | 15 +++++++++ src/math/big/int_test.go | 54 +++++++++++++++++++++++++++++++++ src/math/big/nat.go | 15 ++++++--- 9 files changed, 106 insertions(+), 74 deletions(-) diff --git a/src/crypto/elliptic/elliptic.go b/src/crypto/elliptic/elliptic.go index e2f71cdb63bab..bd5168c5fd842 100644 --- a/src/crypto/elliptic/elliptic.go +++ b/src/crypto/elliptic/elliptic.go @@ -277,7 +277,7 @@ var mask = []byte{0xff, 0x1, 0x3, 0x7, 0xf, 0x1f, 0x3f, 0x7f} func GenerateKey(curve Curve, rand io.Reader) (priv []byte, x, y *big.Int, err error) { N := curve.Params().N bitSize := N.BitLen() - byteLen := (bitSize + 7) >> 3 + byteLen := (bitSize + 7) / 8 priv = make([]byte, byteLen) for x == nil { @@ -304,15 +304,14 @@ func GenerateKey(curve Curve, rand io.Reader) (priv []byte, x, y *big.Int, err e // Marshal converts a point into the uncompressed form specified in section 4.3.6 of ANSI X9.62. func Marshal(curve Curve, x, y *big.Int) []byte { - byteLen := (curve.Params().BitSize + 7) >> 3 + byteLen := (curve.Params().BitSize + 7) / 8 ret := make([]byte, 1+2*byteLen) ret[0] = 4 // uncompressed point - xBytes := x.Bytes() - copy(ret[1+byteLen-len(xBytes):], xBytes) - yBytes := y.Bytes() - copy(ret[1+2*byteLen-len(yBytes):], yBytes) + x.FillBytes(ret[1 : 1+byteLen]) + y.FillBytes(ret[1+byteLen : 1+2*byteLen]) + return ret } @@ -320,7 +319,7 @@ func Marshal(curve Curve, x, y *big.Int) []byte { // It is an error if the point is not in uncompressed form or is not on the curve. // On error, x = nil. func Unmarshal(curve Curve, data []byte) (x, y *big.Int) { - byteLen := (curve.Params().BitSize + 7) >> 3 + byteLen := (curve.Params().BitSize + 7) / 8 if len(data) != 1+2*byteLen { return } diff --git a/src/crypto/rsa/pkcs1v15.go b/src/crypto/rsa/pkcs1v15.go index 499242ffc5b57..3208119ae1ff4 100644 --- a/src/crypto/rsa/pkcs1v15.go +++ b/src/crypto/rsa/pkcs1v15.go @@ -61,8 +61,7 @@ func EncryptPKCS1v15(rand io.Reader, pub *PublicKey, msg []byte) ([]byte, error) m := new(big.Int).SetBytes(em) c := encrypt(new(big.Int), pub, m) - copyWithLeftPad(em, c.Bytes()) - return em, nil + return c.FillBytes(em), nil } // DecryptPKCS1v15 decrypts a plaintext using RSA and the padding scheme from PKCS#1 v1.5. @@ -150,7 +149,7 @@ func decryptPKCS1v15(rand io.Reader, priv *PrivateKey, ciphertext []byte) (valid return } - em = leftPad(m.Bytes(), k) + em = m.FillBytes(make([]byte, k)) firstByteIsZero := subtle.ConstantTimeByteEq(em[0], 0) secondByteIsTwo := subtle.ConstantTimeByteEq(em[1], 2) @@ -256,8 +255,7 @@ func SignPKCS1v15(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed []b return nil, err } - copyWithLeftPad(em, c.Bytes()) - return em, nil + return c.FillBytes(em), nil } // VerifyPKCS1v15 verifies an RSA PKCS#1 v1.5 signature. @@ -286,7 +284,7 @@ func VerifyPKCS1v15(pub *PublicKey, hash crypto.Hash, hashed []byte, sig []byte) c := new(big.Int).SetBytes(sig) m := encrypt(new(big.Int), pub, c) - em := leftPad(m.Bytes(), k) + em := m.FillBytes(make([]byte, k)) // EM = 0x00 || 0x01 || PS || 0x00 || T ok := subtle.ConstantTimeByteEq(em[0], 0) @@ -323,13 +321,3 @@ func pkcs1v15HashInfo(hash crypto.Hash, inLen int) (hashLen int, prefix []byte, } return } - -// copyWithLeftPad copies src to the end of dest, padding with zero bytes as -// needed. -func copyWithLeftPad(dest, src []byte) { - numPaddingBytes := len(dest) - len(src) - for i := 0; i < numPaddingBytes; i++ { - dest[i] = 0 - } - copy(dest[numPaddingBytes:], src) -} diff --git a/src/crypto/rsa/pss.go b/src/crypto/rsa/pss.go index f9844d87329a8..b2adbedb28fa8 100644 --- a/src/crypto/rsa/pss.go +++ b/src/crypto/rsa/pss.go @@ -207,20 +207,19 @@ func emsaPSSVerify(mHash, em []byte, emBits, sLen int, hash hash.Hash) error { // Note that hashed must be the result of hashing the input message using the // given hash function. salt is a random sequence of bytes whose length will be // later used to verify the signature. -func signPSSWithSalt(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed, salt []byte) (s []byte, err error) { +func signPSSWithSalt(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed, salt []byte) ([]byte, error) { emBits := priv.N.BitLen() - 1 em, err := emsaPSSEncode(hashed, emBits, salt, hash.New()) if err != nil { - return + return nil, err } m := new(big.Int).SetBytes(em) c, err := decryptAndCheck(rand, priv, m) if err != nil { - return + return nil, err } - s = make([]byte, priv.Size()) - copyWithLeftPad(s, c.Bytes()) - return + s := make([]byte, priv.Size()) + return c.FillBytes(s), nil } const ( @@ -296,11 +295,9 @@ func VerifyPSS(pub *PublicKey, hash crypto.Hash, digest []byte, sig []byte, opts m := encrypt(new(big.Int), pub, s) emBits := pub.N.BitLen() - 1 emLen := (emBits + 7) / 8 - emBytes := m.Bytes() - if emLen < len(emBytes) { + if m.BitLen() > emLen*8 { return ErrVerification } - em := make([]byte, emLen) - copyWithLeftPad(em, emBytes) + em := m.FillBytes(make([]byte, emLen)) return emsaPSSVerify(digest, em, emBits, opts.saltLength(), hash.New()) } diff --git a/src/crypto/rsa/rsa.go b/src/crypto/rsa/rsa.go index b4bfa13defbdf..28eb5926c1a54 100644 --- a/src/crypto/rsa/rsa.go +++ b/src/crypto/rsa/rsa.go @@ -416,16 +416,9 @@ func EncryptOAEP(hash hash.Hash, random io.Reader, pub *PublicKey, msg []byte, l m := new(big.Int) m.SetBytes(em) c := encrypt(new(big.Int), pub, m) - out := c.Bytes() - if len(out) < k { - // If the output is too small, we need to left-pad with zeros. - t := make([]byte, k) - copy(t[k-len(out):], out) - out = t - } - - return out, nil + out := make([]byte, k) + return c.FillBytes(out), nil } // ErrDecryption represents a failure to decrypt a message. @@ -597,12 +590,9 @@ func DecryptOAEP(hash hash.Hash, random io.Reader, priv *PrivateKey, ciphertext lHash := hash.Sum(nil) hash.Reset() - // Converting the plaintext number to bytes will strip any - // leading zeros so we may have to left pad. We do this unconditionally - // to avoid leaking timing information. (Although we still probably - // leak the number of leading zeros. It's not clear that we can do - // anything about this.) - em := leftPad(m.Bytes(), k) + // We probably leak the number of leading zeros. + // It's not clear that we can do anything about this. + em := m.FillBytes(make([]byte, k)) firstByteIsZero := subtle.ConstantTimeByteEq(em[0], 0) @@ -643,15 +633,3 @@ func DecryptOAEP(hash hash.Hash, random io.Reader, priv *PrivateKey, ciphertext return rest[index+1:], nil } - -// leftPad returns a new slice of length size. The contents of input are right -// aligned in the new slice. -func leftPad(input []byte, size int) (out []byte) { - n := len(input) - if n > size { - n = size - } - out = make([]byte, size) - copy(out[len(out)-n:], input) - return -} diff --git a/src/crypto/tls/key_schedule.go b/src/crypto/tls/key_schedule.go index 2aab323202f7d..314016979afb8 100644 --- a/src/crypto/tls/key_schedule.go +++ b/src/crypto/tls/key_schedule.go @@ -173,11 +173,8 @@ func (p *nistParameters) SharedKey(peerPublicKey []byte) []byte { } xShared, _ := curve.ScalarMult(x, y, p.privateKey) - sharedKey := make([]byte, (curve.Params().BitSize+7)>>3) - xBytes := xShared.Bytes() - copy(sharedKey[len(sharedKey)-len(xBytes):], xBytes) - - return sharedKey + sharedKey := make([]byte, (curve.Params().BitSize+7)/8) + return xShared.FillBytes(sharedKey) } type x25519Parameters struct { diff --git a/src/crypto/x509/sec1.go b/src/crypto/x509/sec1.go index 0bfb90cd5464a..52c108ff1d624 100644 --- a/src/crypto/x509/sec1.go +++ b/src/crypto/x509/sec1.go @@ -52,13 +52,10 @@ func MarshalECPrivateKey(key *ecdsa.PrivateKey) ([]byte, error) { // marshalECPrivateKey marshals an EC private key into ASN.1, DER format and // sets the curve ID to the given OID, or omits it if OID is nil. func marshalECPrivateKeyWithOID(key *ecdsa.PrivateKey, oid asn1.ObjectIdentifier) ([]byte, error) { - privateKeyBytes := key.D.Bytes() - paddedPrivateKey := make([]byte, (key.Curve.Params().N.BitLen()+7)/8) - copy(paddedPrivateKey[len(paddedPrivateKey)-len(privateKeyBytes):], privateKeyBytes) - + privateKey := make([]byte, (key.Curve.Params().N.BitLen()+7)/8) return asn1.Marshal(ecPrivateKey{ Version: 1, - PrivateKey: paddedPrivateKey, + PrivateKey: key.D.FillBytes(privateKey), NamedCurveOID: oid, PublicKey: asn1.BitString{Bytes: elliptic.Marshal(key.Curve, key.X, key.Y)}, }) diff --git a/src/math/big/int.go b/src/math/big/int.go index 8816cf5266cc4..65f32487b58c0 100644 --- a/src/math/big/int.go +++ b/src/math/big/int.go @@ -447,11 +447,26 @@ func (z *Int) SetBytes(buf []byte) *Int { } // Bytes returns the absolute value of x as a big-endian byte slice. +// +// To use a fixed length slice, or a preallocated one, use FillBytes. func (x *Int) Bytes() []byte { buf := make([]byte, len(x.abs)*_S) return buf[x.abs.bytes(buf):] } +// FillBytes sets buf to the absolute value of x, storing it as a zero-extended +// big-endian byte slice, and returns buf. +// +// If the absolute value of x doesn't fit in buf, FillBytes will panic. +func (x *Int) FillBytes(buf []byte) []byte { + // Clear whole buffer. (This gets optimized into a memclr.) + for i := range buf { + buf[i] = 0 + } + x.abs.bytes(buf) + return buf +} + // BitLen returns the length of the absolute value of x in bits. // The bit length of 0 is 0. func (x *Int) BitLen() int { diff --git a/src/math/big/int_test.go b/src/math/big/int_test.go index e3a1587b3f0ad..3c8557323a032 100644 --- a/src/math/big/int_test.go +++ b/src/math/big/int_test.go @@ -1840,3 +1840,57 @@ func BenchmarkDiv(b *testing.B) { }) } } + +func TestFillBytes(t *testing.T) { + checkResult := func(t *testing.T, buf []byte, want *Int) { + t.Helper() + got := new(Int).SetBytes(buf) + if got.CmpAbs(want) != 0 { + t.Errorf("got 0x%x, want 0x%x: %x", got, want, buf) + } + } + panics := func(f func()) (panic bool) { + defer func() { panic = recover() != nil }() + f() + return + } + + for _, n := range []string{ + "0", + "1000", + "0xffffffff", + "-0xffffffff", + "0xffffffffffffffff", + "0x10000000000000000", + "0xabababababababababababababababababababababababababa", + "0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", + } { + t.Run(n, func(t *testing.T) { + t.Logf(n) + x, ok := new(Int).SetString(n, 0) + if !ok { + panic("invalid test entry") + } + + // Perfectly sized buffer. + byteLen := (x.BitLen() + 7) / 8 + buf := make([]byte, byteLen) + checkResult(t, x.FillBytes(buf), x) + + // Way larger, checking all bytes get zeroed. + buf = make([]byte, 100) + for i := range buf { + buf[i] = 0xff + } + checkResult(t, x.FillBytes(buf), x) + + // Too small. + if byteLen > 0 { + buf = make([]byte, byteLen-1) + if !panics(func() { x.FillBytes(buf) }) { + t.Errorf("expected panic for small buffer and value %x", x) + } + } + }) + } +} diff --git a/src/math/big/nat.go b/src/math/big/nat.go index c31ec5156b81d..6a3989bf9d82b 100644 --- a/src/math/big/nat.go +++ b/src/math/big/nat.go @@ -1476,19 +1476,26 @@ func (z nat) expNNMontgomery(x, y, m nat) nat { } // bytes writes the value of z into buf using big-endian encoding. -// len(buf) must be >= len(z)*_S. The value of z is encoded in the -// slice buf[i:]. The number i of unused bytes at the beginning of -// buf is returned as result. +// The value of z is encoded in the slice buf[i:]. If the value of z +// cannot be represented in buf, bytes panics. The number i of unused +// bytes at the beginning of buf is returned as result. func (z nat) bytes(buf []byte) (i int) { i = len(buf) for _, d := range z { for j := 0; j < _S; j++ { i-- - buf[i] = byte(d) + if i >= 0 { + buf[i] = byte(d) + } else if byte(d) != 0 { + panic("math/big: buffer too small to fit value") + } d >>= 8 } } + if i < 0 { + i = 0 + } for i < len(buf) && buf[i] == 0 { i++ }