Skip to content

Commit

Permalink
Merge pull request #108 from 0x4b53/update-ack-fn
Browse files Browse the repository at this point in the history
Add `OnErrFunc` which is more generic in how ack failures are handled
  • Loading branch information
akarl authored May 20, 2024
2 parents a0f116d + 0fb00d3 commit 7141a08
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 15 deletions.
46 changes: 38 additions & 8 deletions middleware/ack.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,46 @@ package middleware
import (
"context"

amqp "github.com/rabbitmq/amqp091-go"

amqprpc "github.com/0x4b53/amqp-rpc/v3"
amqp "github.com/rabbitmq/amqp091-go"
)

// OnErrFunc is the function that will be called when the middleware get an
// error from `Ack`. The error and the delivery will be passed.
type OnErrFunc func(err error, delivery amqp.Delivery)

// OnAckErrorLog is a built-in function that will log the error if any is
// returned from `Ack`.
//
// middleware := AckDelivery(OnAckErrorLog(log.Printf))
func OnAckErrorLog(logFn amqprpc.LogFunc) OnErrFunc {
return func(err error, delivery amqp.Delivery) {
logFn("could not ack delivery (%s): %v\n", delivery.CorrelationId, err)
}
}

// OnAckErrorSendOnChannel will first log the error and correlation ID and then
// try to send on the passed channel. If no one is consuming on the passed
// channel the middleware will not block but instead log a message about missing
// channel consumers.
func OnAckErrorSendOnChannel(logFn amqprpc.LogFunc, ch chan struct{}) OnErrFunc {
logErr := OnAckErrorLog(logFn)

return func(err error, delivery amqp.Delivery) {
logErr(err, delivery)

select {
case ch <- struct{}{}:
default:
logFn("ack middleware: could not send on channel, no one is consuming\n")
}
}
}

// AckDelivery is a middleware that will acknowledge the delivery after the
// handler has been executed. Any error returned from d.Ack will be passed
// to the provided logFunc.
func AckDelivery(logFunc amqprpc.LogFunc) amqprpc.ServerMiddlewareFunc {
// handler has been executed. If the Ack fails the error and the `amqp.Delivery`
// will be passed to the `OnErrFunc`.
func AckDelivery(onErrFn OnErrFunc) amqprpc.ServerMiddlewareFunc {
return func(next amqprpc.HandlerFunc) amqprpc.HandlerFunc {
return func(ctx context.Context, rw *amqprpc.ResponseWriter, d amqp.Delivery) {
acknowledger := amqprpc.NewAwareAcknowledger(d.Acknowledger)
Expand All @@ -23,9 +54,8 @@ func AckDelivery(logFunc amqprpc.LogFunc) amqprpc.ServerMiddlewareFunc {
return
}

err := d.Ack(false)
if err != nil {
logFunc("could not Ack delivery (%s): %v", d.CorrelationId, err)
if err := d.Ack(false); err != nil {
onErrFn(err, d)
}
}
}
Expand Down
51 changes: 44 additions & 7 deletions middleware/ack_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@ package middleware

import (
"context"
"errors"
"log"
"sync/atomic"
"testing"
"time"

amqp "github.com/rabbitmq/amqp091-go"
"github.com/stretchr/testify/assert"
Expand All @@ -13,32 +16,66 @@ import (

func TestAckDelivery(t *testing.T) {
tests := []struct {
handler amqprpc.HandlerFunc
name string
handler amqprpc.HandlerFunc
name string
ackReturn error
didSendOnChannel bool
}{
{
name: "handler doesn't ack",
handler: func(_ context.Context, _ *amqprpc.ResponseWriter, _ amqp.Delivery) {},
name: "handler doesn't ack",
handler: func(_ context.Context, _ *amqprpc.ResponseWriter, _ amqp.Delivery) {},
ackReturn: nil,
},
{
name: "handler does ack",
handler: func(_ context.Context, _ *amqprpc.ResponseWriter, d amqp.Delivery) {
_ = d.Ack(false)
},
ackReturn: nil,
},
{
name: "handler fails to ack",
handler: func(_ context.Context, _ *amqprpc.ResponseWriter, _ amqp.Delivery) {},
ackReturn: errors.New("issue in the multiplexer"), //nolint:err113 // Just a test
didSendOnChannel: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
acknowledger := &amqprpc.MockAcknowledger{}
acknowledger := &amqprpc.MockAcknowledger{
OnAckFn: func() error {
return tt.ackReturn
},
}

didSendOnCh := atomic.Bool{}
didSendOnCh.Store(false)

// We setup a channel to ensure we don't proceed until we started
// the go routine that will listen to the signal.
isListening := make(chan struct{})

ch := make(chan struct{})
go func() {
close(isListening)
<-ch
didSendOnCh.Store(true)
}()

// Block until ready.
<-isListening

handler := AckDelivery(log.Printf)(tt.handler)
handler := AckDelivery(OnAckErrorSendOnChannel(log.Printf, ch))(tt.handler)

rw := amqprpc.NewResponseWriter(&amqp.Publishing{})
d := amqp.Delivery{Acknowledger: acknowledger}
d := amqp.Delivery{Acknowledger: acknowledger, CorrelationId: "id-1234"}

handler(context.Background(), rw, d)

assert.Equal(t, 1, acknowledger.Acks)
assert.Eventually(t, func() bool {
return didSendOnCh.Load() == tt.didSendOnChannel
}, 2*time.Second, 100*time.Millisecond)
})
}
}
6 changes: 6 additions & 0 deletions testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,17 @@ type MockAcknowledger struct {
Acks int
Nacks int
Rejects int
OnAckFn func() error
}

// Ack increases Acks.
func (ma *MockAcknowledger) Ack(_ uint64, _ bool) error {
ma.Acks++

if ma.OnAckFn != nil {
return ma.OnAckFn()
}

return nil
}

Expand Down

0 comments on commit 7141a08

Please sign in to comment.