Skip to content

Commit 806d7e4

Browse files
Merge pull request #10544 from deannagarcia/3.20.x
Apply patch
2 parents 6439c5c + ae718b3 commit 806d7e4

File tree

4 files changed

+150
-35
lines changed

4 files changed

+150
-35
lines changed

src/google/protobuf/extension_set_inl.h

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -206,16 +206,21 @@ const char* ExtensionSet::ParseMessageSetItemTmpl(
206206
const char* ptr, const Msg* extendee, internal::InternalMetadata* metadata,
207207
internal::ParseContext* ctx) {
208208
std::string payload;
209-
uint32_t type_id = 0;
210-
bool payload_read = false;
209+
uint32_t type_id;
210+
enum class State { kNoTag, kHasType, kHasPayload, kDone };
211+
State state = State::kNoTag;
212+
211213
while (!ctx->Done(&ptr)) {
212214
uint32_t tag = static_cast<uint8_t>(*ptr++);
213215
if (tag == WireFormatLite::kMessageSetTypeIdTag) {
214216
uint64_t tmp;
215217
ptr = ParseBigVarint(ptr, &tmp);
216218
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
217-
type_id = tmp;
218-
if (payload_read) {
219+
if (state == State::kNoTag) {
220+
type_id = tmp;
221+
state = State::kHasType;
222+
} else if (state == State::kHasPayload) {
223+
type_id = tmp;
219224
ExtensionInfo extension;
220225
bool was_packed_on_wire;
221226
if (!FindExtension(2, type_id, extendee, ctx, &extension,
@@ -241,20 +246,24 @@ const char* ExtensionSet::ParseMessageSetItemTmpl(
241246
GOOGLE_PROTOBUF_PARSER_ASSERT(value->_InternalParse(p, &tmp_ctx) &&
242247
tmp_ctx.EndedAtLimit());
243248
}
244-
type_id = 0;
249+
state = State::kDone;
245250
}
246251
} else if (tag == WireFormatLite::kMessageSetMessageTag) {
247-
if (type_id != 0) {
252+
if (state == State::kHasType) {
248253
ptr = ParseFieldMaybeLazily(static_cast<uint64_t>(type_id) * 8 + 2, ptr,
249254
extendee, metadata, ctx);
250255
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr != nullptr);
251-
type_id = 0;
256+
state = State::kDone;
252257
} else {
258+
std::string tmp;
253259
int32_t size = ReadSize(&ptr);
254260
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
255-
ptr = ctx->ReadString(ptr, size, &payload);
261+
ptr = ctx->ReadString(ptr, size, &tmp);
256262
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
257-
payload_read = true;
263+
if (state == State::kNoTag) {
264+
payload = std::move(tmp);
265+
state = State::kHasPayload;
266+
}
258267
}
259268
} else {
260269
ptr = ReadTag(ptr - 1, &tag);

src/google/protobuf/wire_format.cc

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -657,9 +657,11 @@ struct WireFormat::MessageSetParser {
657657
const char* _InternalParse(const char* ptr, internal::ParseContext* ctx) {
658658
// Parse a MessageSetItem
659659
auto metadata = reflection->MutableInternalMetadata(msg);
660+
enum class State { kNoTag, kHasType, kHasPayload, kDone };
661+
State state = State::kNoTag;
662+
660663
std::string payload;
661664
uint32_t type_id = 0;
662-
bool payload_read = false;
663665
while (!ctx->Done(&ptr)) {
664666
// We use 64 bit tags in order to allow typeid's that span the whole
665667
// range of 32 bit numbers.
@@ -668,8 +670,11 @@ struct WireFormat::MessageSetParser {
668670
uint64_t tmp;
669671
ptr = ParseBigVarint(ptr, &tmp);
670672
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
671-
type_id = tmp;
672-
if (payload_read) {
673+
if (state == State::kNoTag) {
674+
type_id = tmp;
675+
state = State::kHasType;
676+
} else if (state == State::kHasPayload) {
677+
type_id = tmp;
673678
const FieldDescriptor* field;
674679
if (ctx->data().pool == nullptr) {
675680
field = reflection->FindKnownExtensionByNumber(type_id);
@@ -696,17 +701,17 @@ struct WireFormat::MessageSetParser {
696701
GOOGLE_PROTOBUF_PARSER_ASSERT(value->_InternalParse(p, &tmp_ctx) &&
697702
tmp_ctx.EndedAtLimit());
698703
}
699-
type_id = 0;
704+
state = State::kDone;
700705
}
701706
continue;
702707
} else if (tag == WireFormatLite::kMessageSetMessageTag) {
703-
if (type_id == 0) {
708+
if (state == State::kNoTag) {
704709
int32_t size = ReadSize(&ptr);
705710
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
706711
ptr = ctx->ReadString(ptr, size, &payload);
707712
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
708-
payload_read = true;
709-
} else {
713+
state = State::kHasPayload;
714+
} else if (state == State::kHasType) {
710715
// We're now parsing the payload
711716
const FieldDescriptor* field = nullptr;
712717
if (descriptor->IsExtensionNumber(type_id)) {
@@ -720,7 +725,12 @@ struct WireFormat::MessageSetParser {
720725
ptr = WireFormat::_InternalParseAndMergeField(
721726
msg, ptr, ctx, static_cast<uint64_t>(type_id) * 8 + 2, reflection,
722727
field);
723-
type_id = 0;
728+
state = State::kDone;
729+
} else {
730+
int32_t size = ReadSize(&ptr);
731+
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
732+
ptr = ctx->Skip(ptr, size);
733+
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
724734
}
725735
} else {
726736
// An unknown field in MessageSetItem.

src/google/protobuf/wire_format_lite.h

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1830,6 +1830,9 @@ bool ParseMessageSetItemImpl(io::CodedInputStream* input, MS ms) {
18301830
// we can parse it later.
18311831
std::string message_data;
18321832

1833+
enum class State { kNoTag, kHasType, kHasPayload, kDone };
1834+
State state = State::kNoTag;
1835+
18331836
while (true) {
18341837
const uint32_t tag = input->ReadTagNoLastTag();
18351838
if (tag == 0) return false;
@@ -1838,26 +1841,34 @@ bool ParseMessageSetItemImpl(io::CodedInputStream* input, MS ms) {
18381841
case WireFormatLite::kMessageSetTypeIdTag: {
18391842
uint32_t type_id;
18401843
if (!input->ReadVarint32(&type_id)) return false;
1841-
last_type_id = type_id;
1842-
1843-
if (!message_data.empty()) {
1844+
if (state == State::kNoTag) {
1845+
last_type_id = type_id;
1846+
state = State::kHasType;
1847+
} else if (state == State::kHasPayload) {
18441848
// We saw some message data before the type_id. Have to parse it
18451849
// now.
18461850
io::CodedInputStream sub_input(
18471851
reinterpret_cast<const uint8_t*>(message_data.data()),
18481852
static_cast<int>(message_data.size()));
18491853
sub_input.SetRecursionLimit(input->RecursionBudget());
1850-
if (!ms.ParseField(last_type_id, &sub_input)) {
1854+
if (!ms.ParseField(type_id, &sub_input)) {
18511855
return false;
18521856
}
18531857
message_data.clear();
1858+
state = State::kDone;
18541859
}
18551860

18561861
break;
18571862
}
18581863

18591864
case WireFormatLite::kMessageSetMessageTag: {
1860-
if (last_type_id == 0) {
1865+
if (state == State::kHasType) {
1866+
// Already saw type_id, so we can parse this directly.
1867+
if (!ms.ParseField(last_type_id, input)) {
1868+
return false;
1869+
}
1870+
state = State::kDone;
1871+
} else if (state == State::kNoTag) {
18611872
// We haven't seen a type_id yet. Append this data to message_data.
18621873
uint32_t length;
18631874
if (!input->ReadVarint32(&length)) return false;
@@ -1868,11 +1879,9 @@ bool ParseMessageSetItemImpl(io::CodedInputStream* input, MS ms) {
18681879
auto ptr = reinterpret_cast<uint8_t*>(&message_data[0]);
18691880
ptr = io::CodedOutputStream::WriteVarint32ToArray(length, ptr);
18701881
if (!input->ReadRaw(ptr, length)) return false;
1882+
state = State::kHasPayload;
18711883
} else {
1872-
// Already saw type_id, so we can parse this directly.
1873-
if (!ms.ParseField(last_type_id, input)) {
1874-
return false;
1875-
}
1884+
if (!ms.SkipField(tag, input)) return false;
18761885
}
18771886

18781887
break;

src/google/protobuf/wire_format_unittest.inc

Lines changed: 96 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
#include <google/protobuf/stubs/casts.h>
5050
#include <google/protobuf/stubs/strutil.h>
5151
#include <google/protobuf/stubs/stl_util.h>
52+
#include <google/protobuf/dynamic_message.h>
5253

5354
// clang-format off
5455
#include <google/protobuf/port_def.inc>
@@ -581,28 +582,54 @@ TEST(WireFormatTest, ParseMessageSet) {
581582
EXPECT_EQ(message_set.DebugString(), dynamic_message_set.DebugString());
582583
}
583584

584-
TEST(WireFormatTest, ParseMessageSetWithReverseTagOrder) {
585+
namespace {
586+
std::string BuildMessageSetItemStart() {
585587
std::string data;
586588
{
587-
UNITTEST::TestMessageSetExtension1 message;
588-
message.set_i(123);
589-
// Build a MessageSet manually with its message content put before its
590-
// type_id.
591589
io::StringOutputStream output_stream(&data);
592590
io::CodedOutputStream coded_output(&output_stream);
593591
coded_output.WriteTag(WireFormatLite::kMessageSetItemStartTag);
592+
}
593+
return data;
594+
}
595+
std::string BuildMessageSetItemEnd() {
596+
std::string data;
597+
{
598+
io::StringOutputStream output_stream(&data);
599+
io::CodedOutputStream coded_output(&output_stream);
600+
coded_output.WriteTag(WireFormatLite::kMessageSetItemEndTag);
601+
}
602+
return data;
603+
}
604+
std::string BuildMessageSetTestExtension1(int value = 123) {
605+
std::string data;
606+
{
607+
UNITTEST::TestMessageSetExtension1 message;
608+
message.set_i(value);
609+
io::StringOutputStream output_stream(&data);
610+
io::CodedOutputStream coded_output(&output_stream);
594611
// Write the message content first.
595612
WireFormatLite::WriteTag(WireFormatLite::kMessageSetMessageNumber,
596613
WireFormatLite::WIRETYPE_LENGTH_DELIMITED,
597614
&coded_output);
598615
coded_output.WriteVarint32(message.ByteSizeLong());
599616
message.SerializeWithCachedSizes(&coded_output);
600-
// Write the type id.
601-
uint32_t type_id = message.GetDescriptor()->extension(0)->number();
617+
}
618+
return data;
619+
}
620+
std::string BuildMessageSetItemTypeId(int extension_number) {
621+
std::string data;
622+
{
623+
io::StringOutputStream output_stream(&data);
624+
io::CodedOutputStream coded_output(&output_stream);
602625
WireFormatLite::WriteUInt32(WireFormatLite::kMessageSetTypeIdNumber,
603-
type_id, &coded_output);
604-
coded_output.WriteTag(WireFormatLite::kMessageSetItemEndTag);
626+
extension_number, &coded_output);
605627
}
628+
return data;
629+
}
630+
void ValidateTestMessageSet(const std::string& test_case,
631+
const std::string& data) {
632+
SCOPED_TRACE(test_case);
606633
{
607634
PROTO2_WIREFORMAT_UNITTEST::TestMessageSet message_set;
608635
ASSERT_TRUE(message_set.ParseFromString(data));
@@ -612,6 +639,11 @@ TEST(WireFormatTest, ParseMessageSetWithReverseTagOrder) {
612639
.GetExtension(
613640
UNITTEST::TestMessageSetExtension1::message_set_extension)
614641
.i());
642+
643+
// Make sure it does not contain anything else.
644+
message_set.ClearExtension(
645+
UNITTEST::TestMessageSetExtension1::message_set_extension);
646+
EXPECT_EQ(message_set.SerializeAsString(), "");
615647
}
616648
{
617649
// Test parse the message via Reflection.
@@ -627,6 +659,61 @@ TEST(WireFormatTest, ParseMessageSetWithReverseTagOrder) {
627659
UNITTEST::TestMessageSetExtension1::message_set_extension)
628660
.i());
629661
}
662+
{
663+
// Test parse the message via DynamicMessage.
664+
DynamicMessageFactory factory;
665+
std::unique_ptr<Message> msg(
666+
factory
667+
.GetPrototype(
668+
PROTO2_WIREFORMAT_UNITTEST::TestMessageSet::descriptor())
669+
->New());
670+
msg->ParseFromString(data);
671+
auto* reflection = msg->GetReflection();
672+
std::vector<const FieldDescriptor*> fields;
673+
reflection->ListFields(*msg, &fields);
674+
ASSERT_EQ(fields.size(), 1);
675+
const auto& sub = reflection->GetMessage(*msg, fields[0]);
676+
reflection = sub.GetReflection();
677+
EXPECT_EQ(123, reflection->GetInt32(
678+
sub, sub.GetDescriptor()->FindFieldByName("i")));
679+
}
680+
}
681+
} // namespace
682+
683+
TEST(WireFormatTest, ParseMessageSetWithAnyTagOrder) {
684+
std::string start = BuildMessageSetItemStart();
685+
std::string end = BuildMessageSetItemEnd();
686+
std::string id = BuildMessageSetItemTypeId(
687+
UNITTEST::TestMessageSetExtension1::descriptor()->extension(0)->number());
688+
std::string message = BuildMessageSetTestExtension1();
689+
690+
ValidateTestMessageSet("id + message", start + id + message + end);
691+
ValidateTestMessageSet("message + id", start + message + id + end);
692+
}
693+
694+
TEST(WireFormatTest, ParseMessageSetWithDuplicateTags) {
695+
std::string start = BuildMessageSetItemStart();
696+
std::string end = BuildMessageSetItemEnd();
697+
std::string id = BuildMessageSetItemTypeId(
698+
UNITTEST::TestMessageSetExtension1::descriptor()->extension(0)->number());
699+
std::string other_id = BuildMessageSetItemTypeId(123456);
700+
std::string message = BuildMessageSetTestExtension1();
701+
std::string other_message = BuildMessageSetTestExtension1(321);
702+
703+
// Double id
704+
ValidateTestMessageSet("id + other_id + message",
705+
start + id + other_id + message + end);
706+
ValidateTestMessageSet("id + message + other_id",
707+
start + id + message + other_id + end);
708+
ValidateTestMessageSet("message + id + other_id",
709+
start + message + id + other_id + end);
710+
// Double message
711+
ValidateTestMessageSet("id + message + other_message",
712+
start + id + message + other_message + end);
713+
ValidateTestMessageSet("message + id + other_message",
714+
start + message + id + other_message + end);
715+
ValidateTestMessageSet("message + other_message + id",
716+
start + message + other_message + id + end);
630717
}
631718

632719
void SerializeReverseOrder(

0 commit comments

Comments
 (0)