From 5694768b86d999e4efcd5efef8d2e54e5ba70366 Mon Sep 17 00:00:00 2001 From: Michal Kuratczyk Date: Wed, 11 Oct 2023 19:35:45 +0200 Subject: [PATCH] Use context to implement --time Terminate connections gracefully (more or less at least) Handle ctrl-c better (context canel) --- .github/workflows/ci.yaml | 2 +- cmd/root.go | 37 +++++++++++++---- main.go | 16 -------- pkg/amqp10_client/consumer.go | 67 +++++++++++++++++++------------ pkg/amqp10_client/publisher.go | 73 +++++++++++++++++++--------------- pkg/common/common.go | 5 ++- pkg/mqtt_client/consumer.go | 32 ++++++++++----- pkg/mqtt_client/publisher.go | 53 +++++++++++++----------- pkg/stomp_client/consumer.go | 46 ++++++++++++++------- pkg/stomp_client/publisher.go | 52 +++++++++++++----------- 10 files changed, 225 insertions(+), 158 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 4a4f606..093ca48 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -49,4 +49,4 @@ jobs: run: OMQ_RABBITMQCTL=DOCKER:${{job.services.rabbitmq.id}} bin/ci/before_build.sh - name: Run go test - run: go test ./... + run: go test -count=1 -p 1 -v ./... diff --git a/cmd/root.go b/cmd/root.go index 5203e0a..e02f542 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -1,9 +1,11 @@ package cmd import ( + "context" "fmt" "math" "os" + "os/signal" "strings" "sync" "syscall" @@ -12,6 +14,7 @@ import ( "github.com/rabbitmq/omq/pkg/common" "github.com/rabbitmq/omq/pkg/config" "github.com/rabbitmq/omq/pkg/log" + "github.com/rabbitmq/omq/pkg/metrics" "github.com/rabbitmq/omq/pkg/version" "github.com/spf13/cobra" @@ -178,12 +181,19 @@ func RootCmd() *cobra.Command { func start(cfg config.Config, publisherProto common.Protocol, consumerProto common.Protocol) { var wg sync.WaitGroup - if cfg.Duration > 0 { - go func() { - time.Sleep(cfg.Duration) - _ = syscall.Kill(syscall.Getpid(), syscall.SIGINT) - }() - } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // handle ^C + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + go func() { + <-c + cancel() + println("Received SIGTERM, shutting down...") + time.Sleep(500 * time.Millisecond) + shutdown() + }() if cfg.Consumers > 0 { for i := 1; i <= cfg.Consumers; i++ { @@ -197,7 +207,7 @@ func start(cfg config.Config, publisherProto common.Protocol, consumerProto comm log.Error("Error creating consumer: ", "error", err) os.Exit(1) } - c.Start(subscribed) + c.Start(ctx, subscribed) }() // wait until we know the receiver has subscribed @@ -216,11 +226,16 @@ func start(cfg config.Config, publisherProto common.Protocol, consumerProto comm log.Error("Error creating publisher: ", "error", err) os.Exit(1) } - p.Start() + p.Start(ctx) }() } } + if cfg.Duration > 0 { + log.Debug("Will stop all consumers and publishers at " + time.Now().Add(cfg.Duration).String()) + time.AfterFunc(cfg.Duration, func() { cancel() }) + } + log.Info("Waiting for all publishers and consumers to complete") wg.Wait() } @@ -250,3 +265,9 @@ func defaultUri(proto string) string { } return uri } + +func shutdown() { + metricsServer := metrics.GetMetricsServer() + metricsServer.PrintMetrics() + os.Exit(1) +} diff --git a/main.go b/main.go index 9908a0a..f021e6c 100644 --- a/main.go +++ b/main.go @@ -2,9 +2,7 @@ package main import ( "os" - "os/signal" "runtime/pprof" - "syscall" "github.com/rabbitmq/omq/cmd" "github.com/rabbitmq/omq/pkg/log" @@ -24,14 +22,6 @@ func main() { metricsServer := metrics.GetMetricsServer() metricsServer.Start() - // handle ^C - c := make(chan os.Signal, 1) - signal.Notify(c, os.Interrupt, syscall.SIGTERM) - go func() { - <-c - shutdown() - }() - cmd.Execute() metricsServer.PrintMetrics() @@ -44,9 +34,3 @@ func main() { defer memFile.Close() } } - -func shutdown() { - metricsServer := metrics.GetMetricsServer() - metricsServer.PrintMetrics() - os.Exit(1) -} diff --git a/pkg/amqp10_client/consumer.go b/pkg/amqp10_client/consumer.go index 9a8afb3..f10ef9f 100644 --- a/pkg/amqp10_client/consumer.go +++ b/pkg/amqp10_client/consumer.go @@ -15,10 +15,11 @@ import ( ) type Amqp10Consumer struct { - Id int - Session *amqp.Session - Topic string - Config config.Config + Id int + Connection *amqp.Conn + Session *amqp.Session + Topic string + Config config.Config } func NewConsumer(cfg config.Config, id int) *Amqp10Consumer { @@ -39,15 +40,16 @@ func NewConsumer(cfg config.Config, id int) *Amqp10Consumer { topic := topic.CalculateTopic(cfg.ConsumeFrom, id) return &Amqp10Consumer{ - Id: id, - Session: session, - Topic: topic, - Config: cfg, + Id: id, + Connection: conn, + Session: session, + Topic: topic, + Config: cfg, } } -func (c Amqp10Consumer) Start(subscribed chan bool) { +func (c Amqp10Consumer) Start(ctx context.Context, subscribed chan bool) { var durability amqp.Durability switch c.Config.QueueDurability { case config.None: @@ -63,31 +65,44 @@ func (c Amqp10Consumer) Start(subscribed chan bool) { return } close(subscribed) - log.Debug("consumer subscribed", "protocol", "amqp-1.0", "subscriberId", c.Id, "terminus", c.Topic, "durability", durability) + log.Debug("consumer subscribed", "protocol", "amqp-1.0", "consumerId", c.Id, "terminus", c.Topic, "durability", durability) m := metrics.EndToEndLatency.With(prometheus.Labels{"protocol": "amqp-1.0"}) log.Info("consumer started", "protocol", "amqp-1.0", "consumerId", c.Id, "terminus", c.Topic) for i := 1; i <= c.Config.ConsumeCount; i++ { - msg, err := receiver.Receive(context.TODO(), nil) - if err != nil { - log.Error("failed to receive a message", "protocol", "amqp-1.0", "subscriberId", c.Id, "terminus", c.Topic) + // TODO Receive() is blocking, so cancelling the context won't really stop the consumer + select { + case <-ctx.Done(): + c.Stop("time limit reached") return + default: + msg, err := receiver.Receive(context.TODO(), nil) + if err != nil { + log.Error("failed to receive a message", "protocol", "amqp-1.0", "consumerId", c.Id, "terminus", c.Topic) + return + } + + payload := msg.GetData() + m.Observe(utils.CalculateEndToEndLatency(c.Config.UseMillis, &payload)) + + log.Debug("message received", "protocol", "amqp-1.0", "consumerId", c.Id, "terminus", c.Topic, "size", len(payload)) + + err = receiver.AcceptMessage(context.TODO(), msg) + if err != nil { + log.Error("message NOT accepted", "protocol", "amqp-1.0", "consumerId", c.Id, "terminus", c.Topic) + } + metrics.MessagesConsumed.With(prometheus.Labels{"protocol": "amqp-1.0"}).Inc() + log.Debug("message accepted", "protocol", "amqp-1.0", "consumerId", c.Id, "terminus", c.Topic) } - - payload := msg.GetData() - m.Observe(utils.CalculateEndToEndLatency(c.Config.UseMillis, &payload)) - - log.Debug("message received", "protocol", "amqp-1.0", "subscriberId", c.Id, "terminus", c.Topic, "size", len(payload)) - - err = receiver.AcceptMessage(context.TODO(), msg) - if err != nil { - log.Error("message NOT accepted", "protocol", "amqp-1.0", "subscriberId", c.Id, "terminus", c.Topic) - } - metrics.MessagesConsumed.With(prometheus.Labels{"protocol": "amqp-1.0"}).Inc() - log.Debug("message accepted", "protocol", "amqp-1.0", "subscriberId", c.Id, "terminus", c.Topic) } - log.Debug("consumer finished", "protocol", "amqp-1.0", "subscriberId", c.Id) + c.Stop("message count reached") + log.Debug("consumer finished", "protocol", "amqp-1.0", "consumerId", c.Id) +} + +func (c Amqp10Consumer) Stop(reason string) { + log.Debug("closing connection", "protocol", "amqp-1.0", "consumerId", c.Id, "reason", reason) + _ = c.Connection.Close() } diff --git a/pkg/amqp10_client/publisher.go b/pkg/amqp10_client/publisher.go index f9067ea..079a117 100644 --- a/pkg/amqp10_client/publisher.go +++ b/pkg/amqp10_client/publisher.go @@ -17,11 +17,12 @@ import ( ) type Amqp10Publisher struct { - Id int - Sender *amqp.Sender - Topic string - Config config.Config - msg []byte + Id int + Sender *amqp.Sender + Connectionection *amqp.Conn + Topic string + Config config.Config + msg []byte } func NewPublisher(cfg config.Config, n int) *Amqp10Publisher { @@ -49,8 +50,8 @@ func NewPublisher(cfg config.Config, n int) *Amqp10Publisher { durability = amqp.DurabilityUnsettledState } - topic := topic.CalculateTopic(cfg.PublishTo, n) - sender, err := session.NewSender(context.TODO(), topic, &amqp.SenderOptions{ + terminus := topic.CalculateTopic(cfg.PublishTo, n) + sender, err := session.NewSender(context.TODO(), terminus, &amqp.SenderOptions{ TargetDurability: durability}) if err != nil { log.Error("publisher failed to create a sender", "protocol", "amqp-1.0", "publisherId", n, "error", err.Error()) @@ -58,14 +59,15 @@ func NewPublisher(cfg config.Config, n int) *Amqp10Publisher { } return &Amqp10Publisher{ - Id: n, - Sender: sender, - Topic: topic, - Config: cfg, + Id: n, + Connectionection: conn, + Sender: sender, + Topic: terminus, + Config: cfg, } } -func (p Amqp10Publisher) Start() { +func (p Amqp10Publisher) Start(ctx context.Context) { // sleep random interval to avoid all publishers publishing at the same time s := rand.Intn(1000) time.Sleep(time.Duration(s) * time.Millisecond) @@ -73,44 +75,44 @@ func (p Amqp10Publisher) Start() { p.msg = utils.MessageBody(p.Config.Size) if p.Config.Rate == -1 { - p.StartFullSpeed() + p.StartFullSpeed(ctx) } else { - p.StartRateLimited() + p.StartRateLimited(ctx) } + log.Debug("publisher completed", "protocol", "amqp-1.0", "publisherId", p.Id) } -func (p Amqp10Publisher) StartFullSpeed() { +func (p Amqp10Publisher) StartFullSpeed(ctx context.Context) { log.Info("publisher started", "protocol", "AMQP-1.0", "publisherId", p.Id, "rate", "unlimited", "destination", p.Topic) for i := 1; i <= p.Config.PublishCount; i++ { - p.Send() + select { + case <-ctx.Done(): + return + default: + p.Send() + } } - - log.Debug("publisher stopped", "protocol", "amqp-1.0", "publisherId", p.Id) } -func (p Amqp10Publisher) StartRateLimited() { +func (p Amqp10Publisher) StartRateLimited(ctx context.Context) { log.Info("publisher started", "protocol", "AMQP-1.0", "publisherId", p.Id, "rate", p.Config.Rate, "destination", p.Topic) ticker := time.NewTicker(time.Duration(1000/float64(p.Config.Rate)) * time.Millisecond) - done := make(chan bool) msgSent := 0 - go func() { - for { - select { - case <-done: + for { + select { + case <-ctx.Done(): + p.Stop("time limit reached") + return + case <-ticker.C: + p.Send() + msgSent++ + if msgSent >= p.Config.PublishCount { + p.Stop("publish count reached") return - case <-ticker.C: - p.Send() - msgSent++ } } - }() - for { - time.Sleep(1 * time.Second) - if msgSent >= p.Config.PublishCount { - break - } } } @@ -132,3 +134,8 @@ func (p Amqp10Publisher) Send() { metrics.MessagesPublished.With(prometheus.Labels{"protocol": "amqp-1.0"}).Inc() log.Debug("message sent", "protocol", "amqp-1.0", "publisherId", p.Id) } + +func (p Amqp10Publisher) Stop(reason string) { + log.Debug("closing connection", "protocol", "amqp-1.0", "publisherId", p.Id, "reason", reason) + p.Connectionection.Close() +} diff --git a/pkg/common/common.go b/pkg/common/common.go index d73629d..c6207bf 100644 --- a/pkg/common/common.go +++ b/pkg/common/common.go @@ -1,6 +1,7 @@ package common import ( + "context" "fmt" "github.com/rabbitmq/omq/pkg/amqp10_client" @@ -10,11 +11,11 @@ import ( ) type Publisher interface { - Start() + Start(context.Context) } type Consumer interface { - Start(chan bool) + Start(context.Context, chan bool) } type Protocol int diff --git a/pkg/mqtt_client/consumer.go b/pkg/mqtt_client/consumer.go index 6d21308..de76b5b 100644 --- a/pkg/mqtt_client/consumer.go +++ b/pkg/mqtt_client/consumer.go @@ -1,6 +1,7 @@ package mqtt_client import ( + "context" "fmt" "strings" "time" @@ -32,7 +33,7 @@ func NewConsumer(cfg config.Config, id int) *MqttConsumer { SetAutoReconnect(true). SetCleanSession(cfg.MqttConsumer.CleanSession). SetConnectionLostHandler(func(client mqtt.Client, reason error) { - log.Info("connection lost", "protocol", "MQTT", "consumerId", id) + log.Info("connection lost", "protocol", "mqtt", "consumerId", id) }). SetProtocolVersion(4) @@ -53,7 +54,7 @@ func NewConsumer(cfg config.Config, id int) *MqttConsumer { } } -func (c MqttConsumer) Start(subscribed chan bool) { +func (c MqttConsumer) Start(ctx context.Context, subscribed chan bool) { m := metrics.EndToEndLatency.With(prometheus.Labels{"protocol": "mqtt"}) msgsReceived := 0 @@ -63,23 +64,32 @@ func (c MqttConsumer) Start(subscribed chan bool) { payload := msg.Payload() m.Observe(utils.CalculateEndToEndLatency(c.Config.UseMillis, &payload)) msgsReceived++ - log.Debug("message received", "protocol", "MQTT", "subscriberc.Id", c.Id, "topic", c.Topic, "size", len(payload)) + log.Debug("message received", "protocol", "mqtt", "consumerId", c.Id, "topic", c.Topic, "size", len(payload)) } close(subscribed) token := c.Connection.Subscribe(c.Topic, byte(c.Config.MqttConsumer.QoS), handler) token.Wait() if token.Error() != nil { - log.Error("failed to subscribe", "protocol", "MQTT", "publisherc.Id", c.Id, "error", token.Error()) + log.Error("failed to subscribe", "protocol", "mqtt", "consumerId", c.Id, "error", token.Error()) } - log.Info("consumer started", "protocol", "MQTT", "publisherc.Id", c.Id, "c.Topic", c.Topic) + log.Info("consumer started", "protocol", "mqtt", "consumerId", c.Id, "c.Topic", c.Topic) + + // TODO: currently we can consume more than ConsumerCount messages + for msgsReceived <= c.Config.ConsumeCount { + select { + case <-ctx.Done(): + c.Stop("time limit reached") + return + default: + time.Sleep(1 * time.Second) - defer c.Connection.Disconnect(250) - for { - time.Sleep(1 * time.Second) - if msgsReceived >= c.Config.ConsumeCount { - break } } - log.Debug("consumer finished", "protocol", "MQTT", "publisherc.Id", c.Id) + c.Stop("message count reached") +} + +func (c MqttConsumer) Stop(reason string) { + log.Debug("closing connection", "protocol", "mqtt", "consumerId", c.Id, "reason", reason) + c.Connection.Disconnect(250) } diff --git a/pkg/mqtt_client/publisher.go b/pkg/mqtt_client/publisher.go index 253e2a9..6437712 100644 --- a/pkg/mqtt_client/publisher.go +++ b/pkg/mqtt_client/publisher.go @@ -1,6 +1,7 @@ package mqtt_client import ( + "context" "fmt" "math/rand" "strings" @@ -61,7 +62,7 @@ func NewPublisher(cfg config.Config, n int) *MqttPublisher { } -func (p MqttPublisher) Start() { +func (p MqttPublisher) Start(ctx context.Context) { // sleep random interval to avoid all publishers publishing at the same time s := rand.Intn(1000) time.Sleep(time.Duration(s) * time.Millisecond) @@ -71,43 +72,44 @@ func (p MqttPublisher) Start() { p.msg = utils.MessageBody(p.Config.Size) if p.Config.Rate == -1 { - p.StartFullSpeed() + p.StartFullSpeed(ctx) } else { - p.StartRateLimited() + p.StartRateLimited(ctx) } + log.Debug("publisher stopped", "protocol", "MQTT", "publisherId", p.Id) } -func (p MqttPublisher) StartFullSpeed() { +func (p MqttPublisher) StartFullSpeed(ctx context.Context) { log.Info("publisher started", "protocol", "MQTT", "publisherId", p.Id, "rate", "unlimited", "destination", p.Topic) + for i := 1; i <= p.Config.PublishCount; i++ { - p.Send() + select { + case <-ctx.Done(): + return + default: + p.Send() + } } - - log.Debug("publisher stopped", "protocol", "MQTT", "publisherId", p.Id) } -func (p MqttPublisher) StartRateLimited() { +func (p MqttPublisher) StartRateLimited(ctx context.Context) { log.Info("publisher started", "protocol", "MQTT", "publisherId", p.Id, "rate", p.Config.Rate, "destination", p.Topic) ticker := time.NewTicker(time.Duration(1000/float64(p.Config.Rate)) * time.Millisecond) - done := make(chan bool) - msgsSent := 0 - go func() { - for { - select { - case <-done: + msgSent := 0 + for { + select { + case <-ctx.Done(): + p.Stop("time limit reached") + return + case <-ticker.C: + p.Send() + msgSent++ + if msgSent >= p.Config.PublishCount { + p.Stop("publish count reached") return - case <-ticker.C: - p.Send() - msgsSent++ } } - }() - for { - time.Sleep(1 * time.Second) - if msgsSent >= p.Config.PublishCount { - break - } } } @@ -123,3 +125,8 @@ func (p MqttPublisher) Send() { log.Debug("message sent", "protocol", "MQTT", "publisherId", p.Id) metrics.MessagesPublished.With(prometheus.Labels{"protocol": "mqtt"}).Inc() } + +func (p MqttPublisher) Stop(reason string) { + log.Debug("closing connection", "protocol", "mqtt", "publisherId", p.Id, "reason", reason) + p.Connection.Disconnect(250) +} diff --git a/pkg/stomp_client/consumer.go b/pkg/stomp_client/consumer.go index efb6a9c..1e5f9d4 100644 --- a/pkg/stomp_client/consumer.go +++ b/pkg/stomp_client/consumer.go @@ -1,6 +1,8 @@ package stomp_client import ( + "context" + "github.com/rabbitmq/omq/pkg/config" "github.com/rabbitmq/omq/pkg/log" "github.com/rabbitmq/omq/pkg/metrics" @@ -18,10 +20,11 @@ var o []func(*stomp.Conn) error = []func(*stomp.Conn) error{ } type StompConsumer struct { - Id int - Connection *stomp.Conn - Topic string - Config config.Config + Id int + Connection *stomp.Conn + Subscription *stomp.Subscription + Topic string + Config config.Config } func NewConsumer(cfg config.Config, id int) *StompConsumer { @@ -42,7 +45,7 @@ func NewConsumer(cfg config.Config, id int) *StompConsumer { } } -func (c StompConsumer) Start(subscribed chan bool) { +func (c StompConsumer) Start(ctx context.Context, subscribed chan bool) { var sub *stomp.Subscription var err error if c.Config.QueueDurability == config.None { @@ -54,28 +57,41 @@ func (c StompConsumer) Start(subscribed chan bool) { log.Error("subscription failed", "protocol", "STOMP", "consumerId", c.Id, "queue", c.Topic, "error", err.Error()) return } + c.Subscription = sub close(subscribed) m := metrics.EndToEndLatency.With(prometheus.Labels{"protocol": "stomp"}) log.Info("consumer started", "protocol", "STOMP", "consumerId", c.Id, "destination", c.Topic) for i := 1; i <= c.Config.ConsumeCount; i++ { - msg := <-sub.C - if msg.Err != nil { - log.Error("failed to receive a message", "protocol", "STOMP", "subscriberId", c.Id, "c.Topic", c.Topic, "error", msg.Err) - return - } - m.Observe(utils.CalculateEndToEndLatency(c.Config.UseMillis, &msg.Body)) - log.Debug("message received", "protocol", "stomp", "subscriberId", c.Id, "destination", c.Topic, "size", len(msg.Body), "ack required", msg.ShouldAck()) + select { + case msg := <-sub.C: + if msg.Err != nil { + log.Error("failed to receive a message", "protocol", "STOMP", "consumerId", c.Id, "c.Topic", c.Topic, "error", msg.Err) + return + } + m.Observe(utils.CalculateEndToEndLatency(c.Config.UseMillis, &msg.Body)) + log.Debug("message received", "protocol", "stomp", "consumerId", c.Id, "destination", c.Topic, "size", len(msg.Body), "ack required", msg.ShouldAck()) - err = c.Connection.Ack(msg) - if err != nil { - log.Error("message NOT acknowledged", "protocol", "stomp", "subscriberId", c.Id, "destination", c.Topic) + err = c.Connection.Ack(msg) + if err != nil { + log.Error("message NOT acknowledged", "protocol", "stomp", "consumerId", c.Id, "destination", c.Topic) + } + case <-ctx.Done(): + c.Stop("time limit reached") + return } metrics.MessagesConsumed.With(prometheus.Labels{"protocol": "stomp"}).Inc() } + c.Stop("message count reached") log.Debug("consumer finished", "protocol", "STOMP", "consumerId", c.Id) } + +func (c StompConsumer) Stop(reason string) { + log.Debug("closing connection", "protocol", "stomp", "consumerId", c.Id, "reason", reason) + _ = c.Subscription.Unsubscribe() + _ = c.Connection.Disconnect() +} diff --git a/pkg/stomp_client/publisher.go b/pkg/stomp_client/publisher.go index a0a7ab8..29648a8 100644 --- a/pkg/stomp_client/publisher.go +++ b/pkg/stomp_client/publisher.go @@ -1,6 +1,7 @@ package stomp_client import ( + "context" "math/rand" "time" @@ -46,7 +47,7 @@ func NewPublisher(cfg config.Config, id int) *StompPublisher { } } -func (p StompPublisher) Start() { +func (p StompPublisher) Start(ctx context.Context) { // sleep random interval to avoid all publishers publishing at the same time s := rand.Intn(1000) time.Sleep(time.Duration(s) * time.Millisecond) @@ -54,44 +55,44 @@ func (p StompPublisher) Start() { p.msg = utils.MessageBody(p.Config.Size) if p.Config.Rate == -1 { - p.StartFullSpeed() + p.StartFullSpeed(ctx) } else { - p.StartRateLimited() + p.StartRateLimited(ctx) } } -func (p StompPublisher) StartFullSpeed() { +func (p StompPublisher) StartFullSpeed(ctx context.Context) { log.Info("publisher started", "protocol", "STOMP", "publisherId", p.Id, "rate", "unlimited", "destination", p.Topic) for i := 1; i <= p.Config.PublishCount; i++ { - p.Send() + select { + case <-ctx.Done(): + return + default: + p.Send() + } } - - log.Debug("publisher finished", "publisherId", p.Id) + log.Debug("publisher completed", "protocol", "stomp", "publisherId", p.Id) } -func (p StompPublisher) StartRateLimited() { +func (p StompPublisher) StartRateLimited(ctx context.Context) { log.Info("publisher started", "protocol", "STOMP", "publisherId", p.Id, "rate", p.Config.Rate, "destination", p.Topic) ticker := time.NewTicker(time.Duration(1000/float64(p.Config.Rate)) * time.Millisecond) - done := make(chan bool) - msgsSent := 0 - go func() { - for { - select { - case <-done: + msgSent := 0 + for { + select { + case <-ctx.Done(): + p.Stop("time limit reached") + return + case <-ticker.C: + p.Send() + msgSent++ + if msgSent >= p.Config.PublishCount { + p.Stop("publish count reached") return - case <-ticker.C: - p.Send() - msgsSent++ } } - }() - for { - time.Sleep(1 * time.Second) - if msgsSent >= p.Config.PublishCount { - break - } } } @@ -115,3 +116,8 @@ func (p StompPublisher) Send() { metrics.MessagesPublished.With(prometheus.Labels{"protocol": "stomp"}).Inc() } + +func (p StompPublisher) Stop(reason string) { + log.Debug("closing connection", "protocol", "stomp", "publisherId", p.Id, "reason", reason) + _ = p.Connection.Disconnect() +}