Skip to content

Commit

Permalink
WIP Add musig2-based swap-in protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
sstone committed Nov 5, 2023
1 parent 0ef0cd0 commit 2bfbc19
Show file tree
Hide file tree
Showing 9 changed files with 180 additions and 49 deletions.
66 changes: 50 additions & 16 deletions src/commonMain/kotlin/fr/acinq/lightning/channel/InteractiveTx.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ package fr.acinq.lightning.channel

import fr.acinq.bitcoin.*
import fr.acinq.bitcoin.Script.tail
import fr.acinq.bitcoin.musig2.PublicNonce
import fr.acinq.bitcoin.musig2.SecretNonce
import fr.acinq.lightning.Lightning.randomBytes32
import fr.acinq.lightning.MilliSatoshi
import fr.acinq.lightning.blockchain.electrum.WalletState
import fr.acinq.lightning.blockchain.fee.FeeratePerKw
Expand Down Expand Up @@ -112,7 +115,7 @@ sealed class InteractiveTxInput {
}

/** A local input that funds the interactive transaction, coming from a 2-of-2 swap-in transaction. */
data class LocalSwapIn(override val serialId: Long, override val previousTx: Transaction, override val previousTxOutput: Long, override val sequence: UInt, val userKey: PublicKey, val serverKey: PublicKey, val refundDelay: Int) : Local() {
data class LocalSwapIn(override val serialId: Long, override val previousTx: Transaction, override val previousTxOutput: Long, override val sequence: UInt, val userKey: PublicKey, val serverKey: PublicKey, val refundDelay: Int, val userNonce: SecretNonce?= null, val serverNonce: PublicNonce? = null) : Local() {
override val outPoint: OutPoint = OutPoint(previousTx, previousTxOutput)
override val txOut: TxOut = previousTx.txOut[previousTxOutput.toInt()]
}
Expand Down Expand Up @@ -351,9 +354,12 @@ data class SharedTransaction(
// If we are swapping funds in, we provide our partial signatures to the corresponding inputs.
val swapUserSigs = unsignedTx.txIn.mapIndexed { i, txIn ->
localInputs
.filterIsInstance<InteractiveTxInput.LocalSwapIn>()
.filter { input -> keyManager.swapInOnChainWallet.swapInProtocol.isMine(input.txOut) }
.find { txIn.outPoint == it.outPoint }
?.let { input -> keyManager.swapInOnChainWallet.signSwapInputUser(unsignedTx, i, input.txOut) }
?.let { input -> keyManager.swapInOnChainWallet.signSwapInputUser(unsignedTx, i, input.previousTx.txOut) }
}.filterNotNull()

// If the remote is swapping funds in, they'll need our partial signatures to finalize their witness.
val swapServerSigs = unsignedTx.txIn.mapIndexed { i, txIn ->
remoteInputs
Expand All @@ -365,7 +371,16 @@ data class SharedTransaction(
swapInProtocol.signSwapInputServer(unsignedTx, i, input.txOut, serverKey)
}
}.filterNotNull()
return PartiallySignedSharedTransaction(this, TxSignatures(fundingParams.channelId, unsignedTx, listOf(), sharedSig, swapUserSigs, swapServerSigs))

// val swapUserPartialSigs = unsignedTx.txIn.mapIndexed { i, txIn ->
// localInputs
// .filterIsInstance<InteractiveTxInput.LocalSwapIn>()
// .filter { input -> keyManager.swapInOnChainWallet.swapInProtocolMusig2.isMine(input.txOut) }
// .find { txIn.outPoint == it.outPoint }
// ?.let { input -> keyManager.swapInOnChainWallet.signSwapInputUser(unsignedTx, i, input.previousTx.txOut) }
// }.filterNotNull()

return PartiallySignedSharedTransaction(this, TxSignatures(fundingParams.channelId, unsignedTx, listOf(), sharedSig, swapUserSigs, swapServerSigs, listOf(), listOf()))
}
}

