dwb5013
dwb5013

Reputation: 173

Kafka Transaction Management in Multi-Threaded Environment with Spring Boot

I'm experiencing issues with Kafka transaction management in a multi-threaded environment using Spring Boot. Here's my scenario:

  1. I have a CancelAuthorizationLinkageListener that processes messages in batches:
@Component
@Slf4j
@RequiredArgsConstructor
public class CancelAuthorizationLinkageListener {
    private final CancelAuthorizationLinkageProcessor cancelAuthorizationLinkageProcessor;
    private final ConcurrentMessageProcessor<CancelAuthorizationLinkage> messageProcessor;

    @Bean
    public RecordMessageConverter converter() {
        return new JsonMessageConverter();
    }

    @Bean
    public BatchMessagingMessageConverter batchConverter() {
        return new BatchMessagingMessageConverter(converter());
    }

    @KafkaListener(id = "${spring.kafka.listener.cancel-auth-linkage.id}",
            topics = "${spring.kafka.listener.cancel-auth-linkage.topic.linkage}", autoStartup = "false",
            batch = "true",
            groupId = "cushion",
            concurrency = "1"
    )
    @Transactional("kafkaTransactionManager")
    public void listen(List<Message<CancelAuthorizationLinkage>> messages) {
        try {
            messageProcessor.processMessages(messages, cancelAuthorizationLinkageProcessor::process);
        } catch (Throwable ex) {
            if (ex instanceof MessagingException messagingException) {
                log.error("Failed to process message {}, triggering transaction rollback. Cause: ",
                        messagingException.getFailedMessage(),
                        ex.getCause());
            } else {
                log.error("Failed to process message. Cause: ", ex.getCause());
            }
            throw ex;
        }
    }
}
  2. CancelAuthorizationLinkageProcessor processes one message at a time. If the message processing fails, it throws an exception.
/**
 * Processes a single {@link CancelAuthorizationLinkage}: validates it and, unless the service reports it
 * as a single cancel, forwards the derived API resource to the linkage service.
 */
@Slf4j
@Component
@RequiredArgsConstructor
public class CancelAuthorizationLinkageProcessor {
    private final CancelAuthorizationLinkageServiceInterface cancelAuthorizationLinkageService;
    private final ValidatedCancelAuthorizationLinkageFactory validatedCancelAuthorizationLinkageFactory;

    /**
     * Validates the given linkage and links the authorization through the service.
     *
     * <p>When the service classifies the validated linkage as a single cancel, nothing is linked and the
     * fact is only logged.
     *
     * @param cancelAuthorizationLinkage linkage to process, never {@code null}
     * @throws InvalidValueException declared by this method; presumably raised by validation — confirm
     */
    public void process(
            @NonNull CancelAuthorizationLinkage cancelAuthorizationLinkage)
            throws InvalidValueException {
        final ValidatedCancelAuthorizationLinkage validated =
                validatedCancelAuthorizationLinkageFactory.create(cancelAuthorizationLinkage);

        if (cancelAuthorizationLinkageService.isSingleCancel(validated)) {
            // Todo: Required for MVP2?
            log.info("This authorization is not linked. authorization: [{}]", cancelAuthorizationLinkage);
            return;
        }

        // Build the writer resource, then hand only its API part to the service.
        final CancelAuthorizationLinkageWriterResource writerResource =
                CancelAuthorizationLinkageWriterResource.builder()
                        .apiResource(CancelNotificationRequestResourceFactory.create(validated))
                        .authorization(cancelAuthorizationLinkage.toString())
                        .build();
        cancelAuthorizationLinkageService.linkageAuthorization(writerResource.getApiResource());
    }
}
  3. The ConcurrentMessageProcessor processes messages concurrently:
@Component
@RequiredArgsConstructor
@Slf4j
public class ConcurrentMessageProcessor<T> {
    private final MessageRetryHandler<T> messageRetryHandler;
    private final ListenerPropertiesService listenerPropertiesService;
    private final ThreadPoolTaskExecutor executor;

    /**
     * Processes a list of messages concurrently.
     *
     * @param messages  List of messages to process
     * @param processor Consumer to process the messages
     */
    public void processMessages(List<Message<T>> messages, ThrowingConsumer<T> processor) {
        List<CompletableFuture<SendResult<String, T>>> processingFutures = createProcessingFutures(messages, processor);
        waitForAllProcessingToComplete(processingFutures);
    }

