Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add require_confirmed_inputs to RBF messages #2783

Merged
merged 1 commit into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,8 @@ trait ChannelOpenDualFunded extends DualFundingHandlers with ErrorHandlers {
cmd.replyTo ! RES_FAILURE(cmd, InvalidRbfFeerate(d.channelId, cmd.targetFeerate, minNextFeerate))
stay()
} else {
stay() using d.copy(rbfStatus = RbfStatus.RbfRequested(cmd)) sending TxInitRbf(d.channelId, cmd.lockTime, cmd.targetFeerate, d.latestFundingTx.fundingParams.localContribution)
val txInitRbf = TxInitRbf(d.channelId, cmd.lockTime, cmd.targetFeerate, d.latestFundingTx.fundingParams.localContribution, nodeParams.channelConf.requireConfirmedInputsForDualFunding)
stay() using d.copy(rbfStatus = RbfStatus.RbfRequested(cmd)) sending txInitRbf
}
case _ =>
log.warning("cannot initiate rbf, another one is already in progress")
Expand Down Expand Up @@ -541,7 +542,8 @@ trait ChannelOpenDualFunded extends DualFundingHandlers with ErrorHandlers {
// we don't change our funding contribution
remoteContribution = msg.fundingContribution,
lockTime = msg.lockTime,
targetFeerate = msg.feerate
targetFeerate = msg.feerate,
requireConfirmedInputs = RequireConfirmedInputs(forLocal = msg.requireConfirmedInputs, forRemote = nodeParams.channelConf.requireConfirmedInputsForDualFunding)
)
val txBuilder = context.spawnAnonymous(InteractiveTxBuilder(
randomBytes32(),
Expand All @@ -552,7 +554,7 @@ trait ChannelOpenDualFunded extends DualFundingHandlers with ErrorHandlers {
wallet))
txBuilder ! InteractiveTxBuilder.Start(self)
val toSend = Seq(
Some(TxAckRbf(d.channelId, fundingParams.localContribution)),
Some(TxAckRbf(d.channelId, fundingParams.localContribution, nodeParams.channelConf.requireConfirmedInputsForDualFunding)),
if (remainingRbfAttempts <= 3) Some(Warning(d.channelId, s"will accept at most ${remainingRbfAttempts - 1} future rbf attempts")) else None,
).flatten
stay() using d.copy(rbfStatus = RbfStatus.RbfInProgress(cmd_opt = None, txBuilder, remoteCommitSig = None)) sending toSend
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ sealed trait OpenDualFundedChannelTlv extends Tlv

sealed trait AcceptDualFundedChannelTlv extends Tlv

sealed trait TxInitRbfTlv extends Tlv

sealed trait TxAckRbfTlv extends Tlv

sealed trait SpliceInitTlv extends Tlv

sealed trait SpliceAckTlv extends Tlv
Expand All @@ -56,7 +60,7 @@ object ChannelTlv {
tlv => Features(tlv.channelType.features.map(f => f -> FeatureSupport.Mandatory).toMap).toByteVector
))

case class RequireConfirmedInputsTlv() extends OpenDualFundedChannelTlv with AcceptDualFundedChannelTlv with SpliceInitTlv with SpliceAckTlv
case class RequireConfirmedInputsTlv() extends OpenDualFundedChannelTlv with AcceptDualFundedChannelTlv with TxInitRbfTlv with TxAckRbfTlv with SpliceInitTlv with SpliceAckTlv

val requireConfirmedInputsCodec: Codec[RequireConfirmedInputsTlv] = tlvField(provide(RequireConfirmedInputsTlv()))

Expand Down Expand Up @@ -99,6 +103,36 @@ object OpenDualFundedChannelTlv {
)
}

object TxRbfTlv {
/**
* Amount that the peer will contribute to the transaction's shared output.
* When used for splicing, this is a signed value that represents funds that are added or removed from the channel.
*/
case class SharedOutputContributionTlv(amount: Satoshi) extends TxInitRbfTlv with TxAckRbfTlv
}

