Skip to content

Commit 2b5f030

Browse files
thinaihmosiac1
authored andcommitted
Handle streaming exceptions as bad requests instead of internal errors
1 parent 261608d commit 2b5f030

File tree

5 files changed

+36
-34
lines changed

5 files changed

+36
-34
lines changed

trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/AwsChunkedInputStream.java

+11-9
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,15 @@
1515

1616
import com.google.common.base.Splitter;
1717
import io.trino.aws.proxy.spi.signing.ChunkSigningSession;
18+
import jakarta.ws.rs.WebApplicationException;
1819

19-
import java.io.EOFException;
2020
import java.io.IOException;
2121
import java.io.InputStream;
2222
import java.util.List;
2323
import java.util.Map;
2424
import java.util.Optional;
2525

26+
import static jakarta.ws.rs.core.Response.Status.BAD_REQUEST;
2627
import static java.util.Objects.requireNonNull;
2728

2829
class AwsChunkedInputStream
@@ -61,7 +62,7 @@ public int read()
6162

6263
int i = delegate.read();
6364
if (i < 0) {
64-
throw new EOFException("Unexpected end of stream");
65+
throw new WebApplicationException("Unexpected end of stream", BAD_REQUEST);
6566
}
6667

6768
chunkSigningSession.write((byte) (i & 0xff));
@@ -82,7 +83,7 @@ public int read(byte[] b, int off, int len)
8283

8384
int count = delegate.read(b, off, len);
8485
if (count < 0) {
85-
throw new EOFException("Unexpected end of stream");
86+
throw new WebApplicationException("Unexpected end of stream", BAD_REQUEST);
8687
}
8788

8889
chunkSigningSession.write(b, off, count);
@@ -116,7 +117,7 @@ private void updateBytesRemaining(int count)
116117
nextChunk();
117118
}
118119
else if (bytesRemainingInChunk < 0) {
119-
throw new IllegalStateException("bytesRemainingInChunk has gone negative: " + bytesRemainingInChunk);
120+
throw new WebApplicationException("bytesRemainingInChunk has gone negative: " + bytesRemainingInChunk, BAD_REQUEST);
120121
}
121122
}
122123

@@ -195,10 +196,11 @@ private void nextChunk()
195196
while (false);
196197

197198
if (!success) {
198-
throw new IOException("Invalid chunk header: " + header);
199+
throw new WebApplicationException("Invalid chunk header: " + header, BAD_REQUEST);
199200
}
200201
if (bytesAccountedFor > decodedContentLength) {
201-
throw new IllegalStateException("chunked data headers report a larger size than originally declared in the request: declared %s sent %s".formatted(decodedContentLength, bytesAccountedFor));
202+
throw new WebApplicationException("chunked data headers report a larger size than originally declared in the request: declared %s sent %s".formatted(decodedContentLength, bytesAccountedFor),
203+
BAD_REQUEST);
202204
}
203205
}
204206

@@ -207,7 +209,7 @@ private void readEmptyLine()
207209
{
208210
String crLf = readLine();
209211
if (!crLf.isEmpty()) {
210-
throw new IOException("Expected CR/LF. Instead read: " + crLf);
212+
throw new WebApplicationException("Expected CR/LF. Instead read: " + crLf, BAD_REQUEST);
211213
}
212214
}
213215