    /**
     * Creates processing futures for each message.
     *
     * @param messages  List of messages to process
     * @param processor Consumer to process the messages
     * @return List of processing futures
     */
    private List<CompletableFuture<SendResult<String, T>>> createProcessingFutures(List<Message<T>> messages,
                                                                                   ThrowingConsumer<T> processor) {
        return messages.stream()
                .map(message -> processMessageWithRetry(message, processor))
                .toList();
    }

    /**
     * Processes a message and retries if necessary.
     *
     * @param message   Message to process
     * @param processor Consumer to process the message
     * @return CompletableFuture containing the result
     */
    private CompletableFuture<SendResult<String, T>> processMessageWithRetry(Message<T> message,
                                                                             ThrowingConsumer<T> processor) {
        return CompletableFuture.supplyAsync(() -> tryProcessMessage(message, processor), executor)
                .thenCompose(result -> {
                    // If retry is not needed, return null
                    if (result.isSuccess()) {
                        return CompletableFuture.completedFuture(null);
                    } else {
                        return retryMessage(message, listenerPropertiesService, result.getException());
                    }
                });
    }

    /**
     * Attempts to process a message.
     *
     * @param message   Message to process
     * @param processor Consumer to process the message
     * @return Process result
     */
    private ProcessResult tryProcessMessage(Message<T> message, ThrowingConsumer<T> processor) {
        try {
            processor.accept(message.getPayload());
            log.debug("Successfully processed message: {}", message);
            return new ProcessResult(true, null);
        } catch (Exception e) {
            log.debug("Failed to process message: {}", message, e);
            return new ProcessResult(false, e);
        }
    }

    /**
     * Retries a message.
     *
     * @param message                   Message to retry
     * @param listenerPropertiesService Listener properties service
     * @param originalException         Original exception
     * @return CompletableFuture containing the retry result
     */
    private CompletableFuture<SendResult<String, T>> retryMessage(Message<T> message,
                                                                  ListenerPropertiesService listenerPropertiesService,
                                                                  Exception originalException) {
        try {
            return messageRetryHandler.handle(message, listenerPropertiesService, originalException)
                    .thenApply(result -> {
                        log.debug("Successfully retried message: {}", message);
                        return result;
                    });
        } catch (Exception e) {
            throw new MessagingException(message, e);
        }
    }

    /**
     * Waits for all message processing to complete.
     * The timeout is dynamically set based on the number of messages.
     *
     * @param futures List of processing futures
     */
    private void waitForAllProcessingToComplete(List<CompletableFuture<SendResult<String, T>>> futures) {
        try {
            // Wait for all CompletableFutures to complete
            // Timeout is set to message count * 2 seconds
            CompletableFuture.allOf(futures.toArray(new CompletableFuture[0]))
                    .orTimeout(futures.size() * 2L, TimeUnit.SECONDS).join();
        } catch (CompletionException ex) {
            handleProcessingException(ex);
        }
    }

    private void handleProcessingException(CompletionException ex) {
        Throwable cause = ex.getCause();
        if (cause instanceof MessagingException messagingException) {
            throw messagingException;
        } else {
            throw new MessagingException("Error processing messages", cause);
        }
    }

    @Data
    private static class ProcessResult {
        private final boolean success;
        private final Exception exception;
    }
}

Test:

/**
 * Integration test for {@link CancelAuthorizationLinkageListener} running against an embedded Kafka
 * broker. It checks that the committed consumer offset does not move when a batch fails.
 */
@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.DEFINED_PORT,
        properties = {
                "spring.batch.job.name=cancelAuthorizationLinkageJob",
                "bootstrap-servers: ${spring.embedded.kafka.brokers}"
        })
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
@DirtiesContext
@EmbeddedKafka(
        partitions = 1,
        topics = {
                "${spring.kafka.listener.cancel-auth-linkage.topic.linkage}",
                "${spring.kafka.listener.cancel-auth-linkage.topic.retry}"
        },
        count = 3)
class CancelAuthorizationLinkageListenerTest {
    @SpyBean
    private CancelAuthorizationLinkageServiceInterface cancelAuthorizationLinkageService;

    @Autowired
    private KafkaTemplate<String, CancelAuthorizationLinkage> kafkaTemplate;

    @Autowired
    private CancelAuthorizationLinkageTestService cancelAuthorizationLinkageTestService;

    @SpyBean
    private CancelAuthorizationLinkageListener listener;

