Skip to content

Commit a4eefac

Browse files
author
accelerated
committed
concurrency issues in MessageBuilder internal data
1 parent 2381065 commit a4eefac

File tree

5 files changed

+47
-18
lines changed

5 files changed

+47
-18
lines changed

Diff for: include/cppkafka/message_builder.h

+9
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,15 @@ class MessageBuilder : public BasicMessageBuilder<Buffer, MessageBuilder> {
348348
void construct_buffer(Buffer& lhs, const T& rhs) {
349349
lhs = Buffer(rhs);
350350
}
351+
352+
MessageBuilder clone() const {
353+
return std::move(MessageBuilder(topic()).
354+
key(Buffer(key().get_data(), key().get_size())).
355+
payload(Buffer(payload().get_data(), payload().get_size())).
356+
timestamp(timestamp()).
357+
user_data(user_data()).
358+
internal(internal()));
359+
}
351360
};
352361

353362
/**

Diff for: include/cppkafka/message_internal.h

+9-3
Original file line numberDiff line numberDiff line change
@@ -36,23 +36,29 @@ namespace cppkafka {
3636

3737
class Message;
3838

39-
struct Internal {
39+
class Internal {
40+
public:
4041
virtual ~Internal() = default;
4142
};
4243
using InternalPtr = std::shared_ptr<Internal>;
4344

4445
/**
4546
* \brief Private message data structure
4647
*/
47-
struct MessageInternal {
48+
class MessageInternal {
49+
public:
4850
MessageInternal(void* user_data, std::shared_ptr<Internal> internal);
4951
static std::unique_ptr<MessageInternal> load(Message& message);
52+
void* get_user_data() const;
53+
InternalPtr get_internal() const;
54+
private:
5055
void* user_data_;
5156
InternalPtr internal_;
5257
};
5358

5459
template <typename BuilderType>
55-
struct MessageInternalGuard {
60+
class MessageInternalGuard {
61+
public:
5662
MessageInternalGuard(BuilderType& builder)
5763
: builder_(builder),
5864
user_data_(builder.user_data()) {

Diff for: include/cppkafka/utils/buffered_producer.h

+19-13
Original file line numberDiff line numberDiff line change
@@ -366,9 +366,6 @@ class CPPKAFKA_API BufferedProducer {
366366

367367
template <typename BuilderType>
368368
TrackerPtr add_tracker(BuilderType& builder) {
369-
if (!has_internal_data_ && (max_number_retries_ > 0)) {
370-
has_internal_data_ = true; //enable once
371-
}
372369
if (has_internal_data_ && !builder.internal()) {
373370
// Add message tracker only if it hasn't been added before
374371
TrackerPtr tracker = std::make_shared<Tracker>(SenderType::Async, max_number_retries_);
@@ -426,8 +423,7 @@ BufferedProducer<BufferType>::BufferedProducer(Configuration config)
426423

427424
template <typename BufferType>
428425
void BufferedProducer<BufferType>::add_message(const MessageBuilder& builder) {
429-
add_tracker(const_cast<MessageBuilder&>(builder));
430-
do_add_message(builder, MessagePriority::Low, true);
426+
add_message(Builder(builder)); //make ConcreteBuilder
431427
}
432428

433429
template <typename BufferType>
@@ -438,19 +434,26 @@ void BufferedProducer<BufferType>::add_message(Builder builder) {
438434

439435
template <typename BufferType>
440436
void BufferedProducer<BufferType>::produce(const MessageBuilder& builder) {
441-
add_tracker(const_cast<MessageBuilder&>(builder));
442-
async_produce(builder, true);
437+
if (has_internal_data_) {
438+
MessageBuilder builder_copy(builder.clone());
439+
add_tracker(builder_copy);
440+
async_produce(builder_copy, true);
441+
}
442+
else {
443+
async_produce(builder, true);
444+
}
443445
}
444446

445447
template <typename BufferType>
446448
void BufferedProducer<BufferType>::sync_produce(const MessageBuilder& builder) {
447-
TrackerPtr tracker = add_tracker(const_cast<MessageBuilder&>(builder));
448-
if (tracker) {
449+
if (has_internal_data_) {
450+
MessageBuilder builder_copy(builder.clone());
451+
TrackerPtr tracker = add_tracker(builder_copy);
449452
// produce until we succeed or we reach max retry limit
450453
std::future<bool> should_retry;
451454
do {
452455
should_retry = tracker->get_new_future();
453-
produce_message(builder);
456+
produce_message(builder_copy);
454457
wait_for_acks();
455458
}
456459
while (should_retry.get());
@@ -576,6 +579,9 @@ size_t BufferedProducer<BufferType>::get_flushes_in_progress() const {
576579

577580
template <typename BufferType>
578581
void BufferedProducer<BufferType>::set_max_number_retries(size_t max_number_retries) {
582+
if (!has_internal_data_ && (max_number_retries > 0)) {
583+
has_internal_data_ = true; //enable once
584+
}
579585
max_number_retries_ = max_number_retries;
580586
}
581587

@@ -638,12 +644,12 @@ void BufferedProducer<BufferType>::async_produce(BuilderType&& builder, bool thr
638644
if (test_params && test_params->force_produce_error_) {
639645
throw HandleException(Error(RD_KAFKA_RESP_ERR_UNKNOWN));
640646
}
641-
produce_message(std::forward<BuilderType>(builder));
647+
produce_message(builder);
642648
}
643649
catch (const HandleException& ex) {
644650
// If we have a flush failure callback and it returns true, we retry producing this message later
645651
CallbackInvoker<FlushFailureCallback> callback("flush failure", flush_failure_callback_, &producer_);
646-
if (!callback || callback(std::forward<BuilderType>(builder), ex.get_error())) {
652+
if (!callback || callback(builder, ex.get_error())) {
647653
TrackerPtr tracker = std::static_pointer_cast<Tracker>(builder.internal());
648654
if (tracker && tracker->num_retries_ > 0) {
649655
--tracker->num_retries_;
@@ -671,7 +677,7 @@ void BufferedProducer<BufferType>::on_delivery_report(const Message& message) {
671677
//Get tracker data
672678
TestParameters* test_params = get_test_parameters();
673679
TrackerPtr tracker = has_internal_data_ ?
674-
std::static_pointer_cast<Tracker>(MessageInternal::load(const_cast<Message&>(message))->internal_) : nullptr;
680+
std::static_pointer_cast<Tracker>(MessageInternal::load(const_cast<Message&>(message))->get_internal()) : nullptr;
675681
bool should_retry = false;
676682
if (message.get_error() || (test_params && test_params->force_delivery_error_)) {
677683
// We should produce this message again if we don't have a produce failure callback

Diff for: src/message.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ Message::Message(HandlePtr handle)
6868
Message& Message::load_internal() {
6969
if (user_data_) {
7070
MessageInternal* mi = static_cast<MessageInternal*>(user_data_);
71-
user_data_ = mi->user_data_;
72-
internal_ = mi->internal_;
71+
user_data_ = mi->get_user_data();
72+
internal_ = mi->get_internal();
7373
}
7474
return *this;
7575
}

Diff for: src/message_internal.cpp

+8
Original file line numberDiff line numberDiff line change
@@ -45,4 +45,12 @@ std::unique_ptr<MessageInternal> MessageInternal::load(Message& message) {
4545
static_cast<MessageInternal*>(message.get_handle()->_private) : nullptr);
4646
}
4747

48+
void* MessageInternal::get_user_data() const {
49+
return user_data_;
50+
}
51+
52+
InternalPtr MessageInternal::get_internal() const {
53+
return internal_;
54+
}
55+
4856
}

0 commit comments

Comments
 (0)