Skip to content

Tools 3637 #746

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 22 commits into
base: master
Choose a base branch
from
46 changes: 45 additions & 1 deletion common/bsonutil/bsonutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -477,14 +477,58 @@ func MtoD(m bson.M) bson.D {
// but would return an error if it cannot be reversed by bson.UnmarshalExtJSON.
//
// Prefer it in mongodump to avoid generating Extended JSON that cannot be unmarshaled back.
func MarshalExtJSONReversible(val interface{}, canonical bool, escapeHTML bool) ([]byte, error) {
func MarshalExtJSONReversible(
	val interface{},
	canonical bool,
	escapeHTML bool,
) ([]byte, error) {
	encoded, marshalErr := bson.MarshalExtJSON(val, canonical, escapeHTML)
	if marshalErr != nil {
		return nil, marshalErr
	}

	// Trial decode: the output is only handed back when bson.UnmarshalExtJSON
	// accepts it. The decode target is a zero value of val's dynamic type,
	// boxed in an interface{} whose address is passed to the decoder.
	decodeTarget := reflect.New(reflect.TypeOf(val)).Elem().Interface()
	unmarshalErr := bson.UnmarshalExtJSON(encoded, canonical, &decodeTarget)
	if unmarshalErr != nil {
		return nil, errors2.Wrap(unmarshalErr, "marshal is not reversible")
	}

	return encoded, nil
}

// MarshalExtJSONWithBSONRoundtripConsistency is a wrapper around bson.MarshalExtJSON
// which also validates that BSON objects that are marshaled to ExtJSON objects
// return a consistent BSON object when unmarshaled.
func MarshalExtJSONWithBSONRoundtripConsistency(
val interface{},
canonical bool,
escapeHTML bool,
) ([]byte, error) {
jsonBytes, err := MarshalExtJSONReversible(val, canonical, escapeHTML)
if err != nil {
return nil, err
}

originalBSON, err := bson.Marshal(val)
if err != nil {
return nil, fmt.Errorf("could not marshal into BSON")
}

reversedVal := reflect.New(reflect.TypeOf(val)).Elem().Interface()
err = bson.UnmarshalExtJSON(jsonBytes, canonical, &reversedVal)
if err != nil {
return nil, err
}

reversedBSON, err := bson.Marshal(reversedVal)
if err != nil {
return nil, fmt.Errorf("could not marshal into BSON")
}

if !bytes.Equal(originalBSON, reversedBSON) {
return nil, fmt.Errorf(
"marshaling BSON to ExtJSON and back resulted in discrepancies",
)
}

return jsonBytes, nil
}
55 changes: 50 additions & 5 deletions common/bsonutil/bsonutil_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package bsonutil

import (
"math"
"testing"
"time"

Expand Down Expand Up @@ -104,32 +105,32 @@ func TestMarshalExtJSONReversible(t *testing.T) {

tests := []struct {
val any
canonical bool
reversible bool
expectedJSON string
}{
{
bson.M{"field1": bson.M{"$date": 1257894000000}},
true,
true,
`{"field1":{"$date":{"$numberLong":"1257894000000"}}}`,
},
{
bson.M{"field1": time.Unix(1257894000, 0)},
true,
true,
`{"field1":{"$date":{"$numberLong":"1257894000000"}}}`,
},
{
bson.M{"field1": bson.M{"$date": "invalid"}},
true,
false,
``,
},
}

for _, test := range tests {
json, err := MarshalExtJSONReversible(test.val, test.canonical, false)
json, err := MarshalExtJSONReversible(
test.val,
true, /* canonical */
false, /* escapeHTML */
)
if !test.reversible {
assert.ErrorContains(t, err, "marshal is not reversible")
} else {
Expand All @@ -138,3 +139,47 @@ func TestMarshalExtJSONReversible(t *testing.T) {
assert.Equal(t, test.expectedJSON, string(json))
}
}

// TestMarshalExtJSONWithBSONRoundtripConsistency feeds a table of values to
// MarshalExtJSONWithBSONRoundtripConsistency (canonical mode, no HTML
// escaping) and checks, per row, either the produced JSON or the
// discrepancy error.
func TestMarshalExtJSONWithBSONRoundtripConsistency(t *testing.T) {
	testtype.SkipUnlessTestType(t, testtype.UnitTestType)

	type roundtripCase struct {
		val                          any
		consistentAfterRoundtripping bool
		expectedJSON                 string
	}

	cases := []roundtripCase{
		{
			val:                          bson.M{"field1": bson.M{"grapes": int64(123)}},
			consistentAfterRoundtripping: true,
			expectedJSON:                 `{"field1":{"grapes":{"$numberLong":"123"}}}`,
		},
		{
			// The row expects this "$date" sub-document to be rejected.
			val:                          bson.M{"field1": bson.M{"$date": 1257894000000}},
			consistentAfterRoundtripping: false,
			expectedJSON:                 ``,
		},
		{
			val:                          bson.M{"field1": bson.M{"nanField": math.NaN()}},
			consistentAfterRoundtripping: true,
			expectedJSON:                 `{"field1":{"nanField":{"$numberDouble":"NaN"}}}`,
		},
	}

	for idx := range cases {
		tc := cases[idx]

		got, gotErr := MarshalExtJSONWithBSONRoundtripConsistency(
			tc.val,
			true,  /* canonical */
			false, /* escapeHTML */
		)

		if tc.consistentAfterRoundtripping {
			assert.NoError(t, gotErr)
		} else {
			assert.ErrorContains(
				t,
				gotErr,
				"marshaling BSON to ExtJSON and back resulted in discrepancies",
			)
		}

		// On failure the function returns nil bytes, which compare equal to
		// the empty expectedJSON of the failing rows.
		assert.Equal(t, tc.expectedJSON, string(got))
	}
}
134 changes: 130 additions & 4 deletions common/db/buffered_bulk.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package db
import (
"context"
"fmt"
"strings"

"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
Expand All @@ -26,6 +27,7 @@ const MAX_MESSAGE_SIZE_BYTES = 48000000
type BufferedBulkInserter struct {
collection *mongo.Collection
writeModels []mongo.WriteModel
docs []bson.D
docLimit int
docCount int
byteCount int
Expand Down Expand Up @@ -88,16 +90,18 @@ func (bb *BufferedBulkInserter) ResetBulk() {
bb.writeModels = bb.writeModels[:0]
bb.docCount = 0
bb.byteCount = 0
bb.docs = bb.docs[:0]
}

// Insert adds a document to the buffer for bulk insertion. If the buffer becomes full, the bulk write is performed, returning
// any error that occurs.
func (bb *BufferedBulkInserter) Insert(doc interface{}) (*mongo.BulkWriteResult, error) {
func (bb *BufferedBulkInserter) Insert(doc bson.D) (*mongo.BulkWriteResult, error) {
	rawBytes, err := bson.Marshal(doc)
	if err != nil {
		// %w keeps the underlying encoder error reachable via errors.Is/As;
		// the rendered message is unchanged.
		return nil, fmt.Errorf("bson encoding error: %w", err)
	}

	// Keep the decoded document so flush can find and re-insert it by key.
	// NOTE(review): InsertRaw is not visible here. If it can flush and reset
	// the buffer before queueing rawBytes, the entry appended below would be
	// dropped by ResetBulk while its write model stays queued - confirm that
	// bb.docs and bb.writeModels stay in step.
	bb.docs = append(bb.docs, doc)
	return bb.InsertRaw(rawBytes)
}

Expand Down Expand Up @@ -175,9 +179,131 @@ func (bb *BufferedBulkInserter) TryFlush() (*mongo.BulkWriteResult, error) {
}

func (bb *BufferedBulkInserter) flush() (*mongo.BulkWriteResult, error) {
if bb.docCount == 0 {
return nil, nil

ctx := context.Background()

if bb.docCount == 0 {
return nil, nil
}
res, bulkWriteErr := bb.collection.BulkWrite(ctx, bb.writeModels, bb.bulkWriteOpts)
if bulkWriteErr == nil {
return res, nil
}

bulkWriteException, ok := bulkWriteErr.(mongo.BulkWriteException)
if !ok {
return res, bulkWriteErr
}

var retryDocFilters []bson.D

for _, we := range bulkWriteException.WriteErrors {
if we.Code == ErrDuplicateKeyCode {
var errDetails map[string]bson.Raw
bson.Unmarshal(we.WriteError.Raw, &errDetails)
var filter bson.D
bson.Unmarshal(errDetails["keyValue"], &filter)

exists, err := checkDocumentExistence(ctx, bb.collection, filter)
if err != nil {
return nil, err
}
if !exists {
retryDocFilters = append(retryDocFilters, filter)
} else {
}
}
}

for _, filter := range retryDocFilters {
for _, doc := range bb.docs {
var exists bool
var err error
if compareDocumentWithKeys(filter, doc) {
for range(3) {
_, err = bb.collection.InsertOne(ctx, doc)
if err == nil {
break
}
}
exists, err = checkDocumentExistence(ctx, bb.collection, filter)
if err != nil {
return nil, err
}
if exists {
break
}
}
if !exists {
return nil, fmt.Errorf("could not insert document %+v", doc)
}
}
}

res.InsertedCount += int64(len(retryDocFilters))
return res, bulkWriteErr
}

return bb.collection.BulkWrite(context.Background(), bb.writeModels, bb.bulkWriteOpts)

// extractValueByPath digs into a bson.D using a dotted path to retrieve the value.
//
// Each dot-separated segment of path names a key at the next nesting level.
// Nested documents may be bson.D or bson.M. The second result is false when a
// segment is missing or when an intermediate value is not a document.
func extractValueByPath(doc bson.D, path string) (interface{}, bool) {
	var current interface{} = doc
	for _, part := range strings.Split(path, ".") {
		switch curr := current.(type) {
		case bson.D:
			found := false
			for _, elem := range curr {
				if elem.Key == part {
					current, found = elem.Value, true
					break
				}
			}
			if !found {
				return nil, false
			}
		case bson.M:
			next, ok := curr[part]
			if !ok {
				return nil, false
			}
			current = next
		default:
			// A scalar or array cannot be descended into by key.
			return nil, false
		}
	}
	return current, true
}

// compareDocumentWithKeys checks if the key-value pairs in doc1 exist in doc2.
//
// Each key of doc1 is treated as a dotted path into doc2. Values are compared
// by their BSON encoding (type and bytes) rather than with ==, because ==
// on two interface values holding the same slice- or map-backed type
// (bson.D, bson.A, bson.M) panics at run time, and because it reports
// differently-typed Go values as unequal even when they encode identically.
// A value that cannot be encoded is treated as not matching.
func compareDocumentWithKeys(doc1 bson.D, doc2 bson.D) bool {
	for _, elem := range doc1 {
		value, exists := extractValueByPath(doc2, elem.Key)
		if !exists {
			return false
		}

		// Untyped nil is handled up front so two nil values keep comparing
		// equal without depending on how the encoder treats nil.
		if elem.Value == nil || value == nil {
			if elem.Value == nil && value == nil {
				continue
			}
			return false
		}

		wantType, wantBytes, err := bson.MarshalValue(elem.Value)
		if err != nil {
			return false
		}
		gotType, gotBytes, err := bson.MarshalValue(value)
		if err != nil {
			return false
		}
		if wantType != gotType || string(wantBytes) != string(gotBytes) {
			return false
		}
	}
	return true
}

// checkDocumentExistence reports whether at least one document in collection
// matches the filter given in document.
//
// It issues a raw find command with read concern "majority". The command is
// capped at one document in a single batch: only presence matters, and this
// avoids leaving a cursor open on the server. An error is returned when the
// command fails or when its reply carries no cursor document.
func checkDocumentExistence(ctx context.Context, collection *mongo.Collection, document bson.D) (bool, error) {
	findCmd := bson.D{
		{Key: "find", Value: collection.Name()},
		{Key: "filter", Value: document},
		{Key: "limit", Value: 1},
		{Key: "singleBatch", Value: true},
		{Key: "readConcern", Value: bson.D{{Key: "level", Value: "majority"}}},
	}

	var result bson.M
	if err := collection.Database().RunCommand(ctx, findCmd).Decode(&result); err != nil {
		return false, err
	}

	cursor, ok := result["cursor"].(bson.M)
	if !ok {
		// Reporting "not found" here would make a malformed reply look like a
		// missing document, so surface it as an error.
		return false, fmt.Errorf("find reply for collection %q has no cursor document", collection.Name())
	}

	firstBatch, ok := cursor["firstBatch"].(bson.A)
	return ok && len(firstBatch) > 0, nil
}
2 changes: 1 addition & 1 deletion common/db/buffered_bulk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func TestBufferedBulkInserterInserts(t *testing.T) {

errCnt := 0
for i := 0; i < 1000000; i++ {
result, err := bufBulk.Insert(bson.M{"_id": i})
result, err := bufBulk.Insert(bson.D{{"_id", i}})
if err != nil {
errCnt++
}
Expand Down
9 changes: 9 additions & 0 deletions common/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,15 @@ func CanIgnoreError(err error) bool {
return ok
case mongo.BulkWriteException:
for _, writeErr := range mongoErr.WriteErrors {

var decoded bson.M
err := bson.Unmarshal(writeErr.Raw, &decoded)
if err != nil {
return false
}
keyValue, _ := decoded["keyValue"].(bson.D)
fmt.Printf("TESTING THIS %+v\n", keyValue)

if _, ok := ignorableWriteErrorCodes[writeErr.Code]; !ok {
return false
}
Expand Down
2 changes: 1 addition & 1 deletion mongodump/metadata_dump.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ func (dump *MongoDump) dumpMetadata(
}

// Finally, we send the results to the writer as JSON bytes
jsonBytes, err := bsonutil.MarshalExtJSONReversible(meta, true, false)
jsonBytes, err := bsonutil.MarshalExtJSONWithBSONRoundtripConsistency(meta, true, false)
if err != nil {
return fmt.Errorf(
"error marshaling metadata json for collection `%v`: %v",
Expand Down
Loading