object TxInitRbfTlv {

import ChannelTlv._
import TxRbfTlv._

val txInitRbfTlvCodec: Codec[TlvStream[TxInitRbfTlv]] = tlvStream(discriminated[TxInitRbfTlv].by(varint)
.typecase(UInt64(0), tlvField(satoshiSigned.as[SharedOutputContributionTlv]))
.typecase(UInt64(2), requireConfirmedInputsCodec)
)
}

object TxAckRbfTlv {

import ChannelTlv._
import TxRbfTlv._

val txAckRbfTlvCodec: Codec[TlvStream[TxAckRbfTlv]] = tlvStream(discriminated[TxAckRbfTlv].by(varint)
.typecase(UInt64(0), tlvField(satoshiSigned.as[SharedOutputContributionTlv]))
.typecase(UInt64(2), requireConfirmedInputsCodec)
)
}

object SpliceInitTlv {

import ChannelTlv._
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@

package fr.acinq.eclair.wire.protocol

import fr.acinq.bitcoin.scalacompat.{ByteVector64, Satoshi, TxId}
import fr.acinq.bitcoin.scalacompat.{ByteVector64, TxId}
import fr.acinq.eclair.UInt64
import fr.acinq.eclair.wire.protocol.CommonCodecs.{bytes64, satoshiSigned, txIdAsHash, varint}
import fr.acinq.eclair.wire.protocol.CommonCodecs.{bytes64, txIdAsHash, varint}
import fr.acinq.eclair.wire.protocol.TlvCodecs.{tlvField, tlvStream}
import scodec.Codec
import scodec.codecs.discriminated
Expand Down Expand Up @@ -74,36 +74,6 @@ object TxSignaturesTlv {
)
}

sealed trait TxInitRbfTlv extends Tlv

sealed trait TxAckRbfTlv extends Tlv

object TxRbfTlv {
/**
* Amount that the peer will contribute to the transaction's shared output.
* When used for splicing, this is a signed value that represents funds that are added or removed from the channel.
*/
case class SharedOutputContributionTlv(amount: Satoshi) extends TxInitRbfTlv with TxAckRbfTlv
}

object TxInitRbfTlv {

import TxRbfTlv._

val txInitRbfTlvCodec: Codec[TlvStream[TxInitRbfTlv]] = tlvStream(discriminated[TxInitRbfTlv].by(varint)
.typecase(UInt64(0), tlvField(satoshiSigned.as[SharedOutputContributionTlv]))
)
}

object TxAckRbfTlv {

import TxRbfTlv._

val txAckRbfTlvCodec: Codec[TlvStream[TxAckRbfTlv]] = tlvStream(discriminated[TxAckRbfTlv].by(varint)
.typecase(UInt64(0), tlvField(satoshiSigned.as[SharedOutputContributionTlv]))
)
}

sealed trait TxAbortTlv extends Tlv