@@ -219,7 +221,7 @@ private String readLine()
219221
int i = delegate.read();
220222
if (i < 0) {
221223
delegateIsDone = true;
222-
throw new EOFException("Unexpected end of stream");
224+
throw new WebApplicationException("Unexpected end of stream", BAD_REQUEST);
223225
}
224226
if (i == '\r') {
225227
break;
@@ -229,7 +231,7 @@ private String readLine()
229231

230232
int i = delegate.read();
231233
if (i != '\n') {
232-
throw new IOException("Expected LF. Instead read: " + i);
234+
throw new WebApplicationException("Expected LF. Instead read: " + i, BAD_REQUEST);
233235
}
234236

235237
return line.toString();

trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/HashCheckInputStream.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import java.io.InputStream;
2323
import java.util.Optional;
2424

25+
import static jakarta.ws.rs.core.Response.Status.BAD_REQUEST;
2526
import static jakarta.ws.rs.core.Response.Status.UNAUTHORIZED;
2627
import static java.util.Objects.requireNonNull;
2728

@@ -106,8 +107,7 @@ private void updateBytesRead(int count)
106107
bytesRead += count;
107108
expectedLength.ifPresent(expected -> {
108109
if (bytesRead > expected) {
109-
log.debug("More bytes read than expected. Expected: %s, Actual: %s", expected, bytesRead);
110-
throw new WebApplicationException(UNAUTHORIZED);
110+
throw new WebApplicationException("More bytes read than expected. Expected: %s, Actual: %s".formatted(expected, bytesRead), BAD_REQUEST);
111111
}
112112

113113
if (bytesRead == expected) {

trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/StreamingResponseHandler.java

+4-2
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,10 @@ private void resume(Object result)
101101
{
102102
switch (result) {
103103
case WebApplicationException exception -> resume(exception.getResponse());
104-
case Throwable exception when Throwables.getRootCause(exception) instanceof WebApplicationException webApplicationException -> resume(webApplicationException.getResponse());
105-
case Throwable exception -> resume(jakarta.ws.rs.core.Response.status(INTERNAL_SERVER_ERROR.getStatusCode(), Optional.ofNullable(exception.getMessage()).orElse("Unknown error")).build());
104+
case Throwable exception when Throwables.getRootCause(exception) instanceof WebApplicationException webApplicationException ->
105+
resume(webApplicationException.getResponse());
106+
case Throwable exception ->
107+
resume(jakarta.ws.rs.core.Response.status(INTERNAL_SERVER_ERROR.getStatusCode(), Optional.ofNullable(exception.getMessage()).orElse("Unknown error")).build());
106108
default -> {
107109
if (hasBeenResumed.compareAndSet(false, true)) {
108110
asyncResponse.resume(result);

trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestGenericRestRequests.java

+9-9
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ public void testAwsChunkedUploadInvalidContent()
182182

183183
// Final chunk has an invalid size
184184
Function<String, String> changeSizeOfFinalChunk = chunked -> chunked.replaceFirst("\\r\\n0;chunk-signature=(\\w+)", "\r\n1;chunk-signature=$1");
185-
assertThat(doAwsChunkedUpload(bucket, fileKey, LOREM_IPSUM, 2, validCredential, changeSizeOfFinalChunk).getStatusCode()).isEqualTo(500);
185+
assertThat(doAwsChunkedUpload(bucket, fileKey, LOREM_IPSUM, 2, validCredential, changeSizeOfFinalChunk).getStatusCode()).isEqualTo(400);
186186
assertFileNotInS3(storageClient, bucket, fileKey);
187187

188188
// First chunk has an invalid size
@@ -198,7 +198,7 @@ public void testAwsChunkedUploadInvalidContent()
198198
}
199199
return "%s%s".formatted(newSizeAsString, chunked.substring(firstChunkIdx));
200200
};
201-
assertThat(doAwsChunkedUpload(bucket, fileKey, LOREM_IPSUM, 2, validCredential, changeSizeOfFirstChunk).getStatusCode()).isEqualTo(500);
201+
assertThat(doAwsChunkedUpload(bucket, fileKey, LOREM_IPSUM, 2, validCredential, changeSizeOfFirstChunk).getStatusCode()).isEqualTo(400);
202202
assertFileNotInS3(storageClient, bucket, fileKey);
203203

204204
// Change the signature of each of the chunks
@@ -269,24 +269,24 @@ public void testAwsChunkedCornerCases()
269269
storageClient.createBucket(r -> r.bucket(bucket).build());
270270

271271
// Illegal signature and no final chunk
272-
testAwsChunkedIllegalChunks(bucket, "no-final-chunk", buildFakeChunk(longDummyContent, longDummyContent.length()), longDummyContent.length(), 500);
272+
testAwsChunkedIllegalChunks(bucket, "no-final-chunk", buildFakeChunk(longDummyContent, longDummyContent.length()), longDummyContent.length(), 400);
273273
// Illegal signature with a final chunk
274274
testAwsChunkedIllegalChunks(bucket, "with-final-chunk", "%s%s".formatted(buildFakeChunk(longDummyContent, longDummyContent.length()), buildFakeChunk("", 0)), longDummyContent.length(), 401);
275275
// Illegal signature and no final chunk - more chunked data than we report in the x-amz-decoded-content-length header
276-
testAwsChunkedIllegalChunks(bucket, "no-final-chunk-more-data-than-headers-indicate", buildFakeChunk(longDummyContent, longDummyContent.length()), 4096, 500);
276+
testAwsChunkedIllegalChunks(bucket, "no-final-chunk-more-data-than-headers-indicate", buildFakeChunk(longDummyContent, longDummyContent.length()), 4096, 400);
277277

278278
// Illegal signature with a final chunk - more chunked data than we report in the x-amz-decoded-content-length header
279-
testAwsChunkedIllegalChunks(bucket, "with-final-chunk-more-data-than-headers-indicate", "%s%s".formatted(buildFakeChunk(longDummyContent, longDummyContent.length()), buildFakeChunk("", 0)), 4096, 500);
279+
testAwsChunkedIllegalChunks(bucket, "with-final-chunk-more-data-than-headers-indicate", "%s%s".formatted(buildFakeChunk(longDummyContent, longDummyContent.length()), buildFakeChunk("", 0)), 4096, 400);
280280

281281
// Illegal signature and no final chunk - chunk misreports its size
282-
testAwsChunkedIllegalChunks(bucket, "no-final-chunk-chunk-underreports-size", buildFakeChunk(longDummyContent, 4096), 4096, 500);
282+
testAwsChunkedIllegalChunks(bucket, "no-final-chunk-chunk-underreports-size", buildFakeChunk(longDummyContent, 4096), 4096, 400);
283283
// Illegal signature with a final chunk - chunk misreports its size
284-
testAwsChunkedIllegalChunks(bucket, "with-final-chunk-chunk-underreports-size", "%s%s".formatted(buildFakeChunk(longDummyContent, 4096), buildFakeChunk("", 0)), 4096, 500);
284+
testAwsChunkedIllegalChunks(bucket, "with-final-chunk-chunk-underreports-size", "%s%s".formatted(buildFakeChunk(longDummyContent, 4096), buildFakeChunk("", 0)), 4096, 400);
285285

286286
// Illegal signature and no final chunk - chunk misreports its size
287-
testAwsChunkedIllegalChunks(bucket, "no-final-chunk-chunk-overreports-size", buildFakeChunk(longDummyContent, 9_000_000), 4096, 500);
287+
testAwsChunkedIllegalChunks(bucket, "no-final-chunk-chunk-overreports-size", buildFakeChunk(longDummyContent, 9_000_000), 4096, 400);
288288
// Illegal signature with a final chunk - chunk misreports its size
289-
testAwsChunkedIllegalChunks(bucket, "with-final-chunk-chunk-overreports-size", "%s%s".formatted(buildFakeChunk(longDummyContent, 9_000_000), buildFakeChunk("", 0)), 4096, 500);
289+
testAwsChunkedIllegalChunks(bucket, "with-final-chunk-chunk-overreports-size", "%s%s".formatted(buildFakeChunk(longDummyContent, 9_000_000), buildFakeChunk("", 0)), 4096, 400);
290290
Thread.sleep(1000);
291291
assertThat(listFilesInS3Bucket(storageClient, bucket)).isEmpty();
292292
}

trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/rest/TestAwsChunkedInputStream.java

+10-12
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ private static void tryReadAwsChunkedData(String chunkedData, int decodedContent
282282
private static void testIllegalAwsChunkedData(String chunkedData, int decodedContentLength, TestingChunkSigningSession signingSession, ChunkReader readerMethod)
283283
{
284284
ByteArrayOutputStream testOutput = new ByteArrayOutputStream();
285-
assertThatThrownBy(() -> readerMethod.read(chunkedData, decodedContentLength, signingSession, testOutput)).isInstanceOfAny(IllegalStateException.class, WebApplicationException.class, IOException.class);
285+
assertThatThrownBy(() -> readerMethod.read(chunkedData, decodedContentLength, signingSession, testOutput)).isInstanceOfAny(WebApplicationException.class, IOException.class);
286286
assertThat(testOutput.toByteArray().length).isLessThan(decodedContentLength);
287287
}
288288

@@ -377,7 +377,7 @@ public void testChunkedInputStreamNoClosingChunk()
377377
InputStream in = new AwsChunkedInputStream(inputStream, new DummyChunkSigningSession(), rawBytes.length);
378378
byte[] tmp = new byte[5];
379379
// altered from original test. Our AwsChunkedInputStream is improved and throws when the final chunk is missing or bad
380-
assertThrows(IOException.class, () -> in.read(tmp));
380+
assertThrows(WebApplicationException.class, () -> in.read(tmp));
381381
}
382382

383383
// Truncated stream (missing closing CRLF)
@@ -393,7 +393,7 @@ public void testCorruptChunkedInputStreamTruncatedCRLF()
393393
InputStream in = new AwsChunkedInputStream(inputStream, new DummyChunkSigningSession(), rawBytes.length);
394394
byte[] tmp = new byte[5];
395395
// altered from original test. Our AwsChunkedInputStream is improved and throws when the final chunk is missing or bad
396-
assertThrows(IOException.class, () -> in.read(tmp));
396+
assertThrows(WebApplicationException.class, () -> in.read(tmp));
397397
try {
398398
in.close();
399399
}
@@ -413,7 +413,7 @@ public void testCorruptChunkedInputStreamMissingCRLF()
413413
InputStream in = new AwsChunkedInputStream(inputStream, new DummyChunkSigningSession(), rawBytes.length);
414414
byte[] buffer = new byte[300];
415415
ByteArrayOutputStream out = new ByteArrayOutputStream();
416-
assertThrows(IOException.class, () -> {
416+
assertThrows(WebApplicationException.class, () -> {
417417
int len;
418418
while ((len = in.read(buffer)) > 0) {
419419
out.write(buffer, 0, len);
@@ -431,7 +431,7 @@ public void testCorruptChunkedInputStreamMissingLF()
431431
byte[] rawBytes = s.getBytes(UTF_8);
432432
ByteArrayInputStream inputStream = new ByteArrayInputStream(rawBytes);
433433
InputStream in = new AwsChunkedInputStream(inputStream, new DummyChunkSigningSession(), rawBytes.length);
434-
assertThrows(IOException.class, in::read);
434+
assertThrows(WebApplicationException.class, in::read);
435435
in.close();
436436
}
437437

@@ -444,7 +444,7 @@ public void testCorruptChunkedInputStreamInvalidSize()
444444
byte[] rawBytes = s.getBytes(UTF_8);
445445
ByteArrayInputStream inputStream = new ByteArrayInputStream(rawBytes);
446446
InputStream in = new AwsChunkedInputStream(inputStream, new DummyChunkSigningSession(), rawBytes.length);
447-
assertThrows(IOException.class, in::read);
447+
assertThrows(WebApplicationException.class, in::read);
448448
in.close();
449449
}
450450

@@ -457,7 +457,7 @@ public void testCorruptChunkedInputStreamNegativeSize()
457457
byte[] rawBytes = s.getBytes(UTF_8);
458458
ByteArrayInputStream inputStream = new ByteArrayInputStream(rawBytes);
459459
InputStream in = new AwsChunkedInputStream(inputStream, new DummyChunkSigningSession(), rawBytes.length);
460-
assertThrows(IOException.class, in::read);
460+
assertThrows(WebApplicationException.class, in::read);
461461
in.close();
462462
}
463463

@@ -472,20 +472,18 @@ public void testCorruptChunkedInputStreamTruncatedChunk()
472472
InputStream in = new AwsChunkedInputStream(inputStream, new DummyChunkSigningSession(), rawBytes.length);
473473
byte[] buffer = new byte[300];
474474
assertEquals(2, in.read(buffer));
475-
assertThrows(IOException.class, () -> in.read(buffer));
475+
assertThrows(WebApplicationException.class, () -> in.read(buffer));
476476
in.close();
477477
}
478478

479479
@Test
480480
public void testCorruptChunkedInputStreamClose()
481-
throws IOException
482481
{
483482
String s = "whatever;chunk-signature=0\r\n01234\r\n5;chunk-signature=0\r\n56789\r\n0;chunk-signature=0\r\n\r\n";
484483
byte[] rawBytes = s.getBytes(UTF_8);
485484
ByteArrayInputStream inputStream = new ByteArrayInputStream(rawBytes);
486-
try (InputStream in = new AwsChunkedInputStream(inputStream, new DummyChunkSigningSession(), rawBytes.length)) {
487-
assertThrows(IOException.class, in::read);
488-
}
485+
InputStream in = new AwsChunkedInputStream(inputStream, new DummyChunkSigningSession(), rawBytes.length);
486+
assertThrows(WebApplicationException.class, in::read);
489487
}
490488

491489
@Test

0 commit comments

Comments
 (0)