Skip to content

Commit

Permalink
Add flag to remove need to ack (#37)
Browse files Browse the repository at this point in the history
* Split the ack and non-ack styles into separate classes.

Co-authored-by: Phil Story <[email protected]>
  • Loading branch information
wbaker-figure and mtps authored Dec 1, 2022
1 parent 1e1ea8e commit 51d0520
Showing 1 changed file with 217 additions and 69 deletions.
Original file line number Diff line number Diff line change
@@ -1,13 +1,8 @@
package tech.figure.kafka.coroutines.channels

import tech.figure.kafka.loggingConsumerRebalanceListener
import tech.figure.kafka.records.CommitConsumerRecord
import tech.figure.kafka.records.UnAckedConsumerRecordImpl
import tech.figure.kafka.records.UnAckedConsumerRecords
import java.util.concurrent.atomic.AtomicInteger
import kotlin.concurrent.thread
import kotlin.time.Duration
import kotlin.time.ExperimentalTime
import kotlin.time.toJavaDuration
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.ExperimentalCoroutinesApi
Expand All @@ -21,7 +16,73 @@ import kotlinx.coroutines.selects.SelectClause1
import mu.KotlinLogging
import org.apache.kafka.clients.consumer.Consumer
import org.apache.kafka.clients.consumer.ConsumerRebalanceListener
import org.apache.kafka.clients.consumer.ConsumerRecord
import org.apache.kafka.clients.consumer.ConsumerRecords
import org.apache.kafka.clients.consumer.KafkaConsumer
import org.apache.kafka.clients.consumer.OffsetAndMetadata
import org.apache.kafka.common.TopicPartition
import tech.figure.kafka.loggingConsumerRebalanceListener
import tech.figure.kafka.records.CommitConsumerRecord
import tech.figure.kafka.records.UnAckedConsumerRecordImpl
import tech.figure.kafka.records.UnAckedConsumerRecords

/**
 * Packs a flat list of records back into a [ConsumerRecords] batch, bucketing each record under
 * the [TopicPartition] built from its own topic and partition.
 */
internal fun <K, V> List<ConsumerRecord<K, V>>.toConsumerRecords(): ConsumerRecords<K, V> {
    val byPartition = groupBy { record -> TopicPartition(record.topic(), record.partition()) }
    return ConsumerRecords(byPartition)
}

/**
 * Default channel factory: forwards every argument, unchanged and in order, to
 * [kafkaAckConsumerChannel], so the returned channel emits unacknowledged record batches.
 *
 * @param consumerProperties Kafka consumer settings; also used to build the default [consumer].
 * @param topics Topics the default [init] subscribes to.
 * @param name Forwarded as the channel name.
 * @param pollInterval Forwarded as the poll interval.
 * @param consumer The [Consumer] instance to read from.
 * @param rebalanceListener Listener handed to `subscribe` by the default [init].
 * @param init Callback applied to the [Consumer] for initialization.
 * @return The channel produced by [kafkaAckConsumerChannel].
 * @see [kafkaAckConsumerChannel]
 */
fun <K, V> kafkaConsumerChannel(
    consumerProperties: Map<String, Any>,
    topics: Set<String>,
    name: String = "kafka-channel",
    pollInterval: Duration = DEFAULT_POLL_INTERVAL,
    consumer: Consumer<K, V> = KafkaConsumer(consumerProperties),
    rebalanceListener: ConsumerRebalanceListener = loggingConsumerRebalanceListener(),
    init: Consumer<K, V>.() -> Unit = { subscribe(topics, rebalanceListener) },
): ReceiveChannel<UnAckedConsumerRecords<K, V>> {
    return kafkaAckConsumerChannel(
        consumerProperties,
        topics,
        name,
        pollInterval,
        consumer,
        rebalanceListener,
        init
    )
}

/**
 * Create a [ReceiveChannel] that hands out raw [ConsumerRecords] batches from kafka, with no
 * acknowledgement step: each batch passed to the pre-processing hook is emitted as-is.
 *
 * This factory only constructs the channel; it does not register a JVM shutdown hook.
 *
 * @param consumerProperties Kafka consumer settings for this channel.
 * @param topics Topics to subscribe to. Can be overridden via custom `init` parameter.
 * @param name The thread pool's base name for this consumer.
 * @param pollInterval Interval for kafka consumer [Consumer.poll] method calls.
 * @param consumer The instantiated [Consumer] to use to receive from kafka.
 * @param rebalanceListener Listener handed to `subscribe` by the default [init]; it is not used
 * when a custom [init] is supplied.
 * @param init Callback for initializing the [Consumer].
 * @return A non-running [KafkaConsumerChannel] instance that must be started via
 * [KafkaConsumerChannel.start].
 */
fun <K, V> kafkaNoAckConsumerChannel(
    consumerProperties: Map<String, Any>,
    topics: Set<String>,
    name: String = "kafka-channel",
    pollInterval: Duration = DEFAULT_POLL_INTERVAL,
    consumer: Consumer<K, V> = KafkaConsumer(consumerProperties),
    rebalanceListener: ConsumerRebalanceListener = loggingConsumerRebalanceListener(),
    init: Consumer<K, V>.() -> Unit = { subscribe(topics, rebalanceListener) },
): ReceiveChannel<ConsumerRecords<K, V>> =
    object : KafkaConsumerChannel<K, V, ConsumerRecords<K, V>>(
        consumerProperties,
        topics,
        name,
        pollInterval,
        consumer,
        init
    ) {
        // Pass-through: the batch is wrapped in a single-element list and nothing is written
        // to the context.
        override suspend fun preProcessPollSet(
            records: ConsumerRecords<K, V>,
            context: MutableMap<String, Any>
        ): List<ConsumerRecords<K, V>> = listOf(records)
    }

/**
* Create a [ReceiveChannel] for unacknowledged consumer records from kafka.
Expand All @@ -32,9 +93,10 @@ import org.apache.kafka.clients.consumer.KafkaConsumer
* @param pollInterval Interval for kafka consumer [Consumer.poll] method calls.
* @param consumer The instantiated [Consumer] to use to receive from kafka.
* @param init Callback for initializing the [Consumer].
* @return A non-running [KafkaConsumerChannel] instance that must be started via [KafkaConsumerChannel.start].
* @return A non-running [KafkaConsumerChannel] instance that must be started via
* [KafkaConsumerChannel.start].
*/
fun <K, V> kafkaConsumerChannel(
fun <K, V> kafkaAckConsumerChannel(
consumerProperties: Map<String, Any>,
topics: Set<String>,
name: String = "kafka-channel",
Expand All @@ -43,20 +105,95 @@ fun <K, V> kafkaConsumerChannel(
rebalanceListener: ConsumerRebalanceListener = loggingConsumerRebalanceListener(),
init: Consumer<K, V>.() -> Unit = { subscribe(topics, rebalanceListener) },
): ReceiveChannel<UnAckedConsumerRecords<K, V>> {
return KafkaConsumerChannel(consumerProperties, topics, name, pollInterval, consumer, init).also {
Runtime.getRuntime().addShutdownHook(
Thread {
it.cancel()
return KafkaAckConsumerChannel(
consumerProperties,
topics,
name,
pollInterval,
consumer,
init
).also { Runtime.getRuntime().addShutdownHook(Thread { it.cancel() }) }
}

/**
 * Acking kafka [Consumer] channel implementing the [ReceiveChannel] methods.
 *
 * Every batch handed to [preProcessPollSet] is wrapped into [UnAckedConsumerRecords] whose
 * records share one acknowledgement channel. [postProcessPollSet] then receives one
 * [CommitConsumerRecord] per emitted record from that channel, commits it through the consumer,
 * and signals the commit back on the record's `commitAck` channel.
 *
 * Note: Must operate in a bound thread context regardless of coroutine assignment due to internal
 * kafka threading limitations for poll fetches, acknowledgements, and sends.
 *
 * @param consumerProperties Kafka consumer settings for this channel.
 * @param topics Topics to subscribe to. Can be overridden via custom `init` parameter.
 * @param name The thread pool's base name for this consumer.
 * @param pollInterval Interval for kafka consumer [Consumer.poll] method calls.
 * @param consumer The instantiated [Consumer] to use to receive from kafka.
 * @param init Callback for initializing the [Consumer].
 */
internal class KafkaAckConsumerChannel<K, V>(
    consumerProperties: Map<String, Any>,
    topics: Set<String>,
    name: String,
    pollInterval: Duration,
    consumer: Consumer<K, V>,
    init: Consumer<K, V>.() -> Unit
) :
    KafkaConsumerChannel<K, V, UnAckedConsumerRecords<K, V>>(
        consumerProperties,
        topics,
        name,
        pollInterval,
        consumer,
        init
    ) {
    /**
     * Wraps [records] into unacknowledged batches, one per topic-partition, and stores the
     * acknowledgement channel they share in [context] for [postProcessPollSet].
     *
     * @param records The polled records to wrap.
     * @param context Scratch map shared with [postProcessPollSet]; receives the ack channel.
     * @return One [UnAckedConsumerRecords] per topic-partition found in [records].
     */
    override suspend fun preProcessPollSet(
        records: ConsumerRecords<K, V>,
        context: MutableMap<String, Any>,
    ): List<UnAckedConsumerRecords<K, V>> {
        log.trace { "preProcessPollSet(${records.count()})" }

        // Sized to the batch so that every record can post its ack without suspending.
        val ackChannel = Channel<CommitConsumerRecord>(capacity = records.count())

        // NOTE(review): the key is fixed, so a later preProcessPollSet call on the same context
        // replaces the channel stored by an earlier one. Confirm the caller pairs each
        // pre-process call with its post-process call before reusing the context.
        context[ACK_CHANNEL_KEY] = ackChannel

        return records
            .groupBy { "${it.topic()}-${it.partition()}" }
            .values
            .map { partitionRecords ->
                // One timestamp per partition group, shared by all of its records.
                val timestamp = System.currentTimeMillis()
                UnAckedConsumerRecords(
                    partitionRecords.map { UnAckedConsumerRecordImpl(it, ackChannel, timestamp) }
                )
            }
    }

    /**
     * Waits for one acknowledgement per record in [records], committing each to the broker and
     * confirming it back to the sender, then closes the acknowledgement channel.
     *
     * @param records The batches previously returned by [preProcessPollSet].
     * @param context The map populated by [preProcessPollSet].
     * @throws IllegalStateException if [context] holds no acknowledgement channel.
     */
    @Suppress("unchecked_cast")
    override suspend fun postProcessPollSet(
        records: List<UnAckedConsumerRecords<K, V>>,
        context: Map<String, Any>
    ) {
        log.trace { "postProcessPollSet(records:${records.sumOf { it.count() } })" }
        val ackChannel =
            checkNotNull(context[ACK_CHANNEL_KEY]) {
                "No '$ACK_CHANNEL_KEY' entry in context; preProcessPollSet must run first"
            } as Channel<CommitConsumerRecord>

        try {
            for (batch in records) {
                // An empty batch yields an empty progression, so no explicit emptiness check.
                for (remaining in batch.records.size downTo 1) {
                    log.trace { "waiting for $remaining commits" }
                    val ack = ackChannel.receive()
                    log.trace {
                        "sending to broker ack(${ack.duration.toMillis()}ms):${ack.asCommitable()}"
                    }
                    commit(ack)
                    log.trace { "acking the commit back to flow" }
                    ack.commitAck.send(Unit)
                }
            }
        } finally {
            // Close on failure too, so the channel is never left open after this call.
            ackChannel.close()
        }
    }

    private companion object {
        /** Context key under which the acknowledgement channel is stored. */
        private const val ACK_CHANNEL_KEY = "ack-channel"
    }
}

/**
* Kafka [Consumer] object implementing the [ReceiveChannel] methods.
* Base kafka [Consumer] object implementing the [ReceiveChannel] methods.
*
* Note: Must operate in a bound thread context regardless of coroutine assignment due to internal kafka threading
* limitations for poll fetches, acknowledgements, and sends.
* Note: Must operate in a bound thread context regardless of coroutine assignment due to internal
* kafka threading limitations for poll fetches, acknowledgements, and sends.
*
* @param consumerProperties Kafka consumer settings for this channel.
* @param topics Topics to subscribe to. Can be overridden via custom `init` parameter.
Expand All @@ -65,40 +202,46 @@ fun <K, V> kafkaConsumerChannel(
* @param consumer The instantiated [Consumer] to use to receive from kafka.
* @param init Callback for initializing the [Consumer].
*/
open class KafkaConsumerChannel<K, V>(
abstract class KafkaConsumerChannel<K, V, R>(
consumerProperties: Map<String, Any>,
topics: Set<String> = emptySet(),
name: String = "kafka-channel",
private val pollInterval: Duration = DEFAULT_POLL_INTERVAL,
private val consumer: Consumer<K, V> = KafkaConsumer(consumerProperties),
private val init: Consumer<K, V>.() -> Unit = { subscribe(topics) },
) : ReceiveChannel<UnAckedConsumerRecords<K, V>> {
) : ReceiveChannel<R> {
companion object {
private val threadCounter = AtomicInteger(0)
}

private val log = KotlinLogging.logger {}
protected val log = KotlinLogging.logger {}
private val thread =
thread(name = "$name-${threadCounter.getAndIncrement()}", block = { run() }, isDaemon = true, start = false)
private val sendChannel = Channel<UnAckedConsumerRecords<K, V>>(Channel.UNLIMITED)

private inline fun <T> Channel<T>.use(block: (Channel<T>) -> Unit) {
try {
block(this)
close()
} catch (e: Throwable) {
close(e)
}
}
thread(
name = "$name-${threadCounter.getAndIncrement()}",
block = { run() },
isDaemon = true,
start = false
)
val sendChannel = Channel<R>(Channel.UNLIMITED)

@OptIn(ExperimentalTime::class)
private fun <K, V> Consumer<K, V>.poll(duration: Duration) =
poll(duration.toJavaDuration())
private fun <K, V> Consumer<K, V>.poll(duration: Duration) = poll(duration.toJavaDuration())

private fun <T, L : Iterable<T>> L.ifEmpty(block: () -> L): L =
if (count() == 0) block() else this

@OptIn(ExperimentalCoroutinesApi::class, ExperimentalTime::class)
protected abstract suspend fun preProcessPollSet(
records: ConsumerRecords<K, V>,
context: MutableMap<String, Any>
): List<R>

protected open suspend fun postProcessPollSet(records: List<R>, context: Map<String, Any>) {}

protected fun commit(record: CommitConsumerRecord): OffsetAndMetadata {
consumer.commitSync(record.asCommitable())
return record.offsetAndMetadata
}

@OptIn(ExperimentalCoroutinesApi::class)
fun run() {
consumer.init()

Expand All @@ -108,32 +251,31 @@ open class KafkaConsumerChannel<K, V>(
try {
while (!sendChannel.isClosedForSend) {
log.trace("poll(topics:${consumer.subscription()}) ...")
val polled = consumer.poll(Duration.ZERO).ifEmpty { consumer.poll(pollInterval) }
val polled =
consumer.poll(Duration.ZERO).ifEmpty { consumer.poll(pollInterval) }
val polledCount = polled.count()
if (polledCount == 0) {
continue
}

log.trace("poll(topics:${consumer.subscription()}) got $polledCount records.")
Channel<CommitConsumerRecord>(capacity = polled.count()).use { ackChannel ->
for (it in polled.groupBy { "${it.topic()}-${it.partition()}" }) {
val timestamp = System.currentTimeMillis()
val records = it.value.map {
UnAckedConsumerRecordImpl(it, ackChannel, timestamp)
}
sendChannel.send(UnAckedConsumerRecords(records))
}

if (polledCount > 0) {
val count = AtomicInteger(polledCount)
while (count.getAndDecrement() > 0) {
val it = ackChannel.receive()
log.debug { "ack(${it.duration.toMillis()}ms):${it.asCommitable()}" }
consumer.commitSync(it.asCommitable())
it.commitAck.send(Unit)
}
}
}

// Group by topic-partition to guarantee ordering.
val records =
polled
.groupBy { "${it.topic()}-${it.partition()}" }
.values
.map { it.toConsumerRecords() }

// Convert to internal types.
val context = mutableMapOf<String, Any>()
val processSet = records.map { preProcessPollSet(it, context) }

// Send down the pipeline for processing
processSet
.onEach { it.map { sendChannel.send(it) } }
// Clean up any processing.
.map { postProcessPollSet(it, context) }
}
} finally {
log.info("${coroutineContext.job} shutting down consumer thread")
Expand All @@ -142,7 +284,9 @@ open class KafkaConsumerChannel<K, V>(
consumer.unsubscribe()
consumer.close()
} catch (ex: Exception) {
log.debug { "Consumer failed to be closed. It may have been closed from somewhere else." }
log.debug {
"Consumer failed to be closed. It may have been closed from somewhere else."
}
}
}
}
Expand All @@ -162,19 +306,23 @@ open class KafkaConsumerChannel<K, V>(
@ExperimentalCoroutinesApi
override val isClosedForReceive: Boolean = sendChannel.isClosedForReceive

@ExperimentalCoroutinesApi
override val isEmpty: Boolean = sendChannel.isEmpty
override val onReceive: SelectClause1<UnAckedConsumerRecords<K, V>> get() {
start()
return sendChannel.onReceive
}
@ExperimentalCoroutinesApi override val isEmpty: Boolean = sendChannel.isEmpty
override val onReceive: SelectClause1<R>
get() {
start()
return sendChannel.onReceive
}

override val onReceiveCatching: SelectClause1<ChannelResult<UnAckedConsumerRecords<K, V>>> get() {
start()
return sendChannel.onReceiveCatching
}
override val onReceiveCatching: SelectClause1<ChannelResult<R>>
get() {
start()
return sendChannel.onReceiveCatching
}

@Deprecated("Since 1.2.0, binary compatibility with versions <= 1.1.x", level = DeprecationLevel.HIDDEN)
@Deprecated(
"Since 1.2.0, binary compatibility with versions <= 1.1.x",
level = DeprecationLevel.HIDDEN
)
override fun cancel(cause: Throwable?): Boolean {
cancel(CancellationException("cancel", cause))
return true
Expand All @@ -185,22 +333,22 @@ open class KafkaConsumerChannel<K, V>(
sendChannel.cancel(cause)
}

override fun iterator(): ChannelIterator<UnAckedConsumerRecords<K, V>> {
override fun iterator(): ChannelIterator<R> {
start()
return sendChannel.iterator()
}

override suspend fun receive(): UnAckedConsumerRecords<K, V> {
override suspend fun receive(): R {
start()
return sendChannel.receive()
}

override suspend fun receiveCatching(): ChannelResult<UnAckedConsumerRecords<K, V>> {
override suspend fun receiveCatching(): ChannelResult<R> {
start()
return sendChannel.receiveCatching()
}

override fun tryReceive(): ChannelResult<UnAckedConsumerRecords<K, V>> {
override fun tryReceive(): ChannelResult<R> {
start()
return sendChannel.tryReceive()
}
Expand Down

0 comments on commit 51d0520

Please sign in to comment.