    @SpyBean
    private CancelAuthorizationLinkageListenerProperties properties;

    @Autowired
    private EmbeddedKafkaBroker embeddedKafka;

    /** Captures log events emitted by the listener class so the test can wait for the error log. */
    private ListAppender<ILoggingEvent> logWatcher;

    @BeforeEach
    void setUp() {
        logWatcher = new ListAppender<>();
        logWatcher.start();
        Logger logger = (Logger) LoggerFactory.getLogger(CancelAuthorizationLinkageListener.class);
        logger.addAppender(logWatcher);
    }

    @Test
    @DisplayName("Rollback transaction and throw an exception if an error occurs during retry")
    void testHandleRetry_RetryableException_3() throws Exception {
        // First attempt: No retry, commit successful, offset updated
        List<Message<CancelAuthorizationLinkage>> firstMessages =
                cancelAuthorizationLinkageTestService.createCancelAuthorizationLinkageMessage(3, 0,
                        properties.getLinkageTopic(), 0);
        kafkaTemplate.executeInTransaction(k ->
                firstMessages.stream().map(k::send).toList()
        );
        await()
                .atMost(1, TimeUnit.SECONDS)
                .untilAsserted(() -> verify(listener, atLeastOnce()).listen(any()));
        Map<String, Object> config =
                KafkaTestUtils.consumerProps(this.embeddedKafka.getBrokersAsString(), "cushion", "false");
        config.put(ConsumerConfig.ISOLATION_LEVEL_CONFIG, "read_committed");
        try (KafkaConsumer<?, ?> consumer = new KafkaConsumer<>(config)) {
            final Long firstCommittedOffset =
                    awaitAndGetLastCommittedOffset(consumer, properties.getLinkageTopic(), 0);
            // End offsets are read for inspection while debugging; they are not asserted.
            final Long firstEndOffset = awaitAndGetEndOffset(consumer, properties.getLinkageTopic(), 0);

            // Second attempt: Retry, commit failed, offset not updated, retain the offset from the first attempt
            doReturn(false).when(cancelAuthorizationLinkageService).isSingleCancel(any());
            doThrow(new InternalApiServerErrorException("")).when(cancelAuthorizationLinkageService)
                    .linkageAuthorization(any());
            doReturn(null).when(properties).getRetryTopic();
            List<Message<CancelAuthorizationLinkage>> failedMessages =
                    cancelAuthorizationLinkageTestService.createCancelAuthorizationLinkageMessage(3, 0,
                            properties.getLinkageTopic(), 0);
            kafkaTemplate.executeInTransaction(k ->
                    failedMessages.stream().map(k::send).toList()
            );
            // NOTE(review): the listener spy is not reset between the two batches, so this verification
            // is presumably already satisfied by the first batch's invocation — confirm. The wait on the
            // error log below is what actually proves the failing batch was handled.
            await()
                    .atMost(2, TimeUnit.SECONDS)
                    .pollInterval(500, TimeUnit.MILLISECONDS)
                    .untilAsserted(() -> verify(listener, atLeastOnce()).listen(any()));
            await()
                    .atMost(5, TimeUnit.SECONDS)
                    .pollInterval(500, TimeUnit.MILLISECONDS)
                    .until(() -> logWatcher.list.stream()
                            .filter(lg -> lg.getLevel().equals(Level.ERROR))
                            .anyMatch(
                                    lg -> lg.getFormattedMessage().contains("triggering transaction rollback. Cause: "))
                    );
            // The committed offset must still be the one recorded after the first, successful batch.
            await()
                    .atMost(2, TimeUnit.SECONDS)
                    .pollInterval(500, TimeUnit.MILLISECONDS)
                    .untilAsserted(() -> assertEquals(firstCommittedOffset,
                            awaitAndGetLastCommittedOffset(consumer, properties.getLinkageTopic(), 0)));
            final Long secondEndOffset = awaitAndGetEndOffset(consumer, properties.getLinkageTopic(), 0);
        }
    }

    /**
     * Polls until the consumer's group has a committed offset for the given partition and returns it.
     *
     * @param consumer  consumer used to query the broker
     * @param topic     topic name
     * @param partition partition number
     * @return the committed offset
     */
    private static Long awaitAndGetLastCommittedOffset(Consumer<?, ?> consumer, String topic, int partition) {
        final TopicPartition topicPartition = new TopicPartition(topic, partition);
        return await()
                .atMost(10, TimeUnit.SECONDS)
                .pollInterval(100, TimeUnit.MILLISECONDS)
                .until(() -> Optional.ofNullable(consumer.committed(Set.of(topicPartition)).get(topicPartition))
                                .map(OffsetAndMetadata::offset),
                        Optional::isPresent)
                .orElseThrow();
    }

