Skip to content

Commit

Permalink
Add test for splice with commit index mismatch
Browse files Browse the repository at this point in the history
This is a follow-up for #554
  • Loading branch information
t-bast committed Sep 28, 2023
1 parent 4f9d7d7 commit 062550d
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 56 deletions.
40 changes: 28 additions & 12 deletions src/commonTest/kotlin/fr/acinq/lightning/channel/TestsHelper.kt
Original file line number Diff line number Diff line change
Expand Up @@ -433,53 +433,69 @@ object TestsHelper {
return payer0 to payee0
}

private fun <T : ChannelStateWithCommitments> receiveCommitSigs(receiver: LNChannel<T>, commitSigs: List<CommitSig>): Pair<LNChannel<T>, List<ChannelAction>> {
return commitSigs.fold(Pair(receiver, emptyList())) { pair, commitSig ->
val (statePrev, actionsPrev) = pair
assertTrue(actionsPrev.isEmpty())
val (stateNext, actionsNext) = statePrev.process(ChannelCommand.MessageReceived(commitSig))
assertIs<LNChannel<T>>(stateNext)
Pair(stateNext, actionsNext)
}
}

/**
* Cross sign nodes where nodeA initiate the signature exchange
*/
fun <T : ChannelStateWithCommitments> crossSign(nodeA: LNChannel<T>, nodeB: LNChannel<T>): Pair<LNChannel<T>, LNChannel<T>> {
fun <T : ChannelStateWithCommitments> crossSign(nodeA: LNChannel<T>, nodeB: LNChannel<T>, commitmentsCount: Int = 1): Pair<LNChannel<T>, LNChannel<T>> {
val sCommitIndex = nodeA.state.commitments.localCommitIndex
val rCommitIndex = nodeB.state.commitments.localCommitIndex
val rHasChanges = nodeB.state.commitments.changes.localHasChanges()

val (sender0, sActions0) = nodeA.process(ChannelCommand.Commitment.Sign)
val commitSig0 = sActions0.findOutgoingMessage<CommitSig>()
val commitSigs0 = sActions0.findOutgoingMessages<CommitSig>()
assertEquals(commitmentsCount, commitSigs0.size)
commitSigs0.forEach { assertEquals(commitmentsCount, it.batchSize) }

val (receiver0, rActions0) = nodeB.process(ChannelCommand.MessageReceived(commitSig0))
val (receiver0, rActions0) = receiveCommitSigs(nodeB, commitSigs0)
val revokeAndAck0 = rActions0.findOutgoingMessage<RevokeAndAck>()
val commandSign0 = rActions0.findCommand<ChannelCommand.Commitment.Sign>()

val (sender1, _) = sender0.process(ChannelCommand.MessageReceived(revokeAndAck0))
assertIs<LNChannel<T>>(sender1)
val (receiver1, rActions1) = receiver0.process(commandSign0)
val commitSig1 = rActions1.findOutgoingMessage<CommitSig>()
val commitSigs1 = rActions1.findOutgoingMessages<CommitSig>()
assertEquals(commitmentsCount, commitSigs1.size)
commitSigs1.forEach { assertEquals(commitmentsCount, it.batchSize) }

val (sender2, sActions2) = sender1.process(ChannelCommand.MessageReceived(commitSig1))
val (sender2, sActions2) = receiveCommitSigs(sender1, commitSigs1)
val revokeAndAck1 = sActions2.findOutgoingMessage<RevokeAndAck>()
val (receiver2, _) = receiver1.process(ChannelCommand.MessageReceived(revokeAndAck1))
assertIs<LNChannel<T>>(receiver2)

if (rHasChanges) {
val commandSign1 = sActions2.findCommand<ChannelCommand.Commitment.Sign>()
val (sender3, sActions3) = sender2.process(commandSign1)
val commitSig2 = sActions3.findOutgoingMessage<CommitSig>()
val commitSigs2 = sActions3.findOutgoingMessages<CommitSig>()
assertEquals(commitmentsCount, commitSigs2.size)

val (receiver3, rActions3) = receiver2.process(ChannelCommand.MessageReceived(commitSig2))
val (receiver3, rActions3) = receiveCommitSigs(receiver2, commitSigs2)
val revokeAndAck2 = rActions3.findOutgoingMessage<RevokeAndAck>()
val (sender4, _) = sender3.process(ChannelCommand.MessageReceived(revokeAndAck2))

assertIs<LNChannel<T>>(sender4)
assertIs<LNChannel<T>>(receiver3)
assertEquals(sCommitIndex + 1, sender4.commitments.localCommitIndex)
assertEquals(sCommitIndex + 2, sender4.commitments.remoteCommitIndex)
assertEquals(rCommitIndex + 2, sender4.commitments.remoteCommitIndex)
assertEquals(rCommitIndex + 2, receiver3.commitments.localCommitIndex)
assertEquals(rCommitIndex + 1, receiver3.commitments.remoteCommitIndex)
assertEquals(sCommitIndex + 1, receiver3.commitments.remoteCommitIndex)

return sender4 to receiver3
} else {
assertIs<LNChannel<T>>(sender2)
assertIs<LNChannel<T>>(receiver2)
assertEquals(sCommitIndex + 1, sender2.commitments.localCommitIndex)
assertEquals(sCommitIndex + 1, sender2.commitments.remoteCommitIndex)
assertEquals(rCommitIndex + 1, sender2.commitments.remoteCommitIndex)
assertEquals(rCommitIndex + 1, receiver2.commitments.localCommitIndex)
assertEquals(rCommitIndex + 1, receiver2.commitments.remoteCommitIndex)
assertEquals(sCommitIndex + 1, receiver2.commitments.remoteCommitIndex)

return sender2 to receiver2
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import fr.acinq.lightning.blockchain.electrum.WalletState
import fr.acinq.lightning.blockchain.fee.FeeratePerKw
import fr.acinq.lightning.channel.*
import fr.acinq.lightning.channel.TestsHelper.addHtlc
import fr.acinq.lightning.channel.TestsHelper.crossSign
import fr.acinq.lightning.channel.TestsHelper.fulfillHtlc
import fr.acinq.lightning.channel.TestsHelper.reachNormal
import fr.acinq.lightning.crypto.KeyManager
Expand Down Expand Up @@ -41,6 +42,56 @@ class SpliceTestsCommon : LightningTestSuite() {
spliceIn(alice, bob, listOf(30_000.sat, 40_000.sat, 25_000.sat))
}

@Test
fun `splice funds in -- local and remote commit index mismatch`() {
// Alice and Bob asynchronously exchange HTLCs, which makes their commit indices diverge.
val (nodes, preimages) = run {
val (alice0, bob0) = reachNormal()
// Alice sends an HTLC to Bob and signs it.
val (nodes1, preimage1, _) = addHtlc(15_000_000.msat, alice0, bob0)
val (alice1, bob1) = nodes1
val (alice2, actionsAlice2) = alice1.process(ChannelCommand.Commitment.Sign)
val sigAlice1 = actionsAlice2.findOutgoingMessage<CommitSig>()
// Bob sends an HTLC to Alice before receiving her commit_sig.
val (nodes3, preimage2, _) = addHtlc(10_000_000.msat, bob1, alice2)
val (bob3, alice3) = nodes3
// Bob receives Alice's commit_sig and also signs his HTLC.
val (bob4, actionsBob4) = bob3.process(ChannelCommand.MessageReceived(sigAlice1))
val revBob1 = actionsBob4.findOutgoingMessage<RevokeAndAck>()
actionsBob4.hasCommand<ChannelCommand.Commitment.Sign>()
val (bob5, actionsBob5) = bob4.process(ChannelCommand.Commitment.Sign)
val sigBob = actionsBob5.findOutgoingMessage<CommitSig>()
val (alice4, _) = alice3.process(ChannelCommand.MessageReceived(revBob1))
val (alice5, actionsAlice5) = alice4.process(ChannelCommand.MessageReceived(sigBob))
val revAlice = actionsAlice5.findOutgoingMessage<RevokeAndAck>()
actionsAlice5.hasCommand<ChannelCommand.Commitment.Sign>()
val (alice6, actionsAlice6) = alice5.process(ChannelCommand.Commitment.Sign)
val sigAlice2 = actionsAlice6.findOutgoingMessage<CommitSig>()
val (bob6, _) = bob5.process(ChannelCommand.MessageReceived(revAlice))
val (bob7, actionsBob7) = bob6.process(ChannelCommand.MessageReceived(sigAlice2))
assertIs<LNChannel<Normal>>(bob7)
val revBob2 = actionsBob7.findOutgoingMessage<RevokeAndAck>()
val (alice7, _) = alice6.process(ChannelCommand.MessageReceived(revBob2))
assertIs<LNChannel<Normal>>(alice7)
assertEquals(785_000_000.msat, alice7.commitments.latest.localCommit.spec.toLocal)
assertEquals(190_000_000.msat, alice7.commitments.latest.localCommit.spec.toRemote)
assertEquals(1, alice7.commitments.localCommitIndex)
assertEquals(2, alice7.commitments.remoteCommitIndex)
assertEquals(2, bob7.commitments.localCommitIndex)
assertEquals(1, bob7.commitments.remoteCommitIndex)
Pair(Pair(alice7, bob7), Pair(preimage1, preimage2))
}

// TODO: once we support quiescence, fulfill those HTLCs after the splice instead of before.
val (alice1, bob1) = fulfillHtlc(0, preimages.first, nodes.first, nodes.second)
val (bob2, alice2) = fulfillHtlc(0, preimages.second, bob1, alice1)
val (alice3, bob3) = crossSign(alice2, bob2)
assertEquals(2, alice3.commitments.localCommitIndex)
assertEquals(4, alice3.commitments.remoteCommitIndex)

spliceIn(alice3, bob3, listOf(500_000.sat))
}

@Test
fun `splice cpfp`() {
val (alice, bob) = reachNormal()
Expand Down Expand Up @@ -1069,50 +1120,6 @@ class SpliceTestsCommon : LightningTestSuite() {
}
}

private fun crossSign(alice: LNChannel<Normal>, bob: LNChannel<Normal>, commitmentsCount: Int): Pair<LNChannel<Normal>, LNChannel<Normal>> {
val commitIndexAlice = alice.state.commitments.localCommitIndex
val commitIndexBob = bob.state.commitments.localCommitIndex

val (alice1, actionsAlice1) = alice.process(ChannelCommand.Commitment.Sign)
val commitSigsAlice = actionsAlice1.findOutgoingMessages<CommitSig>()
assertEquals(commitSigsAlice.size, commitmentsCount)
commitSigsAlice.forEach { assertEquals(it.batchSize, commitmentsCount) }

val (bob2, actionsBob2) = commitSigsAlice.fold(Pair(bob, emptyList<ChannelAction>())) { pair, commitSig ->
val (bobPrev, actionsBobPrev) = pair
assertTrue(actionsBobPrev.isEmpty())
val (bobNext, actionsBobNext) = bobPrev.process(ChannelCommand.MessageReceived(commitSig))
assertIs<LNChannel<Normal>>(bobNext)
Pair(bobNext, actionsBobNext)
}
val revokeAndAckBob = actionsBob2.findOutgoingMessage<RevokeAndAck>()
val (bob3, actionsBob3) = bob2.process(actionsBob2.findCommand<ChannelCommand.Commitment.Sign>())
val commitSigsBob = actionsBob3.findOutgoingMessages<CommitSig>()
assertEquals(commitSigsBob.size, commitmentsCount)
commitSigsBob.forEach { assertEquals(it.batchSize, commitmentsCount) }

val (alice2, actionsAlice2) = alice1.process(ChannelCommand.MessageReceived(revokeAndAckBob))
actionsAlice2.has<ChannelAction.Storage.StoreState>()
val (alice3, actionsAlice3) = commitSigsBob.fold(Pair(alice2, emptyList<ChannelAction>())) { pair, commitSig ->
val (alicePrev, actionsAlicePrev) = pair
assertTrue(actionsAlicePrev.isEmpty())
val (aliceNext, actionsAliceNext) = alicePrev.process(ChannelCommand.MessageReceived(commitSig))
assertIs<LNChannel<Normal>>(aliceNext)
Pair(aliceNext, actionsAliceNext)
}
assertIs<LNChannel<Normal>>(alice3)
assertEquals(alice3.commitments.localCommitIndex, commitIndexAlice + 1)
assertEquals(alice3.commitments.remoteCommitIndex, commitIndexAlice + 1)
val revokeAndAckAlice = actionsAlice3.findOutgoingMessage<RevokeAndAck>()

val (bob4, _) = bob3.process(ChannelCommand.MessageReceived(revokeAndAckAlice))
assertIs<LNChannel<Normal>>(bob4)
assertEquals(bob4.commitments.localCommitIndex, commitIndexBob + 1)
assertEquals(bob4.commitments.remoteCommitIndex, commitIndexBob + 1)

return Pair(alice3, bob4)
}

fun disconnect(alice: LNChannel<Normal>, bob: LNChannel<Normal>): Triple<LNChannel<Syncing>, LNChannel<Syncing>, ChannelReestablish> {
val (alice1, actionsAlice1) = alice.process(ChannelCommand.Disconnected)
val (bob1, actionsBob1) = bob.process(ChannelCommand.Disconnected)
Expand Down

0 comments on commit 062550d

Please sign in to comment.