From dfd91ecec73a09e6607aab7b338bd3324927b234 Mon Sep 17 00:00:00 2001 From: Thomas HUET Date: Wed, 16 Oct 2024 15:46:49 +0200 Subject: [PATCH] Share code --- .../payment/OutgoingPaymentHandler.kt | 78 ++++++++----------- 1 file changed, 34 insertions(+), 44 deletions(-) diff --git a/src/commonMain/kotlin/fr/acinq/lightning/payment/OutgoingPaymentHandler.kt b/src/commonMain/kotlin/fr/acinq/lightning/payment/OutgoingPaymentHandler.kt index 3ccc005d2..a55d83f21 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/payment/OutgoingPaymentHandler.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/payment/OutgoingPaymentHandler.kt @@ -27,7 +27,7 @@ import fr.acinq.lightning.wire.UnknownNextPeer class OutgoingPaymentHandler(val nodeParams: NodeParams, val walletParams: WalletParams, val db: OutgoingPaymentsDb) { - interface SendPaymentResult + interface SendPaymentResult: ProcessFailureResult interface ProcessFailureResult interface ProcessFulfillResult @@ -71,6 +71,37 @@ class OutgoingPaymentHandler(val nodeParams: NodeParams, val walletParams: Walle private fun getPaymentAttempt(childId: UUID): PaymentAttempt? = childToPaymentId[childId]?.let { pending[it] } + private suspend fun sendPaymentInternal(request: PayInvoice, failures: List>, channels: Map, currentBlockHeight: Int): SendPaymentResult { + val logger = MDCLogger(logger, staticMdc = request.mdc()) + val attemptNumber = failures.size + val trampolineFees = (request.trampolineFeesOverride ?: walletParams.trampolineFees)[attemptNumber] + logger.info { "trying payment with fee_base=${trampolineFees.feeBase}, fee_proportional=${trampolineFees.feeProportional}" } + val trampolineAmount = request.amount + trampolineFees.calculateFees(request.amount) + return when (val result = selectChannel(trampolineAmount, channels)) { + is Either.Left -> { + logger.warning { "payment failed: ${result.value}" } + if (attemptNumber == 0) { + db.addOutgoingPayment(LightningOutgoingPayment(request.paymentId, request.amount, request.recipient, request.paymentDetails)) + } + db.completeOutgoingPaymentOffchain(request.paymentId, result.value) + removeFromState(request.paymentId) + Failure(request, OutgoingPaymentFailure(result.value, failures)) + } + is Either.Right -> { + val hop = NodeHop(walletParams.trampolineNode.id, request.recipient, trampolineFees.cltvExpiryDelta, trampolineFees.calculateFees(request.amount)) + val (childPayment, sharedSecrets, cmd) = createOutgoingPayment(request, result.value, hop, currentBlockHeight) + if (attemptNumber == 0) { + db.addOutgoingPayment(LightningOutgoingPayment(request.paymentId, request.amount, request.recipient, request.paymentDetails, listOf(childPayment), LightningOutgoingPayment.Status.Pending)) + } else { + db.addOutgoingLightningParts(request.paymentId, listOf(childPayment)) + } + val payment = PaymentAttempt(request, attemptNumber, childPayment, sharedSecrets, failures) + pending[payment.request.paymentId] = payment + Progress(payment.request, payment.fees, listOf(cmd)) + } + } + } + suspend fun sendPayment(request: PayInvoice, channels: Map, currentBlockHeight: Int): SendPaymentResult { val logger = MDCLogger(logger, staticMdc = request.mdc()) logger.info { "sending ${request.amount} to ${request.recipient}" } @@ -90,24 +121,7 @@ class OutgoingPaymentHandler(val nodeParams: NodeParams, val walletParams: Walle logger.error { "invoice has already been paid" } return Failure(request, FinalFailure.AlreadyPaid.toPaymentFailure()) } - val trampolineFees = (request.trampolineFeesOverride ?: walletParams.trampolineFees).first() - val trampolineAmount = request.amount + trampolineFees.calculateFees(request.amount) - return when (val result = selectChannel(trampolineAmount, channels)) { - is Either.Left -> { - logger.warning { "payment failed: ${result.value}" } - db.addOutgoingPayment(LightningOutgoingPayment(request.paymentId, request.amount, request.recipient, request.paymentDetails)) - db.completeOutgoingPaymentOffchain(request.paymentId, result.value) - Failure(request, result.value.toPaymentFailure()) - } - is Either.Right -> { - val hop = NodeHop(walletParams.trampolineNode.id, request.recipient, trampolineFees.cltvExpiryDelta, trampolineFees.calculateFees(request.amount)) - val (childPayment, sharedSecrets, cmd) = createOutgoingPayment(request, result.value, hop, currentBlockHeight) - val payment = PaymentAttempt(request, 0, childPayment, sharedSecrets, listOf()) - db.addOutgoingPayment(LightningOutgoingPayment(request.paymentId, request.amount, request.recipient, request.paymentDetails, listOf(childPayment), LightningOutgoingPayment.Status.Pending)) - pending[request.paymentId] = payment - Progress(request, payment.fees, listOf(cmd)) - } - } + return sendPaymentInternal(request, listOf(), channels, currentBlockHeight) } /** @@ -182,31 +196,7 @@ class OutgoingPaymentHandler(val nodeParams: NodeParams, val walletParams: Walle Failure(payment.request, OutgoingPaymentFailure(finalError, payment.failures + failure)) } else { // The trampoline node is asking us to retry the payment with more fees or a larger expiry delta. - val nextFees = trampolineFees[payment.attemptNumber + 1] - logger.info { "retrying payment with higher fees (base=${nextFees.feeBase}, proportional=${nextFees.feeProportional})..." } - val trampolineAmount = payment.request.amount + nextFees.calculateFees(payment.request.amount) - when (val result = selectChannel(trampolineAmount, channels)) { - is Either.Left -> { - logger.warning { "payment failed: ${result.value}" } - db.completeOutgoingPaymentOffchain(payment.request.paymentId, result.value) - removeFromState(payment.request.paymentId) - Failure(payment.request, OutgoingPaymentFailure(result.value, payment.failures + failure)) - } - is Either.Right -> { - val hop = NodeHop(walletParams.trampolineNode.id, payment.request.recipient, nextFees.cltvExpiryDelta, nextFees.calculateFees(payment.request.amount)) - val (childPayment, sharedSecrets, cmd) = createOutgoingPayment(payment.request, result.value, hop, currentBlockHeight) - db.addOutgoingLightningParts(payment.request.paymentId, listOf(childPayment)) - val payment1 = PaymentAttempt( - request = payment.request, - attemptNumber = payment.attemptNumber + 1, - pending = childPayment, - sharedSecrets = sharedSecrets, - failures = payment.failures + failure - ) - pending[payment1.request.paymentId] = payment1 - Progress(payment1.request, payment1.fees, listOf(cmd)) - } - } + sendPaymentInternal(payment.request, payment.failures + failure, channels, currentBlockHeight) } }