    /**
     * Polls until the end offset of the given partition is available and returns it.
     *
     * @param consumer  consumer used to query the broker
     * @param topic     topic name
     * @param partition partition number
     * @return the end offset reported for the partition
     */
    private static Long awaitAndGetEndOffset(Consumer<?, ?> consumer, String topic, int partition) {
        final TopicPartition topicPartition = new TopicPartition(topic, partition);
        return await()
                .atMost(10, TimeUnit.SECONDS)
                .pollInterval(100, TimeUnit.MILLISECONDS)
                .until(() -> Optional.ofNullable(consumer.endOffsets(List.of(topicPartition)).get(topicPartition)),
                        Optional::isPresent)
                .orElseThrow();
    }
}

Expected Behavior:

When the retry handler throws an exception during the second batch, I expect the listener to roll back the entire transaction. The final committed offset should be the same as the offset after the first successful batch.

Actual Behavior:

The committed offset after the second batch is much larger than the offset after the first batch, suggesting that the transaction was not rolled back as expected.

Questions:

  1. Is my suspicion correct that the transaction management is failing due to the multi-threaded environment?
  2. How should I manage transactions in a multi-threaded environment with Kafka and Spring Boot?
  3. What modifications should I make to my code to ensure proper transaction management?

Any insights or suggestions would be greatly appreciated. Thank you!

EDITED: Depending on the type of exception raised during consumption, DefaultMessageRetryHandler sends the failed message either back to the current topic or to another designated error topic.

/**
 * Default {@link MessageRetryHandler}: re-publishes a failed message, with an incremented attempt header,
 * either to the linkage topic or to the retry topic.
 *
 * @param <T> payload type of the handled messages
 */
@Component
@RequiredArgsConstructor
@Slf4j
public class DefaultMessageRetryHandler<T> implements MessageRetryHandler<T> {
    private final KafkaTemplate<String, T> kafkaTemplate;

    /**
     * Builds the outgoing copy of the failed message and sends it through
     * {@code KafkaTemplate#executeInTransaction}.
     *
     * @param message                   message whose processing failed
     * @param listenerPropertiesService source of topic names, retry limit and header name
     * @param exception                 exception raised while processing the message
     * @return future of the send result
     */
    @Override
    public CompletableFuture<SendResult<String, T>> handle(Message<T> message,
                                                           ListenerPropertiesService listenerPropertiesService,
                                                           Exception exception) {
        log.debug("Retry message: {}", message);
        final String destination = getTargetTopic(message, listenerPropertiesService, exception);
        final Message<T> outbound = buildMessage(message, listenerPropertiesService, destination);
        // NOTE(review): executeInTransaction presumably opens its own producer transaction, independent
        // of any transaction the caller is in — confirm against the Spring Kafka documentation.
        return kafkaTemplate.executeInTransaction(operations -> operations.send(outbound));
    }

    /**
     * Chooses the destination topic: the linkage topic for a retryable exception whose retry limit is not
     * yet reached, the retry topic otherwise.
     *
     * @throws MessageRetryException if the selected topic is not configured
     */
    protected String getTargetTopic(Message<T> message, ListenerPropertiesService listenerPropertiesService,
                                    Exception exception) {
        final boolean requeue = isRetryableException(exception)
                && !hasRetryLimitExceeded(message, listenerPropertiesService);
        if (!requeue) {
            final String retryTopic = listenerPropertiesService.getRetryTopic();
            if (retryTopic == null) {
                throw new MessageRetryException("Retry topic not configured.");
            }
            return retryTopic;
        }

        final String linkageTopic = listenerPropertiesService.getLinkageTopic();
        if (linkageTopic == null) {
            throw new MessageRetryException("Linkage topic not configured.");
        }
        return linkageTopic;
    }

    /** Only {@link InternalApiServerErrorException} is treated as retryable. */
    protected boolean isRetryableException(Exception exception) {
        return exception instanceof InternalApiServerErrorException;
    }

