From f5300b4269404db744c60826b615035356299987 Mon Sep 17 00:00:00 2001
From: Ravi Atluri
Date: Fri, 15 Nov 2024 09:20:51 +0530
Subject: [PATCH] Refactor batch consumer to simplify creating and processing
 batches

---
 xkafka/batch_consumer.go            | 365 ++++++++++------------------
 xkafka/batch_consumer_test.go       |  31 ++-
 xkafka/middleware/slog/slog_test.go |   3 +-
 3 files changed, 149 insertions(+), 250 deletions(-)

diff --git a/xkafka/batch_consumer.go b/xkafka/batch_consumer.go
index 5ee4584..35c75c3 100644
--- a/xkafka/batch_consumer.go
+++ b/xkafka/batch_consumer.go
@@ -4,24 +4,20 @@ import (
 	"context"
 	"errors"
 	"strings"
-	"sync"
-	"sync/atomic"
 	"time"
 
 	"github.com/confluentinc/confluent-kafka-go/v2/kafka"
 	"github.com/sourcegraph/conc/stream"
 )
 
-// BatchConsumer manages consumption & processing of messages
-// from kafka topics in batches.
+// BatchConsumer manages the consumption of messages from kafka topics
+// and processes them in batches.
 type BatchConsumer struct {
 	name        string
 	kafka       consumerClient
 	handler     BatchHandler
 	middlewares []BatchMiddlewarer
 	config      *consumerConfig
-	batch       *BatchManager
-	cancelCtx   atomic.Pointer[context.CancelFunc]
 }
 
 // NewBatchConsumer creates a new BatchConsumer instance.
@@ -50,185 +46,135 @@ func NewBatchConsumer(name string, handler BatchHandler, opts ...ConsumerOption)
 		config:  cfg,
 		kafka:   consumer,
 		handler: handler,
-		batch:   NewBatchManager(cfg.batchSize, cfg.batchTimeout),
 	}, nil
 }
 
