Skip to content

Commit

Permalink
Merge pull request #87 from 0x4b53/ctx-done
Browse files Browse the repository at this point in the history
HandlerFunc ctx is .Done() when server is stopped
  • Loading branch information
akarl authored Aug 28, 2020
2 parents da6d5d5 + c06b33d commit cbae93e
Show file tree
Hide file tree
Showing 13 changed files with 90 additions and 38 deletions.
7 changes: 6 additions & 1 deletion .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ linters-settings:
include-go-root: false
packages:
- github.com/davecgh/go-spew/spew
- github.com/c2fo/testify
misspell:
# Correct spellings using locale preferences for US or UK.
# Default is to use a neutral variety of English.
Expand Down Expand Up @@ -146,13 +147,17 @@ linters:
enable-all: true
disable:
- dupl
- funlen
- gocyclo
- gomnd
- lll
- maligned
- nakedret
- nlreturn
- noctx
- prealloc
- scopelint
- funlen
- testpackage
disable-all: false
# presets:
# - bugs
Expand Down
2 changes: 1 addition & 1 deletion acknowledger.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func (a *AwareAcknowledger) Ack(tag uint64, multiple bool) error {
}

// Nack passes the Nack down to the underlying Acknowledger.
func (a *AwareAcknowledger) Nack(tag uint64, multiple bool, requeue bool) error {
func (a *AwareAcknowledger) Nack(tag uint64, multiple, requeue bool) error {
a.Handled = true
return a.Acknowledger.Nack(tag, multiple, requeue)
}
Expand Down
4 changes: 1 addition & 3 deletions bindings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@ import (
)

func TestFanout(t *testing.T) {
var (
timesCalled int64
)
var timesCalled int64

fanoutHandler := func(ctx context.Context, rw *ResponseWriter, d amqp.Delivery) {
atomic.AddInt64(&timesCalled, 1)
Expand Down
3 changes: 0 additions & 3 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,6 @@ func (c *Client) runPublisher(ouputChan *amqp.Channel) {
c.publishSettings.Immediate,
request.Publishing,
)

if err != nil {
ouputChan.Close()

Expand Down Expand Up @@ -657,7 +656,6 @@ func (c *Client) runRepliesConsumer(inChan *amqp.Channel) error {
false, // no-wait.
c.queueDeclareSettings.Args,
)

if err != nil {
return err
}
Expand All @@ -671,7 +669,6 @@ func (c *Client) runRepliesConsumer(inChan *amqp.Channel) error {
false, // no-wait.
c.consumeSettings.Args,
)

if err != nil {
return err
}
Expand Down
1 change: 1 addition & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ func TestClient_ConfirmsConsumer_return(t *testing.T) {
})
}
}

func TestClient_ConfirmsConsumer_confirm(t *testing.T) {
client := NewClient("")
client.requestsMap = RequestMap{
Expand Down
10 changes: 4 additions & 6 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@ import (
"github.com/streadway/amqp"
)

var (
// ErrUnexpectedConnClosed is returned by ListenAndServe() if the server
// shuts down without calling Stop() and if AMQP does not give an error
// when said shutdown happens.
ErrUnexpectedConnClosed = errors.New("unexpected connection close without specific error")
)
// ErrUnexpectedConnClosed is returned by ListenAndServe() if the server
// shuts down without calling Stop() and if AMQP does not give an error
// when said shutdown happens.
var ErrUnexpectedConnClosed = errors.New("unexpected connection close without specific error")

// OnStartedFunc can be registered at Server.OnStarted(f) and
// Client.OnStarted(f). This is used when you want to do more setup on the
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/google/uuid v1.1.1
github.com/kr/pretty v0.1.0 // indirect
github.com/streadway/amqp v0.0.0-20190827072141-edfb9018d271
github.com/streadway/amqp v1.0.0
github.com/stretchr/testify v1.4.0
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect
gopkg.in/yaml.v2 v2.2.4 // indirect
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/streadway/amqp v0.0.0-20190827072141-edfb9018d271 h1:WhxRHzgeVGETMlmVfqhRn8RIeeNoPr2Czh33I4Zdccw=
github.com/streadway/amqp v0.0.0-20190827072141-edfb9018d271/go.mod h1:AZpEONHx3DKn8O/DFsRAY58/XVQiIPMTMB1SddzLXVw=
github.com/streadway/amqp v1.0.0 h1:kuuDrUJFZL1QYL9hUNuCxNObNzB0bV/ZG5jV3RWAQgo=
github.com/streadway/amqp v1.0.0/go.mod h1:AZpEONHx3DKn8O/DFsRAY58/XVQiIPMTMB1SddzLXVw=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
Expand Down
2 changes: 1 addition & 1 deletion request.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"github.com/streadway/amqp"
)

// Request is a requet to perform with the client
// Request is a requet to perform with the client.
type Request struct {
// Exchange is the exchange to which the rquest will be published when
// passing it to the clients send function.
Expand Down
46 changes: 32 additions & 14 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,21 @@ type Server struct {
// print most of what is happening internally.
// If nil, logging is not done.
debugLog LogFunc

baseContext context.Context
baseContextCancel context.CancelFunc
}

// NewServer will return a pointer to a new Server.
func NewServer(url string) *Server {
baseContext, cancelFunc := context.WithCancel(context.Background())

server := Server{
url: url,
bindings: []HandlerBinding{},
middlewares: []ServerMiddlewareFunc{},
url: url,
bindings: []HandlerBinding{},
middlewares: []ServerMiddlewareFunc{},
baseContext: baseContext,
baseContextCancel: cancelFunc,
dialconfig: amqp.Config{
Dial: DefaultDialer,
},
Expand Down Expand Up @@ -125,7 +132,7 @@ func (s *Server) setDefaults() {
}

// WithExchangeDeclareSettings sets configuration used when the server wants
// to declare exchanges. Default settings are:
// to declare exchanges.
func (s *Server) WithExchangeDeclareSettings(settings ExchangeDeclareSettings) *Server {
s.exchangeDeclareSettings = settings

Expand Down Expand Up @@ -242,7 +249,6 @@ func (s *Server) ListenAndServe() {

for {
err := s.listenAndServe()

// If we couldn't run listenAndServe and an error was returned, make
// sure to check if the stopChan was closed - a user might know about
// connection problems and have call Stop(). If the channel isn't
Expand Down Expand Up @@ -342,18 +348,22 @@ func (s *Server) listenAndServe() error {
return err
}

// 2. We've told amqp to stop delivering messages, now we wait for all
// 2. Tell all handlers that we are stopping, in case they have any long
// running functions.
s.baseContextCancel()

// 3. We've told amqp to stop delivering messages, now we wait for all
// the consumers to finish inflight messages.
consumersWg.Done()
consumersWg.Wait()

// 3. Close the responses chan and wait until the consumers are finished.
// 4. Close the responses chan and wait until the consumers are finished.
// We might still have responses we want to send.
close(s.responses)
responderWg.Done()
responderWg.Wait()

// 4. We have no more messages incoming and we've published all our
// 5. We have no more messages incoming and we've published all our
// responses. The closing of connections and channels are deferred so we can
// just return now.
return nil
Expand Down Expand Up @@ -381,6 +391,7 @@ func (s *Server) consume(binding HandlerBinding, inputCh *amqp.Channel, wg *sync
}

consumerTag := uuid.New().String()

deliveries, err := inputCh.Consume(
queueName,
consumerTag,
Expand All @@ -390,7 +401,6 @@ func (s *Server) consume(binding HandlerBinding, inputCh *amqp.Channel, wg *sync
false, // no-wait.
s.consumeSettings.Args,
)

if err != nil {
return "", err
}
Expand All @@ -403,7 +413,12 @@ func (s *Server) consume(binding HandlerBinding, inputCh *amqp.Channel, wg *sync
return consumerTag, nil
}

func (s *Server) runHandler(handler HandlerFunc, deliveries <-chan amqp.Delivery, queueName string, wg *sync.WaitGroup) {
func (s *Server) runHandler(
handler HandlerFunc,
deliveries <-chan amqp.Delivery,
queueName string,
wg *sync.WaitGroup,
) {
wg.Add(1)
defer wg.Done()

Expand All @@ -425,7 +440,7 @@ func (s *Server) runHandler(handler HandlerFunc, deliveries <-chan amqp.Delivery
},
}

ctx := context.WithValue(context.Background(), CtxQueueName, queueName)
ctx := context.WithValue(s.baseContext, CtxQueueName, queueName)

go func(delivery amqp.Delivery) {
handler(ctx, &rw, delivery)
Expand Down Expand Up @@ -464,7 +479,6 @@ func (s *Server) responder(outCh *amqp.Channel, wg *sync.WaitGroup) {
response.immediate,
response.publishing,
)

if err != nil {
// Close the channel so ensure reconnect.
outCh.Close()
Expand Down Expand Up @@ -507,7 +521,12 @@ func cancelConsumers(channel *amqp.Channel, consumerTags []string) error {

// declareAndBind will declare a queue, an exchange and the queue to the
// exchange.
func declareAndBind(inputCh *amqp.Channel, binding HandlerBinding, queueDeclareSettings QueueDeclareSettings, exchangeDeclareSettings ExchangeDeclareSettings) (string, error) {
func declareAndBind(
inputCh *amqp.Channel,
binding HandlerBinding,
queueDeclareSettings QueueDeclareSettings,
exchangeDeclareSettings ExchangeDeclareSettings,
) (string, error) {
queue, err := inputCh.QueueDeclare(
binding.QueueName,
queueDeclareSettings.Durable,
Expand All @@ -516,7 +535,6 @@ func declareAndBind(inputCh *amqp.Channel, binding HandlerBinding, queueDeclareS
false, // no-wait.
queueDeclareSettings.Args,
)

if err != nil {
return "", err
}
Expand Down
42 changes: 39 additions & 3 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,11 @@ func TestNoAutomaticAck(t *testing.T) {

calls := make(chan struct{}, 2)

server.Bind(DirectBinding("no-auto-ack", func(ctc context.Context, responseWriter *ResponseWriter, d amqp.Delivery) {
calls <- struct{}{}
}))
server.Bind(
DirectBinding("no-auto-ack", func(ctc context.Context, responseWriter *ResponseWriter, d amqp.Delivery) {
calls <- struct{}{}
}),
)

start()

Expand Down Expand Up @@ -234,3 +236,37 @@ func TestServerConfig(t *testing.T) {
assert.Equal(t, s.consumeSettings, cSettings)
assert.Equal(t, s.exchangeDeclareSettings, eSettings)
}

func TestContextDoneWhenServerStopped(t *testing.T) {
server, client, start, stop := initTest()

ctxDone := make(chan bool, 1)

server.Bind(DirectBinding("context.test", func(ctx context.Context, rw *ResponseWriter, d amqp.Delivery) {
select {
case <-ctx.Done():
ctxDone <- true
case <-time.After(5 * time.Second):
ctxDone <- false
}
}))

start()

_, err := client.Send(
NewRequest().
WithRoutingKey("context.test").
WithResponse(false),
)

require.NoError(t, err)

stop()

select {
case wasDone := <-ctxDone:
assert.True(t, wasDone)
case <-time.After(10 * time.Second):
t.Fatalf("handler was never called")
}
}
3 changes: 1 addition & 2 deletions testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func (ma *MockAcknowledger) Ack(tag uint64, multiple bool) error {
}

// Nack increases Nacks.
func (ma *MockAcknowledger) Nack(tag uint64, multiple bool, requeue bool) error {
func (ma *MockAcknowledger) Nack(tag uint64, multiple, requeue bool) error {
ma.Nacks++
return nil
}
Expand Down Expand Up @@ -205,7 +205,6 @@ func initTest() (server *Server, client *Client, start, stop func()) {
WithRoutingKey(defaultTestQueue).
WithResponse(false),
)

if err != nil {
panic(err)
}
Expand Down
2 changes: 1 addition & 1 deletion tls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func createPrivKey(priv *rsa.PrivateKey) string {
f, _ := ioutil.TempFile(".", "priv*.key")
defer f.Close()

var privateKey = &pem.Block{
privateKey := &pem.Block{
Type: "PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(priv),
}
Expand Down

0 comments on commit cbae93e

Please sign in to comment.