Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Sphinx privacy leak #1247

Merged
merged 3 commits into from
Jan 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions eclair-core/src/main/scala/fr/acinq/eclair/crypto/Sphinx.scala
Original file line number Diff line number Diff line change
Expand Up @@ -210,19 +210,20 @@ object Sphinx extends Logging {
* @param associatedData associated data.
* @param ephemeralPublicKey ephemeral key shared with the target node.
* @param sharedSecret shared secret with this hop.
* @param packet current packet (None if the packet hasn't been initialized).
* @param packet current packet or random bytes if the packet hasn't been initialized.
* @param onionPayloadFiller optional onion payload filler, needed only when you're constructing the last packet.
t-bast marked this conversation as resolved.
Show resolved Hide resolved
* @return the next packet.
*/
def wrap(payload: ByteVector, associatedData: ByteVector32, ephemeralPublicKey: PublicKey, sharedSecret: ByteVector32, packet: Option[wire.OnionRoutingPacket], onionPayloadFiller: ByteVector = ByteVector.empty): wire.OnionRoutingPacket = {
def wrap(payload: ByteVector, associatedData: ByteVector32, ephemeralPublicKey: PublicKey, sharedSecret: ByteVector32, packet: Either[ByteVector, wire.OnionRoutingPacket], onionPayloadFiller: ByteVector = ByteVector.empty): wire.OnionRoutingPacket = {
require(payload.length <= PayloadLength - MacLength, s"packet payload cannot exceed ${PayloadLength - MacLength} bytes")

val (currentMac, currentPayload): (ByteVector32, ByteVector) = packet match {
// Packet construction starts with an empty mac and payload.
case None => (ByteVector32.Zeroes, ByteVector.fill(PayloadLength)(0))
case Some(p) => (p.hmac, p.payload)
// Packet construction starts with an empty mac and random payload.
case Left(startingBytes) =>
require(startingBytes.length == PayloadLength, "invalid initial random bytes length")
(ByteVector32.Zeroes, startingBytes)
case Right(p) => (p.hmac, p.payload)
}

val nextOnionPayload = {
val onionPayload1 = payload ++ currentMac ++ currentPayload.dropRight(payload.length + MacLength)
val onionPayload2 = onionPayload1 xor generateStream(generateKey("rho", sharedSecret), PayloadLength)
Expand All @@ -248,12 +249,14 @@ object Sphinx extends Logging {
val (ephemeralPublicKeys, sharedsecrets) = computeEphemeralPublicKeysAndSharedSecrets(sessionKey, publicKeys)
val filler = generateFiller("rho", sharedsecrets.dropRight(1), payloads.dropRight(1))

val lastPacket = wrap(payloads.last, associatedData, ephemeralPublicKeys.last, sharedsecrets.last, None, filler)
// We deterministically-derive the initial payload bytes: see https://github.com/lightningnetwork/lightning-rfc/pull/697
val startingBytes = generateStream(generateKey("pad", sessionKey.value), PayloadLength)
val lastPacket = wrap(payloads.last, associatedData, ephemeralPublicKeys.last, sharedsecrets.last, Left(startingBytes), filler)

@tailrec
def loop(hopPayloads: Seq[ByteVector], ephKeys: Seq[PublicKey], sharedSecrets: Seq[ByteVector32], packet: wire.OnionRoutingPacket): wire.OnionRoutingPacket = {
if (hopPayloads.isEmpty) packet else {
val nextPacket = wrap(hopPayloads.last, associatedData, ephKeys.last, sharedSecrets.last, Some(packet))
val nextPacket = wrap(hopPayloads.last, associatedData, ephKeys.last, sharedSecrets.last, Right(packet))
loop(hopPayloads.dropRight(1), ephKeys.dropRight(1), sharedSecrets.dropRight(1), nextPacket)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ object IncomingPacket {
}

private def validateFinal(add: UpdateAddHtlc, payload: Onion.FinalPayload): Either[FailureMessage, IncomingPacket] = {
if (add.amountMsat < payload.amount) {
if (add.amountMsat != payload.amount) {
Left(FinalIncorrectHtlcAmount(add.amountMsat))
} else if (add.cltvExpiry != payload.expiry) {
Left(FinalIncorrectCltvExpiry(add.cltvExpiry))
Expand All @@ -112,7 +112,7 @@ object IncomingPacket {
}

private def validateFinal(add: UpdateAddHtlc, outerPayload: Onion.FinalPayload, innerPayload: Onion.FinalPayload): Either[FailureMessage, IncomingPacket] = {
if (add.amountMsat < outerPayload.amount) {
if (add.amountMsat != outerPayload.amount) {
Left(FinalIncorrectHtlcAmount(add.amountMsat))
} else if (add.cltvExpiry != outerPayload.expiry) {
Left(FinalIncorrectCltvExpiry(add.cltvExpiry))
Expand Down
Loading