diff --git a/bscript/address.go b/bscript/address.go index 701395d4..4cfef7a1 100644 --- a/bscript/address.go +++ b/bscript/address.go @@ -77,7 +77,7 @@ func NewAddressFromPublicKeyHash(hash []byte, mainnet bool) (*Address, error) { if !mainnet { bb[0] = 111 } - + // nolint:makezero // stop complaining bb = append(bb, hash...) return &Address{ @@ -99,7 +99,7 @@ func NewAddressFromPublicKey(pubKey *bsvec.PublicKey, mainnet bool) (*Address, e if !mainnet { bb[0] = 111 } - + // nolint:makezero // stop complaining bb = append(bb, hash...) return &Address{ diff --git a/txchange.go b/txchange.go index 6778dbd1..2051b39e 100644 --- a/txchange.go +++ b/txchange.go @@ -18,45 +18,83 @@ func (tx *Tx) ChangeToAddress(addr string, f []*Fee) error { } // Change calculates the amount of fees needed to cover the transaction -// and adds the left over change in a new output using the script provided. +// and adds the left over change in a new output using the script provided. func (tx *Tx) Change(s *bscript.Script, f []*Fee) error { + available, hasChange, err := tx.change(s, f, true) + if err != nil { + return err + } + if hasChange { + // add rest of available sats to the change output + tx.Outputs[len(tx.Outputs)-1].Satoshis = available + } + return nil +} + +// ChangeToOutput will calculate fees and add them to an output at the index specified (0 based). +// If an invalid index is supplied and error is returned. +func (tx *Tx) ChangeToOutput(index uint, f []*Fee) error { + if int(index) > len(tx.Outputs)-1 { + return errors.New("index is greater than number of inputs in transaction") + } + available, hasChange, err := tx.change(tx.Outputs[index].LockingScript, f, false) + if err != nil { + return err + } + if hasChange { + tx.Outputs[index].Satoshis += available + } + return nil +} +// CalculateFee will return the amount of fees the current transaction will +// require. +func (tx *Tx) CalculateFee(f []*Fee) (uint64, error) { + total := tx.TotalInputSatoshis() - tx.TotalOutputSatoshis() + sats, _, err := tx.change(nil, f, false) + if err != nil { + return 0, err + } + return total - sats, nil +} + +// change will return the amount of satoshis to add to an input after fees are removed. +// True will be returned if change has been added. +func (tx *Tx) change(s *bscript.Script, f []*Fee, newOutput bool) (uint64, bool, error) { inputAmount := tx.TotalInputSatoshis() outputAmount := tx.TotalOutputSatoshis() if inputAmount < outputAmount { - return errors.New("satoshis inputted to the tx are less than the outputted satoshis") + return 0, false, errors.New("satoshis inputted to the tx are less than the outputted satoshis") } available := inputAmount - outputAmount standardFees, err := ExtractStandardFee(f) if err != nil { - return err + return 0, false, err } if !tx.canAddChange(available, standardFees) { - return nil + return 0, false, err + } + if newOutput { + tx.AddOutput(&Output{Satoshis: 0, LockingScript: s}) } - - tx.AddOutput(&Output{Satoshis: 0, LockingScript: s}) var preSignedFeeRequired uint64 if preSignedFeeRequired, err = tx.getPreSignedFeeRequired(f); err != nil { - return err + return 0, false, err } var expectedUnlockingScriptFees uint64 if expectedUnlockingScriptFees, err = tx.getExpectedUnlockingScriptFees(f); err != nil { - return err + return 0, false, err } available -= preSignedFeeRequired + expectedUnlockingScriptFees - // add rest of available sats to the change output - tx.Outputs[len(tx.Outputs)-1].Satoshis = available - - return nil + return available, true, nil } func (tx *Tx) canAddChange(available uint64, standardFees *Fee) bool { diff --git a/txchange_test.go b/txchange_test.go index 2431c70b..af77e55a 100644 --- a/txchange_test.go +++ b/txchange_test.go @@ -1,6 +1,7 @@ package bt_test import ( + "errors" "testing" "github.com/bitcoinsv/bsvutil" @@ -299,3 +300,140 @@ func TestTx_Change(t *testing.T) { assert.Equal(t, "01000000028ee20a442cdbcc9f9f927d9c2c9370e611675ebc24c064e8e94508ec8eca889e000000006b483045022100fa52a44cd8010ba646a8df6bac6e5e8aa93f24439521c2ce1c8fe6550e73c1750220636e30d757702a6777d8310090962d4bac2b3fd634127856d51b184f5c702c8f4121034aaeabc056f33fd960d1e43fc8a0672723af02f275e54c31381af66a334634caffffffff42eaf7bdddc797a0beb97717ff8846f03c963fb5fe15a2b555b9cbd477b0254e000000006b483045022100c201fd55ef33525b3eb0557fac77408b8ec7f6ea5b00d08512df105172f992d60220753b21519a416dcbeaf1a501d9c36de2aea9c83c6d258320500371819d0758e14121034aaeabc056f33fd960d1e43fc8a0672723af02f275e54c31381af66a334634caffffffff01c62b0000000000001976a9147824dec00be2c45dad83c9b5e9f5d7ef05ba3cf988ac00000000", tx.ToString()) }) } + +func TestTx_ChangeToOutput(t *testing.T) { + tests := map[string]struct { + tx *bt.Tx + index uint + fees []*bt.Fee + expOutputTotal uint64 + expChangeOutput uint64 + err error + }{ + "no change to add should return no change output": { + tx: func() *bt.Tx { + tx := bt.NewTx() + assert.NoError(t, tx.From( + "07912972e42095fe58daaf09161c5a5da57be47c2054dc2aaa52b30fefa1940b", + 0, + "76a914af2590a45ae401651fdbdf59a76ad43d1862534088ac", + 1000)) + assert.NoError(t, tx.PayTo("mxAoAyZFXX6LZBWhoam3vjm6xt9NxPQ15f", 1000)) + return tx + }(), + index: 0, + fees: bt.DefaultFees(), + expOutputTotal: 1000, + expChangeOutput: 1000, + err: nil, + }, "change to add should add change to output": { + tx: func() *bt.Tx { + tx := bt.NewTx() + assert.NoError(t, tx.From( + "07912972e42095fe58daaf09161c5a5da57be47c2054dc2aaa52b30fefa1940b", + 0, + "76a914af2590a45ae401651fdbdf59a76ad43d1862534088ac", + 1000)) + assert.NoError(t, tx.PayTo("mxAoAyZFXX6LZBWhoam3vjm6xt9NxPQ15f", 500)) + return tx + }(), + index: 0, + fees: bt.DefaultFees(), + expOutputTotal: 904, + expChangeOutput: 904, + err: nil, + }, "change to add should add change to specified output": { + tx: func() *bt.Tx { + tx := bt.NewTx() + assert.NoError(t, tx.From( + "07912972e42095fe58daaf09161c5a5da57be47c2054dc2aaa52b30fefa1940b", + 0, + "76a914af2590a45ae401651fdbdf59a76ad43d1862534088ac", + 2500)) + assert.NoError(t, tx.PayTo("mxAoAyZFXX6LZBWhoam3vjm6xt9NxPQ15f", 500)) + assert.NoError(t, tx.PayTo("mxAoAyZFXX6LZBWhoam3vjm6xt9NxPQ15f", 500)) + assert.NoError(t, tx.PayTo("mxAoAyZFXX6LZBWhoam3vjm6xt9NxPQ15f", 500)) + assert.NoError(t, tx.PayTo("mxAoAyZFXX6LZBWhoam3vjm6xt9NxPQ15f", 500)) + return tx + }(), + index: 3, + fees: bt.DefaultFees(), + expOutputTotal: 2353, + expChangeOutput: 853, + err: nil, + }, "index out of range should return error": { + tx: func() *bt.Tx { + tx := bt.NewTx() + assert.NoError(t, tx.From( + "07912972e42095fe58daaf09161c5a5da57be47c2054dc2aaa52b30fefa1940b", + 0, + "76a914af2590a45ae401651fdbdf59a76ad43d1862534088ac", + 1000)) + assert.NoError(t, tx.PayTo("mxAoAyZFXX6LZBWhoam3vjm6xt9NxPQ15f", 500)) + return tx + }(), + index: 1, + fees: bt.DefaultFees(), + err: errors.New("index is greater than number of inputs in transaction"), + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + err := test.tx.ChangeToOutput(test.index, test.fees) + if test.err != nil { + assert.Error(t, err) + assert.Equal(t, test.err, err) + return + } + assert.Equal(t, test.expOutputTotal, test.tx.TotalOutputSatoshis()) + assert.Equal(t, test.expChangeOutput, test.tx.Outputs[test.index].Satoshis) + }) + } +} + +func TestTx_CalculateChange(t *testing.T) { + tests := map[string]struct { + tx *bt.Tx + fees []*bt.Fee + expFees uint64 + err error + }{ + "Transaction with one input one output should return 96": { + tx: func() *bt.Tx { + tx := bt.NewTx() + assert.NoError(t, tx.From( + "07912972e42095fe58daaf09161c5a5da57be47c2054dc2aaa52b30fefa1940b", + 0, + "76a914af2590a45ae401651fdbdf59a76ad43d1862534088ac", + 1000)) + assert.NoError(t, tx.PayTo("mxAoAyZFXX6LZBWhoam3vjm6xt9NxPQ15f", 500)) + return tx + }(), + fees: bt.DefaultFees(), + expFees: 96, + }, "Transaction with one input 4 outputs should return 147": { + tx: func() *bt.Tx { + tx := bt.NewTx() + assert.NoError(t, tx.From( + "07912972e42095fe58daaf09161c5a5da57be47c2054dc2aaa52b30fefa1940b", + 0, + "76a914af2590a45ae401651fdbdf59a76ad43d1862534088ac", + 2500)) + assert.NoError(t, tx.PayTo("mxAoAyZFXX6LZBWhoam3vjm6xt9NxPQ15f", 500)) + assert.NoError(t, tx.PayTo("mxAoAyZFXX6LZBWhoam3vjm6xt9NxPQ15f", 500)) + assert.NoError(t, tx.PayTo("mxAoAyZFXX6LZBWhoam3vjm6xt9NxPQ15f", 500)) + assert.NoError(t, tx.PayTo("mxAoAyZFXX6LZBWhoam3vjm6xt9NxPQ15f", 500)) + return tx + }(), + fees: bt.DefaultFees(), + expFees: 147, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + fee, err := test.tx.CalculateFee(test.fees) + assert.Equal(t, test.err, err) + assert.Equal(t, test.expFees, fee) + }) + } +}