    /**
     * Reads the attempt counter from the message headers.
     *
     * @return the header value, or {@code 0} when the header is absent
     */
    protected int getRetryCount(Message<T> message, ListenerPropertiesService listenerPropertiesService) {
        final String attemptHeaderName = this.getNextRetryAttemptHeader(listenerPropertiesService);
        final Integer attempt = message.getHeaders().get(attemptHeaderName, Integer.class);
        if (attempt == null) {
            return 0;
        }
        return attempt;
    }

    /**
     * Tells whether the message has already used up the configured number of retries.
     *
     * @throws MessageRetryException if no retry limit is configured
     */
    protected boolean hasRetryLimitExceeded(Message<T> message, ListenerPropertiesService listenerPropertiesService) {
        final int attemptsSoFar = getRetryCount(message, listenerPropertiesService);
        final Integer limit = listenerPropertiesService.getRetryLimit();
        if (limit == null) {
            throw new MessageRetryException("Retry limit not configured.");
        }
        return attemptsSoFar >= limit;
    }

    /**
     * Returns the configured name of the header that carries the attempt counter.
     *
     * @throws MessageRetryException if the header name is not configured
     */
    protected String getNextRetryAttemptHeader(ListenerPropertiesService listenerPropertiesService) {
        final String headerName = listenerPropertiesService.getNextRetryAttemptHeader();
        if (headerName != null) {
            return headerName;
        }
        throw new MessageRetryException("Retry attempt header not configured.");
    }

    /**
     * Copies the source message, increments its attempt header and sets the outgoing key and topic.
     *
     * @throws MessageRetryException if the source message has no received-key header
     */
    protected Message<T> buildMessage(Message<T> srcMessage, ListenerPropertiesService listenerPropertiesService,
                                      String targetTopic) {
        final String attemptHeaderName = this.getNextRetryAttemptHeader(listenerPropertiesService);
        final Object originalKey = srcMessage.getHeaders().get(KafkaHeaders.RECEIVED_KEY);
        if (originalKey == null) {
            throw new MessageRetryException("Received key header is null");
        }

        final int nextAttempt = getRetryCount(srcMessage, listenerPropertiesService) + 1;
        return MessageBuilder.fromMessage(srcMessage)
                .setHeader(attemptHeaderName, nextAttempt)
                .setHeader(KafkaHeaders.KEY, originalKey)
                .setHeader(KafkaHeaders.TOPIC, targetTopic)
                .build();
    }
}

EDIT 2 (simplified pseudo-code):

// Simplified sketch of the listener: the batch is handed to the concurrent processor from a class
// annotated as transactional.
@Transactional
class Listener {
    // Delegates the received batch to the concurrent processor.
    public void listener(List<Message> messages) {
        concurrentProcessor.processBatch(messages);
    }
}

// Pseudo-code sketch of the concurrent processor. It is not compilable as written; the notes below mark
// the places where the sketch is inconsistent.
class ConcurrentProcessor {

    // Entry point called by the listener: runs the asynchronous step and rethrows any failure wrapped in
    // an unchecked exception.
    public void processBatch(List<Message> messages) {
            try {
                // NOTE(review): `faildMessages` is not defined in this sketch — presumably `messages`
                // is meant; the call also passes one argument to a two-parameter method.
                multiThread(faildMessages);
            } catch (Exception ex) {
                throw new RuntimeErrorException(ex);
            }
    }

    // Processes one message asynchronously; when the result is not ok, the retry step is chained.
    public CompletableFuture multiThread(Message message, CallBackFunction retry) {
        // NOTE(review): `processSingle` is declared below as a void method, so `.thenCompose` cannot be
        // called on it; the statement is also missing its terminating semicolon.
        return CompletableFuture.supplyAsync(() -> processSingle.thenCompose(result -> {
            if(ok()) {
                // nothing to retry
                return CompletableFuture.completedFuture(null);
            }else {
                return retry(result);
            }

        }))
    }
    
    // Handles a single message and throws when the outcome is not ok.
    public void processSingle(Message message) {
        // do something
        if (!ok())
            throw new Exception();
    }
        
    // Re-sends the failed message inside executeInTransaction.
    // NOTE(review): `CompletableFuture void` is not a valid return type, nothing is returned, and
    // `faildMessage(k::send)` is undefined — presumably this stands for `k.send(failedMessage)`.
    public CompletableFuture void retry(Message<String> message) {
        // do something
        kafkaTemplate.executeInTransaction(k -> faildMessage(k::send))
    }
}

Upvotes: 0

Views: 54

Answers (0)

Related Questions