diff --git a/fs2-aws-sqs/src/main/scala/fs2/aws/sqs/SQS.scala b/fs2-aws-sqs/src/main/scala/fs2/aws/sqs/SQS.scala index 0eae346f..a697f438 100644 --- a/fs2-aws-sqs/src/main/scala/fs2/aws/sqs/SQS.scala +++ b/fs2-aws-sqs/src/main/scala/fs2/aws/sqs/SQS.scala @@ -22,23 +22,14 @@ object SQS { type MsgBody = String def create[F[_]: Async]( - sqsConfig: SqsConfig, + sqsConfig: SqsConfig.Interface, sqs: SqsAsyncClientOp[F] ): F[SQS[F]] = new SQS[F] { override def sqsStream: fs2.Stream[F, Message] = fs2.Stream .awakeEvery[F](sqsConfig.pollRate) - .evalMap(_ => - sqs - .receiveMessage( - ReceiveMessageRequest - .builder() - .queueUrl(sqsConfig.queueUrl) - .maxNumberOfMessages(sqsConfig.fetchMessageCount) - .build() - ) - ) + .evalMap(_ => sqs.receiveMessage(sqsConfig.receiveMessageRequest)) .flatMap(response => fs2.Stream.emits(response.messages().asScala)) override def deleteMessagePipe: Pipe[F, Message, DeleteMessageResponse] = diff --git a/fs2-aws-sqs/src/main/scala/fs2/aws/sqs/SqsConfig.scala b/fs2-aws-sqs/src/main/scala/fs2/aws/sqs/SqsConfig.scala index e3e45090..c32e8049 100644 --- a/fs2-aws-sqs/src/main/scala/fs2/aws/sqs/SqsConfig.scala +++ b/fs2-aws-sqs/src/main/scala/fs2/aws/sqs/SqsConfig.scala @@ -1,9 +1,32 @@ package sqs +import software.amazon.awssdk.services.sqs.model.* import scala.concurrent.duration.* +import scala.jdk.CollectionConverters.* case class SqsConfig( - queueUrl: String, - pollRate: FiniteDuration = 3.seconds, - fetchMessageCount: Int = 100 -) + override val queueUrl: String, + override val pollRate: FiniteDuration = 3.seconds, + override val fetchMessageCount: Int = 100, + override val messageAttributeNames: Option[List[String]] = None +) extends SqsConfig.Interface + +object SqsConfig { + trait Interface { + def queueUrl: String + def pollRate: FiniteDuration + def fetchMessageCount: Int + def messageAttributeNames: Option[List[String]] + + def receiveMessageRequest: ReceiveMessageRequest = buildReceiveRequest.build() + + protected def buildReceiveRequest: ReceiveMessageRequest.Builder = { + val blder = ReceiveMessageRequest + .builder() + .queueUrl(queueUrl) + .maxNumberOfMessages(fetchMessageCount) + + messageAttributeNames.fold(blder)(names => blder.messageAttributeNames(names.asJava)) + } + } +}