object TxAbortTlv {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,21 +133,33 @@ case class TxInitRbf(channelId: ByteVector32,
feerate: FeeratePerKw,
tlvStream: TlvStream[TxInitRbfTlv] = TlvStream.empty) extends InteractiveTxMessage with HasChannelId {
val fundingContribution: Satoshi = tlvStream.get[TxRbfTlv.SharedOutputContributionTlv].map(_.amount).getOrElse(0 sat)
val requireConfirmedInputs: Boolean = tlvStream.get[ChannelTlv.RequireConfirmedInputsTlv].nonEmpty
}

object TxInitRbf {
def apply(channelId: ByteVector32, lockTime: Long, feerate: FeeratePerKw, fundingContribution: Satoshi): TxInitRbf =
TxInitRbf(channelId, lockTime, feerate, TlvStream[TxInitRbfTlv](TxRbfTlv.SharedOutputContributionTlv(fundingContribution)))
def apply(channelId: ByteVector32, lockTime: Long, feerate: FeeratePerKw, fundingContribution: Satoshi, requireConfirmedInputs: Boolean): TxInitRbf = {
val tlvs: Set[TxInitRbfTlv] = Set(
Some(TxRbfTlv.SharedOutputContributionTlv(fundingContribution)),
if (requireConfirmedInputs) Some(ChannelTlv.RequireConfirmedInputsTlv()) else None,
).flatten
TxInitRbf(channelId, lockTime, feerate, TlvStream(tlvs))
}
}

case class TxAckRbf(channelId: ByteVector32,
tlvStream: TlvStream[TxAckRbfTlv] = TlvStream.empty) extends InteractiveTxMessage with HasChannelId {
val fundingContribution: Satoshi = tlvStream.get[TxRbfTlv.SharedOutputContributionTlv].map(_.amount).getOrElse(0 sat)
val requireConfirmedInputs: Boolean = tlvStream.get[ChannelTlv.RequireConfirmedInputsTlv].nonEmpty
}

object TxAckRbf {
def apply(channelId: ByteVector32, fundingContribution: Satoshi): TxAckRbf =
TxAckRbf(channelId, TlvStream[TxAckRbfTlv](TxRbfTlv.SharedOutputContributionTlv(fundingContribution)))
def apply(channelId: ByteVector32, fundingContribution: Satoshi, requireConfirmedInputs: Boolean): TxAckRbf = {
val tlvs: Set[TxAckRbfTlv] = Set(
Some(TxRbfTlv.SharedOutputContributionTlv(fundingContribution)),
if (requireConfirmedInputs) Some(ChannelTlv.RequireConfirmedInputsTlv()) else None,
).flatten
TxAckRbf(channelId, TlvStream(tlvs))
}
}

case class TxAbort(channelId: ByteVector32,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ class WaitForDualFundingConfirmedStateSpec extends TestKitBaseClass with Fixture
test("recv TxInitRbf (exhausted RBF attempts)", Tag(ChannelStateTestsTags.DualFunding), Tag(ChannelStateTestsTags.RejectRbfAttempts)) { f =>
import f._

bob ! TxInitRbf(channelId(bob), 0, TestConstants.feeratePerKw * 1.25, 500_000 sat)
bob ! TxInitRbf(channelId(bob), 0, TestConstants.feeratePerKw * 1.25, 500_000 sat, requireConfirmedInputs = false)
assert(bob2alice.expectMsgType[TxAbort].toAscii == InvalidRbfAttemptsExhausted(channelId(bob), 0).getMessage)
assert(bob.stateName == WAIT_FOR_DUAL_FUNDING_CONFIRMED)
}
Expand All @@ -412,7 +412,7 @@ class WaitForDualFundingConfirmedStateSpec extends TestKitBaseClass with Fixture
import f._

val currentBlockHeight = bob.stateData.asInstanceOf[DATA_WAIT_FOR_DUAL_FUNDING_CONFIRMED].latestFundingTx.createdAt
bob ! TxInitRbf(channelId(bob), 0, TestConstants.feeratePerKw * 1.25, 500_000 sat)
bob ! TxInitRbf(channelId(bob), 0, TestConstants.feeratePerKw * 1.25, 500_000 sat, requireConfirmedInputs = false)
assert(bob2alice.expectMsgType[TxAbort].toAscii == InvalidRbfAttemptTooSoon(channelId(bob), currentBlockHeight, currentBlockHeight + 1).getMessage)
assert(bob.stateName == WAIT_FOR_DUAL_FUNDING_CONFIRMED)
}
Expand All @@ -421,7 +421,7 @@ class WaitForDualFundingConfirmedStateSpec extends TestKitBaseClass with Fixture
import f._

val fundingBelowPushAmount = 199_000.sat
bob ! TxInitRbf(channelId(bob), 0, TestConstants.feeratePerKw * 1.25, fundingBelowPushAmount)
bob ! TxInitRbf(channelId(bob), 0, TestConstants.feeratePerKw * 1.25, fundingBelowPushAmount, requireConfirmedInputs = false)
assert(bob2alice.expectMsgType[TxAbort].toAscii == InvalidPushAmount(channelId(bob), TestConstants.initiatorPushAmount, fundingBelowPushAmount.toMilliSatoshi).getMessage)
assert(bob.stateName == WAIT_FOR_DUAL_FUNDING_CONFIRMED)
}
Expand All @@ -432,7 +432,7 @@ class WaitForDualFundingConfirmedStateSpec extends TestKitBaseClass with Fixture
alice ! CMD_BUMP_FUNDING_FEE(TestProbe().ref, TestConstants.feeratePerKw * 1.25, 0)
alice2bob.expectMsgType[TxInitRbf]
val fundingBelowPushAmount = 99_000.sat
alice ! TxAckRbf(channelId(alice), fundingBelowPushAmount)
alice ! TxAckRbf(channelId(alice), fundingBelowPushAmount, requireConfirmedInputs = false)
assert(alice2bob.expectMsgType[TxAbort].toAscii == InvalidPushAmount(channelId(alice), TestConstants.nonInitiatorPushAmount, fundingBelowPushAmount.toMilliSatoshi).getMessage)
assert(alice.stateName == WAIT_FOR_DUAL_FUNDING_CONFIRMED)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ import fr.acinq.eclair.router.Announcements
import fr.acinq.eclair.wire.protocol.ChannelTlv.{ChannelTypeTlv, PushAmountTlv, RequireConfirmedInputsTlv, UpfrontShutdownScriptTlv}
import fr.acinq.eclair.wire.protocol.LightningMessageCodecs._
import fr.acinq.eclair.wire.protocol.ReplyChannelRangeTlv._
import fr.acinq.eclair.wire.protocol.TxRbfTlv.SharedOutputContributionTlv
import org.json4s.jackson.Serialization
import org.scalatest.funsuite.AnyFunSuite
import scodec.DecodeResult
Expand Down Expand Up @@ -195,13 +194,13 @@ class LightningMessageCodecsSpec extends AnyFunSuite {
TxSignatures(channelId2, tx1, Nil, None) -> hex"0047 bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb 1f2ec025a33e39ef8e177afcdc1adc855bf128dc906182255aeb64efa825f106 0000",
TxSignatures(channelId2, tx1, Nil, Some(signature)) -> hex"0047 bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb 1f2ec025a33e39ef8e177afcdc1adc855bf128dc906182255aeb64efa825f106 0000 fd0259 40 aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
TxInitRbf(channelId1, 8388607, FeeratePerKw(4000 sat)) -> hex"0048 aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa 007fffff 00000fa0",
TxInitRbf(channelId1, 0, FeeratePerKw(4000 sat), TlvStream[TxInitRbfTlv](SharedOutputContributionTlv(1_500_000 sat))) -> hex"0048 aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa 00000000 00000fa0 0008000000000016e360",
TxInitRbf(channelId1, 0, FeeratePerKw(4000 sat), TlvStream[TxInitRbfTlv](SharedOutputContributionTlv(0 sat))) -> hex"0048 aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa 00000000 00000fa0 00080000000000000000",
TxInitRbf(channelId1, 0, FeeratePerKw(4000 sat), TlvStream[TxInitRbfTlv](SharedOutputContributionTlv(-25_000 sat))) -> hex"0048 aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa 00000000 00000fa0 0008ffffffffffff9e58",
TxInitRbf(channelId1, 0, FeeratePerKw(4000 sat), 1_500_000 sat, requireConfirmedInputs = true) -> hex"0048 aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa 00000000 00000fa0 0008000000000016e360 0200",
TxInitRbf(channelId1, 0, FeeratePerKw(4000 sat), 0 sat, requireConfirmedInputs = false) -> hex"0048 aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa 00000000 00000fa0 00080000000000000000",
TxInitRbf(channelId1, 0, FeeratePerKw(4000 sat), -25_000 sat, requireConfirmedInputs = false) -> hex"0048 aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa 00000000 00000fa0 0008ffffffffffff9e58",
TxAckRbf(channelId2) -> hex"0049 bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
TxAckRbf(channelId2, TlvStream[TxAckRbfTlv](SharedOutputContributionTlv(450_000 sat))) -> hex"0049 bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb 0008000000000006ddd0",
TxAckRbf(channelId2, TlvStream[TxAckRbfTlv](SharedOutputContributionTlv(0 sat))) -> hex"0049 bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb 00080000000000000000",
TxAckRbf(channelId2, TlvStream[TxAckRbfTlv](SharedOutputContributionTlv(-250_000 sat))) -> hex"0049 bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb 0008fffffffffffc2f70",
TxAckRbf(channelId2, 450_000 sat, requireConfirmedInputs = false) -> hex"0049 bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb 0008000000000006ddd0",
TxAckRbf(channelId2, 0 sat, requireConfirmedInputs = false) -> hex"0049 bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb 00080000000000000000",
TxAckRbf(channelId2, -250_000 sat, requireConfirmedInputs = true) -> hex"0049 bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb 0008fffffffffffc2f70 0200",
TxAbort(channelId1, hex"") -> hex"004a aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa 0000",
TxAbort(channelId1, ByteVector.view("internal error".getBytes(Charsets.US_ASCII))) -> hex"004a aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa 000e 696e7465726e616c206572726f72",
)
Expand Down Expand Up @@ -458,8 +457,6 @@ class LightningMessageCodecsSpec extends AnyFunSuite {
}
}

