diff --git a/spring-cloud-stream-binder-kafka/src/main/java/org/springframework/cloud/stream/binder/kafka/KafkaMessageChannelBinder.java b/spring-cloud-stream-binder-kafka/src/main/java/org/springframework/cloud/stream/binder/kafka/KafkaMessageChannelBinder.java index dc32c0a6..ad32cb47 100644 --- a/spring-cloud-stream-binder-kafka/src/main/java/org/springframework/cloud/stream/binder/kafka/KafkaMessageChannelBinder.java +++ b/spring-cloud-stream-binder-kafka/src/main/java/org/springframework/cloud/stream/binder/kafka/KafkaMessageChannelBinder.java @@ -114,7 +114,9 @@ import org.springframework.kafka.support.ProducerListener; import org.springframework.kafka.support.SendResult; import org.springframework.kafka.support.TopicPartitionOffset; import org.springframework.kafka.support.TopicPartitionOffset.SeekPosition; +import org.springframework.kafka.support.converter.MessageConverter; import org.springframework.kafka.support.converter.MessagingMessageConverter; +import org.springframework.kafka.support.converter.RecordMessageConverter; import org.springframework.kafka.transaction.KafkaAwareTransactionManager; import org.springframework.kafka.transaction.KafkaTransactionManager; import org.springframework.lang.Nullable; @@ -725,7 +727,7 @@ public class KafkaMessageChannelBinder extends final KafkaMessageDrivenChannelAdapter kafkaMessageDrivenChannelAdapter = new KafkaMessageDrivenChannelAdapter<>(messageListenerContainer, extendedConsumerProperties.isBatchMode() ? ListenerMode.batch : ListenerMode.record); - MessagingMessageConverter messageConverter = getMessageConverter(extendedConsumerProperties); + MessageConverter messageConverter = getMessageConverter(extendedConsumerProperties); kafkaMessageDrivenChannelAdapter.setMessageConverter(messageConverter); kafkaMessageDrivenChannelAdapter.setBeanFactory(getBeanFactory()); kafkaMessageDrivenChannelAdapter.setApplicationContext(applicationContext); @@ -744,7 +746,8 @@ public class KafkaMessageChannelBinder extends messageListenerContainer.setAfterRollbackProcessor(new DefaultAfterRollbackProcessor<>( (record, exception) -> { MessagingException payload = - new MessagingException(messageConverter.toMessage(record, null, null, null), + new MessagingException(((RecordMessageConverter) messageConverter) + .toMessage(record, null, null, null), "Transaction rollback limit exceeded", exception); try { errorInfrastructure.getErrorChannel() @@ -1003,7 +1006,10 @@ public class KafkaMessageChannelBinder extends KafkaMessageSource source = new KafkaMessageSource<>(consumerFactory, consumerProperties); - source.setMessageConverter(getMessageConverter(extendedConsumerProperties)); + MessageConverter messageConverter = getMessageConverter(extendedConsumerProperties); + Assert.isInstanceOf(RecordMessageConverter.class, messageConverter, + "'messageConverter' must be a 'RecordMessageConverter' for polled consumers"); + source.setMessageConverter((RecordMessageConverter) messageConverter); source.setRawMessageHeader(extension.isEnableDlq()); if (!extendedConsumerProperties.isMultiplex()) { @@ -1040,32 +1046,35 @@ public class KafkaMessageChannelBinder extends }); } - private MessagingMessageConverter getMessageConverter( + private MessageConverter getMessageConverter( final ExtendedConsumerProperties extendedConsumerProperties) { - MessagingMessageConverter messageConverter; + + MessageConverter messageConverter; if (extendedConsumerProperties.getExtension().getConverterBeanName() == null) { - messageConverter = new MessagingMessageConverter(); + MessagingMessageConverter mmc = new MessagingMessageConverter(); StandardHeaders standardHeaders = extendedConsumerProperties.getExtension() .getStandardHeaders(); - messageConverter - .setGenerateMessageId(StandardHeaders.id.equals(standardHeaders) + mmc.setGenerateMessageId(StandardHeaders.id.equals(standardHeaders) || StandardHeaders.both.equals(standardHeaders)); - messageConverter.setGenerateTimestamp( + mmc.setGenerateTimestamp( StandardHeaders.timestamp.equals(standardHeaders) || StandardHeaders.both.equals(standardHeaders)); + messageConverter = mmc; } else { try { messageConverter = getApplicationContext().getBean( extendedConsumerProperties.getExtension().getConverterBeanName(), - MessagingMessageConverter.class); + MessageConverter.class); } catch (NoSuchBeanDefinitionException ex) { throw new IllegalStateException( "Converter bean not present in application context", ex); } } - messageConverter.setHeaderMapper(getHeaderMapper(extendedConsumerProperties)); + if (messageConverter instanceof MessagingMessageConverter) { + ((MessagingMessageConverter) messageConverter).setHeaderMapper(getHeaderMapper(extendedConsumerProperties)); + } return messageConverter; } diff --git a/spring-cloud-stream-binder-kafka/src/test/java/org/springframework/cloud/stream/binder/kafka/KafkaBinderTests.java b/spring-cloud-stream-binder-kafka/src/test/java/org/springframework/cloud/stream/binder/kafka/KafkaBinderTests.java index 33688243..ad3ba178 100644 --- a/spring-cloud-stream-binder-kafka/src/test/java/org/springframework/cloud/stream/binder/kafka/KafkaBinderTests.java +++ b/spring-cloud-stream-binder-kafka/src/test/java/org/springframework/cloud/stream/binder/kafka/KafkaBinderTests.java @@ -127,6 +127,7 @@ import org.springframework.kafka.support.KafkaHeaderMapper; import org.springframework.kafka.support.KafkaHeaders; import org.springframework.kafka.support.SendResult; import org.springframework.kafka.support.TopicPartitionOffset; +import org.springframework.kafka.support.converter.BatchMessagingMessageConverter; import org.springframework.kafka.support.converter.MessagingMessageConverter; import org.springframework.kafka.test.core.BrokerAddress; import org.springframework.kafka.test.rule.EmbeddedKafkaRule; @@ -612,6 +613,11 @@ public class KafkaBinderTests extends DirectChannel moduleInputChannel = createBindableChannel("input", createConsumerBindingProperties(consumerProperties)); + MessagingMessageConverter mmc = new MessagingMessageConverter(); + ((GenericApplicationContext) ((KafkaTestBinder) binder).getApplicationContext()) + .registerBean("tSARmmc", MessagingMessageConverter.class, () -> mmc); + consumerProperties.getExtension().setConverterBeanName("tSARmmc"); + Binding producerBinding = binder.bindProducer("foo.bar", moduleOutputChannel, outputBindingProperties.getProducer()); Binding consumerBinding = binder.bindConsumer("foo.bar", @@ -653,6 +659,8 @@ public class KafkaBinderTests extends assertThat(topic.isConsumerTopic()).isTrue(); assertThat(topic.getConsumerGroup()).isEqualTo("testSendAndReceive"); + assertThat(KafkaTestUtils.getPropertyValue(consumerBinding, "lifecycle.recordListener.messageConverter")) + .isSameAs(mmc); producerBinding.unbind(); consumerBinding.unbind(); } @@ -670,6 +678,10 @@ public class KafkaBinderTests extends consumerProperties.getExtension().getConfiguration().put("fetch.min.bytes", "1000"); consumerProperties.getExtension().getConfiguration().put("fetch.max.wait.ms", "5000"); consumerProperties.getExtension().getConfiguration().put("max.poll.records", "2"); + BatchMessagingMessageConverter bmmc = new BatchMessagingMessageConverter(); + ((GenericApplicationContext) ((KafkaTestBinder) binder).getApplicationContext()) + .registerBean("tSARBbmmc", BatchMessagingMessageConverter.class, () -> bmmc); + consumerProperties.getExtension().setConverterBeanName("tSARBbmmc"); DirectChannel moduleInputChannel = createBindableChannel("input", createConsumerBindingProperties(consumerProperties)); @@ -709,6 +721,8 @@ public class KafkaBinderTests extends assertThat(payload.get(1)).isEqualTo("bar".getBytes()); } + assertThat(KafkaTestUtils.getPropertyValue(consumerBinding, "lifecycle.batchListener.batchMessageConverter")) + .isSameAs(bmmc); producerBinding.unbind(); consumerBinding.unbind(); }