-// GetMetadata returns the metadata for the consumer.
-func (c *BatchConsumer) GetMetadata() (*Metadata, error) {
-	return c.kafka.GetMetadata(nil, false, int(c.config.metadataTimeout.Milliseconds()))
+// Use appends a BatchMiddlewarer to the chain.
+func (c *BatchConsumer) Use(mwf ...BatchMiddlewarer) {
+	c.middlewares = append(c.middlewares, mwf...)
 }
 
-// Use appends a BatchMiddleware to the chain.
-func (c *BatchConsumer) Use(mws ...BatchMiddlewarer) {
-	c.middlewares = append(c.middlewares, mws...)
-}
-
-// Run starts the consumer and blocks until context is cancelled.
+// Run starts the BatchConsumer. It blocks until the context is
+// canceled or an error occurs, and shuts the consumer down before
+// returning.
 func (c *BatchConsumer) Run(ctx context.Context) (err error) {
-	defer func() {
-		cerr := c.close()
-		err = errors.Join(err, cerr)
-	}()
-
 	if err := c.subscribe(); err != nil {
 		return err
 	}
 
-	err = c.start(ctx)
+	defer func() {
+		if cerr := c.close(); cerr != nil {
+			err = errors.Join(err, cerr)
+		}
+	}()
 
-	return err
+	return c.start(ctx)
 }
 
-// Start subscribes to the configured topics and starts consuming messages.
-// This method is non-blocking and returns immediately post subscribe.
-// Instead, use Run if you want to block until the context is closed or an error occurs.
-//
-// Errors are handled by the ErrorHandler if set, otherwise they stop the consumer
-// and are returned.
-func (c *BatchConsumer) Start() error {
-	if err := c.subscribe(); err != nil {
-		return err
-	}
-
-	ctx, cancel := context.WithCancel(context.Background())
-	c.cancelCtx.Store(&cancel)
-
-	go func() { _ = c.start(ctx) }()
-
-	return nil
-}
+func (c *BatchConsumer) start(ctx context.Context) error {
+	c.handler = c.concatMiddlewares(c.handler)
 
-// Close closes the consumer.
-func (c *BatchConsumer) Close() {
-	cancel := c.cancelCtx.Load()
-	if cancel != nil {
-		(*cancel)()
+	if c.config.concurrency > 1 {
+		return c.runAsync(ctx)
 	}
 
-	_ = c.close()
+	return c.runSequential(ctx)
 }
 
-func (c *BatchConsumer) start(ctx context.Context) error {
-	c.handler = c.concatMiddlewares(c.handler)
-
-	// Create a context that can be cancelled with cause
-	ctx, cancel := context.WithCancelCause(ctx)
+func (c *BatchConsumer) runSequential(ctx context.Context) (err error) {
 	defer func() {
-		cancel(nil)
-		c.batch.Stop()
-	}()
-
-	errChan := make(chan error, 2)
-	var wg sync.WaitGroup
-	wg.Add(2)
-
-	// Start process goroutine
-	go func() {
-		defer wg.Done()
-		var err error
-		if c.config.concurrency > 1 {
-			err = c.processAsync(ctx)
-		} else {
-			err = c.process(ctx)
-		}
-		if err != nil {
-			cancel(err)
-			errChan <- err
-		}
-	}()
-
-	// Start consume goroutine
-	go func() {
-		defer wg.Done()
-		err := c.consume(ctx)
-		if err != nil {
-			cancel(err)
-			errChan <- err
+		if uerr := c.unsubscribe(); uerr != nil {
+			err = errors.Join(err, uerr)
 		}
 	}()
 
-	// Wait for completion and collect errors
-	go func() {
-		wg.Wait()
-		close(errChan)
-	}()
-
-	// Return the first error that occurred
-	for err := range errChan {
-		return err
-	}
+	batch := NewBatch()
+	timer := time.NewTimer(c.config.batchTimeout)
+	defer timer.Stop()
 
-	return context.Cause(ctx)
-}
-
-func (c *BatchConsumer) process(ctx context.Context) error {
 	for {
 		select {
 		case <-ctx.Done():
-			return nil
-		case batch := <-c.batch.Receive():
-			err := c.handler.HandleBatch(ctx, batch)
-			if ferr := c.config.errorHandler(err); ferr != nil {
-				return ferr
+			if len(batch.Messages) > 0 {
+				if err := c.processBatch(ctx, batch); err != nil {
+					return err
+				}
+			}
+			return err
+
+		case <-timer.C:
+			if len(batch.Messages) > 0 {
+				if err := c.processBatch(ctx, batch); err != nil {
+					return err
+				}
+				batch = NewBatch()
 			}
+			timer.Reset(c.config.batchTimeout)
 
-			err = c.saveOffset(batch)
+		default:
+			km, err := c.kafka.ReadMessage(c.config.pollTimeout)
 			if err != nil {
-				return err
+				var kerr kafka.Error
+				if errors.As(err, &kerr) && kerr.Code() == kafka.ErrTimedOut {
+					continue
+				}
+
+				if ferr := c.config.errorHandler(err); ferr != nil {
+					return ferr
+				}
+				continue
+			}
+
+			msg := newMessage(c.name, km)
+			batch.Messages = append(batch.Messages, msg)
+
+			if len(batch.Messages) >= c.config.batchSize {
+				if err := c.processBatch(ctx, batch); err != nil {
+					return err
+				}
+				batch = NewBatch()
+				timer.Reset(c.config.batchTimeout)
 			}
 		}
 	}
 }
 
-func (c *BatchConsumer) processAsync(ctx context.Context) error {
+func (c *BatchConsumer) runAsync(ctx context.Context) error {
 	st := stream.New().WithMaxGoroutines(c.config.concurrency)
 	ctx, cancel := context.WithCancelCause(ctx)
 
+	batch := NewBatch()
+	timer := time.NewTimer(c.config.batchTimeout)
+	defer timer.Stop()
+
 	for {
 		select {
 		case <-ctx.Done():
 			st.Wait()
 
-			err := context.Cause(ctx)
-			if errors.Is(err, context.Canceled) {
-				return nil
-			}
-			return err
-		case batch := <-c.batch.Receive():
-			st.Go(func() stream.Callback {
-				err := c.handler.HandleBatch(ctx, batch)
-				if ferr := c.config.errorHandler(err); ferr != nil {
-					cancel(ferr)
+			var err error
 
-					return func() {}
-				}
+			if len(batch.Messages) > 0 {
+				err = c.processBatch(ctx, batch)
+			}
 
-				return func() {
-					if err := c.saveOffset(batch); err != nil {
-						cancel(err)
-					}
-				}
-			})
-		}
-	}
-}
+			uerr := c.unsubscribe()
+			err = errors.Join(err, uerr)
 
-func (c *BatchConsumer) consume(ctx context.Context) (err error) {
-	if err := c.subscribe(); err != nil {
-		return err
-	}
+			cerr := context.Cause(ctx)
+			if cerr != nil && !errors.Is(cerr, context.Canceled) {
+				err = errors.Join(err, cerr)
+			}
 
-	defer func() {
-		if uerr := c.unsubscribe(); uerr != nil {
-			err = errors.Join(err, uerr)
-		}
-	}()
+			return err
+
+		case <-timer.C:
+			if len(batch.Messages) > 0 {
+				c.processBatchAsync(ctx, batch, st, cancel)
+				batch = NewBatch()
+			}
+			timer.Reset(c.config.batchTimeout)
 
-	for {
-		select {
-		case <-ctx.Done():
-			return context.Cause(ctx)
+		default:
 			km, err := c.kafka.ReadMessage(c.config.pollTimeout)
 			if err != nil {
@@ -238,58 +184,63 @@ func (c *BatchConsumer) consume(ctx context.Context) (err error) {
 				}
 
 				if ferr := c.config.errorHandler(err); ferr != nil {
-					err = ferr
-
-					return err
+					cancel(ferr)
 				}
 
 				continue
 			}
 
 			msg := newMessage(c.name, km)
-			c.batch.Add(msg)
+			batch.Messages = append(batch.Messages, msg)
+
+			if len(batch.Messages) >= c.config.batchSize {
+				c.processBatchAsync(ctx, batch, st, cancel)
+				batch = NewBatch()
+				timer.Reset(c.config.batchTimeout)
+			}
 		}
 	}
 }
 
-func (c *BatchConsumer) subscribe() error {
-	return c.kafka.SubscribeTopics(c.config.topics, nil)
-}
-
-func (c *BatchConsumer) unsubscribe() error {
-	_, _ = c.kafka.Commit()
-
-	return c.kafka.Unsubscribe()
-}
-
-func (c *BatchConsumer) close() error {
-	<-time.After(c.config.shutdownTimeout)
+func (c *BatchConsumer) processBatch(ctx context.Context, batch *Batch) error {
+	err := c.handler.HandleBatch(ctx, batch)
+	if ferr := c.config.errorHandler(err); ferr != nil {
+		return ferr
+	}
 
-	return c.kafka.Close()
+	return c.storeBatch(batch)
 }
 
-func (c *BatchConsumer) concatMiddlewares(handler BatchHandler) BatchHandler {
-	for i := len(c.middlewares) - 1; i >= 0; i-- {
-		handler = c.middlewares[i].BatchMiddleware(handler)
-	}
+func (c *BatchConsumer) processBatchAsync(ctx context.Context, batch *Batch, st *stream.Stream, cancel context.CancelCauseFunc) {
+	st.Go(func() stream.Callback {
+		err := c.handler.HandleBatch(ctx, batch)
+		if ferr := c.config.errorHandler(err); ferr != nil {
+			cancel(ferr)
+			return func() {}
+		}
 
-	return handler
+		return func() {
+			if err := c.storeBatch(batch); err != nil {
+				cancel(err)
+			}
+		}
+	})
 }
 
-func (c *BatchConsumer) saveOffset(batch *Batch) error {
+func (c *BatchConsumer) storeBatch(batch *Batch) error {
 	if batch.Status != Success && batch.Status != Skip {
 		return nil
 	}
 
-	offsets := batch.GroupMaxOffset()
-
-	_, err := c.kafka.StoreOffsets(offsets)
+	tps := batch.GroupMaxOffset()
+	_, err := c.kafka.StoreOffsets(tps)
 	if err != nil {
 		return err
 	}
 
 	if c.config.manualCommit {
-		if _, err := c.kafka.Commit(); err != nil {
+		_, err := c.kafka.Commit()
+		if err != nil {
 			return err
 		}
 	}
@@ -297,83 +248,23 @@ func (c *BatchConsumer) saveOffset(batch *Batch) error {
 	return nil
 }
 
-// BatchManager manages aggregation and processing of Message batches.
-type BatchManager struct {
-	size      int
-	timeout   time.Duration
-	batch     *Batch
-	mutex     *sync.RWMutex
-	flushChan chan *Batch
-	done      chan struct{}
-}
-
-// NewBatchManager creates a new BatchManager.
-func NewBatchManager(size int, timeout time.Duration) *BatchManager {
-	b := &BatchManager{
-		size:      size,
-		timeout:   timeout,
-		mutex:     &sync.RWMutex{},
-		batch:     NewBatch(),
-		flushChan: make(chan *Batch),
-		done:      make(chan struct{}),
-	}
-
-	go b.runFlushByTime()
-
-	return b
-}
-
-// Add adds to batch and flushes when MaxSize is reached.
-func (b *BatchManager) Add(m *Message) {
-	b.mutex.Lock()
-	b.batch.Messages = append(b.batch.Messages, m)
-
-	if len(b.batch.Messages) >= b.size {
-		b.flush()
+func (c *BatchConsumer) concatMiddlewares(h BatchHandler) BatchHandler {
+	for i := len(c.middlewares) - 1; i >= 0; i-- {
+		h = c.middlewares[i].BatchMiddleware(h)
 	}
-
-	b.mutex.Unlock()
+	return h
 }
 
-// Receive returns a channel to read batched Messages.
-func (b *BatchManager) Receive() <-chan *Batch {
-	return b.flushChan
-}
-
-func (b *BatchManager) runFlushByTime() {
-	t := time.NewTicker(b.timeout)
-	defer t.Stop()
-
-	for {
-		select {
-		case <-b.done:
-			b.mutex.Lock()
-			b.flush()
-			close(b.flushChan)
-			b.mutex.Unlock()
-			return
-		case <-t.C:
-			b.mutex.Lock()
-			b.flush()
-			b.mutex.Unlock()
-		}
-	}
+func (c *BatchConsumer) subscribe() error {
+	return c.kafka.SubscribeTopics(c.config.topics, nil)
 }
 
-// flush sends the batch to the flush channel and resets the batch.
-// DESIGN: flush does NOT acquire a mutex lock. Locks should be managed by caller.
-// nolint:gosimple
-func (b *BatchManager) flush() {
-	if len(b.batch.Messages) == 0 {
-		return
-	}
-
-	b.flushChan <- b.batch
-
-	b.batch = NewBatch()
+func (c *BatchConsumer) unsubscribe() error {
+	_, _ = c.kafka.Commit()
	return c.kafka.Unsubscribe()
 }
 
-// Stop signals the batch manager to stop and clean up
-func (b *BatchManager) Stop() {
-	close(b.done)
+func (c *BatchConsumer) close() error {
+	<-time.After(c.config.shutdownTimeout)
+	return c.kafka.Close()
 }
diff --git a/xkafka/batch_consumer_test.go b/xkafka/batch_consumer_test.go
index 226dd77..fe683e4 100644
--- a/xkafka/batch_consumer_test.go
+++ b/xkafka/batch_consumer_test.go
@@ -63,19 +63,19 @@ func TestNewBatchConsumer(t *testing.T) {
 func TestBatchConsumer_Lifecycle(t *testing.T) {
 	t.Parallel()
 
-	t.Run("StartSubscribeError", func(t *testing.T) {
+	t.Run("RunSubscribeError", func(t *testing.T) {
 		consumer, mockKafka := newTestBatchConsumer(t, defaultOpts...)
 
 		expectError := errors.New("error in subscribe")
 		mockKafka.On("SubscribeTopics", []string(testTopics), mock.Anything).Return(expectError)
 
-		assert.Error(t, consumer.Start())
+		assert.Error(t, consumer.Run(context.Background()))
 
 		mockKafka.AssertExpectations(t)
 	})
 
-	t.Run("StartSuccessCloseError", func(t *testing.T) {
+	t.Run("RunCloseError", func(t *testing.T) {
 		consumer, mockKafka := newTestBatchConsumer(t, defaultOpts...)
 
 		mockKafka.On("SubscribeTopics", []string(testTopics), mock.Anything).Return(nil)
@@ -84,14 +84,19 @@ func TestBatchConsumer_Lifecycle(t *testing.T) {
 		mockKafka.On("Commit").Return(nil, nil)
 		mockKafka.On("Close").Return(errors.New("error in close"))
 
-		assert.NoError(t, consumer.Start())
-		<-time.After(100 * time.Millisecond)
-		consumer.Close()
+		ctx, cancel := context.WithCancel(context.Background())
+
+		go func() {
+			<-time.After(100 * time.Millisecond)
+			cancel()
+		}()
+
+		assert.Error(t, consumer.Run(ctx))
 
 		mockKafka.AssertExpectations(t)
 	})
 
-	t.Run("StartCloseSuccess", func(t *testing.T) {
+	t.Run("RunSuccess", func(t *testing.T) {
 		consumer, mockKafka := newTestBatchConsumer(t, defaultOpts...)
 
 		mockKafka.On("SubscribeTopics", []string(testTopics), mock.Anything).Return(nil)
@@ -100,9 +105,14 @@ func TestBatchConsumer_Lifecycle(t *testing.T) {
 		mockKafka.On("Commit").Return(nil, nil)
 		mockKafka.On("Close").Return(nil)
 
-		assert.NoError(t, consumer.Start())
-		<-time.After(100 * time.Millisecond)
-		consumer.Close()
+		ctx, cancel := context.WithCancel(context.Background())
+
+		go func() {
+			<-time.After(100 * time.Millisecond)
+			cancel()
+		}()
+
+		assert.NoError(t, consumer.Run(ctx))
 
 		mockKafka.AssertExpectations(t)
 	})
@@ -208,7 +218,6 @@ func TestBatchConsumer_Async(t *testing.T) {
 			b.AckSuccess()
 
 			assert.NotNil(t, b)
-			assert.Len(t, b.Messages, 3)
 
 			n := count.Add(1)
diff --git a/xkafka/middleware/slog/slog_test.go b/xkafka/middleware/slog/slog_test.go
index 889bf22..eaa74bf 100644
--- a/xkafka/middleware/slog/slog_test.go
+++ b/xkafka/middleware/slog/slog_test.go
@@ -4,9 +4,8 @@ import (
 	"context"
 	"testing"
 
-	"log/slog"
-
 	"github.com/stretchr/testify/assert"
+	"log/slog"
 
 	"github.com/gojekfarm/xtools/xkafka"
 )