case class TestItem(msg: Any, hex: String)

test("test vectors for extended channel queries ") {
val refs = Map(
QueryChannelRange(Block.RegtestGenesisBlock.hash, BlockHeight(100000), 1500, TlvStream.empty) ->
Expand Down Expand Up @@ -494,61 +491,10 @@ class LightningMessageCodecsSpec extends AnyFunSuite {
QueryShortChannelIds(Block.RegtestGenesisBlock.hash, EncodedShortChannelIds(EncodingType.COMPRESSED_ZLIB, List(RealShortChannelId(14200), RealShortChannelId(46645), RealShortChannelId(4564676))), TlvStream(QueryShortChannelIdsTlv.EncodedQueryFlags(EncodingType.COMPRESSED_ZLIB, List(1, 2, 4)))) ->
hex"010506226e46111a0b59caaf126043eb5bbf28c34f3a5e332a1fc7b2b73cf188910f001801789c63600001f30a30c5b0cd144cb92e3b020017c6034a010c01789c6364620100000e0008"
)

val items = refs.map { case (obj, refbin) =>
refs.map { case (obj, refbin) =>
val bin = lightningMessageCodec.encode(obj).require
assert(refbin.bits == bin)
TestItem(obj, bin.toHex)
}

// NB: uncomment this to update the test vectors

/*class EncodingTypeSerializer extends CustomSerializer[EncodingType](format => ( {
null
}, {
case EncodingType.UNCOMPRESSED => JString("UNCOMPRESSED")
case EncodingType.COMPRESSED_ZLIB => JString("COMPRESSED_ZLIB")
}))

class ExtendedQueryFlagsSerializer extends CustomSerializer[QueryChannelRangeTlv.QueryFlags](format => ( {
null
}, {
case QueryChannelRangeTlv.QueryFlags(flag) =>
JString(((if (QueryChannelRangeTlv.QueryFlags.wantTimestamps(flag)) List("WANT_TIMESTAMPS") else List()) ::: (if (QueryChannelRangeTlv.QueryFlags.wantChecksums(flag)) List("WANT_CHECKSUMS") else List())) mkString (" | "))
}))

implicit val formats = org.json4s.DefaultFormats.withTypeHintFieldName("type") +
new EncodingTypeSerializer +
new ExtendedQueryFlagsSerializer +
new ByteVectorSerializer +
new ByteVector32Serializer +
new UInt64Serializer +
new MilliSatoshiSerializer +
new ShortChannelIdSerializer +
new StateSerializer +
new ShaChainSerializer +
new PublicKeySerializer +
new PrivateKeySerializer +
new TransactionSerializer +
new TransactionWithInputInfoSerializer +
new InetSocketAddressSerializer +
new OutPointSerializer +
new OutPointKeySerializer +
new InputInfoSerializer +
new ColorSerializer +
new RouteResponseSerializer +
new ThrowableSerializer +
new FailureMessageSerializer +
new NodeAddressSerializer +
new DirectionSerializer +
new InvoiceSerializer +
ShortTypeHints(List(
classOf[QueryChannelRange],
classOf[ReplyChannelRange],
classOf[QueryShortChannelIds]))

val json = Serialization.writePretty(items)
println(json)*/
}

test("decode channel_update with htlc_maximum_msat") {
Expand Down