Skip to content

Commit 7b80312

Browse files
committed
Add Netty fast-path
1 parent 695e51e commit 7b80312

File tree

2 files changed

+164
-29
lines changed

2 files changed

+164
-29
lines changed

Diff for: driver-core/src/main/com/mongodb/internal/connection/ByteBufferBsonOutput.java

+137-17
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package com.mongodb.internal.connection;
1818

19+
import com.mongodb.internal.connection.netty.NettyByteBuf;
1920
import org.bson.BsonSerializationException;
2021
import org.bson.ByteBuf;
2122
import org.bson.io.OutputBuffer;
@@ -279,34 +280,153 @@ public void close() {
279280
/**
 * Writes {@code str}, trying a single-buffer fast path before falling back to the superclass.
 *
 * <p>The fast path is taken only when the current buffer reports at least {@code str.length() + 1}
 * remaining bytes, i.e. enough for one byte per char plus one trailing byte. Array-backed buffers and
 * {@link NettyByteBuf} instances each have a dedicated helper; every other case is delegated to
 * {@code super.writeCharacters} starting at index 0.</p>
 *
 * @param str the string to write
 * @param checkForNullCharacters whether a {@code 0x0} char in {@code str} must be rejected
 * @return whatever the selected helper or the superclass returns (presumably the number of bytes written)
 */
protected int writeCharacters(final String str, final boolean checkForNullCharacters) {
280281
ensureOpen();
281282
ByteBuf buf = getCurrentByteBuffer();
283+
// room for the all-ASCII case: one byte per char plus one extra byte
if ((buf.remaining() >= str.length() + 1)) {
284+
if (buf.hasArray()) {
285+
return writeCharactersOnArray(str, checkForNullCharacters, buf);
286+
} else if (buf instanceof NettyByteBuf) {
287+
return writeCharactersOnNettyByteBuf(str, checkForNullCharacters, buf);
288+
}
289+
}
290+
// not enough room in the current buffer, or a buffer type without a fast path
return super.writeCharacters(str, 0, checkForNullCharacters);
291+
}
292+
293+
/**
 * Rejects a packed group of eight single-byte chars if any of its bytes is {@code 0x00}.
 *
 * <p>The chars are packed least-significant-byte first: the lowest byte of {@code chars} is the char at
 * string index {@code i}, the next byte is the char at {@code i + 1}, and so on. The first null character
 * is therefore identified by the lowest marker bit, which is why trailing (not leading) zeros are counted.</p>
 *
 * @param str the string being written; used only to build the error message
 * @param chars eight chars packed one per byte, any byte value allowed
 * @param i index in {@code str} of the char held in the least significant byte of {@code chars}
 * @throws BsonSerializationException if at least one byte of {@code chars} is zero
 */
private static void validateNoNullSingleByteChars(String str, long chars, int i) {
294+
// Zero-byte search that tolerates bytes with the MSB set: adding 0x7F to the low 7 bits of each byte
// sets that byte's MSB iff any of its low 7 bits is set, without carrying into the neighbouring byte.
long tmp = (chars & 0x7F7F7F7F7F7F7F7FL) + 0x7F7F7F7F7F7F7F7FL;
295+
// after the inversion a byte is 0x80 iff the original byte was 0x00, and 0x00 otherwise
tmp = ~(tmp | chars | 0x7F7F7F7F7F7F7F7FL);
296+
if (tmp != 0) {
297+
// lowest marker bit == first null char, because byte 0 of the long is the char at index i
int firstZero = Long.numberOfTrailingZeros(tmp) >>> 3;
298+
throw new BsonSerializationException(format("BSON cstring '%s' is not valid because it contains a null character "
299+
+ "at index %d", str, i + firstZero));
300+
}
301+
}
302+
303+
/**
 * Rejects a packed group of eight ASCII chars if any of its bytes is {@code 0x00}.
 *
 * <p>This shortcut is only valid when every byte of {@code asciiChars} has its most significant bit clear;
 * with that precondition the per-byte addition below cannot carry into the neighbouring byte.</p>
 *
 * @param str the string being written; used only to build the error message
 * @param asciiChars eight chars packed one per byte, each expected to be below {@code 0x80}
 * @param i index in {@code str} added to the byte offset of the first null when reporting the error
 * @throws BsonSerializationException if at least one byte of {@code asciiChars} is zero
 */
private static void validateNoNullAsciiCharacters(String str, long asciiChars, int i) {
304+
// simplified Hacker's delight search for zero with ASCII chars i.e. which doesn't use the MSB
305+
long tmp = asciiChars + 0x7F7F7F7F7F7F7F7FL;
306+
// MSB is 0 iff the byte is 0x00, 1 otherwise
307+
tmp = ~tmp & 0x8080808080808080L;
308+
// MSB is 1 iff the byte is 0x00, 0 otherwise
309+
if (tmp != 0) {
310+
// there's some 0x00 in the word
311+
// the lowest marker bit belongs to the lowest-order zero byte; dividing the bit index by 8 gives its byte offset
int firstZero = Long.numberOfTrailingZeros(tmp) >> 3;
312+
throw new BsonSerializationException(format("BSON cstring '%s' is not valid because it contains a null character "
313+
+ "at index %d", str, i + firstZero));
314+
}
315+
}
316+
317+
/**
 * Netty fast path: stores the leading ASCII run of {@code str} directly into the wrapped Netty buffer,
 * eight chars per {@code setLongLE} call plus a byte-by-byte tail, then either appends the zero terminator
 * or hands the rest of the string (from the first non-ASCII char on) to {@code super.writeCharacters}.
 *
 * <p>The visible caller only invokes this after checking {@code buf.remaining() >= str.length() + 1}.</p>
 *
 * @param str the string to write
 * @param checkForNullCharacters whether a {@code 0x0} char must be rejected with a {@code BsonSerializationException}
 * @param buf the current buffer; it is cast to {@link NettyByteBuf} unconditionally
 * @return {@code str.length() + 1} when every char is ASCII, otherwise the number of chars written here plus
 *         the value returned by {@code super.writeCharacters} for the remainder
 */
private int writeCharactersOnNettyByteBuf(String str, boolean checkForNullCharacters, ByteBuf buf) {
282318
int i = 0;
283-
if (buf.hasArray() && (buf.remaining() >= str.length() + 1)) {
284-
byte[] array = buf.array();
285-
int pos = buf.position();
286-
int len = str.length();
287-
for (;i < len; i++) {
319+
io.netty.buffer.ByteBuf bettyBuf = ((NettyByteBuf) buf).asByteBuf();
320+
// readonly buffers, netty buffers and off-heap NIO ByteBuffer
// NOTE(review): the comment above does not describe any statement nearby; it looks like a leftover - confirm
321+
boolean slowPath = false;
322+
// number of complete 8-char groups; the remaining 0..7 chars are handled one by one further down
int batches = str.length() / 8;
323+
// all stores below use absolute indexes relative to this starting writer index
final int writerIndex = bettyBuf.writerIndex();
324+
// this would avoid resizing the buffer while appending: ASCII length + delimiter required space
325+
bettyBuf.ensureWritable(str.length() + 1);
326+
for (int b = 0; b < batches; b++) {
327+
// i is the string index of the first char of the current group
i = b * 8;
328+
// read 4 chars at time to preserve the 0x0100 cases
// each char gets its own 16-bit lane: even-indexed chars in one long, odd-indexed chars in the other
// NOTE(review): the "<< 16" terms are evaluated as int, so a char >= 0x8000 there sign-extends into the
// upper 32 bits once widened; such a char already trips the 0xFF80 mask below, so the group just
// takes the slow path - confirm this is intended
329+
long evenChars = str.charAt(i) |
330+
str.charAt(i + 2) << 16 |
331+
(long) str.charAt(i + 4) << 32 |
332+
(long) str.charAt(i + 6) << 48;
333+
long oddChars = str.charAt(i + 1) |
334+
str.charAt(i + 3) << 16 |
335+
(long) str.charAt(i + 5) << 32 |
336+
(long) str.charAt(i + 7) << 48;
337+
// check that both the second byte and the MSB of the first byte of each pair is 0
338+
// needed for cases like \u0100 and \u0080
339+
long mergedChars = evenChars | oddChars;
340+
// non-zero means at least one of the eight chars is >= 0x80
if ((mergedChars & 0xFF80FF80FF80FF80L) != 0) {
341+
// when every char still fits in one byte the group is stored anyway and i moves to the first
// non-ASCII char; otherwise i stays at the start of the group
if (allSingleByteChars(mergedChars)) {
342+
i = tryWriteAsciiChars(str, checkForNullCharacters, oddChars, evenChars, bettyBuf, writerIndex, i);
343+
}
344+
slowPath = true;
345+
break;
346+
}
347+
// all ASCII - compose them into a single long
// interleaving odd << 8 with even puts char i + k into byte k, so a little-endian store keeps string order
348+
long asciiChars = oddChars << 8 | evenChars;
349+
if (checkForNullCharacters) {
350+
validateNoNullAsciiCharacters(str, asciiChars, i);
351+
}
352+
bettyBuf.setLongLE(writerIndex + i, asciiChars);
353+
}
354+
if (!slowPath) {
355+
i = batches * 8;
356+
// do the rest, if any
357+
for (; i < str.length(); i++) {
288358
char c = str.charAt(i);
289359
if (checkForNullCharacters && c == 0x0) {
290360
throw new BsonSerializationException(format("BSON cstring '%s' is not valid because it contains a null character "
291361
+ "at index %d", str, i));
292362
}
293-
if (c > 0x80) {
363+
if (c >= 0x80) {
364+
slowPath = true;
294365
break;
295366
}
296-
array[pos + i] = (byte) c;
367+
bettyBuf.setByte(writerIndex + i, c);
297368
}
298-
if (i == len) {
299-
int total = len + 1;
300-
array[pos + len] = 0;
301-
position += total;
302-
buf.position(pos + total);
303-
return len + 1;
369+
}
370+
if (slowPath) {
371+
// ith char is not ASCII:
// account for the i bytes already stored, then let the superclass continue from char i
372+
position += i;
373+
buf.position(writerIndex + i);
374+
return i + super.writeCharacters(str, i, checkForNullCharacters);
375+
} else {
376+
// the whole string was ASCII: append the zero terminator and advance past it
bettyBuf.setByte(writerIndex + str.length(), 0);
377+
int totalWritten = str.length() + 1;
378+
position += totalWritten;
379+
buf.position(writerIndex + totalWritten);
380+
return totalWritten;
381+
}
382+
}
383+
384+
/**
 * Tells whether every 16-bit lane of {@code fourChars} has its high byte clear, i.e. whether every char
 * packed into it is below {@code 0x0100}.
 *
 * @param fourChars chars packed one per 16-bit lane (the visible caller passes the OR of two such words)
 * @return {@code true} if no lane has any of its upper eight bits set
 */
private static boolean allSingleByteChars(long fourChars) {
385+
return (fourChars & 0xFF00FF00FF00FF00L) == 0;
386+
}
387+
388+
/**
 * Stores a group of eight single-byte chars with one {@code setLongLE} call and reports how far the
 * ASCII prefix of that group extends.
 *
 * <p>All eight bytes are stored, including those at and after the first char with its MSB set; per the
 * comment in the body, a later step is expected to rewrite them - presumably the caller resuming the
 * encoding from the returned index.</p>
 *
 * @param str the string being written; only forwarded to the null-character check
 * @param checkForNullCharacters whether a zero byte in the group must be rejected
 * @param oddChars chars at odd offsets of the group, one per 16-bit lane
 * @param evenChars chars at even offsets of the group, one per 16-bit lane
 * @param nettyByteBuf destination buffer
 * @param writerIndex absolute buffer index corresponding to string index 0
 * @param i string index of the first char of the group
 * @return {@code i} plus the byte offset of the first byte whose MSB is set ({@code i + 8} if there is none)
 */
private static int tryWriteAsciiChars(String str, boolean checkForNullCharacters,
389+
long oddChars, long evenChars, io.netty.buffer.ByteBuf nettyByteBuf, int writerIndex, int i) {
390+
// all single byte chars
// shifting the odd lanes up by one byte slots char i + k into byte k of the result
391+
long latinChars = oddChars << 8 | evenChars;
392+
if (checkForNullCharacters) {
393+
validateNoNullSingleByteChars(str, latinChars, i);
394+
}
395+
// keep only the MSB of each byte: set exactly for the bytes that are >= 0x80
long msbSetForNonAscii = latinChars & 0x8080808080808080L;
396+
// bit index of the lowest marker divided by 8 = offset of the first non-ASCII byte in the group
int firstNonAsciiOffset = Long.numberOfTrailingZeros(msbSetForNonAscii) >> 3;
397+
// that's a bit cheating :P but later phases will patch the wrongly encoded ones
398+
nettyByteBuf.setLongLE(writerIndex + i, latinChars);
399+
i += firstNonAsciiOffset;
400+
return i;
401+
}
402+
403+
/**
 * Array-backed fast path: copies the leading ASCII run of {@code str} straight into the buffer's backing
 * array, then either appends the zero terminator or delegates the remainder to {@code super.writeCharacters}.
 *
 * <p>The visible caller only invokes this after checking {@code buf.hasArray()} and
 * {@code buf.remaining() >= str.length() + 1}.</p>
 *
 * <p>NOTE(review): the backing array is indexed with {@code buf.position()} alone; this presumably assumes
 * the buffer's array offset is zero - confirm for sliced buffers.</p>
 *
 * @param str the string to write
 * @param checkForNullCharacters whether a {@code 0x0} char must be rejected
 * @param buf the current, array-backed buffer
 * @return {@code str.length() + 1} when every char is ASCII, otherwise the ASCII prefix length plus the
 *         value returned by {@code super.writeCharacters} for the remainder
 * @throws BsonSerializationException if a null character is found while {@code checkForNullCharacters} is set
 */
private int writeCharactersOnArray(String str, boolean checkForNullCharacters, ByteBuf buf) {
404+
int i = 0;
405+
byte[] array = buf.array();
406+
int pos = buf.position();
407+
int len = str.length();
408+
for (; i < len; i++) {
409+
char c = str.charAt(i);
410+
if (checkForNullCharacters && c == 0x0) {
411+
throw new BsonSerializationException(format("BSON cstring '%s' is not valid because it contains a null character "
412+
+ "at index %d", str, i));
304413
}
305-
// ith character is not ASCII
306-
if (i > 0) {
307-
position += i;
308-
buf.position(pos + i);
414+
// anything from 0x80 upwards is not written as a single byte here; stop and let the fallback take over
if (c >= 0x80) {
415+
break;
309416
}
417+
array[pos + i] = (byte) c;
418+
}
419+
// the loop ran to completion: the whole string was ASCII, so terminate it and advance past the terminator
if (i == len) {
420+
int total = len + 1;
421+
array[pos + len] = 0;
422+
position += total;
423+
buf.position(pos + total);
424+
return len + 1;
425+
}
426+
// ith character is not ASCII
427+
// account for the ASCII prefix already copied before handing over
if (i > 0) {
428+
position += i;
429+
buf.position(pos + i);
310430
}
311431
return i + super.writeCharacters(str, i, checkForNullCharacters);
312432
}

Diff for: driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufferBsonOutputTest.java

+27-12
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import org.bson.BsonSerializationException;
2121
import org.bson.ByteBuf;
2222
import org.bson.types.ObjectId;
23+
import org.jetbrains.annotations.NotNull;
2324
import org.junit.jupiter.api.DisplayName;
2425
import org.junit.jupiter.api.Test;
2526
import org.junit.jupiter.params.ParameterizedTest;
@@ -29,8 +30,10 @@
2930
import java.io.ByteArrayOutputStream;
3031
import java.io.IOException;
3132
import java.nio.ByteBuffer;
33+
import java.nio.ByteOrder;
3234
import java.nio.charset.CharacterCodingException;
3335
import java.nio.charset.StandardCharsets;
36+
import java.util.Arrays;
3437
import java.util.concurrent.ThreadLocalRandom;
3538
import java.util.function.BiConsumer;
3639
import java.util.function.Consumer;
@@ -247,37 +250,49 @@ void shouldWriteEmptyString(final boolean useBranch) {
247250
@ParameterizedTest
248251
@ValueSource(booleans = {false, true})
249252
void shouldWriteAsciiString(final boolean useBranch) {
250-
try (ByteBufferBsonOutput out = new ByteBufferBsonOutput(new SimpleBufferProvider())) {
251-
String v = "Java";
253+
try (ByteBufferBsonOutput out = new ByteBufferBsonOutput(new ByteBufSpecification.NettyBufferProvider())) {
254+
String v = "JavaIsACoolLanguage";
255+
byte[] expected = expectedArrayOf(v);
256+
252257
if (useBranch) {
253258
try (ByteBufferBsonOutput.Branch branch = out.branch()) {
254259
branch.writeString(v);
255260
}
256261
} else {
257262
out.writeString(v);
258263
}
259-
assertArrayEquals(new byte[] {5, 0, 0, 0, 0x4a, 0x61, 0x76, 0x61, 0}, out.toByteArray());
260-
assertEquals(9, out.getPosition());
261-
assertEquals(9, out.size());
264+
assertArrayEquals(expected, out.toByteArray());
265+
assertEquals(expected.length, out.getPosition());
266+
assertEquals(expected.length, out.size());
262267
}
263268
}
264269

270+
/**
 * Builds the bytes the tests expect for a written string: a 4-byte little-endian length prefix
 * (UTF-8 byte count plus one for the terminator), the UTF-8 bytes of {@code v}, and a trailing zero byte.
 *
 * @param v the string whose expected serialized form is needed
 * @return a new array of {@code 4 + utf8Length + 1} bytes
 */
private static @NotNull byte[] expectedArrayOf(String v) {
271+
byte[] encoded = v.getBytes(StandardCharsets.UTF_8);
272+
ByteBuffer expected = ByteBuffer.allocate(4 + encoded.length + 1).order(ByteOrder.LITTLE_ENDIAN);
273+
// the prefix is a full 32-bit int; narrowing it to a byte would corrupt it once the length reaches 128
expected.putInt(encoded.length + 1);
274+
expected.put(encoded);
275+
expected.put((byte) 0);
276+
return expected.array();
277+
}
278+
265279
// Writes a string mixing ASCII and non-ASCII chars, directly or through a branch, and compares the output,
// position and size against the bytes produced by expectedArrayOf.
@DisplayName("should write a UTF-8 string")
266280
@ParameterizedTest
267281
@ValueSource(booleans = {false, true})
268282
void shouldWriteUtf8String(final boolean useBranch) {
269-
try (ByteBufferBsonOutput out = new ByteBufferBsonOutput(new SimpleBufferProvider())) {
270-
String v = "\u0900";
283+
try (ByteBufferBsonOutput out = new ByteBufferBsonOutput(new ByteBufSpecification.NettyBufferProvider())) {
284+
// U+0080 sits at index 6, i.e. inside the first 8-char group, and a char above 0xFF ends the string
String v = "JavaIs\u0080ACool\u0900";
285+
byte[] expected = expectedArrayOf(v);
271286
if (useBranch) {
272287
try (ByteBufferBsonOutput.Branch branch = out.branch()) {
273288
branch.writeString(v);
274289
}
275290
} else {
276291
out.writeString(v);
277292
}
278-
assertArrayEquals(new byte[] {4, 0, 0, 0, (byte) 0xe0, (byte) 0xa4, (byte) 0x80, 0}, out.toByteArray());
279-
assertEquals(8, out.getPosition());
280-
assertEquals(8, out.size());
293+
assertArrayEquals(expected, out.toByteArray());
294+
assertEquals(expected.length, out.getPosition());
295+
assertEquals(expected.length, out.size());
281296
}
282297
}
283298

@@ -304,7 +319,7 @@ void shouldWriteEmptyCString(final boolean useBranch) {
304319
@ParameterizedTest
305320
@ValueSource(booleans = {false, true})
306321
void shouldWriteAsciiCString(final boolean useBranch) {
307-
try (ByteBufferBsonOutput out = new ByteBufferBsonOutput(new SimpleBufferProvider())) {
322+
try (ByteBufferBsonOutput out = new ByteBufferBsonOutput(new ByteBufSpecification.NettyBufferProvider())) {
308323
String v = "Java";
309324
if (useBranch) {
310325
try (ByteBufferBsonOutput.Branch branch = out.branch()) {
@@ -323,7 +338,7 @@ void shouldWriteAsciiCString(final boolean useBranch) {
323338
@ParameterizedTest
324339
@ValueSource(booleans = {false, true})
325340
void shouldWriteUtf8CString(final boolean useBranch) {
326-
try (ByteBufferBsonOutput out = new ByteBufferBsonOutput(new SimpleBufferProvider())) {
341+
try (ByteBufferBsonOutput out = new ByteBufferBsonOutput(new ByteBufSpecification.NettyBufferProvider())) {
327342
String v = "\u0900";
328343
if (useBranch) {
329344
try (ByteBufferBsonOutput.Branch branch = out.branch()) {

0 commit comments

Comments
 (0)