Nipel-Crumple

Reputation: 99

Wait until @KafkaListener finishes processing message in tests with @EmbeddedKafka

I have a @KafkaListener consumer and want to write an integration test for it. The problem is that it is hard to detect the exact moment when Consumer#consume has finished executing, so that I can run asserts only after the message has been processed and the database state has changed.

@Component
public class Consumer {

    private final Service service;

    public Consumer(Service service) {
        this.service = service;
    }

    @KafkaListener(id = "id", groupId = "group", topics = "topic", containerFactory = "factory")
    public void consume(@Payload Message message, Acknowledgment acknowledgment) {
        service.process(message);
        acknowledgment.acknowledge();
    }

}
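
For context, the Message and Service types referenced above are not shown in the question; a minimal sketch of what they might look like (the method bodies are hypothetical):

// Hypothetical stand-ins for the types the listener depends on.
public class Message {
    // payload fields omitted
}

@Component
public class Service {
    public void process(Message message) {
        // e.g. persist the message so the test can assert on database state
    }
}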

Test

@SpringBootTest
@EmbeddedKafka
class Testing {
    // some useful beans

    @Autowired
    private EmbeddedKafkaBroker embeddedKafka;

    @SpyBean
    private Consumer consumer;

    @Test
    void shouldConsume() throws Exception {
        Message message = new Message();
        String topic = "topic";
        Map<String, Object> senderProps = KafkaTestUtils.producerProps(embeddedKafka);
        new KafkaProducer<String, Message>(senderProps).send(new ProducerRecord<>(topic, message))
                .get(1L, TimeUnit.SECONDS);

        Mockito.verify(consumer, Mockito.timeout(1_000L)).consume(any(Message.class), any(Acknowledgment.class));
        // perform some asserts
    }
}

If I add Thread.sleep(1000L) before the asserts, the consumer processes the message and everything works fine, but with Mockito it does not: the asserts run before Consumer#consume has finished executing.

Is there a way (using listeners, etc.) to catch the moment when the @KafkaListener consumer has acknowledged/finished processing the message, so that asserts can run against the resulting database state? The integration test is needed to make sure the end-to-end functionality works.

I also tried #verify checks on the @SpyBean private Service service, on the method Service#process, but that does not work either.
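
That attempt presumably looked roughly like this (a sketch reconstructed from the description above; the timeout value is illustrative):

Mockito.verify(service, Mockito.timeout(1_000L)).process(any(Message.class));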

Upvotes: 5

Views: 10880

Answers (3)

rios0rios0

Reputation: 925

If you want something easier, you can check these options:

  1. Using Awaitility (when you only need to check the effect of the queue processing);
  @Test
  public void testMessageSendReceive_Awaitility() {
    producer.getMysource()
      .output()
      .send(MessageBuilder.withPayload("payload")
        .setHeader("type", "string")
        .build());

    waitAtMost(5, TimeUnit.SECONDS)
      .untilAsserted(() -> {
        then("payload").isEqualTo(
          EmbeddedKafkaAwaitilityTest.this.consumer.getReceivedMessage());
      });
  }
  2. Using CountDownLatch (when you don't have access to the injected listener, e.g. in a @SpringBootTest that doesn't @Autowired your consumer class, this could be a bad idea; see the sketch after this list for one way to wire the latch);
  @Test
  public void testMessageSendReceive() throws InterruptedException {
    producer.getMysource()
      .output()
      .send(MessageBuilder.withPayload("payload")
        .setHeader("type", "string")
        .build());

    latch.await();
    assertThat(consumer.getReceivedMessage()).isEqualTo("payload");
  }
  3. You can also create a BlockingQueue (but I don't think that is a good option):
BlockingQueue<ConsumerRecord<String, String>> consumerRecords = new LinkedBlockingQueue<>();
// a message listener registered on a test container adds each record to this queue;
// the test then blocks until a record arrives (or the timeout expires):
ConsumerRecord<String, String> received = consumerRecords.poll(10, TimeUnit.SECONDS);
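
For option 2, the latch in the snippet above has to live somewhere the test can reach. Here is a minimal sketch, assuming a test-only listener bean (the class and member names are hypothetical):

import java.util.concurrent.CountDownLatch;

import org.springframework.kafka.annotation.KafkaListener;
import org.springframework.stereotype.Component;

// Hypothetical test-only listener that exposes a latch the test can await.
@Component
public class LatchingConsumer {

    private final CountDownLatch latch = new CountDownLatch(1);
    private volatile String receivedMessage;

    @KafkaListener(topics = "topic", groupId = "test-group")
    public void listen(String message) {
        this.receivedMessage = message; // record the payload for later assertions
        latch.countDown();              // signal that processing has finished
    }

    public CountDownLatch getLatch() {
        return latch;
    }

    public String getReceivedMessage() {
        return receivedMessage;
    }
}

The test then awaits the latch with a timeout before asserting, e.g. assertTrue(consumer.getLatch().await(5, TimeUnit.SECONDS)), so a missed message fails the test instead of hanging it.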


Upvotes: 4

Sergio Cervera

Reputation: 11

I had the same problem and solved it by creating the following class:

@Primary
@Service
class MyCustomConsumerForTest(
    service: Service // same dependencies as MyConsumer
) : MyConsumer(service) {

    val latch = CountDownLatch(1)

    override fun listen(message: String) {
        super.listen(message)
        latch.countDown()
    }
}

and my consumer

@Service
open class MyConsumer(
    private val service: Service
) {

    @KafkaListener(topics = ["topic"])
    open fun listen(message: String) {
        // my processing
        service.somefunction(foo)
        // ...
    }
}

and my test

@EnableKafka
@SpringBootTest(classes = [MyCustomConsumerForTest::class,
    KafkaConfig::class])
@EmbeddedKafka(
    partitions = 1,
    controlledShutdown = false,
    brokerProperties = [
        "listeners=PLAINTEXT://localhost:9092",
        "port=9092"
    ])
@ActiveProfiles("test")
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
class CampaignDataValidatorIntegrationTest {

    private val TOPIC_KAFKA = "topic"
 
    @Autowired
    private lateinit var embeddedKafkaBroker: EmbeddedKafkaBroker

    @Autowired
    private lateinit var listener: MyCustomConsumerForTest

    private lateinit var container: KafkaMessageListenerContainer<String, String>
    private lateinit var records: BlockingQueue<ConsumerRecord<String, String>>

    @SpyBean
    private lateinit var service: Service

    @BeforeAll
    fun setUp() {
        val configs = HashMap(KafkaTestUtils.consumerProps("consumer", "false", embeddedKafkaBroker))
        val consumerFactory = DefaultKafkaConsumerFactory(configs, StringDeserializer(), StringDeserializer())
        val containerProperties = ContainerProperties(TOPIC_KAFKA)
        container = KafkaMessageListenerContainer(consumerFactory, containerProperties)
        records = LinkedBlockingQueue()
        container.setupMessageListener(MessageListener<String, String> { records.add(it) })
        container.start()
        embeddedKafkaBroker.partitionsPerTopic.let { ContainerTestUtils.waitForAssignment(container, it) }
    }

    @AfterAll
    fun tearDown() {
        logger.info("Stop Listener")
        container.stop()
    }

    @Test
    fun kafkaSetup_withTopic_ensureSendMessageIsReceived() {
        // Arrange
        val configs = HashMap(KafkaTestUtils.producerProps(embeddedKafkaBroker))
        val producer = DefaultKafkaProducerFactory(configs, StringSerializer(), StringSerializer()).createProducer()

        // Act
        producer.send(ProducerRecord<String, String>(TOPIC_KAFKA, file))
        producer.flush()

        // Assert
        val singleRecord = records.poll(1, TimeUnit.MILLISECONDS)

        listener.latch.await(1000, TimeUnit.MILLISECONDS)

        assert(singleRecord != null)

        verify(service, times(1)).validate(anyOrNull())

        argumentCaptor<Foo>().apply {
            verify(service, times(1)).somefunction(capture())

            Assertions.assertEquals(1, allValues.size)
            Assertions.assertEquals("text", firstValue.text)
        }
    }
}

and my KafkaConfig

@Configuration
@EnableKafka
class KafkaConfig {

    @Value("\${kafka.bootstrap-servers}")
    private lateinit var bootstrapAddress: String

    @Value("\${kafka.consumer.group-id}")
    private lateinit var groupId: String

    @Bean
    fun consumerFactory(): ConsumerFactory<String, String> {
        val props = HashMap<String, Any>()
        props[ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG] = bootstrapAddress
        props[ConsumerConfig.GROUP_ID_CONFIG] = groupId
        props[ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG] = StringDeserializer::class.java
        props[ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG] = StringDeserializer::class.java
        props[ConsumerConfig.AUTO_OFFSET_RESET_CONFIG] = "latest"
        return DefaultKafkaConsumerFactory(props)
    }

    @Bean
    fun kafkaListenerContainerFactory(): ConcurrentKafkaListenerContainerFactory<String, String> {

        val factory = ConcurrentKafkaListenerContainerFactory<String, String>()
        factory.consumerFactory = consumerFactory()
        factory.containerProperties.isMissingTopicsFatal = false
        return factory
    }
}

Upvotes: 1

Javier Gonzalez Benito

Reputation: 362

With the following method you can poll events from two topics over an interval of N seconds. You must call fetchEventFromOutputTopic with enough time for the records to arrive. I use it with Kafka Streams and it works fine.

private Map<String, List<Event>> fetchEventFromOutputTopic(int seconds) throws Exception {
    // 'consumer' is a KafkaConsumer<String, Event> already subscribed to both output topics
    Map<String, List<Event>> result = new HashMap<>();
    result.put("topic-out-0", new ArrayList<>());
    result.put("topic-out-1", new ArrayList<>());

    int i = 0;
    while (i < seconds) {
        ConsumerRecords<String, Event> records = consumer.poll(Duration.ofSeconds(1));
        records.records("topic-out-0").forEach(record -> result.get("topic-out-0").add(record.value()));
        records.records("topic-out-1").forEach(record -> result.get("topic-out-1").add(record.value()));
        Thread.sleep(1000);
        i++;
    }
    return result;
}
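
A hypothetical call site in a test (the five-second budget and the assertions are illustrative):

// give the topology a few seconds to emit, then assert on what arrived
Map<String, List<Event>> events = fetchEventFromOutputTopic(5);
assertThat(events.get("topic-out-0")).isNotEmpty();
assertThat(events.get("topic-out-1")).isNotEmpty();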

Upvotes: 0
