diff --git a/reflect_extension.go b/reflect_extension.go index 74a97bfe..3969aad2 100644 --- a/reflect_extension.go +++ b/reflect_extension.go @@ -2,12 +2,13 @@ package jsoniter import ( "fmt" - "github.com/modern-go/reflect2" "reflect" "sort" "strings" "unicode" "unsafe" + + "github.com/modern-go/reflect2" ) var typeDecoders = map[string]ValDecoder{} @@ -332,6 +333,10 @@ func _getTypeEncoderFromExtension(ctx *ctx, typ reflect2.Type) ValEncoder { } func describeStruct(ctx *ctx, typ reflect2.Type) *StructDescriptor { + return _describeStruct(ctx, typ, nil) +} + +func _describeStruct(ctx *ctx, typ reflect2.Type, parents []reflect2.Type) *StructDescriptor { structType := typ.(*reflect2.UnsafeStructType) embeddedBindings := []*Binding{} bindings := []*Binding{} @@ -347,7 +352,16 @@ func describeStruct(ctx *ctx, typ reflect2.Type) *StructDescriptor { tagParts := strings.Split(tag, ",") if field.Anonymous() && (tag == "" || tagParts[0] == "") { if field.Type().Kind() == reflect.Struct { - structDescriptor := describeStruct(ctx, field.Type()) + if isRecursive(parents, field.Type()) { + return nil + } + parents = append(parents, field.Type()) + + structDescriptor := _describeStruct(ctx, field.Type(), parents) + if structDescriptor == nil { + continue + } + for _, binding := range structDescriptor.Fields { binding.levels = append([]int{i}, binding.levels...) omitempty := binding.Encoder.(*structFieldEncoder).omitempty @@ -359,7 +373,16 @@ func describeStruct(ctx *ctx, typ reflect2.Type) *StructDescriptor { } else if field.Type().Kind() == reflect.Ptr { ptrType := field.Type().(*reflect2.UnsafePtrType) if ptrType.Elem().Kind() == reflect.Struct { - structDescriptor := describeStruct(ctx, ptrType.Elem()) + if isRecursive(parents, field.Type()) { + return nil + } + parents = append(parents, field.Type()) + + structDescriptor := _describeStruct(ctx, ptrType.Elem(), parents) + if structDescriptor == nil { + continue + } + for _, binding := range structDescriptor.Fields { binding.levels = append([]int{i}, binding.levels...) omitempty := binding.Encoder.(*structFieldEncoder).omitempty @@ -395,6 +418,16 @@ func describeStruct(ctx *ctx, typ reflect2.Type) *StructDescriptor { } return createStructDescriptor(ctx, typ, bindings, embeddedBindings) } + +func isRecursive(parents []reflect2.Type, parent reflect2.Type) bool { + for _, p := range parents { + if p == parent { + return true + } + } + return false +} + func createStructDescriptor(ctx *ctx, typ reflect2.Type, bindings []*Binding, embeddedBindings []*Binding) *StructDescriptor { structDescriptor := &StructDescriptor{ Type: typ, diff --git a/reflect_struct_decoder.go b/reflect_struct_decoder.go index 92ae912d..bd23a6b3 100644 --- a/reflect_struct_decoder.go +++ b/reflect_struct_decoder.go @@ -12,22 +12,15 @@ import ( func decoderOfStruct(ctx *ctx, typ reflect2.Type) ValDecoder { bindings := map[string]*Binding{} structDescriptor := describeStruct(ctx, typ) - for _, binding := range structDescriptor.Fields { - for _, fromName := range binding.FromNames { - old := bindings[fromName] - if old == nil { - bindings[fromName] = binding - continue - } - ignoreOld, ignoreNew := resolveConflictBinding(ctx.frozenConfig, old, binding) - if ignoreOld { - delete(bindings, fromName) - } - if !ignoreNew { - bindings[fromName] = binding - } - } + + flattenedBindings := flattenFrom(structDescriptor.Fields, ctx.frozenConfig) + + orderedBindings := resolveBindings(flattenedBindings) + + for _, b := range orderedBindings { + bindings[b.name] = b.binding } + fields := map[string]*structFieldDecoder{} for k, binding := range bindings { fields[k] = binding.Decoder.(*structFieldDecoder) @@ -44,6 +37,22 @@ func decoderOfStruct(ctx *ctx, typ reflect2.Type) ValDecoder { return createStructDecoder(ctx, typ, fields) } +func flattenFrom(bindings []*Binding, cfg *frozenConfig) []*binding { + flattened := make([]*binding, 0, len(bindings)) + + for _, b := range bindings { + for _, fromName := range b.FromNames { + flattened = append(flattened, &binding{ + binding: b, + name: fromName, + hasTag: hasTag(b, cfg), + }) + } + } + + return flattened +} + func createStructDecoder(ctx *ctx, typ reflect2.Type, fields map[string]*structFieldDecoder) ValDecoder { if ctx.disallowUnknownFields { return &generalStructDecoder{typ: typ, fields: fields, disallowUnknownFields: true} diff --git a/reflect_struct_encoder.go b/reflect_struct_encoder.go index 152e3ef5..819a7c95 100644 --- a/reflect_struct_encoder.go +++ b/reflect_struct_encoder.go @@ -2,48 +2,125 @@ package jsoniter import ( "fmt" - "github.com/modern-go/reflect2" "io" "reflect" + "sort" + "strings" "unsafe" + + "github.com/modern-go/reflect2" ) +type binding struct { + binding *Binding + name string + hasTag bool +} + func encoderOfStruct(ctx *ctx, typ reflect2.Type) ValEncoder { - type bindingTo struct { - binding *Binding - toName string - ignored bool - } - orderedBindings := []*bindingTo{} + + orderedBindings := []*binding{} structDescriptor := describeStruct(ctx, typ) - for _, binding := range structDescriptor.Fields { - for _, toName := range binding.ToNames { - new := &bindingTo{ - binding: binding, - toName: toName, - } - for _, old := range orderedBindings { - if old.toName != toName { - continue - } - old.ignored, new.ignored = resolveConflictBinding(ctx.frozenConfig, old.binding, new.binding) - } - orderedBindings = append(orderedBindings, new) - } - } + + fields := flattenTo(structDescriptor.Fields, ctx.frozenConfig) + + orderedBindings = resolveBindings(fields) + if len(orderedBindings) == 0 { return &emptyStructEncoder{} } + finalOrderedFields := []structFieldTo{} for _, bindingTo := range orderedBindings { - if !bindingTo.ignored { - finalOrderedFields = append(finalOrderedFields, structFieldTo{ - encoder: bindingTo.binding.Encoder.(*structFieldEncoder), - toName: bindingTo.toName, + finalOrderedFields = append(finalOrderedFields, structFieldTo{ + encoder: bindingTo.binding.Encoder.(*structFieldEncoder), + toName: bindingTo.name, + }) + } + + return &structEncoder{typ, finalOrderedFields} +} + +func flattenTo(bindings []*Binding, cfg *frozenConfig) []*binding { + flattened := make([]*binding, 0, len(bindings)) + + for _, b := range bindings { + for _, toName := range b.ToNames { + flattened = append(flattened, &binding{ + binding: b, + name: toName, + hasTag: hasTag(b, cfg), }) } } - return &structEncoder{typ, finalOrderedFields} + + return flattened +} + +func hasTag(b *Binding, cfg *frozenConfig) bool { + before, _, _ := strings.Cut(b.Field.Tag().Get(cfg.getTagKey()), ",") + return before != "" +} + +func resolveBindings(fields []*binding) []*binding { + sort.SliceStable(fields, func(i, j int) bool { + // As per std's encoding/json, + // it sorts fields by names, index depth(here we call it levels) and tags. + // We've already sorted fields by index order in describeStruct. + // By using stable sorting, we avoid sorting them again. + if fields[i].name != fields[j].name { + return fields[i].name < fields[j].name + } + if len(fields[i].binding.levels) != len(fields[j].binding.levels) { + return len(fields[i].binding.levels) < len(fields[j].binding.levels) + } + if fields[i].hasTag != fields[j].hasTag { + return fields[i].hasTag + } + return true // equal. + }) + + orderedBindings := trimOverlappingBindings(fields) + + sort.Slice(orderedBindings, func(i, j int) bool { + left := orderedBindings[i].binding.levels + right := orderedBindings[j].binding.levels + k := 0 + for { + if left[k] < right[k] { + return true + } else if left[k] > right[k] { + return false + } + k++ + } + }) + + return orderedBindings +} + +func trimOverlappingBindings(bindings []*binding) []*binding { + out := bindings[:0] + for nameRange, i := 0, 0; i < len(bindings); i += nameRange { + for nameRange = 1; i+nameRange < len(bindings); nameRange++ { + endOfRange := bindings[i+nameRange] + if endOfRange.name != bindings[i].name { + break + } + } + if nameRange == 1 { // only one field for that name + out = append(out, bindings[i]) + } else { + fields := bindings[i : i+nameRange] + if len(fields[0].binding.levels) == len(fields[1].binding.levels) && + fields[0].hasTag == fields[1].hasTag { + continue + } + out = append(out, fields[0]) + } + } + + return out } func createCheckIsEmpty(ctx *ctx, typ reflect2.Type) checkIsEmpty { diff --git a/type_tests/struct_embedded_test.go b/type_tests/struct_embedded_test.go index dcab2a42..aca5bb3e 100644 --- a/type_tests/struct_embedded_test.go +++ b/type_tests/struct_embedded_test.go @@ -61,6 +61,9 @@ func init() { (*SameLevel2Tagged)(nil), (*EmbeddedPtr)(nil), (*UnnamedLiteral)(nil), + (*EmbeddedRecursive)(nil), + (*EmbeddedRecursive2)(nil), + (*EmbeddedRecursive3)(nil), ) } @@ -236,3 +239,42 @@ type EmbeddedPtr struct { type UnnamedLiteral struct { _ struct{} } + +type EmbeddedRecursive struct { + Foo string + Recursive1 +} + +type Recursive1 struct { + R string + Recursive2 +} + +type Recursive2 struct { + Foo string + R string + RR string + *EmbeddedRecursive +} + +type EmbeddedRecursive2 struct { + Foo string + Recursive1 + Recursive3 +} + +type Recursive3 struct { + Foo string + RR string + *EmbeddedRecursive2 +} + +type Recursive4 struct { + Bar string + Recursive3 +} + +type EmbeddedRecursive3 struct { + Foo string + *EmbeddedRecursive3 +} diff --git a/type_tests/struct_tags_test.go b/type_tests/struct_tags_test.go index ae4b80f2..cd075c8e 100644 --- a/type_tests/struct_tags_test.go +++ b/type_tests/struct_tags_test.go @@ -149,6 +149,32 @@ func init() { (*struct { Field bool `json:"中文"` })(nil), + (*struct { + Foo string `json:"Bar"` + Bar string + })(nil), + (*struct { + Foo string `json:"Bar"` + Bar string `json:"Foo"` + })(nil), + (*struct { + Foo string + Bar string `json:"Foo"` + })(nil), + (*struct { + Foo string `json:"Bar"` + Bar string `json:"Bar"` + })(nil), + (*struct { + Foo string `json:"Bar"` + Bar string + Baz string `json:"Bar"` + })(nil), + (*struct { + Foo string + F string + EmbeddedOmitEmptyE + })(nil), ) }