-
-
Notifications
You must be signed in to change notification settings - Fork 34
/
ecdsa_sha.go
153 lines (133 loc) · 3.66 KB
/
ecdsa_sha.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
package jwt
import (
"crypto"
"crypto/ecdsa"
"crypto/rand"
"math/big"
"github.com/gbrlsnchs/jwt/v3/internal"
)
var (
	// ErrECDSANilPrivKey is the error for trying to sign a JWT with a nil private key.
	// It is returned by Sign when no private key was configured, and is the
	// value newECDSASHA panics with when neither a private nor a public key is set.
	ErrECDSANilPrivKey = internal.NewError("jwt: ECDSA private key is nil")

	// ErrECDSANilPubKey is the error for trying to verify a JWT with a nil public key.
	ErrECDSANilPubKey = internal.NewError("jwt: ECDSA public key is nil")

	// ErrECDSAVerification is the error for an invalid ECDSA signature, including
	// a decoded signature whose length does not match the curve's size.
	ErrECDSAVerification = internal.NewError("jwt: ECDSA verification failed")

	// Compile-time assertion that *ECDSASHA satisfies Algorithm.
	_ Algorithm = (*ECDSASHA)(nil)
)
// ECDSAPrivateKey is an option to set a private key to the ECDSA-SHA algorithm.
// The key is used by Sign; when no public key option is supplied, the
// constructor also uses this key's embedded public key for verification.
func ECDSAPrivateKey(priv *ecdsa.PrivateKey) func(*ECDSASHA) {
	setPrivateKey := func(target *ECDSASHA) {
		target.priv = priv
	}
	return setPrivateKey
}
// ECDSAPublicKey is an option to set a public key to the ECDSA-SHA algorithm.
// The key is used by Verify and determines the signature size reported by Size.
func ECDSAPublicKey(pub *ecdsa.PublicKey) func(*ECDSASHA) {
	setPublicKey := func(target *ECDSASHA) {
		target.pub = pub
	}
	return setPublicKey
}
// byteSize returns the number of whole bytes needed to hold bitSize bits,
// rounding up when bitSize is not a multiple of 8 (e.g. 521 -> 66).
// Non-positive remainders (including negative inputs) are not rounded.
func byteSize(bitSize int) int {
	whole, rem := bitSize/8, bitSize%8
	if rem <= 0 {
		return whole
	}
	return whole + 1
}
// ECDSASHA is an algorithm that uses ECDSA to sign SHA hashes.
// Values are created by NewES256, NewES384 and NewES512; the zero value is
// not usable because pool and size are only set by the constructor.
type ECDSASHA struct {
	name string            // algorithm name returned by Name, e.g. "ES256"
	priv *ecdsa.PrivateKey // signing key; may be nil, in which case Sign returns ErrECDSANilPrivKey
	pub  *ecdsa.PublicKey  // verification key; taken from priv by the constructor when not set
	sha  crypto.Hash       // hash identifier; the pool below is built from sha.New
	size int               // raw signature length in bytes: 2 * byteSize(pub curve bit size)
	pool *hashPool         // source of digests for headerPayload (presumably reuses hash states — see hashPool)
}
// newECDSASHA builds an ECDSASHA called name that digests with sha.
//
// Options are applied in order and nil options are skipped. If no public key
// was supplied, the private key's embedded public key is used; if neither key
// was supplied the function panics with ErrECDSANilPrivKey. The signature size
// is derived from the public key's curve.
func newECDSASHA(name string, opts []func(*ECDSASHA), sha crypto.Hash) *ECDSASHA {
	alg := &ECDSASHA{
		name: name,
		sha:  sha,
		pool: newHashPool(sha.New),
	}
	for _, apply := range opts {
		if apply == nil {
			continue
		}
		apply(alg)
	}
	if alg.pub == nil && alg.priv == nil {
		panic(ErrECDSANilPrivKey)
	}
	if alg.pub == nil {
		alg.pub = &alg.priv.PublicKey
	}
	// R and S are each encoded as fixed-width integers, hence the factor of 2.
	alg.size = 2 * byteSize(alg.pub.Params().BitSize)
	return alg
}
// NewES256 creates a new algorithm using ECDSA and SHA-256.
// It panics if neither a private nor a public key option is supplied.
// The curve comes from the supplied key and is not checked here.
func NewES256(opts ...func(*ECDSASHA)) *ECDSASHA {
	const alg, hash = "ES256", crypto.SHA256
	return newECDSASHA(alg, opts, hash)
}
// NewES384 creates a new algorithm using ECDSA and SHA-384.
// It panics if neither a private nor a public key option is supplied.
// The curve comes from the supplied key and is not checked here.
func NewES384(opts ...func(*ECDSASHA)) *ECDSASHA {
	const alg, hash = "ES384", crypto.SHA384
	return newECDSASHA(alg, opts, hash)
}
// NewES512 creates a new algorithm using ECDSA and SHA-512.
// It panics if neither a private nor a public key option is supplied.
// The curve comes from the supplied key and is not checked here.
func NewES512(opts ...func(*ECDSASHA)) *ECDSASHA {
	const alg, hash = "ES512", crypto.SHA512
	return newECDSASHA(alg, opts, hash)
}
// Name returns the algorithm's name ("ES256", "ES384" or "ES512"),
// as fixed by the constructor that built it.
func (es *ECDSASHA) Name() string {
	algName := es.name
	return algName
}
// Sign signs headerPayload using the ECDSA-SHA algorithm and returns the
// fixed-width R || S signature bytes. It returns ErrECDSANilPrivKey when the
// algorithm was configured without a private key.
func (es *ECDSASHA) Sign(headerPayload []byte) ([]byte, error) {
	if priv := es.priv; priv != nil {
		return es.sign(headerPayload)
	}
	return nil, ErrECDSANilPrivKey
}
// Size returns the signature's byte size: twice the byte length of the public
// key's curve bit size, as computed once at construction.
func (es *ECDSASHA) Size() int {
	sigLen := es.size
	return sigLen
}
// Verify verifies a signature based on headerPayload using ECDSA-SHA.
//
// sig is first decoded with internal.DecodeToBytes (presumably base64url —
// confirm in internal); decoding errors are returned as-is. The decoded bytes
// must be exactly R || S, each padded to the curve's byte size, otherwise
// ErrECDSAVerification is returned. Errors from computing the digest are
// propagated, and a signature that does not verify yields ErrECDSAVerification.
// A nil public key yields ErrECDSANilPubKey.
func (es *ECDSASHA) Verify(headerPayload, sig []byte) (err error) {
	pub := es.pub
	if pub == nil {
		return ErrECDSANilPubKey
	}
	raw, err := internal.DecodeToBytes(sig)
	if err != nil {
		return err
	}
	half := byteSize(pub.Params().BitSize)
	if len(raw) != 2*half {
		return ErrECDSAVerification
	}
	digest, err := es.pool.sign(headerPayload)
	if err != nil {
		return err
	}
	r := new(big.Int).SetBytes(raw[:half])
	s := new(big.Int).SetBytes(raw[half:])
	if ecdsa.Verify(pub, digest, r, s) {
		return nil
	}
	return ErrECDSAVerification
}
// sign digests headerPayload, signs the digest with es.priv using crypto/rand,
// and returns the signature as R || S. Each integer is left-padded with zeros
// to the byte size of the private key's curve, giving a fixed-width output
// (the layout RFC 7518 §3.4 specifies for JWS ECDSA signatures).
//
// The caller must ensure es.priv is non-nil; Sign performs that check.
func (es *ECDSASHA) sign(headerPayload []byte) ([]byte, error) {
	sum, err := es.pool.sign(headerPayload)
	if err != nil {
		return nil, err
	}
	r, s, err := ecdsa.Sign(rand.Reader, es.priv, sum)
	if err != nil {
		return nil, err
	}
	keyBytes := byteSize(es.priv.Params().BitSize)
	// Build the whole signature in one allocation instead of two halves
	// plus a growing append; big.Int.Bytes omits leading zeros, so each value
	// is copied right-aligned into its zero-initialized half.
	sig := make([]byte, 2*keyBytes)
	rb, sb := r.Bytes(), s.Bytes()
	copy(sig[keyBytes-len(rb):keyBytes], rb)
	copy(sig[2*keyBytes-len(sb):], sb)
	return sig, nil
}