Skip to content

Commit

Permalink
Use context to implement --time
Browse files Browse the repository at this point in the history
Terminate connections gracefully (more or less at least)
Handle ctrl-c better (context canel)
  • Loading branch information
mkuratczyk committed Oct 11, 2023
1 parent c0960c9 commit 5694768
Show file tree
Hide file tree
Showing 10 changed files with 225 additions and 158 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,4 @@ jobs:
run: OMQ_RABBITMQCTL=DOCKER:${{job.services.rabbitmq.id}} bin/ci/before_build.sh

- name: Run go test
run: go test ./...
run: go test -count=1 -p 1 -v ./...
37 changes: 29 additions & 8 deletions cmd/root.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package cmd

import (
"context"
"fmt"
"math"
"os"
"os/signal"
"strings"
"sync"
"syscall"
Expand All @@ -12,6 +14,7 @@ import (
"github.com/rabbitmq/omq/pkg/common"
"github.com/rabbitmq/omq/pkg/config"
"github.com/rabbitmq/omq/pkg/log"
"github.com/rabbitmq/omq/pkg/metrics"
"github.com/rabbitmq/omq/pkg/version"

"github.com/spf13/cobra"
Expand Down Expand Up @@ -178,12 +181,19 @@ func RootCmd() *cobra.Command {
func start(cfg config.Config, publisherProto common.Protocol, consumerProto common.Protocol) {
var wg sync.WaitGroup

if cfg.Duration > 0 {
go func() {
time.Sleep(cfg.Duration)
_ = syscall.Kill(syscall.Getpid(), syscall.SIGINT)
}()
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

// handle ^C
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
go func() {
<-c
cancel()
println("Received SIGTERM, shutting down...")
time.Sleep(500 * time.Millisecond)
shutdown()
}()

if cfg.Consumers > 0 {
for i := 1; i <= cfg.Consumers; i++ {
Expand All @@ -197,7 +207,7 @@ func start(cfg config.Config, publisherProto common.Protocol, consumerProto comm
log.Error("Error creating consumer: ", "error", err)
os.Exit(1)
}
c.Start(subscribed)
c.Start(ctx, subscribed)
}()

// wait until we know the receiver has subscribed
Expand All @@ -216,11 +226,16 @@ func start(cfg config.Config, publisherProto common.Protocol, consumerProto comm
log.Error("Error creating publisher: ", "error", err)
os.Exit(1)
}
p.Start()
p.Start(ctx)
}()
}
}

if cfg.Duration > 0 {
log.Debug("Will stop all consumers and publishers at " + time.Now().Add(cfg.Duration).String())
time.AfterFunc(cfg.Duration, func() { cancel() })
}
log.Info("Waiting for all publishers and consumers to complete")
wg.Wait()
}

Expand Down Expand Up @@ -250,3 +265,9 @@ func defaultUri(proto string) string {
}
return uri
}

func shutdown() {
metricsServer := metrics.GetMetricsServer()
metricsServer.PrintMetrics()
os.Exit(1)
}
16 changes: 0 additions & 16 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@ package main

import (
"os"
"os/signal"
"runtime/pprof"
"syscall"

"github.com/rabbitmq/omq/cmd"
"github.com/rabbitmq/omq/pkg/log"
Expand All @@ -24,14 +22,6 @@ func main() {
metricsServer := metrics.GetMetricsServer()
metricsServer.Start()

// handle ^C
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
go func() {
<-c
shutdown()
}()

cmd.Execute()
metricsServer.PrintMetrics()

Expand All @@ -44,9 +34,3 @@ func main() {
defer memFile.Close()
}
}

func shutdown() {
metricsServer := metrics.GetMetricsServer()
metricsServer.PrintMetrics()
os.Exit(1)
}
67 changes: 41 additions & 26 deletions pkg/amqp10_client/consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@ import (
)

type Amqp10Consumer struct {
Id int
Session *amqp.Session
Topic string
Config config.Config
Id int
Connection *amqp.Conn
Session *amqp.Session
Topic string
Config config.Config
}

func NewConsumer(cfg config.Config, id int) *Amqp10Consumer {
Expand All @@ -39,15 +40,16 @@ func NewConsumer(cfg config.Config, id int) *Amqp10Consumer {
topic := topic.CalculateTopic(cfg.ConsumeFrom, id)

return &Amqp10Consumer{
Id: id,
Session: session,
Topic: topic,
Config: cfg,
Id: id,
Connection: conn,
Session: session,
Topic: topic,
Config: cfg,
}

}

func (c Amqp10Consumer) Start(subscribed chan bool) {
func (c Amqp10Consumer) Start(ctx context.Context, subscribed chan bool) {
var durability amqp.Durability
switch c.Config.QueueDurability {
case config.None:
Expand All @@ -63,31 +65,44 @@ func (c Amqp10Consumer) Start(subscribed chan bool) {
return
}
close(subscribed)
log.Debug("consumer subscribed", "protocol", "amqp-1.0", "subscriberId", c.Id, "terminus", c.Topic, "durability", durability)
log.Debug("consumer subscribed", "protocol", "amqp-1.0", "consumerId", c.Id, "terminus", c.Topic, "durability", durability)

m := metrics.EndToEndLatency.With(prometheus.Labels{"protocol": "amqp-1.0"})

log.Info("consumer started", "protocol", "amqp-1.0", "consumerId", c.Id, "terminus", c.Topic)

for i := 1; i <= c.Config.ConsumeCount; i++ {
msg, err := receiver.Receive(context.TODO(), nil)
if err != nil {
log.Error("failed to receive a message", "protocol", "amqp-1.0", "subscriberId", c.Id, "terminus", c.Topic)
// TODO Receive() is blocking, so cancelling the context won't really stop the consumer
select {
case <-ctx.Done():
c.Stop("time limit reached")
return
default:
msg, err := receiver.Receive(context.TODO(), nil)
if err != nil {
log.Error("failed to receive a message", "protocol", "amqp-1.0", "consumerId", c.Id, "terminus", c.Topic)
return
}

payload := msg.GetData()
m.Observe(utils.CalculateEndToEndLatency(c.Config.UseMillis, &payload))

log.Debug("message received", "protocol", "amqp-1.0", "consumerId", c.Id, "terminus", c.Topic, "size", len(payload))

err = receiver.AcceptMessage(context.TODO(), msg)
if err != nil {
log.Error("message NOT accepted", "protocol", "amqp-1.0", "consumerId", c.Id, "terminus", c.Topic)
}
metrics.MessagesConsumed.With(prometheus.Labels{"protocol": "amqp-1.0"}).Inc()
log.Debug("message accepted", "protocol", "amqp-1.0", "consumerId", c.Id, "terminus", c.Topic)
}

payload := msg.GetData()
m.Observe(utils.CalculateEndToEndLatency(c.Config.UseMillis, &payload))

log.Debug("message received", "protocol", "amqp-1.0", "subscriberId", c.Id, "terminus", c.Topic, "size", len(payload))

err = receiver.AcceptMessage(context.TODO(), msg)
if err != nil {
log.Error("message NOT accepted", "protocol", "amqp-1.0", "subscriberId", c.Id, "terminus", c.Topic)
}
metrics.MessagesConsumed.With(prometheus.Labels{"protocol": "amqp-1.0"}).Inc()
log.Debug("message accepted", "protocol", "amqp-1.0", "subscriberId", c.Id, "terminus", c.Topic)
}

log.Debug("consumer finished", "protocol", "amqp-1.0", "subscriberId", c.Id)
c.Stop("message count reached")
log.Debug("consumer finished", "protocol", "amqp-1.0", "consumerId", c.Id)
}

func (c Amqp10Consumer) Stop(reason string) {
log.Debug("closing connection", "protocol", "amqp-1.0", "consumerId", c.Id, "reason", reason)
_ = c.Connection.Close()
}
73 changes: 40 additions & 33 deletions pkg/amqp10_client/publisher.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@ import (
)

type Amqp10Publisher struct {
Id int
Sender *amqp.Sender
Topic string
Config config.Config
msg []byte
Id int
Sender *amqp.Sender
Connectionection *amqp.Conn
Topic string
Config config.Config
msg []byte
}

func NewPublisher(cfg config.Config, n int) *Amqp10Publisher {
Expand Down Expand Up @@ -49,68 +50,69 @@ func NewPublisher(cfg config.Config, n int) *Amqp10Publisher {
durability = amqp.DurabilityUnsettledState
}

topic := topic.CalculateTopic(cfg.PublishTo, n)
sender, err := session.NewSender(context.TODO(), topic, &amqp.SenderOptions{
terminus := topic.CalculateTopic(cfg.PublishTo, n)
sender, err := session.NewSender(context.TODO(), terminus, &amqp.SenderOptions{
TargetDurability: durability})
if err != nil {
log.Error("publisher failed to create a sender", "protocol", "amqp-1.0", "publisherId", n, "error", err.Error())
return nil
}

return &Amqp10Publisher{
Id: n,
Sender: sender,
Topic: topic,
Config: cfg,
Id: n,
Connectionection: conn,
Sender: sender,
Topic: terminus,
Config: cfg,
}
}

func (p Amqp10Publisher) Start() {
func (p Amqp10Publisher) Start(ctx context.Context) {
// sleep random interval to avoid all publishers publishing at the same time
s := rand.Intn(1000)
time.Sleep(time.Duration(s) * time.Millisecond)

p.msg = utils.MessageBody(p.Config.Size)

if p.Config.Rate == -1 {
p.StartFullSpeed()
p.StartFullSpeed(ctx)
} else {
p.StartRateLimited()
p.StartRateLimited(ctx)
}
log.Debug("publisher completed", "protocol", "amqp-1.0", "publisherId", p.Id)
}

func (p Amqp10Publisher) StartFullSpeed() {
func (p Amqp10Publisher) StartFullSpeed(ctx context.Context) {
log.Info("publisher started", "protocol", "AMQP-1.0", "publisherId", p.Id, "rate", "unlimited", "destination", p.Topic)

for i := 1; i <= p.Config.PublishCount; i++ {
p.Send()
select {
case <-ctx.Done():
return
default:
p.Send()
}
}

log.Debug("publisher stopped", "protocol", "amqp-1.0", "publisherId", p.Id)
}

func (p Amqp10Publisher) StartRateLimited() {
func (p Amqp10Publisher) StartRateLimited(ctx context.Context) {
log.Info("publisher started", "protocol", "AMQP-1.0", "publisherId", p.Id, "rate", p.Config.Rate, "destination", p.Topic)
ticker := time.NewTicker(time.Duration(1000/float64(p.Config.Rate)) * time.Millisecond)
done := make(chan bool)

msgSent := 0
go func() {
for {
select {
case <-done:
for {
select {
case <-ctx.Done():
p.Stop("time limit reached")
return
case <-ticker.C:
p.Send()
msgSent++
if msgSent >= p.Config.PublishCount {
p.Stop("publish count reached")
return
case <-ticker.C:
p.Send()
msgSent++
}
}
}()
for {
time.Sleep(1 * time.Second)
if msgSent >= p.Config.PublishCount {
break
}
}
}

Expand All @@ -132,3 +134,8 @@ func (p Amqp10Publisher) Send() {
metrics.MessagesPublished.With(prometheus.Labels{"protocol": "amqp-1.0"}).Inc()
log.Debug("message sent", "protocol", "amqp-1.0", "publisherId", p.Id)
}

func (p Amqp10Publisher) Stop(reason string) {
log.Debug("closing connection", "protocol", "amqp-1.0", "publisherId", p.Id, "reason", reason)
p.Connectionection.Close()
}
5 changes: 3 additions & 2 deletions pkg/common/common.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package common

import (
"context"
"fmt"

"github.com/rabbitmq/omq/pkg/amqp10_client"
Expand All @@ -10,11 +11,11 @@ import (
)

type Publisher interface {
Start()
Start(context.Context)
}

type Consumer interface {
Start(chan bool)
Start(context.Context, chan bool)
}

type Protocol int
Expand Down
Loading

0 comments on commit 5694768

Please sign in to comment.