Skip to content

Commit

Permalink
Added in support for deriving Hard keys (#19)
Browse files Browse the repository at this point in the history
  • Loading branch information
swdee authored Oct 21, 2020
1 parent 267dc38 commit d3c6d31
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 7 deletions.
83 changes: 81 additions & 2 deletions derive.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@ package schnorrkel
import (
"crypto/rand"
"errors"

"github.com/gtank/merlin"
r255 "github.com/gtank/ristretto255"
)

const ChainCodeLength = 32

var (
ErrDeriveHardKeyType = errors.New("Failed to derive hard key type, DerivableKey must be a SecretKey")
)

// DerivableKey implements DeriveKey
type DerivableKey interface {
Encode() [32]byte
Expand Down Expand Up @@ -70,7 +73,45 @@ func (ek *ExtendedKey) DeriveKey(t *merlin.Transcript) (*ExtendedKey, error) {
return ek.key.DeriveKey(t, ek.chaincode)
}

// DeriveKeySimple derives a subkey identified by byte array i and chain code.
// HardDeriveMiniSecretKey implements BIP-32 like "hard" derivation of a mini
// secret from an extended key's secret key
func (ek *ExtendedKey) HardDeriveMiniSecretKey(i []byte) (*ExtendedKey, error) {
sk, err := ek.Secret()
if err != nil {
return nil, err
}

msk, chainCode, err := sk.HardDeriveMiniSecretKey(i, ek.chaincode)
if err != nil {
return nil, err
}

return NewExtendedKey(msk, chainCode), nil
}

// DeriveKeyHard derives a Hard subkey identified by the byte array i and chain
// code
func DeriveKeyHard(key DerivableKey, i []byte, cc [ChainCodeLength]byte) (*ExtendedKey, error) {
switch key.(type) {
case *SecretKey:
msk, resCC, err := key.(*SecretKey).HardDeriveMiniSecretKey(i, cc)
if err != nil {
return nil, err
}
return NewExtendedKey(msk.ExpandEd25519(), resCC), nil

default:
return nil, ErrDeriveHardKeyType
}
}

// DerviveKeySoft is an alias for DervieKeySimple() used to derive a Soft subkey
// identified by the byte array i and chain code
func DeriveKeySoft(key DerivableKey, i []byte, cc [ChainCodeLength]byte) (*ExtendedKey, error) {
return DeriveKeySimple(key, i, cc)
}

// DeriveKeySimple derives a Soft subkey identified by byte array i and chain code.
func DeriveKeySimple(key DerivableKey, i []byte, cc [ChainCodeLength]byte) (*ExtendedKey, error) {
t := merlin.NewTranscript("SchnorrRistrettoHDKD")
t.AppendMessage([]byte("sign-bytes"), i)
Expand Down Expand Up @@ -115,6 +156,44 @@ func (sk *SecretKey) DeriveKey(t *merlin.Transcript, cc [ChainCodeLength]byte) (
}, nil
}

// HardDeriveMiniSecretKey implements BIP-32 like "hard" derivation of a mini
// secret from a secret key
func (sk *SecretKey) HardDeriveMiniSecretKey(i []byte, cc [ChainCodeLength]byte) (
*MiniSecretKey, [ChainCodeLength]byte, error) {

t := merlin.NewTranscript("SchnorrRistrettoHDKD")
t.AppendMessage([]byte("sign-bytes"), i)
t.AppendMessage([]byte("chain-code"), cc[:])
skenc := sk.Encode()
t.AppendMessage([]byte("secret-key"), skenc[:])

msk := [MiniSecretKeyLength]byte{}
mskBytes := t.ExtractBytes([]byte("HDKD-hard"), MiniSecretKeyLength)
copy(msk[:], mskBytes)

ccRes := [ChainCodeLength]byte{}
ccBytes := t.ExtractBytes([]byte("HDKD-chaincode"), ChainCodeLength)
copy(ccRes[:], ccBytes)

miniSec, err := NewMiniSecretKeyFromRaw(msk)

return miniSec, ccRes, err
}

// HardDeriveMiniSecretKey implements BIP-32 like "hard" derivation of a mini
// secret from a mini secret key
func (mk *MiniSecretKey) HardDeriveMiniSecretKey(i []byte, cc [ChainCodeLength]byte) (
*MiniSecretKey, [ChainCodeLength]byte, error) {
sk := mk.ExpandEd25519()
return sk.HardDeriveMiniSecretKey(i, cc)
}

// DeriveKey derives an Extended Key from the Mini Secret Key
func (mk *MiniSecretKey) DeriveKey(t *merlin.Transcript, cc [ChainCodeLength]byte) (*ExtendedKey, error) {
sk := mk.ExpandEd25519()
return sk.DeriveKey(t, cc)
}

func (pk *PublicKey) DeriveKey(t *merlin.Transcript, cc [ChainCodeLength]byte) (*ExtendedKey, error) {
sc, dcc := pk.DeriveScalarAndChaincode(t, cc)

Expand Down
58 changes: 53 additions & 5 deletions derive_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,52 @@ func TestDerivePublicAndPrivateMatch(t *testing.T) {
}
}

func TestDerive_rust(t *testing.T) {
func TestDeriveSoft(t *testing.T) {
// test vectors from https://github.com/Warchant/sr25519-crust/blob/master/test/derive.cpp#L32
kp, err := hex.DecodeString("4c1250e05afcd79e74f6c035aee10248841090e009b6fd7ba6a98d5dc743250cafa4b32c608e3ee2ba624850b3f14c75841af84b16798bf1ee4a3875aa37a2cee661e416406384fe1ca091980958576d2bff7c461636e9f22c895f444905ea1f")
c := commonVectors{
KeyPair: "4c1250e05afcd79e74f6c035aee10248841090e009b6fd7ba6a98d5dc743250cafa4b32c608e3ee2ba624850b3f14c75841af84b16798bf1ee4a3875aa37a2cee661e416406384fe1ca091980958576d2bff7c461636e9f22c895f444905ea1f",
ChainCode: "0c666f6f00000000000000000000000000000000000000000000000000000000",
Public: "b21e5aabeeb35d6a1bf76226a6c65cd897016df09ef208243e59eed2401f5357",
Hard: false,
}

deriveCommon(t, c)
}

func TestDeriveHard(t *testing.T) {
// test vectors from https://github.com/Warchant/sr25519-crust/blob/4b167a8db2c4114561e5380e3493375df426b124/test/derive.cpp#L13
c := commonVectors{
KeyPair: "4c1250e05afcd79e74f6c035aee10248841090e009b6fd7ba6a98d5dc743250cafa4b32c608e3ee2ba624850b3f14c75841af84b16798bf1ee4a3875aa37a2cee661e416406384fe1ca091980958576d2bff7c461636e9f22c895f444905ea1f",
ChainCode: "14416c6963650000000000000000000000000000000000000000000000000000",
Public: "d8db757f04521a940f0237c8a1e44dfbe0b3e39af929eb2e9e257ba61b9a0a1a",
Hard: true,
}

deriveCommon(t, c)
}

// commonVectors is a struct to set the vectors used for deriving soft or hard
// keys for testing
type commonVectors struct {
// KeyPair in the hex encoded string of a known keypair
KeyPair string
// ChainCode is the chain code for generating the derived key hex encoded
ChainCode string
// Public is the expected resulting public key of the derived key hex
// encoded
Public string
// Hard indicates if the vectors are for deriving a Hard key
Hard bool
}

// deriveCommon provides common functions for testing Soft and Hard key derivation
func deriveCommon(t *testing.T, vec commonVectors) {
kp, err := hex.DecodeString(vec.KeyPair)
if err != nil {
t.Fatal(err)
}

cc, err := hex.DecodeString("0c666f6f00000000000000000000000000000000000000000000000000000000")
cc, err := hex.DecodeString(vec.ChainCode)
if err != nil {
t.Fatal(err)
}
Expand All @@ -94,12 +132,20 @@ func TestDerive_rust(t *testing.T) {

ccBytes := [32]byte{}
copy(ccBytes[:], cc)
derived, err := DeriveKeySimple(priv, []byte{}, ccBytes)

var derived *ExtendedKey

if vec.Hard {
derived, err = DeriveKeyHard(priv, []byte{}, ccBytes)
} else {
derived, err = DeriveKeySimple(priv, []byte{}, ccBytes)
}

if err != nil {
t.Fatal(err)
}

expectedPub, err := hex.DecodeString("b21e5aabeeb35d6a1bf76226a6c65cd897016df09ef208243e59eed2401f5357")
expectedPub, err := hex.DecodeString(vec.Public)
if err != nil {
t.Fatal(err)
}
Expand All @@ -108,7 +154,9 @@ func TestDerive_rust(t *testing.T) {
if err != nil {
t.Fatal(err)
}

resultPubBytes := resultPub.Encode()

if !bytes.Equal(expectedPub, resultPubBytes[:]) {
t.Fatalf("Fail: got %x expected %x", resultPubBytes, expectedPub)
}
Expand Down
5 changes: 5 additions & 0 deletions keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ import (
r255 "github.com/gtank/ristretto255"
)

const (
// MiniSecretKeyLength is the len in bytes of the MiniSecret Key
MiniSecretKeyLength = 32
)

// MiniSecretKey is a secret scalar
type MiniSecretKey struct {
key [32]byte
Expand Down

0 comments on commit d3c6d31

Please sign in to comment.