diff --git a/.github/workflows/integration-tests.yaml b/.github/workflows/integration-tests.yaml
index 794a4a1..2c754b8 100644
--- a/.github/workflows/integration-tests.yaml
+++ b/.github/workflows/integration-tests.yaml
@@ -23,12 +23,18 @@ jobs:
     image: tfgco/pusher:ci-test
     services:
       zookeeper:
-        image: wurstmeister/zookeeper
+        image: confluentinc/cp-zookeeper:7.4.0
+        env:
+          ZOOKEEPER_CLIENT_PORT: 2181
       kafka:
-        image: wurstmeister/kafka:0.10.1.0-2
+        image: confluentinc/cp-kafka:7.4.0
+        options: --health-cmd "kafka-topics --list --bootstrap-server kafka:9092" --health-interval 10s --health-timeout 10s --health-retries 15
         env:
-          KAFKA_ADVERTISED_HOST_NAME: kafka
-          KAFKA_ADVERTISED_PORT: 9092
+          KAFKA_ADVERTISED_LISTENERS: PLAINTEXT://kafka:9092,PLAINTEXT_HOST://kafka:29092
+          KAFKA_LISTENER_SECURITY_PROTOCOL_MAP: PLAINTEXT:PLAINTEXT,PLAINTEXT_HOST:PLAINTEXT
+          KAFKA_LISTENERS: PLAINTEXT://kafka:9092,PLAINTEXT_HOST://0.0.0.0:29092
+          KAFKA_OFFSETS_TOPIC_REPLICATION_FACTOR: 1
+          KAFKA_DEFAULT_REPLICATION_FACTOR: 1
           KAFKA_ZOOKEEPER_CONNECT: "zookeeper:2181"
           KAFKA_NUM_PARTITIONS: 5
           KAFKA_CREATE_TOPICS: "com.games.test:5:1"
diff --git a/.github/workflows/unit-tests.yaml b/.github/workflows/unit-tests.yaml
index f0ce882..8cd9530 100644
--- a/.github/workflows/unit-tests.yaml
+++ b/.github/workflows/unit-tests.yaml
@@ -48,7 +48,7 @@ jobs:
       - name: Build
         run: docker run -v $PWD:/go/src/github.com/topfreegames/pusher tfgco/pusher:ci-test go build -v -o bin/pusher main.go
       - name: Test
-        run: docker run -v $PWD:/go/src/github.com/topfreegames/pusher tfgco/pusher:ci-test ginkgo -v -r --randomizeAllSpecs --randomizeSuites --cover --focus="\[Unit\].*" .
+        run: docker run -v $PWD:/go/src/github.com/topfreegames/pusher tfgco/pusher:ci-test ginkgo -v -r --randomizeAllSpecs --randomizeSuites --cover --focus="\[Unit\].*" --skipPackage=e2e .
       - name: Lint
         continue-on-error: true
         run: docker run -v $PWD:/go/src/github.com/topfreegames/pusher tfgco/pusher:ci-test golangci-lint run
diff --git a/Makefile b/Makefile
index 1b90d60..a077e10 100644
--- a/Makefile
+++ b/Makefile
@@ -105,8 +105,8 @@ test-unit:
 	@echo "-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-="
 	@echo
 	@export $ACK_GINKGO_RC=true
-	@$(GINKGO) --race -trace -r --randomizeAllSpecs --randomizeSuites --cover --focus="\[Unit\].*" .
-	@$(MAKE) test-coverage-func
+	@$(GINKGO) -trace -r --randomizeAllSpecs --randomizeSuites --cover --focus="\[Unit\].*" --skipPackage=e2e .
+	@#$(MAKE) test-coverage-func
 	@echo
 	@echo "-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-="
 	@echo "= Unit tests finished.                                    ="
@@ -120,7 +120,7 @@ run-integration-test:
 	@echo "-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-="
 	@echo
 	@export $ACK_GINKGO_RC=true
-	@$(GINKGO) --race -trace -r -tags=integration --randomizeAllSpecs --randomizeSuites --focus="\[Integration\].*" .
+	@$(GINKGO) -trace -r -tags=integration --randomizeAllSpecs --randomizeSuites --focus="\[Integration\].*" .
 	@echo
 	@echo "-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-="
 	@echo "= Integration tests finished.                             ="
@@ -173,4 +173,7 @@ integration-test-container-dev: build-image-dev start-deps-container-dev test-db
 .PHONY: mocks
 mocks:
-	$(MOCKGENERATE) -source=interfaces/client.go -destination=mocks/firebase/client.go
\ No newline at end of file
+	$(MOCKGENERATE) -source=interfaces/client.go -destination=mocks/firebase/client.go
+	$(MOCKGENERATE) -source=interfaces/apns.go -destination=mocks/interfaces/apns.go
+	$(MOCKGENERATE) -source=interfaces/statsd.go -destination=mocks/interfaces/statsd.go
+	$(MOCKGENERATE) -source=interfaces/feedback_reporter.go -destination=mocks/interfaces/feedback_reporter.go
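For context on the dual-listener Kafka setup above: containers on the CI network reach the broker via PLAINTEXT on kafka:9092, while the second listener is advertised on kafka:29092 (the address the test configs below switch to). A minimal sketch, using the confluent-kafka-go v2 client this PR already depends on, of checking that a broker is reachable before tests run; the broker address and timeout are assumptions, not part of the PR:

    package main

    import (
        "fmt"
        "log"

        "github.com/confluentinc/confluent-kafka-go/v2/kafka"
    )

    func main() {
        // Hypothetical broker address; mirrors the PLAINTEXT_HOST listener above.
        admin, err := kafka.NewAdminClient(&kafka.ConfigMap{"bootstrap.servers": "kafka:29092"})
        if err != nil {
            log.Fatal(err)
        }
        defer admin.Close()

        // GetMetadata asks the broker for topic/broker metadata; 5s timeout assumed.
        md, err := admin.GetMetadata(nil, true, 5000)
        if err != nil {
            log.Fatal(err)
        }
        for name := range md.Topics {
            fmt.Println(name)
        }
    }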
=" @@ -173,4 +173,7 @@ integration-test-container-dev: build-image-dev start-deps-container-dev test-db .PHONY: mocks mocks: - $(MOCKGENERATE) -source=interfaces/client.go -destination=mocks/firebase/client.go \ No newline at end of file + $(MOCKGENERATE) -source=interfaces/client.go -destination=mocks/firebase/client.go + $(MOCKGENERATE) -source=interfaces/apns.go -destination=mocks/interfaces/apns.go + $(MOCKGENERATE) -source=interfaces/statsd.go -destination=mocks/interfaces/statsd.go + $(MOCKGENERATE) -source=interfaces/feedback_reporter.go -destination=mocks/interfaces/feedback_reporter.go diff --git a/cmd/apns.go b/cmd/apns.go index bfcc997..73ecf9e 100644 --- a/cmd/apns.go +++ b/cmd/apns.go @@ -28,6 +28,7 @@ import ( "github.com/sirupsen/logrus" "github.com/spf13/cobra" "github.com/spf13/viper" + "github.com/topfreegames/pusher/config" "github.com/topfreegames/pusher/interfaces" "github.com/topfreegames/pusher/pusher" "github.com/topfreegames/pusher/util" @@ -35,7 +36,8 @@ import ( func startApns( debug, json, production bool, - config *viper.Viper, + vConfig *viper.Viper, + config *config.Config, statsdClientOrNil interfaces.StatsDClient, dbOrNil interfaces.DB, queueOrNil interfaces.APNSPushQueue, @@ -49,7 +51,7 @@ func startApns( } else { log.Level = logrus.InfoLevel } - return pusher.NewAPNSPusher(production, config, log, statsdClientOrNil, dbOrNil, queueOrNil) + return pusher.NewAPNSPusher(production, vConfig, config, log, statsdClientOrNil, dbOrNil, queueOrNil) } // apnsCmd represents the apns command @@ -58,19 +60,19 @@ var apnsCmd = &cobra.Command{ Short: "starts pusher in apns mode", Long: `starts pusher in apns mode`, Run: func(cmd *cobra.Command, args []string) { - config, err := util.NewViperWithConfigFile(cfgFile) + config, vConfig, err := config.NewConfigAndViper(cfgFile) if err != nil { panic(err) } - sentryURL := config.GetString("sentry.url") + sentryURL := vConfig.GetString("sentry.url") if sentryURL != "" { raven.SetDSN(sentryURL) } ctx := context.Background() - apnsPusher, err := startApns(debug, json, production, config, nil, nil, nil) + apnsPusher, err := startApns(debug, json, production, vConfig, config, nil, nil, nil) if err != nil { raven.CaptureErrorAndWait(err, map[string]string{ "version": util.Version, diff --git a/cmd/apns_test.go b/cmd/apns_test.go index 694cf86..eaa760a 100644 --- a/cmd/apns_test.go +++ b/cmd/apns_test.go @@ -24,6 +24,7 @@ package cmd import ( "fmt" + "github.com/topfreegames/pusher/config" "os" . 
"github.com/onsi/ginkgo" @@ -31,23 +32,23 @@ import ( "github.com/sirupsen/logrus" "github.com/spf13/viper" "github.com/topfreegames/pusher/mocks" - "github.com/topfreegames/pusher/util" ) var _ = Describe("APNS", func() { - cfg := os.Getenv("CONFIG_FILE") - if cfg == "" { - cfg = "../config/test.yaml" + configFile := os.Getenv("CONFIG_FILE") + if configFile == "" { + configFile = "../config/test.yaml" } - var config *viper.Viper + var vConfig *viper.Viper + var cfg *config.Config var mockPushQueue *mocks.APNSPushQueueMock var mockDB *mocks.PGMock var mockStatsDClient *mocks.StatsDClientMock BeforeEach(func() { var err error - config, err = util.NewViperWithConfigFile(cfg) + cfg, vConfig, err = config.NewConfigAndViper(configFile) Expect(err).NotTo(HaveOccurred()) mockDB = mocks.NewPGMock(0, 1) mockPushQueue = mocks.NewAPNSPushQueueMock() @@ -56,7 +57,7 @@ var _ = Describe("APNS", func() { Describe("[Unit]", func() { It("Should return apnsPusher without errors", func() { - apnsPusher, err := startApns(false, false, false, config, mockStatsDClient, mockDB, mockPushQueue) + apnsPusher, err := startApns(false, false, false, vConfig, cfg, mockStatsDClient, mockDB, mockPushQueue) Expect(err).NotTo(HaveOccurred()) Expect(apnsPusher).NotTo(BeNil()) Expect(apnsPusher.ViperConfig).NotTo(BeNil()) @@ -66,21 +67,21 @@ var _ = Describe("APNS", func() { }) It("Should set log to json format", func() { - apnsPusher, err := startApns(false, true, false, config, mockStatsDClient, mockDB, mockPushQueue) + apnsPusher, err := startApns(false, true, false, vConfig, cfg, mockStatsDClient, mockDB, mockPushQueue) Expect(err).NotTo(HaveOccurred()) Expect(apnsPusher).NotTo(BeNil()) Expect(fmt.Sprintf("%T", apnsPusher.Logger.Formatter)).To(Equal(fmt.Sprintf("%T", &logrus.JSONFormatter{}))) }) It("Should set log to debug", func() { - apnsPusher, err := startApns(true, false, false, config, mockStatsDClient, mockDB, mockPushQueue) + apnsPusher, err := startApns(true, false, false, vConfig, cfg, mockStatsDClient, mockDB, mockPushQueue) Expect(err).NotTo(HaveOccurred()) Expect(apnsPusher).NotTo(BeNil()) Expect(apnsPusher.Logger.Level).To(Equal(logrus.DebugLevel)) }) It("Should set log to production", func() { - apnsPusher, err := startApns(false, false, true, config, mockStatsDClient, mockDB, mockPushQueue) + apnsPusher, err := startApns(false, false, true, vConfig, cfg, mockStatsDClient, mockDB, mockPushQueue) Expect(err).NotTo(HaveOccurred()) Expect(apnsPusher).NotTo(BeNil()) Expect(apnsPusher.IsProduction).To(BeTrue()) diff --git a/config/config.go b/config/config.go index cb7e866..711f87c 100644 --- a/config/config.go +++ b/config/config.go @@ -14,9 +14,15 @@ type ( // Config is the struct that holds all the configuration for the Pusher. Config struct { GCM GCM + Apns Apns + Queue Kafka GracefulShutdownTimeout int } + Kafka struct { + Brokers string + } + GCM struct { Apps string PingInterval int @@ -24,6 +30,18 @@ type ( MaxPendingMessages int LogStatsInterval int } + + Apns struct { + Apps string + Certs map[string]Cert + } + + Cert struct { + AuthKeyPath string + KeyID string + TeamID string + Topic string + } ) // NewConfigAndViper returns a new Config object and the corresponding viper instance. 
@@ -45,7 +63,7 @@ func NewConfigAndViper(configFile string) (*Config, *viper.Viper, error) {
 	return config, v, nil
 }

-func (c *Config) GetAppsArray() []string {
+func (c *Config) GetGcmAppsArray() []string {
 	arr := strings.Split(c.GCM.Apps, ",")
 	res := make([]string, 0, len(arr))
 	for _, a := range arr {
@@ -55,6 +73,16 @@ func (c *Config) GetAppsArray() []string {
 	return res
 }

+func (c *Config) GetApnsAppsArray() []string {
+	arr := strings.Split(c.Apns.Apps, ",")
+	res := make([]string, 0, len(arr))
+	for _, a := range arr {
+		res = append(res, strings.TrimSpace(a))
+	}
+
+	return res
+}
+
 func decodeHookFunc() viper.DecoderConfigOption {
 	hooks := mapstructure.ComposeDecodeHookFunc(
 		StringToMapStringHookFunc(),
diff --git a/config/docker_test.yaml b/config/docker_test.yaml
index 5df2b65..5694800 100644
--- a/config/docker_test.yaml
+++ b/config/docker_test.yaml
@@ -24,7 +24,7 @@ gcm:
 queue:
   topics:
     - "^push-[^-_]+_(apns|gcm)[_-](single|massive)"
-  brokers: "kafka:9092"
+  brokers: "kafka:29092"
   group: testGroup
   sessionTimeout: 6000
   fetch.min.bytes: 1
@@ -36,10 +36,10 @@ feedback:
     - kafka
   kafka:
     topics: "com.games.test.feedbacks"
-    brokers: "kafka:9092"
+    brokers: "kafka:29092"
   cache:
-    requestTimeout: 100
-    cleaningInterval: 20
+    requestTimeout: 3000
+    cleaningInterval: 600
 stats:
   reporters:
     - statsd
@@ -65,7 +65,7 @@ feedbackListeners:
   queue:
     topics:
       - "^push-[^-_]+-(apns|gcm)-feedbacks"
-    brokers: "kafka:9092"
+    brokers: "kafka:29092"
     group: testGroup
     sessionTimeout: 6000
    fetch.min.bytes: 1
diff --git a/config/test.yaml b/config/test.yaml
index 0d08321..ebb57d8 100644
--- a/config/test.yaml
+++ b/config/test.yaml
@@ -38,8 +38,8 @@ feedback:
     topics: "com.games.test.feedbacks"
     brokers: "localhost:9941"
   cache:
-    requestTimeout: 100
-    cleaningInterval: 20
+    requestTimeout: 3000
+    cleaningInterval: 600
 stats:
   reporters:
     - statsd
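The raised cache timings matter because the handler tests later in this patch derive their sleeps from them. A sketch of reading the value through viper, assuming the same keys as config/test.yaml and a path relative to the repo root:

    package main

    import (
        "fmt"
        "time"

        "github.com/spf13/viper"
    )

    func main() {
        v := viper.New()
        v.SetConfigFile("config/test.yaml") // path is an assumption for this sketch
        if err := v.ReadInConfig(); err != nil {
            panic(err)
        }
        timeout := time.Duration(v.GetInt("feedback.cache.requestTimeout")) * time.Millisecond
        fmt.Println(timeout) // 3s with the values above
    }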
diff --git a/e2e/apns_e2e_test.go b/e2e/apns_e2e_test.go
new file mode 100644
index 0000000..d82e27c
--- /dev/null
+++ b/e2e/apns_e2e_test.go
@@ -0,0 +1,322 @@
+package e2e
+
+import (
+	"context"
+	"fmt"
+	"github.com/confluentinc/confluent-kafka-go/v2/kafka"
+	"github.com/google/uuid"
+	"github.com/sideshow/apns2"
+	"github.com/sirupsen/logrus"
+	"github.com/spf13/viper"
+	"github.com/stretchr/testify/suite"
+	"github.com/topfreegames/pusher/config"
+	mocks "github.com/topfreegames/pusher/mocks/interfaces"
+	"github.com/topfreegames/pusher/pusher"
+	"github.com/topfreegames/pusher/structs"
+	"go.uber.org/mock/gomock"
+	"os"
+	"strings"
+	"testing"
+	"time"
+)
+
+const wait = 5 * time.Second
+const timeout = 1 * time.Minute
+const topicTemplate = "push-%s_apns-single"
+
+type ApnsE2ETestSuite struct {
+	suite.Suite
+
+	config  *config.Config
+	vConfig *viper.Viper
+}
+
+func TestApnsE2eSuite(t *testing.T) {
+	suite.Run(t, new(ApnsE2ETestSuite))
+}
+
+func (s *ApnsE2ETestSuite) SetupSuite() {
+	configFile := os.Getenv("CONFIG_FILE")
+	if configFile == "" {
+		configFile = "../config/test.yaml"
+	}
+	c, v, err := config.NewConfigAndViper(configFile)
+	s.Require().NoError(err)
+	s.config = c
+	s.vConfig = v
+}
+
+func (s *ApnsE2ETestSuite) setupApnsPusher() (*mocks.MockAPNSPushQueue, *mocks.MockStatsDClient, chan *structs.ResponseWithMetadata) {
+	responsesChannel := make(chan *structs.ResponseWithMetadata)
+
+	ctrl := gomock.NewController(s.T())
+	mockApnsClient := mocks.NewMockAPNSPushQueue(ctrl)
+	mockApnsClient.EXPECT().ResponseChannel().Return(responsesChannel)
+
+	statsdClientMock := mocks.NewMockStatsDClient(ctrl)
+	// Gauge can be called any number of times by the Go stats reporter.
+	statsdClientMock.EXPECT().Gauge(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil)
+
+	logger := logrus.New()
+	logger.Level = logrus.DebugLevel
+
+	s.assureTopicsExist()
+	time.Sleep(wait)
+
+	apnsPusher, err := pusher.NewAPNSPusher(false, s.vConfig, s.config, logger, statsdClientMock, nil, mockApnsClient)
+	s.Require().NoError(err)
+	ctx := context.Background()
+	go apnsPusher.Start(ctx)
+
+	time.Sleep(wait)
+
+	return mockApnsClient, statsdClientMock, responsesChannel
+}
+
+func (s *ApnsE2ETestSuite) TestSimpleNotification() {
+	appName := strings.Split(uuid.NewString(), "-")[0]
+	s.config.Apns.Apps = appName
+	s.vConfig.Set("queue.topics", []string{fmt.Sprintf(topicTemplate, appName)})
+
+	mockApnsClient, statsdClientMock, responsesChannel := s.setupApnsPusher()
+	producer, err := kafka.NewProducer(&kafka.ConfigMap{
+		"bootstrap.servers": s.config.Queue.Brokers,
+	})
+	s.Require().NoError(err)
+
+	app := s.config.GetApnsAppsArray()[0]
+	topic := "push-" + app + "_apns-single"
+	token := "token"
+	testDone := make(chan bool)
+	mockApnsClient.EXPECT().
+		Push(gomock.Any()).
+		DoAndReturn(func(notification *apns2.Notification) error {
+			s.Equal(token, notification.DeviceToken)
+			s.Equal(s.config.Apns.Certs[app].Topic, notification.Topic)
+
+			go func() {
+				responsesChannel <- &structs.ResponseWithMetadata{
+					ApnsID:      notification.ApnsID,
+					Sent:        true,
+					StatusCode:  200,
+					DeviceToken: token,
+				}
+			}()
+			return nil
+		})
+
+	statsdClientMock.EXPECT().
+		Incr("sent", []string{fmt.Sprintf("platform:%s", "apns"), fmt.Sprintf("game:%s", app)}, float64(1)).
+		DoAndReturn(func(string, []string, float64) error {
+			return nil
+		})
+
+	statsdClientMock.EXPECT().
+		Incr("ack", []string{fmt.Sprintf("platform:%s", "apns"), fmt.Sprintf("game:%s", app)}, float64(1)).
+		DoAndReturn(func(string, []string, float64) error {
+			testDone <- true
+			return nil
+		})
+
+	err = producer.Produce(&kafka.Message{
+		TopicPartition: kafka.TopicPartition{
+			Topic:     &topic,
+			Partition: kafka.PartitionAny,
+		},
+		Value: []byte(`{"deviceToken":"` + token + `", "payload": {"aps": {"alert": "Hello"}}}`),
+	},
+		nil)
+	s.Require().NoError(err)
+
+	// Give it some time to process the message.
+	timer := time.NewTimer(timeout)
+	select {
+	case <-testDone:
+		// Wait some time to make sure it won't call the push client again after the testDone signal.
+		time.Sleep(wait)
+	case <-timer.C:
+		s.FailNow("Timeout waiting for Handler to report notification sent")
+	}
+}
+
+func (s *ApnsE2ETestSuite) TestNotificationRetry() {
+	appName := strings.Split(uuid.NewString(), "-")[0]
+	s.config.Apns.Apps = appName
+	s.vConfig.Set("queue.topics", []string{fmt.Sprintf(topicTemplate, appName)})
+
+	mockApnsClient, statsdClientMock, responsesChannel := s.setupApnsPusher()
+
+	producer, err := kafka.NewProducer(&kafka.ConfigMap{
+		"bootstrap.servers": s.config.Queue.Brokers,
+	})
+	s.Require().NoError(err)
+
+	app := s.config.GetApnsAppsArray()[0]
+	topic := "push-" + app + "_apns-single"
+	token := "token"
+	done := make(chan bool)
+
+	mockApnsClient.EXPECT().
+		Push(gomock.Any()).
+		DoAndReturn(func(notification *apns2.Notification) error {
+			s.Equal(token, notification.DeviceToken)
+			s.Equal(s.config.Apns.Certs[app].Topic, notification.Topic)
+
+			go func() {
+				responsesChannel <- &structs.ResponseWithMetadata{
+					ApnsID:      notification.ApnsID,
+					Sent:        true,
+					StatusCode:  429,
+					Reason:      apns2.ReasonTooManyRequests,
+					DeviceToken: token,
+				}
+			}()
+			return nil
+		})
+
+	mockApnsClient.EXPECT().
+		Push(gomock.Any()).
+		DoAndReturn(func(notification *apns2.Notification) error {
+			s.Equal(token, notification.DeviceToken)
+			s.Equal(s.config.Apns.Certs[app].Topic, notification.Topic)
+
+			go func() {
+				responsesChannel <- &structs.ResponseWithMetadata{
+					ApnsID:      notification.ApnsID,
+					Sent:        true,
+					StatusCode:  200,
+					DeviceToken: token,
+				}
+			}()
+			return nil
+		})
+
+	statsdClientMock.EXPECT().
+		Incr("sent", []string{fmt.Sprintf("platform:%s", "apns"), fmt.Sprintf("game:%s", app)}, float64(1)).
+		DoAndReturn(func(string, []string, float64) error {
+			return nil
+		})
+
+	statsdClientMock.EXPECT().
+		Incr("ack", []string{fmt.Sprintf("platform:%s", "apns"), fmt.Sprintf("game:%s", app)}, float64(1)).
+		DoAndReturn(func(string, []string, float64) error {
+			done <- true
+			return nil
+		})
+
+	err = producer.Produce(&kafka.Message{
+		TopicPartition: kafka.TopicPartition{
+			Topic:     &topic,
+			Partition: kafka.PartitionAny,
+		},
+		Value: []byte(`{"deviceToken":"` + token + `", "payload": {"aps": {"alert": "Hello"}}}`),
+	},
+		nil)
+	s.Require().NoError(err)
+
+	// Give it some time to process the message.
+	timer := time.NewTimer(timeout)
+	select {
+	case <-done:
+		// Wait some time to make sure it won't call the push client again after the done signal.
+		time.Sleep(wait)
+	case <-timer.C:
+		s.FailNow("Timeout waiting for Handler to report notification sent")
+	}
+}
+
+func (s *ApnsE2ETestSuite) TestMultipleNotifications() {
+	appName := strings.Split(uuid.NewString(), "-")[0]
+	s.config.Apns.Apps = appName
+	s.vConfig.Set("queue.topics", []string{fmt.Sprintf(topicTemplate, appName)})
+
+	mockApnsClient, statsdClientMock, responsesChannel := s.setupApnsPusher()
+
+	notificationsToSend := 10
+	producer, err := kafka.NewProducer(&kafka.ConfigMap{
+		"bootstrap.servers": s.config.Queue.Brokers,
+	})
+	s.Require().NoError(err)
+
+	app := s.config.GetApnsAppsArray()[0]
+	topic := fmt.Sprintf(topicTemplate, app)
+	token := "token"
+	done := make(chan bool)
+
+	for i := 0; i < notificationsToSend; i++ {
+		mockApnsClient.EXPECT().
+			Push(gomock.Any()).
+			DoAndReturn(func(notification *apns2.Notification) error {
+				s.Equal(s.config.Apns.Certs[app].Topic, notification.Topic)
+
+				go func() {
+					responsesChannel <- &structs.ResponseWithMetadata{
+						ApnsID:      notification.ApnsID,
+						Sent:        true,
+						StatusCode:  200,
+						DeviceToken: notification.DeviceToken,
+					}
+				}()
+				return nil
+			})
+	}
+
+	statsdClientMock.EXPECT().
+		Incr("sent", []string{fmt.Sprintf("platform:%s", "apns"), fmt.Sprintf("game:%s", app)}, float64(1)).
+		Times(notificationsToSend).
+		DoAndReturn(func(string, []string, float64) error {
+			return nil
+		})
+
+	statsdClientMock.EXPECT().
+		Incr("ack", []string{fmt.Sprintf("platform:%s", "apns"), fmt.Sprintf("game:%s", app)}, float64(1)).
+		Times(notificationsToSend).
+		DoAndReturn(func(string, []string, float64) error {
+			done <- true
+			return nil
+		})
+
+	for i := 0; i < notificationsToSend; i++ {
+		err = producer.Produce(&kafka.Message{
+			TopicPartition: kafka.TopicPartition{
+				Topic:     &topic,
+				Partition: kafka.PartitionAny,
+			},
+			Value: []byte(`{"deviceToken":"` + fmt.Sprintf("%s%d", token, i) + `", "payload": {"aps": {"alert": "Hello"}}}`),
+		},
+			nil)
+		s.Require().NoError(err)
+	}
+	// Give it some time to process the messages.
+	timer := time.NewTimer(timeout)
+	for i := 0; i < notificationsToSend; i++ {
+		select {
+		case <-done:
+		case <-timer.C:
+			s.FailNow("Timeout waiting for Handler to report notification sent")
+		}
+	}
+	// Wait some time to make sure it won't call the push client again after everything is done.
+	time.Sleep(wait)
+}
+
+func (s *ApnsE2ETestSuite) assureTopicsExist() {
+	producer, err := kafka.NewProducer(&kafka.ConfigMap{
+		"bootstrap.servers": s.config.Queue.Brokers,
+	})
+	s.Require().NoError(err)
+
+	apnsApps := s.config.GetApnsAppsArray()
+	for _, a := range apnsApps {
+		topic := fmt.Sprintf(topicTemplate, a)
+		err = producer.Produce(&kafka.Message{
+			TopicPartition: kafka.TopicPartition{
+				Topic:     &topic,
+				Partition: kafka.PartitionAny,
+			},
+			Value: []byte("not a notification"),
+		},
+			nil)
+		s.Require().NoError(err)
+	}
+}
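The suite's assureTopicsExist relies on a throwaway produce to auto-create topics, which only works when the broker allows auto-creation. A sketch of the alternative, creating topics explicitly with the confluent-kafka-go AdminClient; the broker address, app name, partition count and replication factor here are assumptions:

    package main

    import (
        "context"
        "fmt"
        "log"
        "time"

        "github.com/confluentinc/confluent-kafka-go/v2/kafka"
    )

    func main() {
        // Hypothetical broker address; the e2e suite reads it from config.Queue.Brokers.
        admin, err := kafka.NewAdminClient(&kafka.ConfigMap{"bootstrap.servers": "localhost:9941"})
        if err != nil {
            log.Fatal(err)
        }
        defer admin.Close()

        ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
        defer cancel()

        // Create the topic the suite's topicTemplate would produce for a made-up app.
        results, err := admin.CreateTopics(ctx, []kafka.TopicSpecification{{
            Topic:             fmt.Sprintf("push-%s_apns-single", "mygame"),
            NumPartitions:     1,
            ReplicationFactor: 1,
        }})
        if err != nil {
            log.Fatal(err)
        }
        for _, r := range results {
            fmt.Println(r.Topic, r.Error)
        }
    }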
diff --git a/extensions/apns_message_handler.go b/extensions/apns_message_handler.go
index 8df9785..83c7f58 100644
--- a/extensions/apns_message_handler.go
+++ b/extensions/apns_message_handler.go
@@ -61,7 +61,7 @@ type APNSMessageHandler struct {
 	teamID                   string
 	appName                  string
 	PushQueue                interfaces.APNSPushQueue
-	Topic                    string
+	ApnsTopic                string
 	Config                   *viper.Viper
 	failuresReceived         int64
 	InFlightNotificationsMap map[string]*inFlightNotification
@@ -99,7 +99,7 @@ func NewAPNSMessageHandler(
 		authKeyPath:        authKeyPath,
 		keyID:              keyID,
 		teamID:             teamID,
-		Topic:              topic,
+		ApnsTopic:          topic,
 		appName:            appName,
 		Config:             config,
 		failuresReceived:   0,
@@ -118,6 +118,14 @@ func NewAPNSMessageHandler(
 		PushQueue:          pushQueue,
 		consumptionManager: consumptionManager,
 	}
+
+	if a.Logger != nil {
+		a.Logger = a.Logger.WithFields(log.Fields{
+			"source": "APNSMessageHandler",
+			"game":   appName,
+		}).Logger
+	}
+
 	if err := a.configure(); err != nil {
 		return nil, err
 	}
@@ -165,7 +173,13 @@ func (a *APNSMessageHandler) loadConfigurationDefaults() {
 // HandleResponses from apns.
 func (a *APNSMessageHandler) HandleResponses() {
 	for response := range a.PushQueue.ResponseChannel() {
-		a.handleAPNSResponse(response)
+		err := a.handleAPNSResponse(response)
+		if err != nil {
+			a.Logger.
+				WithField("method", "HandleResponses").
+				WithError(err).
+				Error("error handling response")
+		}
 	}
 }

@@ -193,8 +207,13 @@ func (a *APNSMessageHandler) CleanMetadataCache() {
 }

 // HandleMessages get messages from msgChan and send to APNS.
-func (a *APNSMessageHandler) HandleMessages(ctx context.Context, message interfaces.KafkaMessage) {
-	a.Logger.WithField("message", message).Debug("received message to send to apns")
+func (a *APNSMessageHandler) HandleMessages(_ context.Context, message interfaces.KafkaMessage) {
+	l := a.Logger.WithFields(log.Fields{
+		"method":    "HandleMessages",
+		"jsonValue": string(message.Value),
+		"topic":     message.Topic,
+	})
+	l.Debug("received message to send to apns")
 	notification, err := a.buildNotification(message)
 	if err != nil {
 		return
@@ -203,7 +222,11 @@ func (a *APNSMessageHandler) HandleMessages(_ context.Context, message interfaces.KafkaMessage) {
 		return
 	}
 	statsReporterHandleNotificationSent(a.StatsReporters, a.appName, "apns")
+
+	apnsResMutex.Lock()
 	a.sentMessages++
+	apnsResMutex.Unlock()
+
 	a.inFlightNotificationsMapLock.Lock()
 	ifn := &inFlightNotification{
 		notification: notification,
@@ -244,7 +267,11 @@ func (a *APNSMessageHandler) sendNotification(notification *Notification) error {
 	l := a.Logger.WithField("method", "sendNotification")
 	if notification.PushExpiry > 0 && notification.PushExpiry < MakeTimestamp() {
 		l.Warnf("ignoring push message because it has expired: %s", notification.Payload)
+
+		apnsResMutex.Lock()
 		a.ignoredMessages++
+		apnsResMutex.Unlock()
+
 		if a.pendingMessagesWG != nil {
 			a.pendingMessagesWG.Done()
 		}
@@ -253,7 +280,11 @@ func (a *APNSMessageHandler) sendNotification(notification *Notification) error {
 	payload, err := json.Marshal(notification.Payload)
 	if err != nil {
 		l.WithError(err).Error("error marshaling message payload")
+
+		apnsResMutex.Lock()
 		a.ignoredMessages++
+		apnsResMutex.Unlock()
+
 		if a.pendingMessagesWG != nil {
 			a.pendingMessagesWG.Done()
 		}
@@ -261,7 +292,7 @@ func (a *APNSMessageHandler) sendNotification(notification *Notification) error {
 	}
 	l.WithField("notification", notification).Debug("adding notification to apns push queue")
 	a.PushQueue.Push(&apns2.Notification{
-		Topic:       a.Topic,
+		Topic:       a.ApnsTopic,
 		DeviceToken: notification.DeviceToken,
 		Payload:     payload,
 		ApnsID:      notification.ApnsID,
@@ -288,20 +319,25 @@ func (a *APNSMessageHandler) handleAPNSResponse(responseWithMetadata *structs.ResponseWithMetadata) error {
 		sendAttempts := inFlightNotificationInstance.sendAttempts.Load()
 		if responseWithMetadata.Reason == apns2.ReasonTooManyRequests && uint(sendAttempts) < a.maxRetryAttempts {
-			a.consumptionManager.Pause(inFlightNotificationInstance.kafkaTopic)
+			l.WithFields(log.Fields{
+				"sendAttempts": sendAttempts,
+				"maxRetries":   a.maxRetryAttempts,
+				"apnsID":       responseWithMetadata.ApnsID,
+			}).Debug("retrying notification")
 			inFlightNotificationInstance.sendAttempts.Add(1)
 			<-time.After(a.retryInterval)
 			if err := a.sendNotification(inFlightNotificationInstance.notification); err == nil {
 				return nil
 			}
 		}
-		if uint(sendAttempts) > 0 {
-			a.consumptionManager.Resume(inFlightNotificationInstance.kafkaTopic)
-		}
+
 		responseWithMetadata.Metadata = inFlightNotificationInstance.notification.Metadata
 		responseWithMetadata.Timestamp = responseWithMetadata.Metadata["timestamp"].(int64)
 		delete(responseWithMetadata.Metadata, "timestamp")
+
+		a.inFlightNotificationsMapLock.Lock()
 		delete(a.InFlightNotificationsMap, responseWithMetadata.ApnsID)
+		a.inFlightNotificationsMapLock.Unlock()
 	}

 	apnsResMutex.Lock()
diff --git a/extensions/apns_message_handler_test.go b/extensions/apns_message_handler_test.go
index 381a950..571b29b 100644
--- a/extensions/apns_message_handler_test.go
+++ b/extensions/apns_message_handler_test.go
@@ -26,19 +26,19 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
+	uuid "github.com/satori/go.uuid"
+	"github.com/sideshow/apns2"
+	mock_interfaces "github.com/topfreegames/pusher/mocks/interfaces"
 	"os"
 	"time"

 	. "github.com/onsi/ginkgo"
 	. "github.com/onsi/gomega"
-	uuid "github.com/satori/go.uuid"
-	"github.com/sideshow/apns2"
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus/hooks/test"
 	"github.com/topfreegames/pusher/interfaces"
 	"github.com/topfreegames/pusher/mocks"
 	"github.com/topfreegames/pusher/structs"
-	. "github.com/topfreegames/pusher/testing"
 	"github.com/topfreegames/pusher/util"
 )

@@ -50,6 +50,7 @@ var _ = FDescribe("APNS Message Handler", func() {
 	var mockPushQueue *mocks.APNSPushQueueMock
 	var mockStatsDClient *mocks.StatsDClientMock
 	var statsClients []interfaces.StatsReporter
+	mockConsumptionManager := mock_interfaces.NewMockConsumptionManager()
 	ctx := context.Background()

 	configFile := os.Getenv("CONFIG_FILE")
@@ -96,7 +97,7 @@ var _ = FDescribe("APNS Message Handler", func() {
 			statsClients,
 			feedbackClients,
 			mockPushQueue,
-			nil,
+			mockConsumptionManager,
 		)
 		Expect(err).NotTo(HaveOccurred())
 		db.(*mocks.PGMock).RowsReturned = 0
@@ -109,8 +110,10 @@ var _ = FDescribe("APNS Message Handler", func() {
 			Expect(handler).NotTo(BeNil())
 			Expect(handler.Config).NotTo(BeNil())
 			Expect(handler.IsProduction).To(Equal(isProduction))
+			apnsResMutex.Lock()
 			Expect(handler.responsesReceived).To(Equal(int64(0)))
 			Expect(handler.sentMessages).To(Equal(int64(0)))
+			apnsResMutex.Unlock()
 		})
 	})

@@ -129,8 +132,10 @@ var _ = FDescribe("APNS Message Handler", func() {
 				ApnsID: apnsID,
 			}
 			handler.handleAPNSResponse(res)
+			apnsResMutex.Lock()
 			Expect(handler.responsesReceived).To(Equal(int64(1)))
 			Expect(handler.successesReceived).To(Equal(int64(1)))
+			apnsResMutex.Unlock()
 		})

 		It("if response has error push.ReasonUnregistered", func() {
@@ -148,8 +153,10 @@ var _ = FDescribe("APNS Message Handler", func() {
 				Reason: apns2.ReasonUnregistered,
 			}
 			handler.handleAPNSResponse(res)
+			apnsResMutex.Lock()
 			Expect(handler.responsesReceived).To(Equal(int64(1)))
 			Expect(handler.failuresReceived).To(Equal(int64(1)))
+			apnsResMutex.Unlock()
 		})

 		It("if response has error push.ErrBadDeviceToken", func() {
@@ -167,8 +174,10 @@ var _ = FDescribe("APNS Message Handler", func() {
 				Reason: apns2.ReasonBadDeviceToken,
 			}
 			handler.handleAPNSResponse(res)
+			apnsResMutex.Lock()
 			Expect(handler.responsesReceived).To(Equal(int64(1)))
 			Expect(handler.failuresReceived).To(Equal(int64(1)))
+			apnsResMutex.Unlock()
 		})

 		It("if response has error push.ErrBadCertificate", func() {
@@ -186,8 +195,10 @@ var _ = FDescribe("APNS Message Handler", func() {
 				Reason: apns2.ReasonBadCertificate,
 			}
 			handler.handleAPNSResponse(res)
+			apnsResMutex.Lock()
 			Expect(handler.responsesReceived).To(Equal(int64(1)))
 			Expect(handler.failuresReceived).To(Equal(int64(1)))
+			apnsResMutex.Unlock()
 		})

 		It("if response has error push.ErrBadCertificateEnvironment", func() {
@@ -205,8 +216,10 @@ var _ = FDescribe("APNS Message Handler", func() {
 				Reason: apns2.ReasonBadCertificateEnvironment,
 			}
 			handler.handleAPNSResponse(res)
+			apnsResMutex.Lock()
 			Expect(handler.responsesReceived).To(Equal(int64(1)))
 			Expect(handler.failuresReceived).To(Equal(int64(1)))
+			apnsResMutex.Unlock()
 		})

 		It("if response has error push.ErrForbidden", func() {
@@ -224,8 +237,10 @@ var _ = FDescribe("APNS Message Handler", func() {
 				Reason: apns2.ReasonForbidden,
 			}
 			handler.handleAPNSResponse(res)
+			apnsResMutex.Lock()
 			Expect(handler.responsesReceived).To(Equal(int64(1)))
 			Expect(handler.failuresReceived).To(Equal(int64(1)))
+			apnsResMutex.Unlock()
 		})

 		It("if response has error push.ErrMissingTopic", func() {
@@ -243,8 +258,10 @@ var _ = FDescribe("APNS Message Handler", func() {
 				Reason: apns2.ReasonMissingTopic,
 			}
 			handler.handleAPNSResponse(res)
+			apnsResMutex.Lock()
 			Expect(handler.responsesReceived).To(Equal(int64(1)))
 			Expect(handler.failuresReceived).To(Equal(int64(1)))
+			apnsResMutex.Unlock()
 		})

 		It("if response has error push.ErrTopicDisallowed", func() {
@@ -262,8 +279,10 @@ var _ = FDescribe("APNS Message Handler", func() {
 				Reason: apns2.ReasonTopicDisallowed,
 			}
 			handler.handleAPNSResponse(res)
+			apnsResMutex.Lock()
 			Expect(handler.responsesReceived).To(Equal(int64(1)))
 			Expect(handler.failuresReceived).To(Equal(int64(1)))
+			apnsResMutex.Unlock()
 		})

 		It("if response has error push.ErrDeviceTokenNotForTopic", func() {
@@ -281,8 +300,10 @@ var _ = FDescribe("APNS Message Handler", func() {
 				Reason: apns2.ReasonDeviceTokenNotForTopic,
 			}
 			handler.handleAPNSResponse(res)
+			apnsResMutex.Lock()
 			Expect(handler.responsesReceived).To(Equal(int64(1)))
 			Expect(handler.failuresReceived).To(Equal(int64(1)))
+			apnsResMutex.Unlock()
 		})

 		It("if response has error push.ErrIdleTimeout", func() {
@@ -300,8 +321,10 @@ var _ = FDescribe("APNS Message Handler", func() {
 				Reason: apns2.ReasonIdleTimeout,
 			}
 			handler.handleAPNSResponse(res)
+			apnsResMutex.Lock()
 			Expect(handler.responsesReceived).To(Equal(int64(1)))
 			Expect(handler.failuresReceived).To(Equal(int64(1)))
+			apnsResMutex.Unlock()
 		})

 		It("if response has error push.ErrShutdown", func() {
@@ -319,8 +342,10 @@ var _ = FDescribe("APNS Message Handler", func() {
 				Reason: apns2.ReasonShutdown,
 			}
 			handler.handleAPNSResponse(res)
+			apnsResMutex.Lock()
 			Expect(handler.responsesReceived).To(Equal(int64(1)))
 			Expect(handler.failuresReceived).To(Equal(int64(1)))
+			apnsResMutex.Unlock()
 		})

 		It("if response has error push.ErrInternalServerError", func() {
@@ -338,8 +363,10 @@ var _ = FDescribe("APNS Message Handler", func() {
 				Reason: apns2.ReasonInternalServerError,
 			}
 			handler.handleAPNSResponse(res)
+			apnsResMutex.Lock()
 			Expect(handler.responsesReceived).To(Equal(int64(1)))
 			Expect(handler.failuresReceived).To(Equal(int64(1)))
+			apnsResMutex.Unlock()
 		})

 		It("if response has error push.ErrServiceUnavailable", func() {
@@ -357,8 +384,10 @@ var _ = FDescribe("APNS Message Handler", func() {
 				Reason: apns2.ReasonServiceUnavailable,
 			}
 			handler.handleAPNSResponse(res)
+			apnsResMutex.Lock()
 			Expect(handler.responsesReceived).To(Equal(int64(1)))
 			Expect(handler.failuresReceived).To(Equal(int64(1)))
+			apnsResMutex.Unlock()
 		})

 		It("if response has untracked error", func() {
@@ -376,8 +405,10 @@ var _ = FDescribe("APNS Message Handler", func() {
 				Reason: apns2.ReasonMethodNotAllowed,
 			}
 			handler.handleAPNSResponse(res)
+			apnsResMutex.Lock()
 			Expect(handler.responsesReceived).To(Equal(int64(1)))
 			Expect(handler.failuresReceived).To(Equal(int64(1)))
+			apnsResMutex.Unlock()
 		})
 	})

@@ -480,9 +511,11 @@ var _ = FDescribe("APNS Message Handler", func() {
 				Value: []byte(`{ "aps" : { "alert" : "Hello HTTP/2" } }`),
 			})
 			Expect(func() { go handler.CleanMetadataCache() }).ShouldNot(Panic())
-			time.Sleep(500 * time.Millisecond)
+			time.Sleep(2 * time.Duration(config.GetInt("feedback.cache.requestTimeout")) * time.Millisecond)
+			handler.inFlightNotificationsMapLock.Lock()
 			Expect(*handler.requestsHeap).To(BeEmpty())
 			Expect(handler.InFlightNotificationsMap).To(BeEmpty())
+			handler.inFlightNotificationsMapLock.Unlock()
 		})

 		It("should not panic if a request got a response", func() {
@@ -497,9 +530,12 @@ var _ = FDescribe("APNS Message Handler", func() {
 			}
 			handler.handleAPNSResponse(res)

-			time.Sleep(500 * time.Millisecond)
+			time.Sleep(2 * time.Duration(config.GetInt("feedback.cache.requestTimeout")) * time.Millisecond)
+
+			handler.inFlightNotificationsMapLock.Lock()
 			Expect(*handler.requestsHeap).To(BeEmpty())
 			Expect(handler.InFlightNotificationsMap).To(BeEmpty())
+			handler.inFlightNotificationsMapLock.Unlock()
 		})

 		It("should handle all responses or remove them after timeout", func() {
@@ -528,10 +564,12 @@ var _ = FDescribe("APNS Message Handler", func() {
 			Expect(func() { go sendRequests() }).ShouldNot(Panic())
 			time.Sleep(10 * time.Millisecond)
 			Expect(func() { go handleResponses() }).ShouldNot(Panic())
-			time.Sleep(500 * time.Millisecond)
+			time.Sleep(2 * time.Duration(config.GetInt("feedback.cache.requestTimeout")) * time.Millisecond)

+			handler.inFlightNotificationsMapLock.Lock()
 			Expect(*handler.requestsHeap).To(BeEmpty())
 			Expect(handler.InFlightNotificationsMap).To(BeEmpty())
+			handler.inFlightNotificationsMapLock.Unlock()
 		})
 	})

@@ -542,11 +580,14 @@ var _ = FDescribe("APNS Message Handler", func() {
 			handler.successesReceived = 60
 			handler.failuresReceived = 30
 			Expect(func() { go handler.LogStats() }).ShouldNot(Panic())
-			Eventually(func() []logrus.Entry { return hook.Entries }).Should(ContainLogMessage("flushing stats"))
+			time.Sleep(2 * handler.LogStatsInterval)
+
+			apnsResMutex.Lock()
 			Eventually(func() int64 { return handler.sentMessages }).Should(Equal(int64(0)))
 			Eventually(func() int64 { return handler.responsesReceived }).Should(Equal(int64(0)))
 			Eventually(func() int64 { return handler.successesReceived }).Should(Equal(int64(0)))
 			Eventually(func() int64 { return handler.failuresReceived }).Should(Equal(int64(0)))
+			apnsResMutex.Unlock()
 		})
 	})

@@ -606,7 +647,6 @@ var _ = FDescribe("APNS Message Handler", func() {
 	Describe("Feedback Reporter sent message", func() {
 		BeforeEach(func() {
 			mockKafkaProducerClient = mocks.NewKafkaProducerClientMock()
-
 			kc, err := NewKafkaProducer(config, logger, mockKafkaProducerClient)
 			Expect(err).NotTo(HaveOccurred())

@@ -627,7 +667,7 @@ var _ = FDescribe("APNS Message Handler", func() {
 				statsClients,
 				feedbackClients,
 				mockPushQueue,
-				nil,
+				mockConsumptionManager,
 			)
 			Expect(err).NotTo(HaveOccurred())
 		})

@@ -643,7 +683,9 @@ var _ = FDescribe("APNS Message Handler", func() {
 				"game":     "game",
 				"platform": "apns",
 			}
+			handler.inFlightNotificationsMapLock.Lock()
 			handler.InFlightNotificationsMap["idTest1"] = &inFlightNotification{notification: &Notification{Metadata: metadata}}
+			handler.inFlightNotificationsMapLock.Unlock()
 			res := &structs.ResponseWithMetadata{
 				StatusCode: 200,
 				ApnsID:     "idTest1",
@@ -665,7 +707,9 @@ var _ = FDescribe("APNS Message Handler", func() {
 				"game":     "game",
 				"platform": "apns",
 			}
+			handler.inFlightNotificationsMapLock.Lock()
 			handler.InFlightNotificationsMap["idTest1"] = &inFlightNotification{notification: &Notification{Metadata: metadata}}
+			handler.inFlightNotificationsMapLock.Unlock()
 			res := &structs.ResponseWithMetadata{
 				StatusCode: 200,
 				ApnsID:     "idTest1",

@@ -680,6 +724,7 @@ var _ = FDescribe("APNS Message Handler", func() {
 		})

 		It("should send feedback if success and metadata is not present", func() {
+			handler.inFlightNotificationsMapLock.Lock()
 			handler.InFlightNotificationsMap["idTest1"] = &inFlightNotification{
 				notification: &Notification{
 					Metadata: map[string]interface{}{
@@ -687,6 +732,7 @@ var _ = FDescribe("APNS Message Handler", func() {
 					},
 				},
 			}
+			handler.inFlightNotificationsMapLock.Unlock()
 			res := &structs.ResponseWithMetadata{
 				StatusCode: 200,
 				ApnsID:     "idTest1",

@@ -708,7 +754,10 @@ var _ = FDescribe("APNS Message Handler", func() {
 				"game":     "game",
 				"platform": "apns",
 			}
+			handler.inFlightNotificationsMapLock.Lock()
 			handler.InFlightNotificationsMap["idTest1"] = &inFlightNotification{notification: &Notification{Metadata: metadata}}
+			handler.inFlightNotificationsMapLock.Unlock()
+
 			res := &structs.ResponseWithMetadata{
 				StatusCode: 400,
 				ApnsID:     "idTest1",
@@ -732,7 +781,11 @@ var _ = FDescribe("APNS Message Handler", func() {
 				"game":     "game",
 				"platform": "apns",
 			}
+
+			handler.inFlightNotificationsMapLock.Lock()
 			handler.InFlightNotificationsMap["idTest1"] = &inFlightNotification{notification: &Notification{Metadata: metadata}}
+			handler.inFlightNotificationsMapLock.Unlock()
+
 			res := &structs.ResponseWithMetadata{
 				StatusCode: 400,
 				ApnsID:     "idTest1",
@@ -765,6 +818,41 @@ var _ = FDescribe("APNS Message Handler", func() {
 			Expect(fromKafka.Metadata).To(BeNil())
 			Expect(string(msg.Value)).To(ContainSubstring("BadDeviceToken"))
 		})
+
+		It("should not deadlock on handle retry for handle apns response", func() {
+			metadata := map[string]interface{}{
+				"some":      "metadata",
+				"timestamp": time.Now().Unix(),
+				"game":      "game",
+				"platform":  "apns",
+			}
+			handler.inFlightNotificationsMapLock.Lock()
+			handler.InFlightNotificationsMap["idTest1"] = &inFlightNotification{notification: &Notification{Metadata: metadata}}
+			handler.InFlightNotificationsMap["idTest2"] = &inFlightNotification{notification: &Notification{Metadata: metadata}}
+			handler.inFlightNotificationsMapLock.Unlock()
+
+			res := &structs.ResponseWithMetadata{
+				StatusCode: 429,
+				ApnsID:     "idTest1",
+				Reason:     apns2.ReasonTooManyRequests,
+			}
+
+			res2 := &structs.ResponseWithMetadata{
+				StatusCode: 429,
+				ApnsID:     "idTest2",
+				Reason:     apns2.ReasonTooManyRequests,
+			}
+			go func() {
+				defer GinkgoRecover()
+				err := handler.handleAPNSResponse(res)
+				Expect(err).NotTo(HaveOccurred())
+			}()
+			go func() {
+				defer GinkgoRecover()
+				err := handler.handleAPNSResponse(res2)
+				Expect(err).NotTo(HaveOccurred())
+			}()
+		})
 	})

 	Describe("Cleanup", func() {
@@ -803,9 +891,23 @@ var _ = FDescribe("APNS Message Handler", func() {
 				Topic: "push-game_apns",
 				Value: []byte(`{ "aps" : { "alert" : "Hello HTTP/2" } }`),
 			})
+
+			apnsResMutex.Lock()
 			Expect(handler.ignoredMessages).To(Equal(int64(0)))
 			Eventually(handler.PushQueue.ResponseChannel(), 5*time.Second).Should(Receive())
 			Expect(handler.sentMessages).To(Equal(int64(1)))
+			apnsResMutex.Unlock()
+		})
+
+		It("should be able to call HandleMessages concurrently with no errors", func() {
+			msg := interfaces.KafkaMessage{
+				Topic: "push-game_apns",
+				Value: []byte(`{ "aps" : { "alert" : "Hello" } }`),
+			}
+
+			go handler.HandleMessages(context.Background(), msg)
+			go handler.HandleMessages(context.Background(), msg)
+			go handler.HandleMessages(context.Background(), msg)
 		})
 	})

@@ -816,8 +918,10 @@ var _ = FDescribe("APNS Message Handler", func() {
 				Value: []byte(fmt.Sprintf(`{ "aps" : { "alert" : "Hello HTTP/2" }, "push_expiry": %d }`, MakeTimestamp()-int64(100))),
 			})
 			Eventually(handler.PushQueue.ResponseChannel(), 5*time.Second).ShouldNot(Receive())
+			apnsResMutex.Lock()
 			Expect(handler.sentMessages).To(Equal(int64(0)))
 			Expect(handler.ignoredMessages).To(Equal(int64(1)))
+			apnsResMutex.Unlock()
 		})
 		It("should send message if PushExpiry is in the future", func() {
 			handler.HandleMessages(ctx, interfaces.KafkaMessage{
@@ -825,7 +929,10 @@ var _ = FDescribe("APNS Message Handler", func() {
 				Value: []byte(fmt.Sprintf(`{ "aps" : { "alert" : "Hello HTTP/2" }, "push_expiry": %d}`, MakeTimestamp()+int64(100))),
 			})
 			Eventually(handler.PushQueue.ResponseChannel(), 5*time.Second).ShouldNot(Receive())
+
+			apnsResMutex.Lock()
 			Expect(handler.sentMessages).To(Equal(int64(1)))
+			apnsResMutex.Unlock()
 		})
 	})

@@ -834,7 +941,9 @@ var _ = FDescribe("APNS Message Handler", func() {
 			Expect(func() { go handler.HandleResponses() }).ShouldNot(Panic())
 			handler.PushQueue.ResponseChannel() <- &structs.ResponseWithMetadata{}
 			time.Sleep(50 * time.Millisecond)
+			apnsResMutex.Lock()
 			Expect(handler.responsesReceived).To(Equal(int64(1)))
+			apnsResMutex.Unlock()
 		})
 	})
diff --git a/extensions/apns_push_queue.go b/extensions/apns_push_queue.go
index 10cc5fe..d5d890c 100644
--- a/extensions/apns_push_queue.go
+++ b/extensions/apns_push_queue.go
@@ -66,7 +66,10 @@ func NewAPNSPushQueue(

 // Configure configures queues and token
 func (p *APNSPushQueue) Configure() error {
-	l := p.Logger.WithField("method", "configure")
+	l := p.Logger.WithFields(log.Fields{
+		"source": "APNSPushQueue",
+		"method": "configure",
+	})
 	err := p.configureCertificate()
 	if err != nil {
 		return err
@@ -77,8 +80,10 @@ func (p *APNSPushQueue) Configure() error {
 	for i := 0; i < connectionPoolSize; i++ {
 		client := apns2.NewTokenClient(p.token)
 		if p.IsProduction {
+			l.Debug("using production")
 			client = client.Production()
 		} else {
+			l.Debug("using development")
 			client = client.Development()
 		}
 		p.clients <- client
diff --git a/extensions/apns_push_queue_test.go b/extensions/apns_push_queue_test.go
index c52e11d..fed1e77 100644
--- a/extensions/apns_push_queue_test.go
+++ b/extensions/apns_push_queue_test.go
@@ -32,7 +32,7 @@ import (
 	"github.com/topfreegames/pusher/util"
 )

-var _ = Describe("APNS Message Handler", func() {
+var _ = Describe("APNS Push Queue", func() {
 	var queue *APNSPushQueue

 	configFile := os.Getenv("CONFIG_FILE")
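The counter changes above all follow one pattern: take the shared mutex for every read-modify-write, in the handler and in any test that inspects the counters. A self-contained sketch of the same idea; the type and field names here are illustrative, not the handler's:

    package main

    import (
        "fmt"
        "sync"
    )

    type counters struct {
        mu           sync.Mutex
        sentMessages int64
    }

    func (c *counters) incr() {
        // Guard the increment, matching how the handler wraps sentMessages++.
        c.mu.Lock()
        defer c.mu.Unlock()
        c.sentMessages++
    }

    func main() {
        c := &counters{}
        var wg sync.WaitGroup
        for i := 0; i < 100; i++ {
            wg.Add(1)
            go func() { defer wg.Done(); c.incr() }()
        }
        wg.Wait()
        fmt.Println(c.sentMessages) // always 100; without the mutex this would race
    }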
diff --git a/extensions/gcm_message_handler_test.go b/extensions/gcm_message_handler_test.go
index 4f8532d..0b3dada 100644
--- a/extensions/gcm_message_handler_test.go
+++ b/extensions/gcm_message_handler_test.go
@@ -24,6 +24,7 @@ package extensions

 import (
 	"encoding/json"
+	"github.com/sirupsen/logrus"
 	"github.com/spf13/viper"
 	"github.com/stretchr/testify/suite"
 	"github.com/topfreegames/pusher/config"
@@ -32,7 +33,6 @@ import (
 	"time"

 	uuid "github.com/satori/go.uuid"
-	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus/hooks/test"
 	"github.com/topfreegames/go-gcm"
 	"github.com/topfreegames/pusher/interfaces"
@@ -45,14 +45,6 @@ type GCMMessageHandlerTestSuite struct {
 	config  *config.Config
 	vConfig *viper.Viper
 	game    string
-	logger  *logrus.Logger
-	hooks   *test.Hook
-
-	mockClient        *mocks.GCMClientMock
-	mockStatsdClient  *mocks.StatsDClientMock
-	mockKafkaProducer *mocks.KafkaProducerClientMock
-
-	handler *GCMMessageHandler
 }

 func TestGCMMessageHandlerSuite(t *testing.T) {
@@ -72,31 +64,34 @@ func (s *GCMMessageHandlerTestSuite) SetupSuite() {
 	s.game = "game"
 }

-func (s *GCMMessageHandlerTestSuite) SetupSubTest() {
-	s.logger, s.hooks = test.NewNullLogger()
-
-	s.mockClient = mocks.NewGCMClientMock()
-
-	s.mockStatsdClient = mocks.NewStatsDClientMock()
-	statsD, err := NewStatsD(s.vConfig, s.logger, s.mockStatsdClient)
+func (s *GCMMessageHandlerTestSuite) setupHandler() (
+	*GCMMessageHandler,
+	*mocks.GCMClientMock,
+	*mocks.StatsDClientMock,
+	*mocks.KafkaProducerClientMock,
+) {
+	logger, _ := test.NewNullLogger()
+	mockClient := mocks.NewGCMClientMock()
+	mockStatsdClient := mocks.NewStatsDClientMock()
+
+	statsD, err := NewStatsD(s.vConfig, logger, mockStatsdClient)
 	s.Require().NoError(err)

-	s.mockKafkaProducer = mocks.NewKafkaProducerClientMock()
-	kc, err := NewKafkaProducer(s.vConfig, s.logger, s.mockKafkaProducer)
+	mockKafkaProducer := mocks.NewKafkaProducerClientMock()
+	kc, err := NewKafkaProducer(s.vConfig, logger, mockKafkaProducer)
 	s.Require().NoError(err)

 	statsClients := []interfaces.StatsReporter{statsD}
 	feedbackClients := []interfaces.FeedbackReporter{kc}
-
 	handler, err := NewGCMMessageHandlerWithClient(
 		s.game,
 		false,
 		s.vConfig,
-		s.logger,
+		logger,
 		nil,
 		statsClients,
 		feedbackClients,
-		s.mockClient,
+		mockClient,
 	)
 	s.NoError(err)
 	s.Require().NotNil(handler)
@@ -105,9 +100,9 @@ func (s *GCMMessageHandlerTestSuite) SetupSubTest() {
 	s.False(handler.IsProduction)
 	s.Equal(int64(0), handler.responsesReceived)
 	s.Equal(int64(0), handler.sentMessages)
-	s.Len(s.mockClient.MessagesSent, 0)
+	s.Len(mockClient.MessagesSent, 0)

-	s.handler = handler
+	return handler, mockClient, mockStatsdClient, mockKafkaProducer
 }

 func (s *GCMMessageHandlerTestSuite) TestConfigureHandler() {
@@ -116,7 +111,7 @@ func (s *GCMMessageHandlerTestSuite) TestConfigureHandler() {
 		s.game,
 		false,
 		s.vConfig,
-		s.logger,
+		logrus.New(),
 		nil,
 		[]interfaces.StatsReporter{},
 		[]interfaces.FeedbackReporter{},
@@ -129,105 +124,115 @@ func (s *GCMMessageHandlerTestSuite) TestConfigureHandler() {

 func (s *GCMMessageHandlerTestSuite) TestHandleGCMResponse() {
 	s.Run("should succeed if response has no error", func() {
+		handler, _, _, mockKafkaProducer := s.setupHandler()
 		res := gcm.CCSMessage{}
-		s.mockKafkaProducer.StartConsumingMessagesInProduceChannel()
-		err := s.handler.handleGCMResponse(res)
+		mockKafkaProducer.StartConsumingMessagesInProduceChannel()
+		err := handler.handleGCMResponse(res)
 		s.NoError(err)
-		s.Equal(int64(1), s.handler.responsesReceived)
-		s.Equal(int64(1), s.handler.successesReceived)
+		s.Equal(int64(1), handler.responsesReceived)
+		s.Equal(int64(1), handler.successesReceived)
 	})

 	s.Run("if response has error DEVICE_UNREGISTERED", func() {
+		handler, _, _, mockKafkaProducer := s.setupHandler()
 		res := gcm.CCSMessage{
 			Error: "DEVICE_UNREGISTERED",
 		}
-		s.mockKafkaProducer.StartConsumingMessagesInProduceChannel()
-		err := s.handler.handleGCMResponse(res)
+		mockKafkaProducer.StartConsumingMessagesInProduceChannel()
+		err := handler.handleGCMResponse(res)
 		s.Error(err)
-		s.Equal(int64(1), s.handler.responsesReceived)
-		s.Equal(int64(1), s.handler.failuresReceived)
+		s.Equal(int64(1), handler.responsesReceived)
+		s.Equal(int64(1), handler.failuresReceived)
 	})

 	s.Run("if response has error BAD_REGISTRATION", func() {
+		handler, _, _, mockKafkaProducer := s.setupHandler()
 		res := gcm.CCSMessage{
 			Error: "BAD_REGISTRATION",
 		}
-		s.mockKafkaProducer.StartConsumingMessagesInProduceChannel()
-		err := s.handler.handleGCMResponse(res)
+		mockKafkaProducer.StartConsumingMessagesInProduceChannel()
+		err := handler.handleGCMResponse(res)
 		s.Error(err)
-		s.Equal(int64(1), s.handler.responsesReceived)
-		s.Equal(int64(1), s.handler.failuresReceived)
+		s.Equal(int64(1), handler.responsesReceived)
+		s.Equal(int64(1), handler.failuresReceived)
 	})

 	s.Run("if response has error INVALID_JSON", func() {
+		handler, _, _, mockKafkaProducer := s.setupHandler()
 		res := gcm.CCSMessage{
 			Error: "INVALID_JSON",
 		}
-		s.mockKafkaProducer.StartConsumingMessagesInProduceChannel()
-		err := s.handler.handleGCMResponse(res)
+		mockKafkaProducer.StartConsumingMessagesInProduceChannel()
+		err := handler.handleGCMResponse(res)
 		s.Error(err)
-		s.Equal(int64(1), s.handler.responsesReceived)
-		s.Equal(int64(1), s.handler.failuresReceived)
+		s.Equal(int64(1), handler.responsesReceived)
+		s.Equal(int64(1), handler.failuresReceived)
 	})

 	s.Run("if response has error SERVICE_UNAVAILABLE", func() {
+		handler, _, _, mockKafkaProducer := s.setupHandler()
 		res := gcm.CCSMessage{
 			Error: "SERVICE_UNAVAILABLE",
 		}
-		s.mockKafkaProducer.StartConsumingMessagesInProduceChannel()
-		err := s.handler.handleGCMResponse(res)
+		mockKafkaProducer.StartConsumingMessagesInProduceChannel()
+		err := handler.handleGCMResponse(res)
 		s.Error(err)
-		s.Equal(int64(1), s.handler.responsesReceived)
-		s.Equal(int64(1), s.handler.failuresReceived)
+		s.Equal(int64(1), handler.responsesReceived)
+		s.Equal(int64(1), handler.failuresReceived)
 	})

 	s.Run("if response has error INTERNAL_SERVER_ERROR", func() {
+		handler, _, _, mockKafkaProducer := s.setupHandler()
 		res := gcm.CCSMessage{
 			Error: "INTERNAL_SERVER_ERROR",
 		}
-		s.mockKafkaProducer.StartConsumingMessagesInProduceChannel()
-		err := s.handler.handleGCMResponse(res)
+		mockKafkaProducer.StartConsumingMessagesInProduceChannel()
+		err := handler.handleGCMResponse(res)
 		s.Error(err)
-		s.Equal(int64(1), s.handler.responsesReceived)
-		s.Equal(int64(1), s.handler.failuresReceived)
+		s.Equal(int64(1), handler.responsesReceived)
+		s.Equal(int64(1), handler.failuresReceived)
 	})

 	s.Run("if response has error DEVICE_MESSAGE_RATE_EXCEEDED", func() {
+		handler, _, _, mockKafkaProducer := s.setupHandler()
 		res := gcm.CCSMessage{
 			Error: "DEVICE_MESSAGE_RATE_EXCEEDED",
 		}
-		s.mockKafkaProducer.StartConsumingMessagesInProduceChannel()
-		err := s.handler.handleGCMResponse(res)
+		mockKafkaProducer.StartConsumingMessagesInProduceChannel()
+		err := handler.handleGCMResponse(res)
 		s.Error(err)
-		s.Equal(int64(1), s.handler.responsesReceived)
-		s.Equal(int64(1), s.handler.failuresReceived)
+		s.Equal(int64(1), handler.responsesReceived)
+		s.Equal(int64(1), handler.failuresReceived)
 	})

 	s.Run("if response has error TOPICS_MESSAGE_RATE_EXCEEDED", func() {
+		handler, _, _, mockKafkaProducer := s.setupHandler()
 		res := gcm.CCSMessage{
 			Error: "TOPICS_MESSAGE_RATE_EXCEEDED",
 		}
-		s.mockKafkaProducer.StartConsumingMessagesInProduceChannel()
-		err := s.handler.handleGCMResponse(res)
+		mockKafkaProducer.StartConsumingMessagesInProduceChannel()
+		err := handler.handleGCMResponse(res)
 		s.Error(err)
-		s.Equal(int64(1), s.handler.responsesReceived)
-		s.Equal(int64(1), s.handler.failuresReceived)
+		s.Equal(int64(1), handler.responsesReceived)
+		s.Equal(int64(1), handler.failuresReceived)
 	})

 	s.Run("if response has untracked error", func() {
+		handler, _, _, mockKafkaProducer := s.setupHandler()
 		res := gcm.CCSMessage{
 			Error: "BAD_ACK",
 		}
-		s.mockKafkaProducer.StartConsumingMessagesInProduceChannel()
-		err := s.handler.handleGCMResponse(res)
+		mockKafkaProducer.StartConsumingMessagesInProduceChannel()
+		err := handler.handleGCMResponse(res)
 		s.Error(err)
-		s.Equal(int64(1), s.handler.responsesReceived)
-		s.Equal(int64(1), s.handler.failuresReceived)
+		s.Equal(int64(1), handler.responsesReceived)
+		s.Equal(int64(1), handler.failuresReceived)
 	})
 }

 func (s *GCMMessageHandlerTestSuite) TestSendMessage() {
 	s.Run("should not send message if expire is in the past", func() {
+		handler, _, _, _ := s.setupHandler()
 		ttl := uint(0)
 		metadata := map[string]interface{}{
 			"some": "metadata",
@@ -251,17 +256,17 @@ func (s *GCMMessageHandlerTestSuite) TestSendMessage() {
 		msgBytes, err := json.Marshal(msg)
 		s.Require().NoError(err)

-		err = s.handler.sendMessage(interfaces.KafkaMessage{
+		err = handler.sendMessage(interfaces.KafkaMessage{
 			Topic: "push-game_gcm",
 			Value: msgBytes,
 		})
 		s.Require().NoError(err)
-		s.Equal(int64(0), s.handler.sentMessages)
-		s.Equal(int64(1), s.handler.ignoredMessages)
-		s.Contains(s.hooks.LastEntry().Message, "ignoring push")
+		s.Equal(int64(0), handler.sentMessages)
+		s.Equal(int64(1), handler.ignoredMessages)
 	})

 	s.Run("should send message if PushExpiry is in the future", func() {
+		handler, _, _, _ := s.setupHandler()
 		ttl := uint(0)
 		metadata := map[string]interface{}{
 			"some": "metadata",
@@ -285,28 +290,33 @@ func (s *GCMMessageHandlerTestSuite) TestSendMessage() {
 		msgBytes, err := json.Marshal(msg)
 		s.Require().NoError(err)

-		err = s.handler.sendMessage(interfaces.KafkaMessage{
+		err = handler.sendMessage(interfaces.KafkaMessage{
 			Topic: "push-game_gcm",
 			Value: msgBytes,
 		})
 		s.Require().NoError(err)
-		s.Equal(int64(1), s.handler.sentMessages)
-		s.Equal(int64(0), s.handler.ignoredMessages)
+		gcmResMutex.Lock()
+		s.Equal(int64(1), handler.sentMessages)
+		s.Equal(int64(0), handler.ignoredMessages)
+		gcmResMutex.Unlock()
 	})

 	s.Run("should send message and not increment sentMessages if an error occurs", func() {
-		err := s.handler.sendMessage(interfaces.KafkaMessage{
+		handler, mockClient, _, _ := s.setupHandler()
+		err := handler.sendMessage(interfaces.KafkaMessage{
 			Topic: "push-game_gcm",
-			Value: []byte("gogogo"),
+			Value: []byte("value"),
 		})
 		s.Require().Error(err)
-		s.Equal(int64(0), s.handler.sentMessages)
-		s.Equal(s.hooks.LastEntry().Message, "Error unmarshalling message.")
-		s.Len(s.mockClient.MessagesSent, 0)
-		s.Len(s.handler.pendingMessages, 0)
+		gcmResMutex.Lock()
+		s.Equal(int64(0), handler.sentMessages)
+		s.Len(handler.pendingMessages, 0)
+		gcmResMutex.Unlock()
+		s.Len(mockClient.MessagesSent, 0)
 	})

 	s.Run("should send xmpp message", func() {
+		handler, mockClient, _, _ := s.setupHandler()
 		ttl := uint(0)
 		msg := &interfaces.Message{
 			TimeToLive: &ttl,
@@ -320,17 +330,20 @@ func (s *GCMMessageHandlerTestSuite) TestSendMessage() {
 		msgBytes, err := json.Marshal(msg)
 		s.Require().NoError(err)

-		err = s.handler.sendMessage(interfaces.KafkaMessage{
+		err = handler.sendMessage(interfaces.KafkaMessage{
 			Topic: "push-game_gcm",
 			Value: msgBytes,
 		})
 		s.Require().NoError(err)
-		s.Equal(int64(1), s.handler.sentMessages)
-		s.Len(s.mockClient.MessagesSent, 1)
-		s.Len(s.handler.pendingMessages, 1)
+		gcmResMutex.Lock()
+		s.Equal(int64(1), handler.sentMessages)
+		s.Len(mockClient.MessagesSent, 1)
+		s.Len(handler.pendingMessages, 1)
+		gcmResMutex.Unlock()
 	})

 	s.Run("should send xmpp message with metadata", func() {
+		handler, mockClient, _, _ := s.setupHandler()
 		ttl := uint(0)
 		metadata := map[string]interface{}{
 			"some": "metadata",
@@ -354,17 +367,20 @@ func (s *GCMMessageHandlerTestSuite) TestSendMessage() {
 		msgBytes, err := json.Marshal(msg)
 		s.Require().NoError(err)

-		err = s.handler.sendMessage(interfaces.KafkaMessage{
+		err = handler.sendMessage(interfaces.KafkaMessage{
 			Topic: "push-game_gcm",
 			Value: msgBytes,
 		})
 		s.Require().NoError(err)
-		s.Equal(int64(1), s.handler.sentMessages)
-		s.Len(s.mockClient.MessagesSent, 1)
-		s.Len(s.handler.pendingMessages, 1)
+		gcmResMutex.Lock()
+		s.Equal(int64(1), handler.sentMessages)
+		s.Len(mockClient.MessagesSent, 1)
+		s.Len(handler.pendingMessages, 1)
+		gcmResMutex.Unlock()
 	})

 	s.Run("should forward metadata content on GCM request", func() {
+		handler, mockClient, _, _ := s.setupHandler()
 		ttl := uint(0)
 		metadata := map[string]interface{}{
 			"some": "metadata",
@@ -385,22 +401,25 @@ func (s *GCMMessageHandlerTestSuite) TestSendMessage() {
 		msgBytes, err := json.Marshal(msg)
 		s.Require().NoError(err)

-		err = s.handler.sendMessage(interfaces.KafkaMessage{
+		err = handler.sendMessage(interfaces.KafkaMessage{
 			Topic: "push-game_gcm",
 			Value: msgBytes,
 		})
 		s.Require().NoError(err)
-		s.Equal(int64(1), s.handler.sentMessages)
-		s.Len(s.mockClient.MessagesSent, 1)
-		s.Len(s.handler.pendingMessages, 1)
+		gcmResMutex.Lock()
+		s.Equal(int64(1), handler.sentMessages)
+		s.Len(mockClient.MessagesSent, 1)
+		s.Len(handler.pendingMessages, 1)
+		gcmResMutex.Unlock()

-		sentMessage := s.mockClient.MessagesSent[0]
+		sentMessage := mockClient.MessagesSent[0]
 		s.NotNil(sentMessage)
 		s.Equal("metadata", sentMessage.Data["some"])
 	})

 	s.Run("should forward nested metadata content on GCM request", func() {
+		handler, mockClient, _, _ := s.setupHandler()
 		ttl := uint(0)
 		metadata := map[string]interface{}{
 			"some": "metadata",
@@ -423,17 +442,19 @@ func (s *GCMMessageHandlerTestSuite) TestSendMessage() {
 		msgBytes, err := json.Marshal(msg)
 		s.Require().NoError(err)

-		err = s.handler.sendMessage(interfaces.KafkaMessage{
+		err = handler.sendMessage(interfaces.KafkaMessage{
 			Topic: "push-game_gcm",
 			Value: msgBytes,
 		})
 		s.Require().NoError(err)
-		s.Equal(int64(1), s.handler.sentMessages)
-		s.Len(s.mockClient.MessagesSent, 1)
-		s.Len(s.handler.pendingMessages, 1)
+		gcmResMutex.Lock()
+		s.Equal(int64(1), handler.sentMessages)
+		s.Len(mockClient.MessagesSent, 1)
+		s.Len(handler.pendingMessages, 1)
+		gcmResMutex.Unlock()

-		sentMessage := s.mockClient.MessagesSent[0]
+		sentMessage := mockClient.MessagesSent[0]
 		s.NotNil(sentMessage)
 		s.Equal("metadata", sentMessage.Data["some"])
 		s.Len(sentMessage.Data["nested"], 1)
@@ -441,6 +462,7 @@ func (s *GCMMessageHandlerTestSuite) TestSendMessage() {
 	})

 	s.Run("should wait to send message if maxPendingMessages is reached", func() {
+		handler, _, _, _ := s.setupHandler()
 		ttl := uint(0)
 		msg := &gcm.XMPPMessage{
 			TimeToLive: &ttl,
@@ -453,23 +475,30 @@ func (s *GCMMessageHandlerTestSuite) TestSendMessage() {
 		s.NoError(err)

 		for i := 1; i <= 3; i++ {
-			err = s.handler.sendMessage(interfaces.KafkaMessage{
+			err = handler.sendMessage(interfaces.KafkaMessage{
 				Topic: "push-game_gcm",
 				Value: msgBytes,
 			})
 			s.NoError(err)
-			s.Equal(int64(i), s.handler.sentMessages)
-			s.Equal(i, len(s.handler.pendingMessages))
+			s.Equal(int64(i), handler.sentMessages)
+			s.Equal(i, len(handler.pendingMessages))
 		}

-		go s.handler.sendMessage(interfaces.KafkaMessage{
-			Topic: "push-game_gcm",
-			Value: msgBytes,
-		})
+		go func() {
+			err := handler.sendMessage(interfaces.KafkaMessage{
+				Topic: "push-game_gcm",
+				Value: msgBytes,
+			})
+			s.Require().NoError(err)
+		}()

-		<-s.handler.pendingMessages
+		<-handler.pendingMessages
 		s.Eventually(
-			func() bool { return s.handler.sentMessages == 4 },
+			func() bool {
+				gcmResMutex.Lock()
+				defer gcmResMutex.Unlock()
+				return handler.sentMessages == 4
+			},
 			5*time.Second,
 			100*time.Millisecond,
 		)
@@ -478,30 +507,34 @@ func (s *GCMMessageHandlerTestSuite) TestSendMessage() {

 func (s *GCMMessageHandlerTestSuite) TestCleanCache() {
 	s.Run("should remove from push queue after timeout", func() {
-		s.mockKafkaProducer.StartConsumingMessagesInProduceChannel()
-		err := s.handler.sendMessage(interfaces.KafkaMessage{
+		handler, _, _, mockKafkaProducer := s.setupHandler()
+		mockKafkaProducer.StartConsumingMessagesInProduceChannel()
+		err := handler.sendMessage(interfaces.KafkaMessage{
 			Topic: "push-game_gcm",
 			Value: []byte(`{ "aps" : { "alert" : "Hello HTTP/2" } }`),
 		})
 		s.Require().NoError(err)

-		go s.handler.CleanMetadataCache()
+		go handler.CleanMetadataCache()

-		time.Sleep(500 * time.Millisecond)
+		time.Sleep(2 * time.Duration(s.vConfig.GetInt("feedback.cache.requestTimeout")) * time.Millisecond)

-		s.Empty(s.handler.requestsHeap)
-		s.Empty(s.handler.InflightMessagesMetadata)
+		s.True(handler.requestsHeap.Empty())
+		handler.inflightMessagesMetadataLock.Lock()
+		s.Empty(handler.InflightMessagesMetadata)
+		handler.inflightMessagesMetadataLock.Unlock()
 	})

 	s.Run("should succeed if request gets a response", func() {
-		s.mockKafkaProducer.StartConsumingMessagesInProduceChannel()
-		err := s.handler.sendMessage(interfaces.KafkaMessage{
+		handler, _, _, mockKafkaProducer := s.setupHandler()
+		mockKafkaProducer.StartConsumingMessagesInProduceChannel()
+		err := handler.sendMessage(interfaces.KafkaMessage{
 			Topic: "push-game_gcm",
 			Value: []byte(`{ "aps" : { "alert" : "Hello HTTP/2" } }`),
 		})
 		s.Require().NoError(err)

-		go s.handler.CleanMetadataCache()
+		go handler.CleanMetadataCache()

 		res := gcm.CCSMessage{
 			From: "testToken1",
@@ -509,21 +542,24 @@ func (s *GCMMessageHandlerTestSuite) TestCleanCache() {
 			MessageType: "ack",
 			Category: "testCategory",
 		}
-		err = s.handler.handleGCMResponse(res)
+		err = handler.handleGCMResponse(res)
 		s.Require().NoError(err)

-		time.Sleep(500 * time.Millisecond)
+		time.Sleep(2 * time.Duration(s.vConfig.GetInt("feedback.cache.requestTimeout")) * time.Millisecond)

-		s.Empty(s.handler.requestsHeap)
-		s.Empty(s.handler.InflightMessagesMetadata)
+		s.True(handler.requestsHeap.Empty())
+		handler.inflightMessagesMetadataLock.Lock()
+		s.Empty(handler.InflightMessagesMetadata)
+		handler.inflightMessagesMetadataLock.Unlock()
 	})

 	s.Run("should handle all responses or remove them after timeout", func() {
-		s.mockKafkaProducer.StartConsumingMessagesInProduceChannel()
+		handler, _, _, mockKafkaProducer := s.setupHandler()
+		mockKafkaProducer.StartConsumingMessagesInProduceChannel()
 		n := 10
 		sendRequests := func() {
 			for i := 0; i < n; i++ {
-				err := s.handler.sendMessage(interfaces.KafkaMessage{
+				err := handler.sendMessage(interfaces.KafkaMessage{
 					Topic: "push-game_gcm",
 					Value: []byte(`{ "aps" : { "alert" : "Hello HTTP/2" } }`),
 				})
@@ -540,54 +576,52 @@ func (s *GCMMessageHandlerTestSuite) TestCleanCache() {
 					Category: "testCategory",
 				}

-				err := s.handler.handleGCMResponse(res)
+				err := handler.handleGCMResponse(res)
 				s.Require().NoError(err)
 			}
 		}

-		go s.handler.CleanMetadataCache()
+		go handler.CleanMetadataCache()
 		go sendRequests()
-		time.Sleep(500 * time.Millisecond)
+		time.Sleep(2 * time.Duration(s.vConfig.GetInt("feedback.cache.requestTimeout")) * time.Millisecond)

 		go handleResponses()
-		time.Sleep(500 * time.Millisecond)
+		time.Sleep(2 * time.Duration(s.vConfig.GetInt("feedback.cache.requestTimeout")) * time.Millisecond)

-		s.Empty(s.handler.requestsHeap)
-		s.Empty(s.handler.InflightMessagesMetadata)
+		s.True(handler.requestsHeap.Empty())
+		handler.inflightMessagesMetadataLock.Lock()
+		s.Empty(handler.InflightMessagesMetadata)
+		handler.inflightMessagesMetadataLock.Unlock()
 	})
 }

 func (s *GCMMessageHandlerTestSuite) TestLogStats() {
 	s.Run("should log stats and reset them", func() {
-		s.handler.sentMessages = 100
-		s.handler.responsesReceived = 90
-		s.handler.successesReceived = 60
-		s.handler.failuresReceived = 30
-		s.handler.ignoredMessages = 10
+		handler, _, _, _ := s.setupHandler()
+		handler.sentMessages = 100
+		handler.responsesReceived = 90
+		handler.successesReceived = 60
+		handler.failuresReceived = 30
+		handler.ignoredMessages = 10
+
+		go handler.LogStats()

-		go s.handler.LogStats()
 		s.Eventually(func() bool {
-			for _, e := range s.hooks.Entries {
-				if e.Message == "flushing stats" {
-					return
true - } - } - return false - }, - time.Second, - time.Millisecond*100, - ) - s.Eventually(func() bool { return s.handler.sentMessages == int64(0) }, time.Second, time.Millisecond*100) - s.Eventually(func() bool { return s.handler.responsesReceived == int64(0) }, time.Second, time.Millisecond*100) - s.Eventually(func() bool { return s.handler.successesReceived == int64(0) }, time.Second, time.Millisecond*100) - s.Eventually(func() bool { return s.handler.failuresReceived == int64(0) }, time.Second, time.Millisecond*100) - s.Eventually(func() bool { return s.handler.ignoredMessages == int64(0) }, time.Second, time.Millisecond*100) + gcmResMutex.Lock() + defer gcmResMutex.Unlock() + return handler.sentMessages == int64(0) && + handler.responsesReceived == int64(0) && + handler.successesReceived == int64(0) && + handler.failuresReceived == int64(0) && + handler.ignoredMessages == int64(0) + }, time.Second, time.Millisecond*100) }) } func (s *GCMMessageHandlerTestSuite) TestStatsReporter() { s.Run("should call HandleNotificationSent upon message sent to queue", func() { - s.mockKafkaProducer.StartConsumingMessagesInProduceChannel() + handler, _, mockStatsdClient, mockKafkaProducer := s.setupHandler() + mockKafkaProducer.StartConsumingMessagesInProduceChannel() ttl := uint(0) msg := &gcm.XMPPMessage{ TimeToLive: &ttl, @@ -604,40 +638,43 @@ func (s *GCMMessageHandlerTestSuite) TestStatsReporter() { Topic: "push-game_gcm", Value: msgBytes, } - err = s.handler.sendMessage(kafkaMessage) + err = handler.sendMessage(kafkaMessage) s.NoError(err) - err = s.handler.sendMessage(kafkaMessage) + err = handler.sendMessage(kafkaMessage) s.NoError(err) - s.Equal(int64(2), s.mockStatsdClient.Counts["sent"]) + s.Equal(int64(2), mockStatsdClient.Counts["sent"]) }) s.Run("should call HandleNotificationSuccess upon response received", func() { - s.mockKafkaProducer.StartConsumingMessagesInProduceChannel() + handler, _, mockStatsdClient, mockKafkaProducer := s.setupHandler() + mockKafkaProducer.StartConsumingMessagesInProduceChannel() res := gcm.CCSMessage{} - err := s.handler.handleGCMResponse(res) + err := handler.handleGCMResponse(res) s.Require().NoError(err) - err = s.handler.handleGCMResponse(res) + err = handler.handleGCMResponse(res) s.Require().NoError(err) - s.Equal(int64(2), s.mockStatsdClient.Counts["ack"]) + s.Equal(int64(2), mockStatsdClient.Counts["ack"]) }) s.Run("should call HandleNotificationFailure upon error response received", func() { - s.mockKafkaProducer.StartConsumingMessagesInProduceChannel() + handler, _, mockStatsdClient, mockKafkaProducer := s.setupHandler() + mockKafkaProducer.StartConsumingMessagesInProduceChannel() res := gcm.CCSMessage{ Error: "DEVICE_UNREGISTERED", } - s.handler.handleGCMResponse(res) - s.handler.handleGCMResponse(res) + err := handler.handleGCMResponse(res) + s.Error(err) - s.Equal(int64(2), s.mockStatsdClient.Counts["failed"]) + s.Equal(int64(1), mockStatsdClient.Counts["failed"]) }) } func (s *GCMMessageHandlerTestSuite) TestFeedbackReporter() { s.Run("should include a timestamp in feedback root and the hostname in metadata", func() { + handler, _, _, mockKafkaProducer := s.setupHandler() timestampNow := time.Now().Unix() hostname, err := os.Hostname() s.Require().NoError(err) @@ -649,17 +686,22 @@ func (s *GCMMessageHandlerTestSuite) TestFeedbackReporter() { "game": "game", "platform": "gcm", } - s.handler.InflightMessagesMetadata["idTest1"] = metadata + handler.inflightMessagesMetadataLock.Lock() + handler.InflightMessagesMetadata["idTest1"] = metadata + 
handler.inflightMessagesMetadataLock.Unlock() res := gcm.CCSMessage{ From: "testToken1", MessageID: "idTest1", MessageType: "ack", Category: "testCategory", } - go s.handler.handleGCMResponse(res) + go func() { + err := handler.handleGCMResponse(res) + s.Require().NoError(err) + }() fromKafka := &CCSMessageWithMetadata{} - msg := <-s.mockKafkaProducer.ProduceChannel() + msg := <-mockKafkaProducer.ProduceChannel() err = json.Unmarshal(msg.Value, fromKafka) s.Require().NoError(err) s.Equal(timestampNow, fromKafka.Timestamp) @@ -667,23 +709,31 @@ func (s *GCMMessageHandlerTestSuite) TestFeedbackReporter() { }) s.Run("should send feedback if success and metadata is present", func() { + handler, _, _, mockKafkaProducer := s.setupHandler() metadata := map[string]interface{}{ "some": "metadata", "timestamp": time.Now().Unix(), "game": "game", "platform": "gcm", } - s.handler.InflightMessagesMetadata["idTest1"] = metadata + + handler.inflightMessagesMetadataLock.Lock() + handler.InflightMessagesMetadata["idTest1"] = metadata + handler.inflightMessagesMetadataLock.Unlock() + res := gcm.CCSMessage{ From: "testToken1", MessageID: "idTest1", MessageType: "ack", Category: "testCategory", } - go s.handler.handleGCMResponse(res) + go func() { + err := handler.handleGCMResponse(res) + s.Require().NoError(err) + }() fromKafka := &CCSMessageWithMetadata{} - msg := <-s.mockKafkaProducer.ProduceChannel() + msg := <-mockKafkaProducer.ProduceChannel() err := json.Unmarshal(msg.Value, fromKafka) s.Require().NoError(err) s.Equal(res.From, fromKafka.From) @@ -694,16 +744,20 @@ func (s *GCMMessageHandlerTestSuite) TestFeedbackReporter() { }) s.Run("should send feedback if success and metadata is not present", func() { + handler, _, _, mockKafkaProducer := s.setupHandler() res := gcm.CCSMessage{ From: "testToken1", MessageID: "idTest1", MessageType: "ack", Category: "testCategory", } - go s.handler.handleGCMResponse(res) + go func() { + err := handler.handleGCMResponse(res) + s.Require().NoError(err) + }() fromKafka := &CCSMessageWithMetadata{} - msg := <-s.mockKafkaProducer.ProduceChannel() + msg := <-mockKafkaProducer.ProduceChannel() err := json.Unmarshal(msg.Value, fromKafka) s.Require().NoError(err) s.Equal(res.From, fromKafka.From) @@ -714,13 +768,18 @@ func (s *GCMMessageHandlerTestSuite) TestFeedbackReporter() { }) s.Run("should send feedback if error and metadata is present and token should be deleted", func() { + handler, _, _, mockKafkaProducer := s.setupHandler() metadata := map[string]interface{}{ "some": "metadata", "timestamp": time.Now().Unix(), "game": "game", "platform": "gcm", } - s.handler.InflightMessagesMetadata["idTest1"] = metadata + + handler.inflightMessagesMetadataLock.Lock() + handler.InflightMessagesMetadata["idTest1"] = metadata + handler.inflightMessagesMetadataLock.Unlock() + res := gcm.CCSMessage{ From: "testToken1", MessageID: "idTest1", @@ -728,10 +787,13 @@ func (s *GCMMessageHandlerTestSuite) TestFeedbackReporter() { Category: "testCategory", Error: "BAD_REGISTRATION", } - go s.handler.handleGCMResponse(res) + go func() { + err := handler.handleGCMResponse(res) + s.Error(err) + }() fromKafka := &CCSMessageWithMetadata{} - msg := <-s.mockKafkaProducer.ProduceChannel() + msg := <-mockKafkaProducer.ProduceChannel() err := json.Unmarshal(msg.Value, fromKafka) s.Require().NoError(err) s.Equal(res.From, fromKafka.From) @@ -744,13 +806,18 @@ func (s *GCMMessageHandlerTestSuite) TestFeedbackReporter() { }) s.Run("should send feedback if error and metadata is present and token should not 
be deleted", func() { + handler, _, _, mockKafkaProducer := s.setupHandler() metadata := map[string]interface{}{ "some": "metadata", "timestamp": time.Now().Unix(), "game": "game", "platform": "gcm", } - s.handler.InflightMessagesMetadata["idTest1"] = metadata + + handler.inflightMessagesMetadataLock.Lock() + handler.InflightMessagesMetadata["idTest1"] = metadata + handler.inflightMessagesMetadataLock.Unlock() + res := gcm.CCSMessage{ From: "testToken1", MessageID: "idTest1", @@ -758,10 +825,13 @@ func (s *GCMMessageHandlerTestSuite) TestFeedbackReporter() { Category: "testCategory", Error: "INVALID_JSON", } - go s.handler.handleGCMResponse(res) + go func() { + err := handler.handleGCMResponse(res) + s.Error(err) + }() fromKafka := &CCSMessageWithMetadata{} - msg := <-s.mockKafkaProducer.ProduceChannel() + msg := <-mockKafkaProducer.ProduceChannel() err := json.Unmarshal(msg.Value, fromKafka) s.Require().NoError(err) s.Equal(res.From, fromKafka.From) @@ -773,6 +843,7 @@ func (s *GCMMessageHandlerTestSuite) TestFeedbackReporter() { s.Nil(fromKafka.Metadata["deleteToken"]) }) s.Run("should send feedback if error and metadata is not present", func() { + handler, _, _, mockKafkaProducer := s.setupHandler() res := gcm.CCSMessage{ From: "testToken1", MessageID: "idTest1", @@ -780,10 +851,13 @@ func (s *GCMMessageHandlerTestSuite) TestFeedbackReporter() { Category: "testCategory", Error: "BAD_REGISTRATION", } - go s.handler.handleGCMResponse(res) + go func() { + err := handler.handleGCMResponse(res) + s.Error(err) + }() fromKafka := &CCSMessageWithMetadata{} - msg := <-s.mockKafkaProducer.ProduceChannel() + msg := <-mockKafkaProducer.ProduceChannel() err := json.Unmarshal(msg.Value, fromKafka) s.Require().NoError(err) s.Equal(res.From, fromKafka.From) @@ -794,11 +868,3 @@ func (s *GCMMessageHandlerTestSuite) TestFeedbackReporter() { s.Nil(fromKafka.Metadata) }) } - -//func (s *GCMMessageHandlerTestSuite) TestCleanup() { -// s.Run("should close GCM client without errors", func() { -// err := s.handler.Cleanup() -// s.NoError(err) -// s.True(s.mockClient.Closed) -// }) -//} diff --git a/extensions/handler/message_handler_test.go b/extensions/handler/message_handler_test.go index 447af61..fbece07 100644 --- a/extensions/handler/message_handler_test.go +++ b/extensions/handler/message_handler_test.go @@ -88,9 +88,11 @@ func (s *MessageHandlerTestSuite) TestSendMessage() { } s.handler.HandleMessages(ctx, message) + s.handler.statsMutex.Lock() s.Equal(int64(0), s.handler.stats.sent) s.Equal(int64(0), s.handler.stats.failures) s.Equal(int64(0), s.handler.stats.ignored) + s.handler.statsMutex.Unlock() }) s.Run("should ignore message if it has expired", func() { @@ -103,7 +105,9 @@ func (s *MessageHandlerTestSuite) TestSendMessage() { s.Require().NoError(err) s.handler.HandleMessages(ctx, interfaces.KafkaMessage{Value: bytes}) + s.handler.statsMutex.Lock() s.Equal(int64(1), s.handler.stats.ignored) + s.handler.statsMutex.Unlock() }) s.Run("should report failure if cannot send message", func() { @@ -160,7 +164,10 @@ func (s *MessageHandlerTestSuite) TestSendMessage() { s.Fail("did not send feedback to kafka") } + s.handler.statsMutex.Lock() s.Equal(int64(1), s.handler.stats.failures) + s.handler.statsMutex.Unlock() + s.Equal(int64(1), s.mockStatsdClient.Counts["failed"]) }) @@ -218,7 +225,9 @@ func (s *MessageHandlerTestSuite) TestSendMessage() { s.Fail("did not send feedback to kafka") } + s.handler.statsMutex.Lock() s.Equal(int64(1), s.handler.stats.sent) + s.handler.statsMutex.Unlock() s.Equal(int64(1), 
s.mockStatsdClient.Counts["sent"]) s.Equal(int64(1), s.mockStatsdClient.Counts["ack"]) }) diff --git a/extensions/kafka_consumer.go b/extensions/kafka_consumer.go index 696a7b8..332a8b8 100644 --- a/extensions/kafka_consumer.go +++ b/extensions/kafka_consumer.go @@ -23,8 +23,10 @@ package extensions import ( + "context" "fmt" "sync" + "time" "github.com/confluentinc/confluent-kafka-go/v2/kafka" "github.com/getsentry/raven-go" @@ -46,12 +48,10 @@ type KafkaConsumer struct { Logger *logrus.Logger FetchMinBytes int FetchWaitMaxMs int - messagesReceived int64 msgChan chan interfaces.KafkaMessage SessionTimeout int pendingMessagesWG *sync.WaitGroup stopChannel chan struct{} - run bool HandleAllMessagesBeforeExiting bool } @@ -64,8 +64,7 @@ func NewKafkaConsumer( ) (*KafkaConsumer, error) { q := &KafkaConsumer{ Config: config, - Logger: logger, - messagesReceived: 0, + Logger: logger.WithField("source", "extensions.KafkaConsumer").Logger, pendingMessagesWG: nil, stopChannel: *stopChannel, } @@ -126,7 +125,6 @@ func (q *KafkaConsumer) configureConsumer(client interfaces.KafkaConsumerClient) "session.timeout.ms": q.SessionTimeout, "fetch.min.bytes": q.FetchMinBytes, "fetch.wait.max.ms": q.FetchWaitMaxMs, - "enable.auto.commit": true, "default.topic.config": kafka.ConfigMap{ "auto.offset.reset": q.OffsetResetStrategy, }, @@ -141,10 +139,8 @@ func (q *KafkaConsumer) configureConsumer(client interfaces.KafkaConsumerClient) "fetch.min.bytes": q.FetchMinBytes, "fetch.wait.max.ms": q.FetchWaitMaxMs, "session.timeout.ms": q.SessionTimeout, - "enable.auto.commit": true, "default.topic.config": kafka.ConfigMap{ - "auto.offset.reset": q.OffsetResetStrategy, - "enable.auto.commit": true, + "auto.offset.reset": q.OffsetResetStrategy, }, }) if err != nil { @@ -166,7 +162,7 @@ func (q *KafkaConsumer) PendingMessagesWaitGroup() *sync.WaitGroup { // StopConsuming stops consuming messages from the queue func (q *KafkaConsumer) StopConsuming() { - q.run = false + close(q.stopChannel) } func (q *KafkaConsumer) Pause(topic string) error { @@ -209,14 +205,16 @@ func (q *KafkaConsumer) MessagesChannel() *chan interfaces.KafkaMessage { } // ConsumeLoop consume messages from the queue and put in messages to send channel -func (q *KafkaConsumer) ConsumeLoop() error { - q.run = true +func (q *KafkaConsumer) ConsumeLoop(ctx context.Context) error { l := q.Logger.WithFields(logrus.Fields{ "method": "ConsumeLoop", "topics": q.Topics, }) - err := q.Consumer.SubscribeTopics(q.Topics, nil) + err := q.Consumer.SubscribeTopics(q.Topics, func(_ *kafka.Consumer, event kafka.Event) error { + l.WithField("event", event.String()).Debug("got event from Kafka") + return nil + }) if err != nil { l.WithError(err).Error("error subscribing to topics") @@ -226,33 +224,46 @@ func (q *KafkaConsumer) ConsumeLoop() error { l.Info("successfully subscribed to topics") //nolint[:gosimple] - for q.run { - message, err := q.Consumer.ReadMessage(100) - if message == nil && err.(kafka.Error).IsTimeout() { - continue - } - if err != nil { - q.handleError(err) - continue + for { + select { + case <-q.stopChannel: + l.Info("stopping kafka consumer") + return nil + case <-ctx.Done(): + l.Info("context done. 
Will stop consuming messages.") + return nil + default: + message, err := q.Consumer.ReadMessage(100 * time.Millisecond) + if message == nil && err.(kafka.Error).IsTimeout() { + continue + } + if err != nil { + q.handleError(err) + continue + } + l.Debug("got message from Kafka") + q.receiveMessage(message.TopicPartition, message.Value) + + _, err = q.Consumer.CommitMessage(message) + if err != nil { + q.handleError(err) + return fmt.Errorf("error committing message: %s", err.Error()) + } } - q.receiveMessage(message.TopicPartition, message.Value) } - return nil } func (q *KafkaConsumer) receiveMessage(topicPartition kafka.TopicPartition, value []byte) { l := q.Logger.WithFields(logrus.Fields{ - "method": "receiveMessage", + "method": "receiveMessage", + "topic": *topicPartition.Topic, + "partitionKey": topicPartition.Partition, + "jsonValue": string(value), }) l.Debug("Processing received message...") - q.messagesReceived++ - if q.messagesReceived%1000 == 0 { - l.Infof("messages from kafka: %d", q.messagesReceived) - } - l.Debugf("message on %s:\n%s\n", topicPartition, string(value)) if q.pendingMessagesWG != nil { q.pendingMessagesWG.Add(1) } @@ -265,7 +276,7 @@ func (q *KafkaConsumer) receiveMessage(topicPartition kafka.TopicPartition, valu q.msgChan <- message - l.Debug("Received message processed.") + l.Debug("added message to channel") } func (q *KafkaConsumer) handleError(err error) { @@ -282,9 +293,6 @@ func (q *KafkaConsumer) handleError(err error) { // Cleanup closes kafka consumer connection func (q *KafkaConsumer) Cleanup() error { - if q.run { - q.StopConsuming() - } if q.Consumer != nil { err := q.Consumer.Close() if err != nil { diff --git a/extensions/kafka_consumer_test.go b/extensions/kafka_consumer_test.go index 36196e3..a582139 100644 --- a/extensions/kafka_consumer_test.go +++ b/extensions/kafka_consumer_test.go @@ -1,6 +1,7 @@ package extensions import ( + "context" "fmt" "os" "time" @@ -31,7 +32,8 @@ var _ = Describe("Kafka Extension", func() { startConsuming := func() { go func() { defer GinkgoRecover() - consumer.ConsumeLoop() + goFuncErr := consumer.ConsumeLoop(context.Background()) + Expect(goFuncErr).NotTo(HaveOccurred()) }() time.Sleep(5 * time.Millisecond) } @@ -74,24 +76,27 @@ var _ = Describe("Kafka Extension", func() { Describe("Stop consuming", func() { It("should stop consuming", func() { - consumer.run = true consumer.StopConsuming() - Expect(consumer.run).To(BeFalse()) + Expect(consumer.stopChannel).To(BeClosed()) }) }) Describe("Consume loop", func() { It("should fail if subscribing to topic fails", func() { kafkaConsumerClientMock.Error = fmt.Errorf("could not subscribe") - err := consumer.ConsumeLoop() + err := consumer.ConsumeLoop(context.Background()) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("could not subscribe")) }) It("should subscribe to topic", func() { startConsuming() - defer consumer.StopConsuming() - Eventually(kafkaConsumerClientMock.SubscribedTopics, 5).Should(HaveKey("com.games.test")) + + time.Sleep(100 * time.Millisecond) + consumer.StopConsuming() + time.Sleep(100 * time.Millisecond) + + Expect(kafkaConsumerClientMock.SubscribedTopics).To(HaveKey("com.games.test")) }) It("should receive message", func() { @@ -104,14 +109,12 @@ var _ = Describe("Kafka Extension", func() { } val := []byte("test") event := &kafka.Message{TopicPartition: part, Value: val} - consumer.messagesReceived = 999 publishEvent(event) Eventually(consumer.msgChan, 5).Should(Receive(&interfaces.KafkaMessage{ Topic: topic, Value: val, 
})) - Expect(consumer.messagesReceived).To(BeEquivalentTo(1000)) }) }) @@ -134,17 +137,15 @@ var _ = Describe("Kafka Extension", func() { Describe("Pending Messages Waiting Group", func() { It("should return the waiting group", func() { - pmwg := consumer.PendingMessagesWaitGroup() - Expect(pmwg).NotTo(BeNil()) + pendingMessagesWaitGroup := consumer.PendingMessagesWaitGroup() + Expect(pendingMessagesWaitGroup).NotTo(BeNil()) }) }) Describe("Cleanup", func() { It("should stop running upon cleanup", func() { - consumer.run = true err := consumer.Cleanup() Expect(err).NotTo(HaveOccurred()) - Expect(consumer.run).To(BeFalse()) }) It("should close connection to kafka upon cleanup", func() { @@ -195,7 +196,10 @@ var _ = Describe("Kafka Extension", func() { Expect(err).NotTo(HaveOccurred()) Expect(client).NotTo(BeNil()) defer client.StopConsuming() - go client.ConsumeLoop() + go func() { + goFuncErr := client.ConsumeLoop(context.Background()) + Expect(goFuncErr).NotTo(HaveOccurred()) + }() // Required to assure the consumer to be ready before producing a message time.Sleep(5 * time.Second) diff --git a/extensions/kafka_producer.go b/extensions/kafka_producer.go index c6a375d..6c31f34 100644 --- a/extensions/kafka_producer.go +++ b/extensions/kafka_producer.go @@ -52,6 +52,11 @@ func NewKafkaProducer(config *viper.Viper, logger *log.Logger, clientOrNil ...in if len(clientOrNil) == 1 { producer = clientOrNil[0] } + + if q.Logger != nil { + q.Logger = q.Logger.WithField("source", "extensions.KafkaProducer").Logger + } + err := q.configure(producer) return q, err } @@ -125,6 +130,10 @@ func (q *KafkaProducer) listenForKafkaResponses() { // SendFeedback sends the feedback to the kafka Queue func (q *KafkaProducer) SendFeedback(game string, platform string, feedback []byte) { topic := "push-" + game + "-" + platform + "-feedbacks" + l := q.Logger.WithFields(log.Fields{ + "method": "SendFeedback", + "topic": topic, + }) m := &kafka.Message{ TopicPartition: kafka.TopicPartition{ Topic: &topic, @@ -133,4 +142,5 @@ func (q *KafkaProducer) SendFeedback(game string, platform string, feedback []by Value: feedback, } q.Producer.ProduceChannel() <- m + l.Debug("feedback sent to ProduceChannel") } diff --git a/extensions/timeout_heap.go b/extensions/timeout_heap.go index a6459a2..6a78a4a 100644 --- a/extensions/timeout_heap.go +++ b/extensions/timeout_heap.go @@ -90,8 +90,10 @@ func (th *TimeoutHeap) Pop() interface{} { return node } -// Returns true if heap is empty -func (th *TimeoutHeap) empty() bool { +// Empty returns true if heap has no elements +func (th *TimeoutHeap) Empty() bool { + mutex.Lock() + defer mutex.Unlock() return th.Len() == 0 } @@ -99,7 +101,7 @@ func (th *TimeoutHeap) completeHasExpiredRequest() (string, int64, bool) { mutex.Lock() defer mutex.Unlock() - if th.empty() { + if len(*th) == 0 { return "", 0, false } diff --git a/extensions/timeout_heap_test.go b/extensions/timeout_heap_test.go index b88210d..0656efa 100644 --- a/extensions/timeout_heap_test.go +++ b/extensions/timeout_heap_test.go @@ -71,12 +71,12 @@ var _ = Describe("[Unit]", func() { } }) - It("should return true if heap is empty", func() { + It("should return true if heap is Empty", func() { th := NewTimeoutHeap(config) - Ω(th.empty()).Should(BeTrue()) + Ω(th.Empty()).Should(BeTrue()) th.AddRequest("token") - Ω(th.empty()).Should(BeFalse()) + Ω(th.Empty()).Should(BeFalse()) }) It("should return nodes in order of time stamp from threads", func() { diff --git a/feedback/broker.go b/feedback/broker.go index ee6c50d..6ae1fb3 
100644 --- a/feedback/broker.go +++ b/feedback/broker.go @@ -108,7 +108,6 @@ func (b *Broker) Start() { // Stop stops all routines from processing the in channel and closes all output channels. func (b *Broker) Stop() { - b.run = false close(b.stopChannel) close(b.InvalidTokenOutChan) } @@ -118,7 +117,7 @@ func (b *Broker) processMessages() { "method", "processMessages", ) - for b.run { + for { select { case msg, ok := <-b.InChan: if ok { @@ -144,11 +143,11 @@ func (b *Broker) processMessages() { } case <-b.stopChannel: - break + l.Info("stop processing Broker's in channel") + return } } - l.Info("stop processing Broker's in channel") } func (b *Broker) routeAPNSMessage(msg *structs.ResponseWithMetadata, game string) { diff --git a/feedback/broker_test.go b/feedback/broker_test.go index 8727167..b6922f1 100644 --- a/feedback/broker_test.go +++ b/feedback/broker_test.go @@ -30,19 +30,16 @@ import ( . "github.com/onsi/gomega" uuid "github.com/satori/go.uuid" "github.com/sideshow/apns2" + "github.com/sirupsen/logrus" + "github.com/sirupsen/logrus/hooks/test" "github.com/spf13/viper" "github.com/topfreegames/go-gcm" "github.com/topfreegames/pusher/structs" - "github.com/topfreegames/pusher/testing" - - "github.com/sirupsen/logrus" - "github.com/sirupsen/logrus/hooks/test" "github.com/topfreegames/pusher/util" ) var _ = Describe("Broker", func() { var logger *logrus.Logger - var hook *test.Hook var inChan chan QueueMessage var config *viper.Viper var err error @@ -53,7 +50,7 @@ var _ = Describe("Broker", func() { } BeforeEach(func() { - logger, hook = test.NewNullLogger() + logger, _ = test.NewNullLogger() config, err = util.NewViperWithConfigFile(configFile) Expect(err).NotTo(HaveOccurred()) @@ -70,8 +67,7 @@ var _ = Describe("Broker", func() { close(inChan) broker.Stop() - Eventually(func() []logrus.Entry { return hook.Entries }). 
- Should(testing.ContainLogMessage("stop processing Broker's in channel")) + Expect(broker.stopChannel).To(BeClosed()) }) Describe("APNS Feedback Messages", func() { diff --git a/feedback/invalid_token.go b/feedback/invalid_token.go index 6439597..ae00733 100644 --- a/feedback/invalid_token.go +++ b/feedback/invalid_token.go @@ -25,6 +25,7 @@ package feedback import ( "fmt" "strings" + "sync" "time" "github.com/getsentry/raven-go" @@ -62,9 +63,9 @@ type InvalidTokenHandler struct { InChan chan *InvalidToken Buffer []*InvalidToken + BufferLock sync.Mutex bufferSize int - run bool stopChan chan bool } @@ -127,13 +128,11 @@ func (i *InvalidTokenHandler) Start() { ) l.Info("starting invalid token handler") - i.run = true go i.processMessages() } // Stop stops the Handler from consuming messages from the intake channel func (i *InvalidTokenHandler) Stop() { - i.run = false close(i.stopChan) } @@ -145,10 +144,11 @@ func (i *InvalidTokenHandler) processMessages() { flushTicker := time.NewTicker(i.flushTime) defer flushTicker.Stop() - for i.run { + for { select { case tk, ok := <-i.InChan: if ok { + i.BufferLock.Lock() i.Buffer = append(i.Buffer, tk) if len(i.Buffer) >= i.bufferSize { @@ -156,19 +156,21 @@ func (i *InvalidTokenHandler) processMessages() { i.deleteTokens(i.Buffer) i.Buffer = make([]*InvalidToken, 0, i.bufferSize) } + i.BufferLock.Unlock() } case <-flushTicker.C: - l.Debug("flush ticker") + i.BufferLock.Lock() i.deleteTokens(i.Buffer) i.Buffer = make([]*InvalidToken, 0, i.bufferSize) + i.BufferLock.Unlock() case <-i.stopChan: - break + l.Info("stop processing Invalid Token Handler's in channel") + return } } - l.Info("stop processing Invalid Token Handler's in channel") } // deleteTokens groups tokens by game and platform and deletes them from the diff --git a/feedback/invalid_token_test.go b/feedback/invalid_token_test.go index bf0656b..5169932 100644 --- a/feedback/invalid_token_test.go +++ b/feedback/invalid_token_test.go @@ -35,7 +35,6 @@ import ( "github.com/topfreegames/pusher/extensions" "github.com/topfreegames/pusher/interfaces" "github.com/topfreegames/pusher/mocks" - "github.com/topfreegames/pusher/testing" "github.com/topfreegames/pusher/util" ) @@ -92,7 +91,7 @@ var _ = Describe("InvalidToken Handler", func() { }) It("Should flush because buffer is full", func() { - logger, hook := test.NewNullLogger() + logger, _ := test.NewNullLogger() logger.Level = logrus.DebugLevel mockClient := mocks.NewPGMock(0, 1) @@ -118,8 +117,9 @@ var _ = Describe("InvalidToken Handler", func() { inChan <- t } - Eventually(func() []logrus.Entry { return hook.Entries }). - Should(testing.ContainLogMessage("buffer is full")) + time.Sleep(time.Second) + handler.Stop() + time.Sleep(500 * time.Millisecond) Eventually(func() int64 { return mockStatsDClient.Counts[MetricsTokensDeleteSuccess] @@ -135,7 +135,7 @@ var _ = Describe("InvalidToken Handler", func() { }) It("Should flush because reached flush timeout", func() { - logger, hook := test.NewNullLogger() + logger, _ := test.NewNullLogger() logger.Level = logrus.DebugLevel mockClient := mocks.NewPGMock(0, 1) @@ -161,20 +161,13 @@ var _ = Describe("InvalidToken Handler", func() { inChan <- t } - Eventually(func() []logrus.Entry { return hook.Entries }). 
- Should(testing.ContainLogMessage("flush ticker")) - - Eventually(func() int64 { - return mockStatsDClient.Counts[MetricsTokensDeleteSuccess] - }).Should(BeEquivalentTo(2)) - - Eventually(func() int64 { - return mockStatsDClient.Counts[MetricsTokensDeleteError] - }).Should(BeEquivalentTo(0)) + time.Sleep(200 * time.Millisecond) + handler.Stop() + time.Sleep(500 * time.Millisecond) - Eventually(func() int64 { - return mockStatsDClient.Counts[MetricsTokensDeleteNonexistent] - }).Should(BeEquivalentTo(0)) + Expect(mockStatsDClient.Counts[MetricsTokensDeleteSuccess]).To(BeEquivalentTo(2)) + Expect(mockStatsDClient.Counts[MetricsTokensDeleteError]).To(BeEquivalentTo(0)) + Expect(mockStatsDClient.Counts[MetricsTokensDeleteNonexistent]).To(BeEquivalentTo(0)) }) }) @@ -258,6 +251,9 @@ var _ = Describe("InvalidToken Handler", func() { Tokens: []interface{}{"EEEEEEEEEE", "FFFFFFFFFF"}, }, } + time.Sleep(time.Second) + handler.Stop() + time.Sleep(500 * time.Millisecond) for _, res := range expResults { Eventually(func() interface{} { @@ -272,7 +268,7 @@ var _ = Describe("InvalidToken Handler", func() { }) It("should not break if token does not exist in db", func() { - logger, hook := test.NewNullLogger() + logger, _ := test.NewNullLogger() mockClient := mocks.NewPGMock(0, 1) inChan := make(chan *InvalidToken, 100) @@ -302,24 +298,17 @@ var _ = Describe("InvalidToken Handler", func() { Game: "sniper", Platform: "apns", } - Consistently(func() []logrus.Entry { return hook.Entries }). - ShouldNot(testing.ContainLogMessage("error deleting tokens")) - - Eventually(func() int64 { - return mockStatsDClient.Counts[MetricsTokensDeleteSuccess] - }).Should(BeEquivalentTo(0)) - - Eventually(func() int64 { - return mockStatsDClient.Counts[MetricsTokensDeleteError] - }).Should(BeEquivalentTo(0)) + time.Sleep(time.Second) + handler.Stop() + time.Sleep(time.Second) - Eventually(func() int64 { - return mockStatsDClient.Counts[MetricsTokensDeleteNonexistent] - }).Should(BeEquivalentTo(1)) + Expect(mockStatsDClient.Counts[MetricsTokensDeleteSuccess]).To(BeEquivalentTo(0)) + Expect(mockStatsDClient.Counts[MetricsTokensDeleteError]).To(BeEquivalentTo(0)) + Expect(mockStatsDClient.Counts[MetricsTokensDeleteNonexistent]).To(BeEquivalentTo(1)) }) It("should not break if a pg error occurred", func() { - logger, hook := test.NewNullLogger() + logger, _ := test.NewNullLogger() mockClient := mocks.NewPGMock(0, 1) inChan := make(chan *InvalidToken, 100) @@ -349,13 +338,7 @@ var _ = Describe("InvalidToken Handler", func() { Platform: "apns", } - for len(mockClient.Execs) < 2 { - time.Sleep(10 * time.Millisecond) - } - - Eventually(func() []logrus.Entry { - return hook.Entries - }).Should(testing.ContainLogMessage("error deleting tokens")) + time.Sleep(100 * time.Millisecond) mockClient.Error = nil mockClient.RowsAffected = 1 @@ -370,31 +353,15 @@ var _ = Describe("InvalidToken Handler", func() { expQuery := "DELETE FROM sniper_apns WHERE token IN (?0);" expTokens := []interface{}{"BBBBBBBBBB"} - Eventually(func() interface{} { - if len(mockClient.Execs) >= 3 { - return mockClient.Execs[2][0] - } - return nil - }).Should(BeEquivalentTo(expQuery)) - - Eventually(func() interface{} { - if len(mockClient.Execs) >= 3 { - return mockClient.Execs[2][1] - } - return nil - }).Should(BeEquivalentTo(expTokens)) + time.Sleep(200 * time.Millisecond) + handler.Stop() + time.Sleep(200 * time.Millisecond) - Eventually(func() int64 { - return mockStatsDClient.Counts[MetricsTokensDeleteSuccess] - }).Should(BeEquivalentTo(1)) - - Eventually(func() 
int64 { - return mockStatsDClient.Counts[MetricsTokensDeleteError] - }).Should(BeEquivalentTo(1)) - - Eventually(func() int64 { - return mockStatsDClient.Counts[MetricsTokensDeleteNonexistent] - }).Should(BeEquivalentTo(0)) + Expect(mockClient.Execs[2][0]).To(BeEquivalentTo(expQuery)) + Expect(mockClient.Execs[2][1]).To(BeEquivalentTo(expTokens)) + Expect(mockStatsDClient.Counts[MetricsTokensDeleteSuccess]).To(BeEquivalentTo(1)) + Expect(mockStatsDClient.Counts[MetricsTokensDeleteError]).To(BeEquivalentTo(1)) + Expect(mockStatsDClient.Counts[MetricsTokensDeleteNonexistent]).To(BeEquivalentTo(0)) }) }) }) diff --git a/feedback/kafka_consumer.go b/feedback/kafka_consumer.go index 3d74810..b49a5b4 100644 --- a/feedback/kafka_consumer.go +++ b/feedback/kafka_consumer.go @@ -23,10 +23,13 @@ package feedback import ( + "context" + "fmt" "sync" + "time" "github.com/confluentinc/confluent-kafka-go/v2/kafka" - raven "github.com/getsentry/raven-go" + "github.com/getsentry/raven-go" "github.com/sirupsen/logrus" "github.com/spf13/viper" "github.com/topfreegames/pusher/extensions" @@ -53,6 +55,8 @@ type KafkaConsumer struct { stopChannel chan struct{} run bool HandleAllMessagesBeforeExiting bool + consumerContext context.Context + stopFunc context.CancelFunc } // NewKafkaConsumer for creating a new KafkaConsumer instance @@ -64,7 +68,7 @@ func NewKafkaConsumer( q := &KafkaConsumer{ Config: config, - Logger: logger, + Logger: logger.WithField("source", "feedback.KafkaConsumer").Logger, messagesReceived: 0, pendingMessagesWG: nil, stopChannel: *stopChannel, @@ -80,6 +84,8 @@ func NewKafkaConsumer( return nil, err } + q.consumerContext, q.stopFunc = context.WithCancel(context.Background()) + return q, nil } @@ -130,14 +136,12 @@ func (q *KafkaConsumer) configureConsumer(client interfaces.KafkaConsumerClient) "session.timeout.ms": q.SessionTimeout, "fetch.min.bytes": q.FetchMinBytes, "fetch.wait.max.ms": q.FetchWaitMaxMs, - "enable.auto.commit": true, "default.topic.config": kafka.ConfigMap{ - "auto.offset.reset": q.OffsetResetStrategy, - "enable.auto.commit": true, + "auto.offset.reset": q.OffsetResetStrategy, }, "topics": q.Topics, }) - l.Debug("configuring kafka queue extension") + l.Debug("configuring kafka queue") if client == nil { c, err := kafka.NewConsumer(&kafka.ConfigMap{ @@ -146,10 +150,8 @@ func (q *KafkaConsumer) configureConsumer(client interfaces.KafkaConsumerClient) "fetch.min.bytes": q.FetchMinBytes, "fetch.wait.max.ms": q.FetchWaitMaxMs, "session.timeout.ms": q.SessionTimeout, - "enable.auto.commit": true, "default.topic.config": kafka.ConfigMap{ - "auto.offset.reset": q.OffsetResetStrategy, - "enable.auto.commit": true, + "auto.offset.reset": q.OffsetResetStrategy, }, }) @@ -174,7 +176,7 @@ func (q *KafkaConsumer) PendingMessagesWaitGroup() *sync.WaitGroup { // StopConsuming stops consuming messages from the queue func (q *KafkaConsumer) StopConsuming() { - q.run = false + q.stopFunc() } // MessagesChannel returns the channel that will receive all messages got from kafka @@ -183,7 +185,7 @@ func (q *KafkaConsumer) MessagesChannel() chan QueueMessage { } // ConsumeLoop consume messages from the queue and put in messages to send channel -func (q *KafkaConsumer) ConsumeLoop() error { +func (q *KafkaConsumer) ConsumeLoop(ctx context.Context) error { l := q.Logger.WithFields(logrus.Fields{ "method": "ConsumeLoop", "topics": q.Topics, @@ -197,20 +199,34 @@ func (q *KafkaConsumer) ConsumeLoop() error { l.Info("successfully subscribed to topics") - q.run = true - for q.run { 
- message, err := q.Consumer.ReadMessage(100) - if message == nil && err.(kafka.Error).IsTimeout() { - continue - } - if err != nil { - q.handleError(err) - continue + for { + select { + case <-q.consumerContext.Done(): + l.Info("context done, stopping consuming") + return nil + case <-q.stopChannel: + l.Info("stop channel closed, stopping consuming") + return nil + case <-ctx.Done(): + l.Info("context done, stopping consuming") + return nil + default: + message, err := q.Consumer.ReadMessage(100 * time.Millisecond) + if message == nil && err.(kafka.Error).IsTimeout() { + continue + } + if err != nil { + q.handleError(err) + continue + } + q.receiveMessage(message.TopicPartition, message.Value) + _, err = q.Consumer.CommitMessage(message) + if err != nil { + q.handleError(err) + return fmt.Errorf("error committing message: %s", err.Error()) + } } - q.receiveMessage(message.TopicPartition, message.Value) } - - return nil } func (q *KafkaConsumer) receiveMessage(topicPartition kafka.TopicPartition, value []byte) { @@ -254,9 +270,12 @@ func (q *KafkaConsumer) handleError(err error) { // Cleanup closes kafka consumer connection func (q *KafkaConsumer) Cleanup() error { - if q.run { + select { + case <-q.consumerContext.Done(): + default: q.StopConsuming() } + if q.Consumer != nil { err := q.Consumer.Close() if err != nil { diff --git a/feedback/kafka_consumer_test.go b/feedback/kafka_consumer_test.go index ec7cfb1..136b2fd 100644 --- a/feedback/kafka_consumer_test.go +++ b/feedback/kafka_consumer_test.go @@ -23,6 +23,7 @@ package feedback import ( + "context" "fmt" "os" "time" @@ -52,7 +53,8 @@ var _ = Describe("Kafka Consumer", func() { startConsuming := func() { go func() { defer GinkgoRecover() - consumer.ConsumeLoop() + err := consumer.ConsumeLoop(context.Background()) + Expect(err).NotTo(HaveOccurred()) }() time.Sleep(5 * time.Millisecond) } @@ -95,23 +97,25 @@ var _ = Describe("Kafka Consumer", func() { Describe("Stop consuming", func() { It("should stop consuming", func() { - consumer.run = true consumer.StopConsuming() - Expect(consumer.run).To(BeFalse()) + Expect(consumer.consumerContext.Done()).To(BeClosed()) }) }) Describe("Consume loop", func() { It("should fail if subscribing to topic fails", func() { kafkaConsumerClientMock.Error = fmt.Errorf("could not subscribe") - err := consumer.ConsumeLoop() + err := consumer.ConsumeLoop(context.Background()) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("could not subscribe")) }) It("should subscribe to topic", func() { startConsuming() - defer consumer.StopConsuming() + time.Sleep(100 * time.Millisecond) + consumer.StopConsuming() + time.Sleep(100 * time.Millisecond) + Eventually(kafkaConsumerClientMock.SubscribedTopics, 5).Should(HaveKey("com.games.test")) }) @@ -163,10 +167,9 @@ var _ = Describe("Kafka Consumer", func() { Describe("Cleanup", func() { It("should stop running upon cleanup", func() { - consumer.run = true err := consumer.Cleanup() Expect(err).NotTo(HaveOccurred()) - Expect(consumer.run).To(BeFalse()) + Expect(consumer.consumerContext.Done()).To(BeClosed()) }) It("should close connection to kafka upon cleanup", func() { @@ -213,7 +216,7 @@ var _ = Describe("Kafka Consumer", func() { Describe("ConsumeLoop", func() { It("should consume message and add it to msgChan", func() { - logger, _ := test.NewNullLogger() + logger := logrus.New() logger.Level = logrus.DebugLevel stopChannel := make(chan struct{}) @@ -226,13 +229,21 @@ client, err := NewKafkaConsumer(config, logger, 
&stopChannel) Expect(err).NotTo(HaveOccurred()) Expect(client).NotTo(BeNil()) + p, err := kafka.NewProducer(&kafka.ConfigMap{"bootstrap.servers": client.Brokers}) + Expect(err).NotTo(HaveOccurred()) + err = p.Produce(&kafka.Message{TopicPartition: kafka.TopicPartition{Topic: &client.Topics[0], Partition: kafka.PartitionAny}, Value: value}, nil) + Expect(err).NotTo(HaveOccurred()) + time.Sleep(5 * time.Second) + + go func() { + goFuncErr := client.ConsumeLoop(context.Background()) + Expect(goFuncErr).NotTo(HaveOccurred()) + }() defer client.StopConsuming() - go client.ConsumeLoop() // Required to assure the consumer to be ready before producing a message time.Sleep(5 * time.Second) - p, err := kafka.NewProducer(&kafka.ConfigMap{"bootstrap.servers": client.Brokers}) Expect(err).NotTo(HaveOccurred()) err = p.Produce( &kafka.Message{ @@ -244,7 +254,7 @@ nil, ) Expect(err).NotTo(HaveOccurred()) - Eventually(client.msgChan, 10*time.Second).Should(Receive(Equal(&KafkaMessage{ + Eventually(client.msgChan).WithPolling(100 * time.Millisecond).WithTimeout(5 * time.Second).Should(Receive(Equal(&KafkaMessage{ Game: game, Platform: platform, Value: value, diff --git a/feedback/listener.go b/feedback/listener.go index 2033d30..e05c636 100644 --- a/feedback/listener.go +++ b/feedback/listener.go @@ -23,6 +23,7 @@ package feedback import ( + "context" "fmt" "os" "os/signal" @@ -127,7 +128,7 @@ func (l *Listener) Start() { ) log.Info("starting the feedback listener...") - go l.Queue.ConsumeLoop() + go l.Queue.ConsumeLoop(context.Background()) l.Broker.Start() l.InvalidTokenHandler.Start() @@ -146,7 +147,7 @@ func (l *Listener) Start() { log.WithField("signal", sig.String()).Warn("terminating due to caught signal") l.run = false case <-l.stopChannel: - log.Warn("Stop channel closed\n") + log.Warn("Stop channel closed") l.run = false case <-flushTicker.C: l.flushStats() @@ -163,8 +164,12 @@ func (l *Listener) flushStats() { "broker_in_channel", float64(len(l.Broker.InChan)), "", "") statsReporterReportMetricGauge(l.StatsReporters, "broker_invalid_token_channel", float64(len(l.Broker.InvalidTokenOutChan)), "", "") + + l.InvalidTokenHandler.BufferLock.Lock() + bufferSize := float64(len(l.InvalidTokenHandler.Buffer)) + l.InvalidTokenHandler.BufferLock.Unlock() statsReporterReportMetricGauge(l.StatsReporters, - "invalid_token_handler_buffer", float64(len(l.InvalidTokenHandler.Buffer)), "", "") + "invalid_token_handler_buffer", bufferSize, "", "") } // Cleanup ends the Listener execution diff --git a/feedback/listener_test.go b/feedback/listener_test.go index ffa8ff1..793dba4 100644 --- a/feedback/listener_test.go +++ b/feedback/listener_test.go @@ -229,7 +229,7 @@ var _ = Describe("Feedback Listener", func() { }) It("should delete a batch of tokens from a single game", func() { - logger, _ := test.NewNullLogger() + logger := logrus.New() logger.Level = logrus.DebugLevel config.Set("feedbackListeners.queue.group", fmt.Sprintf("group-%s", uuid.NewV4().String())) diff --git a/feedback/queue.go b/feedback/queue.go index 9cd13a8..4b6f00d 100644 --- a/feedback/queue.go +++ b/feedback/queue.go @@ -22,7 +22,10 @@ package feedback -import "sync" +import ( + "context" + "sync" +) // QueueMessage defines the interface that should be implemented by the type // produced by a Queue @@ -35,7 +38,7 @@ type QueueMessage interface { // Queue interface for making new queues pluggable easily type Queue interface { MessagesChannel() chan QueueMessage - ConsumeLoop() error + ConsumeLoop(ctx context.Context) error StopConsuming() Cleanup() 
error PendingMessagesWaitGroup() *sync.WaitGroup diff --git a/interfaces/kafka.go b/interfaces/kafka.go index e9599d7..43b8145 100644 --- a/interfaces/kafka.go +++ b/interfaces/kafka.go @@ -42,4 +42,5 @@ type KafkaConsumerClient interface { Pause([]kafka.TopicPartition) error Resume([]kafka.TopicPartition) error Assignment() ([]kafka.TopicPartition, error) + CommitMessage(message *kafka.Message) ([]kafka.TopicPartition, error) } diff --git a/interfaces/queue.go b/interfaces/queue.go index 3b3356f..4695d1e 100644 --- a/interfaces/queue.go +++ b/interfaces/queue.go @@ -22,7 +22,10 @@ package interfaces -import "sync" +import ( + "context" + "sync" +) // KafkaMessage sent through the Channel. type KafkaMessage struct { @@ -34,7 +37,7 @@ type KafkaMessage struct { // Queue interface for making new queues pluggable easily. type Queue interface { MessagesChannel() *chan KafkaMessage - ConsumeLoop() error + ConsumeLoop(ctx context.Context) error StopConsuming() PendingMessagesWaitGroup() *sync.WaitGroup } diff --git a/mocks/apns.go b/mocks/apns.go index f952e96..238029c 100644 --- a/mocks/apns.go +++ b/mocks/apns.go @@ -25,6 +25,7 @@ package mocks import ( "github.com/sideshow/apns2" "github.com/topfreegames/pusher/structs" + "sync" ) // APNSPushQueueMock should be used for tests that need to send pushs to APNS. @@ -32,18 +33,22 @@ type APNSPushQueueMock struct { responseChannel chan *structs.ResponseWithMetadata Closed bool PushedNotification *apns2.Notification + internalLock sync.Mutex } // NewAPNSPushQueueMock creates a new instance. func NewAPNSPushQueueMock() *APNSPushQueueMock { return &APNSPushQueueMock{ responseChannel: make(chan *structs.ResponseWithMetadata), + internalLock: sync.Mutex{}, } } // Push records the sent message in the MessagesSent collection func (m *APNSPushQueueMock) Push(n *apns2.Notification) { + m.internalLock.Lock() m.PushedNotification = n + m.internalLock.Unlock() } func (m *APNSPushQueueMock) Configure() error { diff --git a/mocks/interfaces/apns.go b/mocks/interfaces/apns.go new file mode 100644 index 0000000..8eda2be --- /dev/null +++ b/mocks/interfaces/apns.go @@ -0,0 +1,93 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: interfaces/apns.go +// +// Generated by this command: +// +// mockgen -source=interfaces/apns.go -destination=mocks/interfaces/apns.go +// + +// Package mock_interfaces is a generated GoMock package. +package mock_interfaces + +import ( + reflect "reflect" + + apns2 "github.com/sideshow/apns2" + structs "github.com/topfreegames/pusher/structs" + gomock "go.uber.org/mock/gomock" +) + +// MockAPNSPushQueue is a mock of APNSPushQueue interface. +type MockAPNSPushQueue struct { + ctrl *gomock.Controller + recorder *MockAPNSPushQueueMockRecorder +} + +// MockAPNSPushQueueMockRecorder is the mock recorder for MockAPNSPushQueue. +type MockAPNSPushQueueMockRecorder struct { + mock *MockAPNSPushQueue +} + +// NewMockAPNSPushQueue creates a new mock instance. +func NewMockAPNSPushQueue(ctrl *gomock.Controller) *MockAPNSPushQueue { + mock := &MockAPNSPushQueue{ctrl: ctrl} + mock.recorder = &MockAPNSPushQueueMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockAPNSPushQueue) EXPECT() *MockAPNSPushQueueMockRecorder { + return m.recorder +} + +// Close mocks base method. +func (m *MockAPNSPushQueue) Close() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Close") +} + +// Close indicates an expected call of Close. 
+func (mr *MockAPNSPushQueueMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockAPNSPushQueue)(nil).Close)) +} + +// Configure mocks base method. +func (m *MockAPNSPushQueue) Configure() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Configure") + ret0, _ := ret[0].(error) + return ret0 +} + +// Configure indicates an expected call of Configure. +func (mr *MockAPNSPushQueueMockRecorder) Configure() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Configure", reflect.TypeOf((*MockAPNSPushQueue)(nil).Configure)) +} + +// Push mocks base method. +func (m *MockAPNSPushQueue) Push(arg0 *apns2.Notification) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Push", arg0) +} + +// Push indicates an expected call of Push. +func (mr *MockAPNSPushQueueMockRecorder) Push(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Push", reflect.TypeOf((*MockAPNSPushQueue)(nil).Push), arg0) +} + +// ResponseChannel mocks base method. +func (m *MockAPNSPushQueue) ResponseChannel() chan *structs.ResponseWithMetadata { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ResponseChannel") + ret0, _ := ret[0].(chan *structs.ResponseWithMetadata) + return ret0 +} + +// ResponseChannel indicates an expected call of ResponseChannel. +func (mr *MockAPNSPushQueueMockRecorder) ResponseChannel() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResponseChannel", reflect.TypeOf((*MockAPNSPushQueue)(nil).ResponseChannel)) +} diff --git a/mocks/interfaces/consumption_manager.go b/mocks/interfaces/consumption_manager.go new file mode 100644 index 0000000..59b9e26 --- /dev/null +++ b/mocks/interfaces/consumption_manager.go @@ -0,0 +1,20 @@ +package mock_interfaces + +import "github.com/topfreegames/pusher/interfaces" + +type MockConsumptionManager struct { +} + +func (m *MockConsumptionManager) Pause(topic string) error { + return nil +} + +func (m *MockConsumptionManager) Resume(topic string) error { + return nil +} + +func NewMockConsumptionManager() *MockConsumptionManager { + return &MockConsumptionManager{} +} + +var _ interfaces.ConsumptionManager = &MockConsumptionManager{} diff --git a/mocks/interfaces/feedback_reporter.go b/mocks/interfaces/feedback_reporter.go new file mode 100644 index 0000000..cb9767a --- /dev/null +++ b/mocks/interfaces/feedback_reporter.go @@ -0,0 +1,51 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: interfaces/feedback_reporter.go +// +// Generated by this command: +// +// mockgen -source=interfaces/feedback_reporter.go -destination=mocks/interfaces/feedback_reporter.go +// + +// Package mock_interfaces is a generated GoMock package. +package mock_interfaces + +import ( + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockFeedbackReporter is a mock of FeedbackReporter interface. +type MockFeedbackReporter struct { + ctrl *gomock.Controller + recorder *MockFeedbackReporterMockRecorder +} + +// MockFeedbackReporterMockRecorder is the mock recorder for MockFeedbackReporter. +type MockFeedbackReporterMockRecorder struct { + mock *MockFeedbackReporter +} + +// NewMockFeedbackReporter creates a new mock instance. 
+func NewMockFeedbackReporter(ctrl *gomock.Controller) *MockFeedbackReporter { + mock := &MockFeedbackReporter{ctrl: ctrl} + mock.recorder = &MockFeedbackReporterMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockFeedbackReporter) EXPECT() *MockFeedbackReporterMockRecorder { + return m.recorder +} + +// SendFeedback mocks base method. +func (m *MockFeedbackReporter) SendFeedback(game, platform string, feedback []byte) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SendFeedback", game, platform, feedback) +} + +// SendFeedback indicates an expected call of SendFeedback. +func (mr *MockFeedbackReporterMockRecorder) SendFeedback(game, platform, feedback any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendFeedback", reflect.TypeOf((*MockFeedbackReporter)(nil).SendFeedback), game, platform, feedback) +} diff --git a/mocks/interfaces/statsd.go b/mocks/interfaces/statsd.go new file mode 100644 index 0000000..3ddbd08 --- /dev/null +++ b/mocks/interfaces/statsd.go @@ -0,0 +1,110 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: interfaces/statsd.go +// +// Generated by this command: +// +// mockgen -source=interfaces/statsd.go -destination=mocks/interfaces/statsd.go +// + +// Package mock_interfaces is a generated GoMock package. +package mock_interfaces + +import ( + reflect "reflect" + time "time" + + gomock "go.uber.org/mock/gomock" +) + +// MockStatsDClient is a mock of StatsDClient interface. +type MockStatsDClient struct { + ctrl *gomock.Controller + recorder *MockStatsDClientMockRecorder +} + +// MockStatsDClientMockRecorder is the mock recorder for MockStatsDClient. +type MockStatsDClientMockRecorder struct { + mock *MockStatsDClient +} + +// NewMockStatsDClient creates a new mock instance. +func NewMockStatsDClient(ctrl *gomock.Controller) *MockStatsDClient { + mock := &MockStatsDClient{ctrl: ctrl} + mock.recorder = &MockStatsDClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockStatsDClient) EXPECT() *MockStatsDClientMockRecorder { + return m.recorder +} + +// Close mocks base method. +func (m *MockStatsDClient) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockStatsDClientMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockStatsDClient)(nil).Close)) +} + +// Count mocks base method. +func (m *MockStatsDClient) Count(arg0 string, arg1 int64, arg2 []string, arg3 float64) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Count", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(error) + return ret0 +} + +// Count indicates an expected call of Count. +func (mr *MockStatsDClientMockRecorder) Count(arg0, arg1, arg2, arg3 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Count", reflect.TypeOf((*MockStatsDClient)(nil).Count), arg0, arg1, arg2, arg3) +} + +// Gauge mocks base method. +func (m *MockStatsDClient) Gauge(arg0 string, arg1 float64, arg2 []string, arg3 float64) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Gauge", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(error) + return ret0 +} + +// Gauge indicates an expected call of Gauge. 
+func (mr *MockStatsDClientMockRecorder) Gauge(arg0, arg1, arg2, arg3 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Gauge", reflect.TypeOf((*MockStatsDClient)(nil).Gauge), arg0, arg1, arg2, arg3) +} + +// Incr mocks base method. +func (m *MockStatsDClient) Incr(arg0 string, arg1 []string, arg2 float64) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Incr", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// Incr indicates an expected call of Incr. +func (mr *MockStatsDClientMockRecorder) Incr(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Incr", reflect.TypeOf((*MockStatsDClient)(nil).Incr), arg0, arg1, arg2) +} + +// Timing mocks base method. +func (m *MockStatsDClient) Timing(arg0 string, arg1 time.Duration, arg2 []string, arg3 float64) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Timing", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(error) + return ret0 +} + +// Timing indicates an expected call of Timing. +func (mr *MockStatsDClientMockRecorder) Timing(arg0, arg1, arg2, arg3 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Timing", reflect.TypeOf((*MockStatsDClient)(nil).Timing), arg0, arg1, arg2, arg3) +} diff --git a/mocks/kafka.go b/mocks/kafka.go index 8c41b2b..14f0056 100644 --- a/mocks/kafka.go +++ b/mocks/kafka.go @@ -178,3 +178,7 @@ func (k *KafkaConsumerClientMock) Assign(partitions []kafka.TopicPartition) erro func (k *KafkaConsumerClientMock) Assignment() ([]kafka.TopicPartition, error) { return k.Assignments, nil } + +func (k *KafkaConsumerClientMock) CommitMessage(m *kafka.Message) ([]kafka.TopicPartition, error) { + return nil, nil +} diff --git a/pusher/apns.go b/pusher/apns.go index 2d14210..ccdada2 100644 --- a/pusher/apns.go +++ b/pusher/apns.go @@ -25,10 +25,9 @@ package pusher import ( "errors" "fmt" - "strings" - "github.com/sirupsen/logrus" "github.com/spf13/viper" + "github.com/topfreegames/pusher/config" "github.com/topfreegames/pusher/extensions" "github.com/topfreegames/pusher/interfaces" ) @@ -41,7 +40,8 @@ type APNSPusher struct { // NewAPNSPusher for getting a new APNSPusher instance func NewAPNSPusher( isProduction bool, - config *viper.Viper, + vConfig *viper.Viper, + config *config.Config, logger *logrus.Logger, statsdClientOrNil interfaces.StatsDClient, db interfaces.DB, @@ -49,7 +49,8 @@ func NewAPNSPusher( ) (*APNSPusher, error) { a := &APNSPusher{ Pusher: Pusher{ - ViperConfig: config, + ViperConfig: vConfig, + Config: config, IsProduction: isProduction, Logger: logger, stopChannel: make(chan struct{}), @@ -91,15 +92,19 @@ func (a *APNSPusher) configure(queue interfaces.APNSPushQueue, db interfaces.DB, a.MessageHandler = make(map[string]interfaces.MessageHandler) a.Queue = q l.Info("Configuring messageHandler") - for _, k := range strings.Split(a.ViperConfig.GetString("apns.apps"), ",") { + for _, k := range a.Config.GetApnsAppsArray() { authKeyPath := a.ViperConfig.GetString("apns.certs." + k + ".authKeyPath") keyID := a.ViperConfig.GetString("apns.certs." + k + ".keyID") teamID := a.ViperConfig.GetString("apns.certs." + k + ".teamID") topic := a.ViperConfig.GetString("apns.certs." 
+ k + ".topic") - l.Infof( - "Configuring messageHandler for game %s with key: %s", - k, authKeyPath, - ) + + l.WithFields(logrus.Fields{ + "app": k, + "authKeyPath": authKeyPath, + "teamID": teamID, + "topic": topic, + }).Info("configuring apns message handler") + handler, err := extensions.NewAPNSMessageHandler( authKeyPath, keyID, @@ -112,7 +117,7 @@ func (a *APNSPusher) configure(queue interfaces.APNSPushQueue, db interfaces.DB, a.Queue.PendingMessagesWaitGroup(), a.StatsReporters, a.feedbackReporters, - nil, + queue, interfaces.ConsumptionManager(q), ) if err == nil { diff --git a/pusher/apns_test.go b/pusher/apns_test.go index a1a3bbd..e38944e 100644 --- a/pusher/apns_test.go +++ b/pusher/apns_test.go @@ -24,6 +24,7 @@ package pusher import ( "context" + "github.com/topfreegames/pusher/config" "os" "time" @@ -32,11 +33,11 @@ import ( "github.com/sirupsen/logrus/hooks/test" "github.com/spf13/viper" "github.com/topfreegames/pusher/mocks" - "github.com/topfreegames/pusher/util" ) var _ = Describe("APNS Pusher", func() { - var config *viper.Viper + var vConfig *viper.Viper + var cfg *config.Config configFile := os.Getenv("CONFIG_FILE") if configFile == "" { configFile = "../config/test.yaml" @@ -46,7 +47,7 @@ var _ = Describe("APNS Pusher", func() { BeforeEach(func() { var err error - config, err = util.NewViperWithConfigFile(configFile) + cfg, vConfig, err = config.NewConfigAndViper(configFile) Expect(err).NotTo(HaveOccurred()) hook.Reset() }) @@ -67,7 +68,8 @@ var _ = Describe("APNS Pusher", func() { It("should return configured pusher", func() { pusher, err := NewAPNSPusher( isProduction, - config, + vConfig, + cfg, logger, mockStatsDClient, mockDB, @@ -76,7 +78,6 @@ var _ = Describe("APNS Pusher", func() { Expect(err).NotTo(HaveOccurred()) Expect(pusher).NotTo(BeNil()) Expect(pusher.IsProduction).To(Equal(isProduction)) - Expect(pusher.run).To(BeFalse()) Expect(pusher.Queue).NotTo(BeNil()) Expect(pusher.ViperConfig).NotTo(BeNil()) Expect(pusher.MessageHandler).NotTo(BeNil()) @@ -90,7 +91,8 @@ var _ = Describe("APNS Pusher", func() { It("should launch go routines and run forever", func() { pusher, err := NewAPNSPusher( isProduction, - config, + vConfig, + cfg, logger, mockStatsDClient, mockDB, @@ -99,25 +101,28 @@ var _ = Describe("APNS Pusher", func() { Expect(err).NotTo(HaveOccurred()) Expect(len(pusher.MessageHandler)).To(Equal(1)) Expect(pusher).NotTo(BeNil()) - defer func() { pusher.run = false }() - go pusher.Start(context.Background()) + ctx := context.Background() + ctx, cancel := context.WithCancel(ctx) + defer cancel() + go pusher.Start(ctx) time.Sleep(50 * time.Millisecond) }) It("should not ignore failed handlers", func() { - config.Set("apns.apps", "game,invalidgame") - config.Set("apns.certs.invalidgame.authKeyPath", "../tls/authkey_invalid.p8") - config.Set("apns.certs.invalidgame.keyID", "oiejowijefiowejf") - config.Set("apns.certs.invalidgame.teamID", "aoijeoijfiowejfoij") - config.Set("apns.certs.invalidgame.teamID", "com.invalidgame.test") + cfg.Apns.Apps = "game,invalidgame" + vConfig.Set("apns.certs.invalidgame.authKeyPath", "../tls/authkey_invalid.p8") + vConfig.Set("apns.certs.invalidgame.keyID", "oiejowijefiowejf") + vConfig.Set("apns.certs.invalidgame.teamID", "aoijeoijfiowejfoij") + vConfig.Set("apns.certs.invalidgame.teamID", "com.invalidgame.test") _, err := NewAPNSPusher( isProduction, - config, + vConfig, + cfg, logger, mockStatsDClient, mockDB, - mockPushQueue, + nil, ) Expect(err).To(HaveOccurred()) }) diff --git a/pusher/gcm.go b/pusher/gcm.go index 
diff --git a/pusher/gcm.go b/pusher/gcm.go
index 5d3ea37..38fb6a8 100644
--- a/pusher/gcm.go
+++ b/pusher/gcm.go
@@ -95,7 +95,7 @@ func (g *GCMPusher) createMessageHandlerForApps(ctx context.Context) error {
 	})
 	g.MessageHandler = make(map[string]interfaces.MessageHandler)
-	for _, app := range g.Config.GetAppsArray() {
+	for _, app := range g.Config.GetGcmAppsArray() {
 		credentials := g.ViperConfig.GetString("gcm.firebaseCredentials." + app)
 		l = l.WithField("app", app)
 		if credentials != "" { // Firebase is configured, use new handler
diff --git a/pusher/pusher.go b/pusher/pusher.go
index cf9d5e5..5dced5a 100644
--- a/pusher/pusher.go
+++ b/pusher/pusher.go
@@ -48,7 +48,6 @@ type Pusher struct {
 	MessageHandler          map[string]interfaces.MessageHandler
 	stopChannel             chan struct{}
 	IsProduction            bool
-	run                     bool
 }
 
 func (p *Pusher) loadConfigurationDefaults() {
@@ -79,25 +78,39 @@ func (p *Pusher) configureStatsReporters(clientOrNil interfaces.StatsDClient) er
 }
 
 func (p *Pusher) routeMessages(ctx context.Context, msgChan *chan interfaces.KafkaMessage) {
+	l := p.Logger.WithFields(logrus.Fields{
+		"method": "routeMessages",
+		"source": "pusher",
+	})
 	//nolint[:gosimple]
-	for p.run {
+	for {
 		select {
 		case message := <-*msgChan:
+			l = l.WithFields(logrus.Fields{
+				"game":      message.Game,
+				"jsonValue": string(message.Value),
+				"topic":     message.Topic,
+			})
+
+			l.Debug("got message from message channel")
+
 			if handler, ok := p.MessageHandler[message.Game]; ok {
 				handler.HandleMessages(ctx, message)
 			} else {
-				p.Logger.WithFields(logrus.Fields{
-					"method": "routeMessages",
-					"game":   message.Game,
-				}).Error("Game not found")
+				l.Error("Game not found")
 			}
+		case <-ctx.Done():
+			l.Info("Context done. Will stop routing messages.")
+			return
+		case <-p.stopChannel:
+			l.Info("Stop channel closed. Will stop routing messages.")
+			return
 		}
 	}
 }
 
 // Start starts pusher
 func (p *Pusher) Start(ctx context.Context) {
-	p.run = true
 	l := p.Logger.WithFields(logrus.Fields{
 		"method": "start",
 	})
@@ -109,21 +122,17 @@ func (p *Pusher) Start(ctx context.Context) {
 		go v.CleanMetadataCache()
 	}
 	//nolint[:errcheck]
-	go p.Queue.ConsumeLoop()
+	go p.Queue.ConsumeLoop(ctx)
 	go p.reportGoStats()
 
 	sigchan := make(chan os.Signal)
 	signal.Notify(sigchan, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
 
-	for p.run {
-		select {
-		case sig := <-sigchan:
-			l.Warnf("caught signal %v: terminating\n", sig)
-			p.run = false
-		case <-p.stopChannel:
-			l.Warn("Stop channel closed\n")
-			p.run = false
-		}
+	select {
+	case sig := <-sigchan:
+		l.Infof("caught signal %v: terminating", sig)
+	case <-ctx.Done():
+		l.Info("Context done. Will stop consuming.")
 	}
 	p.Queue.StopConsuming()
 	GracefulShutdown(p.Queue.PendingMessagesWaitGroup(), time.Duration(p.GracefulShutdownTimeout)*time.Second)
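With the run flag removed, shutdown has a single contract: cancel the context (or deliver SIGINT/SIGTERM) and Start, routeMessages, and the consume loop all unwind before StopConsuming and GracefulShutdown drain pending messages. A self-contained sketch of the caller side (the interface and helper are hypothetical, written so the snippet compiles on its own):

package main

import (
	"context"
	"time"
)

// starter matches the Start signature shared by the pushers.
type starter interface {
	Start(ctx context.Context)
}

// runUntilCancelled demonstrates the cancellation contract that
// replaced the removed `run` flag: the caller owns the lifetime.
func runUntilCancelled(p starter, d time.Duration) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	go p.Start(ctx)

	// Simulate a test teardown or a deploy after some work.
	time.Sleep(d)
	cancel()
	// Start's select observes ctx.Done() and returns; routeMessages
	// does the same, then StopConsuming and GracefulShutdown drain
	// in-flight messages.
}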