Skip to content

Commit d1635e1

Browse files
committed
Apply patch
1 parent 5b37c91 commit d1635e1

File tree

4 files changed

+149
-35
lines changed

4 files changed

+149
-35
lines changed

src/google/protobuf/extension_set_inl.h

+18-9
Original file line numberDiff line numberDiff line change
@@ -206,16 +206,21 @@ const char* ExtensionSet::ParseMessageSetItemTmpl(
206206
const char* ptr, const Msg* extendee, internal::InternalMetadata* metadata,
207207
internal::ParseContext* ctx) {
208208
std::string payload;
209-
uint32_t type_id = 0;
210-
bool payload_read = false;
209+
uint32_t type_id;
210+
enum class State { kNoTag, kHasType, kHasPayload, kDone };
211+
State state = State::kNoTag;
212+
211213
while (!ctx->Done(&ptr)) {
212214
uint32_t tag = static_cast<uint8_t>(*ptr++);
213215
if (tag == WireFormatLite::kMessageSetTypeIdTag) {
214216
uint64_t tmp;
215217
ptr = ParseBigVarint(ptr, &tmp);
216218
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
217-
type_id = tmp;
218-
if (payload_read) {
219+
if (state == State::kNoTag) {
220+
type_id = tmp;
221+
state = State::kHasType;
222+
} else if (state == State::kHasPayload) {
223+
type_id = tmp;
219224
ExtensionInfo extension;
220225
bool was_packed_on_wire;
221226
if (!FindExtension(2, type_id, extendee, ctx, &extension,
@@ -241,20 +246,24 @@ const char* ExtensionSet::ParseMessageSetItemTmpl(
241246
GOOGLE_PROTOBUF_PARSER_ASSERT(value->_InternalParse(p, &tmp_ctx) &&
242247
tmp_ctx.EndedAtLimit());
243248
}
244-
type_id = 0;
249+
state = State::kDone;
245250
}
246251
} else if (tag == WireFormatLite::kMessageSetMessageTag) {
247-
if (type_id != 0) {
252+
if (state == State::kHasType) {
248253
ptr = ParseFieldMaybeLazily(static_cast<uint64_t>(type_id) * 8 + 2, ptr,
249254
extendee, metadata, ctx);
250255
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr != nullptr);
251-
type_id = 0;
256+
state = State::kDone;
252257
} else {
258+
std::string tmp;
253259
int32_t size = ReadSize(&ptr);
254260
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
255-
ptr = ctx->ReadString(ptr, size, &payload);
261+
ptr = ctx->ReadString(ptr, size, &tmp);
256262
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
257-
payload_read = true;
263+
if (state == State::kNoTag) {
264+
payload = std::move(tmp);
265+
state = State::kHasPayload;
266+
}
258267
}
259268
} else {
260269
ptr = ReadTag(ptr - 1, &tag);

src/google/protobuf/wire_format.cc

+18-8
Original file line numberDiff line numberDiff line change
@@ -657,9 +657,11 @@ struct WireFormat::MessageSetParser {
657657
const char* _InternalParse(const char* ptr, internal::ParseContext* ctx) {
658658
// Parse a MessageSetItem
659659
auto metadata = reflection->MutableInternalMetadata(msg);
660+
enum class State { kNoTag, kHasType, kHasPayload, kDone };
661+
State state = State::kNoTag;
662+
660663
std::string payload;
661664
uint32_t type_id = 0;
662-
bool payload_read = false;
663665
while (!ctx->Done(&ptr)) {
664666
// We use 64 bit tags in order to allow typeid's that span the whole
665667
// range of 32 bit numbers.
@@ -668,8 +670,11 @@ struct WireFormat::MessageSetParser {
668670
uint64_t tmp;
669671
ptr = ParseBigVarint(ptr, &tmp);
670672
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
671-
type_id = tmp;
672-
if (payload_read) {
673+
if (state == State::kNoTag) {
674+
type_id = tmp;
675+
state = State::kHasType;
676+
} else if (state == State::kHasPayload) {
677+
type_id = tmp;
673678
const FieldDescriptor* field;
674679
if (ctx->data().pool == nullptr) {
675680
field = reflection->FindKnownExtensionByNumber(type_id);
@@ -696,17 +701,17 @@ struct WireFormat::MessageSetParser {
696701
GOOGLE_PROTOBUF_PARSER_ASSERT(value->_InternalParse(p, &tmp_ctx) &&
697702
tmp_ctx.EndedAtLimit());
698703
}
699-
type_id = 0;
704+
state = State::kDone;
700705
}
701706
continue;
702707
} else if (tag == WireFormatLite::kMessageSetMessageTag) {
703-
if (type_id == 0) {
708+
if (state == State::kNoTag) {
704709
int32_t size = ReadSize(&ptr);
705710
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
706711
ptr = ctx->ReadString(ptr, size, &payload);
707712
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
708-
payload_read = true;
709-
} else {
713+
state = State::kHasPayload;
714+
} else if (state == State::kHasType) {
710715
// We're now parsing the payload
711716
const FieldDescriptor* field = nullptr;
712717
if (descriptor->IsExtensionNumber(type_id)) {
@@ -720,7 +725,12 @@ struct WireFormat::MessageSetParser {
720725
ptr = WireFormat::_InternalParseAndMergeField(
721726
msg, ptr, ctx, static_cast<uint64_t>(type_id) * 8 + 2, reflection,
722727
field);
723-
type_id = 0;
728+
state = State::kDone;
729+
} else {
730+
int32_t size = ReadSize(&ptr);
731+
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
732+
ptr = ctx->Skip(ptr, size);
733+
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
724734
}
725735
} else {
726736
// An unknown field in MessageSetItem.

src/google/protobuf/wire_format_lite.h

+18-9
Original file line numberDiff line numberDiff line change
@@ -1845,6 +1845,9 @@ bool ParseMessageSetItemImpl(io::CodedInputStream* input, MS ms) {
18451845
// we can parse it later.
18461846
std::string message_data;
18471847

1848+
enum class State { kNoTag, kHasType, kHasPayload, kDone };
1849+
State state = State::kNoTag;
1850+
18481851
while (true) {
18491852
const uint32_t tag = input->ReadTagNoLastTag();
18501853
if (tag == 0) return false;
@@ -1853,26 +1856,34 @@ bool ParseMessageSetItemImpl(io::CodedInputStream* input, MS ms) {
18531856
case WireFormatLite::kMessageSetTypeIdTag: {
18541857
uint32_t type_id;
18551858
if (!input->ReadVarint32(&type_id)) return false;
1856-
last_type_id = type_id;
1857-
1858-
if (!message_data.empty()) {
1859+
if (state == State::kNoTag) {
1860+
last_type_id = type_id;
1861+
state = State::kHasType;
1862+
} else if (state == State::kHasPayload) {
18591863
// We saw some message data before the type_id. Have to parse it
18601864
// now.
18611865
io::CodedInputStream sub_input(
18621866
reinterpret_cast<const uint8_t*>(message_data.data()),
18631867
static_cast<int>(message_data.size()));
18641868
sub_input.SetRecursionLimit(input->RecursionBudget());
1865-
if (!ms.ParseField(last_type_id, &sub_input)) {
1869+
if (!ms.ParseField(type_id, &sub_input)) {
18661870
return false;
18671871
}
18681872
message_data.clear();
1873+
state = State::kDone;
18691874
}
18701875

18711876
break;
18721877
}
18731878

18741879
case WireFormatLite::kMessageSetMessageTag: {
1875-
if (last_type_id == 0) {
1880+
if (state == State::kHasType) {
1881+
// Already saw type_id, so we can parse this directly.
1882+
if (!ms.ParseField(last_type_id, input)) {
1883+
return false;
1884+
}
1885+
state = State::kDone;
1886+
} else if (state == State::kNoTag) {
18761887
// We haven't seen a type_id yet. Append this data to message_data.
18771888
uint32_t length;
18781889
if (!input->ReadVarint32(&length)) return false;
@@ -1883,11 +1894,9 @@ bool ParseMessageSetItemImpl(io::CodedInputStream* input, MS ms) {
18831894
auto ptr = reinterpret_cast<uint8_t*>(&message_data[0]);
18841895
ptr = io::CodedOutputStream::WriteVarint32ToArray(length, ptr);
18851896
if (!input->ReadRaw(ptr, length)) return false;
1897+
state = State::kHasPayload;
18861898
} else {
1887-
// Already saw type_id, so we can parse this directly.
1888-
if (!ms.ParseField(last_type_id, input)) {
1889-
return false;
1890-
}
1899+
if (!ms.SkipField(tag, input)) return false;
18911900
}
18921901

18931902
break;

src/google/protobuf/wire_format_unittest.inc

+95-9
Original file line numberDiff line numberDiff line change
@@ -580,28 +580,54 @@ TEST(WireFormatTest, ParseMessageSet) {
580580
EXPECT_EQ(message_set.DebugString(), dynamic_message_set.DebugString());
581581
}
582582

583-
TEST(WireFormatTest, ParseMessageSetWithReverseTagOrder) {
583+
namespace {
584+
std::string BuildMessageSetItemStart() {
584585
std::string data;
585586
{
586-
UNITTEST::TestMessageSetExtension1 message;
587-
message.set_i(123);
588-
// Build a MessageSet manually with its message content put before its
589-
// type_id.
590587
io::StringOutputStream output_stream(&data);
591588
io::CodedOutputStream coded_output(&output_stream);
592589
coded_output.WriteTag(WireFormatLite::kMessageSetItemStartTag);
590+
}
591+
return data;
592+
}
593+
std::string BuildMessageSetItemEnd() {
594+
std::string data;
595+
{
596+
io::StringOutputStream output_stream(&data);
597+
io::CodedOutputStream coded_output(&output_stream);
598+
coded_output.WriteTag(WireFormatLite::kMessageSetItemEndTag);
599+
}
600+
return data;
601+
}
602+
std::string BuildMessageSetTestExtension1(int value = 123) {
603+
std::string data;
604+
{
605+
UNITTEST::TestMessageSetExtension1 message;
606+
message.set_i(value);
607+
io::StringOutputStream output_stream(&data);
608+
io::CodedOutputStream coded_output(&output_stream);
593609
// Write the message content first.
594610
WireFormatLite::WriteTag(WireFormatLite::kMessageSetMessageNumber,
595611
WireFormatLite::WIRETYPE_LENGTH_DELIMITED,
596612
&coded_output);
597613
coded_output.WriteVarint32(message.ByteSizeLong());
598614
message.SerializeWithCachedSizes(&coded_output);
599-
// Write the type id.
600-
uint32 type_id = message.GetDescriptor()->extension(0)->number();
615+
}
616+
return data;
617+
}
618+
std::string BuildMessageSetItemTypeId(int extension_number) {
619+
std::string data;
620+
{
621+
io::StringOutputStream output_stream(&data);
622+
io::CodedOutputStream coded_output(&output_stream);
601623
WireFormatLite::WriteUInt32(WireFormatLite::kMessageSetTypeIdNumber,
602-
type_id, &coded_output);
603-
coded_output.WriteTag(WireFormatLite::kMessageSetItemEndTag);
624+
extension_number, &coded_output);
604625
}
626+
return data;
627+
}
628+
void ValidateTestMessageSet(const std::string& test_case,
629+
const std::string& data) {
630+
SCOPED_TRACE(test_case);
605631
{
606632
PROTO2_WIREFORMAT_UNITTEST::TestMessageSet message_set;
607633
ASSERT_TRUE(message_set.ParseFromString(data));
@@ -611,6 +637,11 @@ TEST(WireFormatTest, ParseMessageSetWithReverseTagOrder) {
611637
.GetExtension(
612638
UNITTEST::TestMessageSetExtension1::message_set_extension)
613639
.i());
640+
641+
// Make sure it does not contain anything else.
642+
message_set.ClearExtension(
643+
UNITTEST::TestMessageSetExtension1::message_set_extension);
644+
EXPECT_EQ(message_set.SerializeAsString(), "");
614645
}
615646
{
616647
// Test parse the message via Reflection.
@@ -626,6 +657,61 @@ TEST(WireFormatTest, ParseMessageSetWithReverseTagOrder) {
626657
UNITTEST::TestMessageSetExtension1::message_set_extension)
627658
.i());
628659
}
660+
{
661+
// Test parse the message via DynamicMessage.
662+
DynamicMessageFactory factory;
663+
std::unique_ptr<Message> msg(
664+
factory
665+
.GetPrototype(
666+
PROTO2_WIREFORMAT_UNITTEST::TestMessageSet::descriptor())
667+
->New());
668+
msg->ParseFromString(data);
669+
auto* reflection = msg->GetReflection();
670+
std::vector<const FieldDescriptor*> fields;
671+
reflection->ListFields(*msg, &fields);
672+
ASSERT_EQ(fields.size(), 1);
673+
const auto& sub = reflection->GetMessage(*msg, fields[0]);
674+
reflection = sub.GetReflection();
675+
EXPECT_EQ(123, reflection->GetInt32(
676+
sub, sub.GetDescriptor()->FindFieldByName("i")));
677+
}
678+
}
679+
} // namespace
680+
681+
TEST(WireFormatTest, ParseMessageSetWithAnyTagOrder) {
682+
std::string start = BuildMessageSetItemStart();
683+
std::string end = BuildMessageSetItemEnd();
684+
std::string id = BuildMessageSetItemTypeId(
685+
UNITTEST::TestMessageSetExtension1::descriptor()->extension(0)->number());
686+
std::string message = BuildMessageSetTestExtension1();
687+
688+
ValidateTestMessageSet("id + message", start + id + message + end);
689+
ValidateTestMessageSet("message + id", start + message + id + end);
690+
}
691+
692+
TEST(WireFormatTest, ParseMessageSetWithDuplicateTags) {
693+
std::string start = BuildMessageSetItemStart();
694+
std::string end = BuildMessageSetItemEnd();
695+
std::string id = BuildMessageSetItemTypeId(
696+
UNITTEST::TestMessageSetExtension1::descriptor()->extension(0)->number());
697+
std::string other_id = BuildMessageSetItemTypeId(123456);
698+
std::string message = BuildMessageSetTestExtension1();
699+
std::string other_message = BuildMessageSetTestExtension1(321);
700+
701+
// Double id
702+
ValidateTestMessageSet("id + other_id + message",
703+
start + id + other_id + message + end);
704+
ValidateTestMessageSet("id + message + other_id",
705+
start + id + message + other_id + end);
706+
ValidateTestMessageSet("message + id + other_id",
707+
start + message + id + other_id + end);
708+
// Double message
709+
ValidateTestMessageSet("id + message + other_message",
710+
start + id + message + other_message + end);
711+
ValidateTestMessageSet("message + id + other_message",
712+
start + message + id + other_message + end);
713+
ValidateTestMessageSet("message + other_message + id",
714+
start + message + other_message + id + end);
629715
}
630716

631717
void SerializeReverseOrder(

0 commit comments

Comments
 (0)