Skip to content

Commit

Permalink
Merge pull request #109 from 0x4b53/restart-server-ch
Browse files Browse the repository at this point in the history
Add support to simpler restart server by using a restart channel
  • Loading branch information
bombsimon authored May 15, 2024
2 parents a5ac1c1 + 91a1768 commit a0f116d
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 15 deletions.
3 changes: 2 additions & 1 deletion client.go
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,8 @@ func (c *Client) runOnce() error {

go c.runPublisher(outputCh)

err = monitorAndWait(
_, err = monitorAndWait(
make(chan struct{}),
c.stopChan,
inputConn.NotifyClose(make(chan *amqp.Error)),
outputConn.NotifyClose(make(chan *amqp.Error)),
Expand Down
8 changes: 5 additions & 3 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ type PublishSettings struct {
ConfirmMode bool
}

func monitorAndWait(stopChan chan struct{}, amqpErrs ...chan *amqp.Error) error {
func monitorAndWait(restartChan, stopChan chan struct{}, amqpErrs ...chan *amqp.Error) (bool, error) {
result := make(chan error, len(amqpErrs))

// Setup monitoring for connections and channels, can be several connections and several channels.
Expand All @@ -121,9 +121,11 @@ func monitorAndWait(stopChan chan struct{}, amqpErrs ...chan *amqp.Error) error

select {
case err := <-result:
return err
return true, err
case <-restartChan:
return true, nil
case <-stopChan:
return nil
return false, nil
}
}

Expand Down
76 changes: 65 additions & 11 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ type Server struct {
// channel will be closed when Stop() is called.
stopChan chan struct{}

// restartChan channel is used to signal restarts. It can be set by the user
// so they can restart the server without having to call Stop()/Start()
restartChan chan struct{}

// isRunning is 1 when the server is running.
isRunning int32

Expand All @@ -94,6 +98,9 @@ func NewServer(url string) *Server {
errorLog: log.Printf, // use the standard logger default.
//nolint:revive // Keep variables for clarity
debugLog: func(format string, args ...interface{}) {}, // don't print anything default.
// We ensure to always create a channel so we can call `Restart` without
// blocking.
restartChan: make(chan struct{}),
}

server.setDefaults()
Expand Down Expand Up @@ -190,6 +197,14 @@ func (s *Server) WithDebugLogger(f LogFunc) *Server {
return s
}

// WithRestartChan will add a channel to the server that will trigger a restart
// when it's triggered.
func (s *Server) WithRestartChan(ch chan struct{}) *Server {
s.restartChan = ch

return s
}

// AddMiddleware will add a ServerMiddleware to the list of middlewares to be
// triggered before the handle func for each request.
func (s *Server) AddMiddleware(m ServerMiddlewareFunc) *Server {
Expand Down Expand Up @@ -241,7 +256,7 @@ func (s *Server) ListenAndServe() {
}

for {
err := s.listenAndServe()
shouldRestart, err := s.listenAndServe()
// If we couldn't run listenAndServe and an error was returned, make
// sure to check if the stopChan was closed - a user might know about
// connection problems and have call Stop(). If the channel isn't
Expand All @@ -259,6 +274,17 @@ func (s *Server) ListenAndServe() {
}
}

if shouldRestart {
// We must set up responses again. It's required to close to shut
// down the responders so we let the shutdown process close it and
// then we re-create it here.
s.responses = make(chan processedRequest)

s.debugLog("server: listener restarting")

continue
}

s.debugLog("server: listener exiting gracefully")

break
Expand All @@ -267,7 +293,7 @@ func (s *Server) ListenAndServe() {
atomic.StoreInt32(&s.isRunning, 0)
}

func (s *Server) listenAndServe() error {
func (s *Server) listenAndServe() (bool, error) {
s.debugLog("server: starting listener: %s", s.url)

// We are using two different connections here because:
Expand All @@ -277,15 +303,15 @@ func (s *Server) listenAndServe() error {
// -- https://godoc.org/github.com/rabbitmq/amqp091-go#Channel.Consume
inputConn, outputConn, err := createConnections(s.url, s.dialconfig)
if err != nil {
return err
return false, err
}

defer inputConn.Close()
defer outputConn.Close()

inputCh, outputCh, err := createChannels(inputConn, outputConn)
if err != nil {
return err
return false, err
}

defer inputCh.Close()
Expand All @@ -297,7 +323,7 @@ func (s *Server) listenAndServe() error {
false,
)
if err != nil {
return err
return false, err
}

// Notify everyone that the server has started.
Expand All @@ -312,7 +338,7 @@ func (s *Server) listenAndServe() error {
// cancel our consumers.
consumerTags, err := s.startConsumers(inputCh, &consumersWg)
if err != nil {
return err
return false, err
}

// This WaitGroup will reach 0 when the responder() has finished sending
Expand All @@ -322,23 +348,28 @@ func (s *Server) listenAndServe() error {

go s.responder(outputCh, &responderWg)

err = monitorAndWait(
shouldRestart, err := monitorAndWait(
s.restartChan,
s.stopChan,
inputConn.NotifyClose(make(chan *amqp.Error)),
outputConn.NotifyClose(make(chan *amqp.Error)),
inputCh.NotifyClose(make(chan *amqp.Error)),
outputCh.NotifyClose(make(chan *amqp.Error)),
)
if err != nil {
return err
return shouldRestart, err
}

s.debugLog("server: gracefully shutting down")
if shouldRestart {
s.debugLog("server: restarting server")
} else {
s.debugLog("server: gracefully shutting down")
}

// 1. Tell amqp we want to shut down by canceling all the consumers.
err = cancelConsumers(inputCh, consumerTags)
if err != nil {
return err
return shouldRestart, err
}

// 3. We've told amqp to stop delivering messages, now we wait for all
Expand All @@ -355,7 +386,7 @@ func (s *Server) listenAndServe() error {
// 5. We have no more messages incoming and we've published all our
// responses. The closing of connections and channels are deferred so we can
// just return now.
return nil
return shouldRestart, nil
}

func (s *Server) startConsumers(inputCh *amqp.Channel, wg *sync.WaitGroup) ([]string, error) {
Expand Down Expand Up @@ -499,6 +530,29 @@ func (s *Server) Stop() {
close(s.stopChan)
}

// Restart will gracefully disconnect from AMQP exactly like `Stop` but instead
// of returning from `ListenAndServe` it will set everything up again from
// scratch and start listening again. This can be useful if a server restart is
// wanted without running `ListenAndServe` in a loop.
func (s *Server) Restart() {
// Restart is noop if not running.
if atomic.LoadInt32(&s.isRunning) == 0 {
return
}

// Ensure we never block on the restartChan, if we're in the middle of a
// setup or teardown process we won't be listening on this channel and if so
// we do a noop.
// This can likely happen e.g. if you have multiple messages in memory and
// acknowledging them stops working, you might call `Restart` on all of them
// but only the first one should trigger the restart.
select {
case s.restartChan <- struct{}{}:
default:
s.debugLog("server: no listener on restartChan, ensure server is running")
}
}

// cancelConsumers will cancel the specified consumers.
func cancelConsumers(channel *amqp.Channel, consumerTags []string) error {
for _, consumerTag := range consumerTags {
Expand Down
47 changes: 47 additions & 0 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package amqprpc
import (
"context"
"fmt"
"log"
"testing"
"time"

Expand Down Expand Up @@ -151,6 +152,52 @@ func TestServerReconnect(t *testing.T) {
assert.Equal(t, []byte("Hello"), reply.Body)
}

func TestManualRestart(t *testing.T) {
hasStarted := make(chan struct{})
restartChan := make(chan struct{})

s := NewServer(testURL).
WithRestartChan(restartChan).
WithDebugLogger(log.Printf).
WithAutoAck(false)

s.OnStarted(func(_, _ *amqp.Connection, _, _ *amqp.Channel) {
hasStarted <- struct{}{}
})

s.Bind(DirectBinding("myqueue", func(_ context.Context, rw *ResponseWriter, d amqp.Delivery) {
_ = d.Ack(false)

fmt.Fprintf(rw, "Hello")
}))

// Wait for the initial startup signal.
go func() { <-hasStarted }()

stop := startAndWait(s)
defer stop()

c := NewClient(testURL)
defer c.Stop()

request := NewRequest().WithRoutingKey("myqueue")
reply, err := c.Send(request)
require.NoError(t, err)
assert.Equal(t, []byte("Hello"), reply.Body)

// We only care about one restart but let's call multiple ones to ensure
// we're not blocking.
s.Restart()
s.Restart()
s.Restart()
<-hasStarted

request = NewRequest().WithRoutingKey("myqueue")
reply, err = c.Send(request)
require.NoError(t, err)
assert.Equal(t, []byte("Hello"), reply.Body)
}

func TestServerOnStarted(t *testing.T) {
errs := make(chan string, 4)

Expand Down

0 comments on commit a0f116d

Please sign in to comment.