Expand Down Expand Up @@ -462,6 +477,7 @@ sealed class InteractiveTxSessionAction {
data class InvalidTxWeight(val channelId: ByteVector32, val txId: ByteVector32) : RemoteFailure() { override fun toString(): String = "transaction weight is too big for standardness rules (txId=$txId)" }
data class InvalidTxFeerate(val channelId: ByteVector32, val txId: ByteVector32, val targetFeerate: FeeratePerKw, val actualFeerate: FeeratePerKw) : RemoteFailure() { override fun toString(): String = "transaction feerate too low (txId=$txId, targetFeerate=$targetFeerate, actualFeerate=$actualFeerate" }
data class InvalidTxDoesNotDoubleSpendPreviousTx(val channelId: ByteVector32, val txId: ByteVector32, val previousTxId: ByteVector32) : RemoteFailure() { override fun toString(): String = "transaction replacement with txId=$txId doesn't double-spend previous attempt (txId=$previousTxId)" }
data class MissingNonce(val channelId: ByteVector32, val serialId: Long): RemoteFailure() { override fun toString(): String = "missing musig2 nonce for input serial_id=$serialId)" }
// @formatter:on
}

Expand All @@ -476,8 +492,8 @@ data class InteractiveTxSession(
val remoteInputs: List<InteractiveTxInput.Incoming> = listOf(),
val localOutputs: List<InteractiveTxOutput.Outgoing> = listOf(),
val remoteOutputs: List<InteractiveTxOutput.Incoming> = listOf(),
val txCompleteSent: Boolean = false,
val txCompleteReceived: Boolean = false,
val txCompleteSent: TxComplete? = null,
val txCompleteReceived: TxComplete? = null,
val inputsReceivedCount: Int = 0,
val outputsReceivedCount: Int = 0,
) {
Expand Down Expand Up @@ -513,21 +529,24 @@ data class InteractiveTxSession(
previousTxs
)

val isComplete: Boolean = txCompleteSent && txCompleteReceived
val isComplete: Boolean = txCompleteSent != null && txCompleteReceived != null

fun send(): Pair<InteractiveTxSession, InteractiveTxSessionAction> {
return when (val msg = toSend.firstOrNull()) {
null -> {
val txComplete = TxComplete(fundingParams.channelId)
val next = copy(txCompleteSent = true)
// generate a new secret nonce for each musig2 swapin every time we send TxComplete
val localMusig2SwapIns = localInputs.filterIsInstance<InteractiveTxInput.LocalSwapIn>().filter { swapInKeys.swapInProtocolMusig2.isMine(it.txOut) }
val secretNonces = localMusig2SwapIns.map { it.serialId to SecretNonce.generate(swapInKeys.userPrivateKey, swapInKeys.userPublicKey, null, null, null, randomBytes32()) }.toMap()
val txComplete = TxComplete(fundingParams.channelId, secretNonces = secretNonces)
val next = copy(txCompleteSent = txComplete)
if (next.isComplete) {
Pair(next, next.validateTx(txComplete))
} else {
Pair(next, InteractiveTxSessionAction.SendMessage(txComplete))
}
}
is Either.Left -> {
val next = copy(toSend = toSend.tail(), localInputs = localInputs + msg.value, txCompleteSent = false)
val next = copy(toSend = toSend.tail(), localInputs = localInputs + msg.value, txCompleteSent = null)
val swapInParams = TxAddInputTlv.SwapInParams(swapInKeys.userPublicKey, swapInKeys.remoteServerPublicKey, swapInKeys.refundDelay)
val txAddInput = when (msg.value) {
is InteractiveTxInput.LocalOnly -> TxAddInput(fundingParams.channelId, msg.value.serialId, msg.value.previousTx, msg.value.previousTxOutput, msg.value.sequence)
Expand All @@ -537,7 +556,7 @@ data class InteractiveTxSession(
Pair(next, InteractiveTxSessionAction.SendMessage(txAddInput))
}
is Either.Right -> {
val next = copy(toSend = toSend.tail(), localOutputs = localOutputs + msg.value, txCompleteSent = false)
val next = copy(toSend = toSend.tail(), localOutputs = localOutputs + msg.value, txCompleteSent = null)
val txAddOutput = when (msg.value) {
is InteractiveTxOutput.Local -> TxAddOutput(fundingParams.channelId, msg.value.serialId, msg.value.amount, msg.value.pubkeyScript)
is InteractiveTxOutput.Shared -> TxAddOutput(fundingParams.channelId, msg.value.serialId, msg.value.amount, msg.value.pubkeyScript)
Expand Down Expand Up @@ -617,19 +636,19 @@ data class InteractiveTxSession(
is TxAddInput -> {
receiveInput(message).fold(
{ f -> Pair(this, f) },
{ input -> copy(remoteInputs = remoteInputs + input, inputsReceivedCount = inputsReceivedCount + 1, txCompleteReceived = false).send() }
{ input -> copy(remoteInputs = remoteInputs + input, inputsReceivedCount = inputsReceivedCount + 1, txCompleteReceived = null).send() }
)
}
is TxAddOutput -> {
receiveOutput(message).fold(
{ f -> Pair(this, f) },
{ output -> copy(remoteOutputs = remoteOutputs + output, outputsReceivedCount = outputsReceivedCount + 1, txCompleteReceived = false).send() }
{ output -> copy(remoteOutputs = remoteOutputs + output, outputsReceivedCount = outputsReceivedCount + 1, txCompleteReceived = null).send() }
)
}
is TxRemoveInput -> {
val remoteInputs1 = remoteInputs.filterNot { i -> (i as InteractiveTxInput).serialId == message.serialId }
if (remoteInputs.size != remoteInputs1.size) {
val next = copy(remoteInputs = remoteInputs1, txCompleteReceived = false)
val next = copy(remoteInputs = remoteInputs1, txCompleteReceived = null)
next.send()
} else {
Pair(this, InteractiveTxSessionAction.UnknownSerialId(message.channelId, message.serialId))
Expand All @@ -638,14 +657,14 @@ data class InteractiveTxSession(
is TxRemoveOutput -> {
val remoteOutputs1 = remoteOutputs.filterNot { o -> (o as InteractiveTxOutput).serialId == message.serialId }
if (remoteOutputs.size != remoteOutputs1.size) {
val next = copy(remoteOutputs = remoteOutputs1, txCompleteReceived = false)
val next = copy(remoteOutputs = remoteOutputs1, txCompleteReceived = null)
next.send()
} else {
Pair(this, InteractiveTxSessionAction.UnknownSerialId(message.channelId, message.serialId))
}
}
is TxComplete -> {
val next = copy(txCompleteReceived = true)
val next = copy(txCompleteReceived = message)
if (next.isComplete) {
Pair(next, next.validateTx(null))
} else {
Expand All @@ -656,6 +675,10 @@ data class InteractiveTxSession(
}

private fun validateTx(txComplete: TxComplete?): InteractiveTxSessionAction {
// tx_complete MUST have been sent and received for us to reach this state, require is used here to tell the compiler that txCompleteSent and txCompleteReceived are not null
require(txCompleteSent != null)
require(txCompleteReceived != null)

if (localInputs.size + remoteInputs.size > 252 || localOutputs.size + remoteOutputs.size > 252) {
return InteractiveTxSessionAction.InvalidTxInputOutputCount(fundingParams.channelId, localInputs.size + remoteInputs.size, localOutputs.size + remoteOutputs.size)
}
Expand Down Expand Up @@ -690,8 +713,19 @@ data class InteractiveTxSession(
}
sharedInputs.first()
}
val localOnlyInputsWithNonces = localOnlyInputs.map {
when {
it is InteractiveTxInput.LocalSwapIn && swapInKeys.swapInProtocolMusig2.isMine(it.txOut) -> {
val userNonce = txCompleteSent.secretNonces[it.serialId]
val serverNonce = txCompleteReceived.serverNonces[it.serialId]
if (userNonce == null || serverNonce == null) return InteractiveTxSessionAction.MissingNonce(fundingParams.channelId, it.serialId)
it.copy(userNonce = userNonce, serverNonce = serverNonce)
}
else -> it
}
}

val sharedTx = SharedTransaction(sharedInput, sharedOutput, localOnlyInputs, remoteOnlyInputs, localOnlyOutputs, remoteOnlyOutputs, fundingParams.lockTime)
val sharedTx = SharedTransaction(sharedInput, sharedOutput, localOnlyInputsWithNonces, remoteOnlyInputs, localOnlyOutputs, remoteOnlyOutputs, fundingParams.lockTime)
val tx = sharedTx.buildUnsignedTx()
if (sharedTx.localAmountIn < sharedTx.localAmountOut || sharedTx.remoteAmountIn < sharedTx.remoteAmountOut) {
return InteractiveTxSessionAction.InvalidTxChangeAmount(fundingParams.channelId, tx.txid)
Expand Down
33 changes: 21 additions & 12 deletions src/commonMain/kotlin/fr/acinq/lightning/crypto/KeyManager.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@ package fr.acinq.lightning.crypto
import fr.acinq.bitcoin.*
import fr.acinq.bitcoin.DeterministicWallet.hardened
import fr.acinq.bitcoin.io.ByteArrayInput
import fr.acinq.bitcoin.musig2.PublicNonce
import fr.acinq.bitcoin.musig2.SecretNonce
import fr.acinq.lightning.DefaultSwapInParams
import fr.acinq.lightning.NodeParams
import fr.acinq.lightning.blockchain.fee.FeeratePerKw
import fr.acinq.lightning.transactions.SwapInProtocol
import fr.acinq.lightning.transactions.SwapInProtocolMusig2
import fr.acinq.lightning.transactions.Transactions
import fr.acinq.lightning.utils.sum
import fr.acinq.lightning.utils.toByteVector
Expand Down Expand Up @@ -128,7 +131,8 @@ interface KeyManager {
fun localServerPrivateKey(remoteNodeId: PublicKey): PrivateKey = DeterministicWallet.derivePrivateKey(localServerExtendedPrivateKey, perUserPath(remoteNodeId)).privateKey

val swapInProtocol = SwapInProtocol(userPublicKey, remoteServerPublicKey, refundDelay)
val redeemScript: List<ScriptElt> = swapInProtocol.redeemScript
val swapInProtocolMusig2 = SwapInProtocolMusig2(userPublicKey, remoteServerPublicKey, refundDelay)

val pubkeyScript: List<ScriptElt> = swapInProtocol.pubkeyScript
val address: String = swapInProtocol.address(chain)

Expand All @@ -146,14 +150,10 @@ interface KeyManager {
"wsh(and_v(v:pk($userKey),or_d(pk(${remoteServerPublicKey.toHex()}),older($refundDelay))))"
}

fun signSwapInputUser(fundingTx: Transaction, index: Int, parentTxOut: TxOut): ByteVector64 {
return swapInProtocol.signSwapInputUser(fundingTx, index, parentTxOut, userPrivateKey)
fun signSwapInputUser(fundingTx: Transaction, index: Int, parentTxOuts: List<TxOut>): ByteVector64 {
return swapInProtocol.signSwapInputUser(fundingTx, index, parentTxOuts[fundingTx.txIn[index].outPoint.index.toInt()] , userPrivateKey)
}

fun signSwapInputServer(fundingTx: Transaction, index: Int, parentTxOut: TxOut, remoteNodeId: PublicKey): ByteVector64 {
return swapInProtocol.signSwapInputServer(fundingTx, index, parentTxOut, localServerPrivateKey(remoteNodeId))
}

/**
* Create a recovery transaction that spends a swap-in transaction after the refund delay has passed
* @param swapInTx swap-in transaction
Expand All @@ -162,7 +162,7 @@ interface KeyManager {
* @return a signed transaction that spends our swap-in transaction. It cannot be published until `swapInTx` has enough confirmations
*/
fun createRecoveryTransaction(swapInTx: Transaction, address: String, feeRate: FeeratePerKw): Transaction? {
val utxos = swapInTx.txOut.filter { it.publicKeyScript.contentEquals(Script.write(pubkeyScript)) }
val utxos = swapInTx.txOut.filter { it.publicKeyScript.contentEquals(Script.write(swapInProtocol.pubkeyScript)) || it.publicKeyScript.contentEquals(Script.write(swapInProtocolMusig2.pubkeyScript))}
return if (utxos.isEmpty()) {
null
} else {
Expand All @@ -175,17 +175,26 @@ interface KeyManager {
txOut = listOf(ourOutput),
lockTime = 0
)
val fees = run {
val recoveryTx = utxos.foldIndexed(unsignedTx) { index, tx, utxo ->

fun sign(tx: Transaction, index: Int, utxo: TxOut): Transaction {
return if (swapInProtocol.isMine(utxo)) {
val sig = swapInProtocol.signSwapInputUser(tx, index, utxo, userPrivateKey)
tx.updateWitness(index, swapInProtocol.witnessRefund(sig))
} else {
val sig = swapInProtocolMusig2.signSwapInputRefund(tx, index, utxos, userPrivateKey)
tx.updateWitness(index, swapInProtocolMusig2.witnessRefund(sig))
}
}

val fees = run {
val recoveryTx = utxos.foldIndexed(unsignedTx) { index, tx, utxo ->
sign(tx, index, utxo)
}
Transactions.weight2fee(feeRate, recoveryTx.weight())
}
val unsignedTx1 = unsignedTx.copy(txOut = listOf(ourOutput.copy(amount = ourOutput.amount - fees)))
val recoveryTx = utxos.foldIndexed(unsignedTx1) { index, tx, utxo ->
val sig = swapInProtocol.signSwapInputUser(tx, index, utxo, userPrivateKey)
tx.updateWitness(index, swapInProtocol.witnessRefund(sig))
sign(tx, index, utxo)
}
// this tx is signed but cannot be published until swapInTx has `refundDelay` confirmations
recoveryTx
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ class SwapInProtocol(val userPublicKey: PublicKey, val serverPublicKey: PublicKe

val pubkeyScript: List<ScriptElt> = Script.pay2wsh(redeemScript)

fun isMine(txOut: TxOut): Boolean = txOut.publicKeyScript.contentEquals(Script.write(pubkeyScript))

fun address(chain: NodeParams.Chain): String = Bitcoin.addressFromPublicKeyScript(chain.chainHash, pubkeyScript).result!!

fun witness(userSig: ByteVector64, serverSig: ByteVector64): ScriptWitness {
Expand Down Expand Up @@ -55,6 +57,8 @@ class SwapInProtocolMusig2(val userPublicKey: PublicKey, val serverPublicKey: Pu
private val executionData = Script.ExecutionData(annex = null, tapleafHash = merkleRoot)
private val controlBlock = byteArrayOf((Script.TAPROOT_LEAF_TAPSCRIPT + (if (parity) 1 else 0)).toByte()) + internalPubKey.value.toByteArray()

fun isMine(txOut: TxOut): Boolean = txOut.publicKeyScript.contentEquals(Script.write(pubkeyScript))

fun address(chain: NodeParams.Chain): String = Bitcoin.addressFromPublicKeyScript(chain.chainHash, pubkeyScript).result!!

fun witness(commonSig: ByteVector64): ScriptWitness = ScriptWitness(listOf(commonSig))
Expand Down
Loading

0 comments on commit 2bfbc19

Please sign in to comment.