Skip to content

Commit d8a03d7

Browse files
committed
GH-1018 Ensure AWS adapter can pass raw InputStream
Resolves #1018
1 parent e76cca0 commit d8a03d7

File tree

3 files changed

+86
-1
lines changed

3 files changed

+86
-1
lines changed

spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/main/java/org/springframework/cloud/function/adapter/aws/AWSLambdaUtils.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
package org.springframework.cloud.function.adapter.aws;
1818

19+
import java.io.IOException;
20+
import java.io.InputStream;
1921
import java.lang.reflect.Type;
2022
import java.nio.charset.StandardCharsets;
2123
import java.util.HashMap;
@@ -31,7 +33,9 @@
3133
import org.springframework.http.HttpStatus;
3234
import org.springframework.messaging.Message;
3335
import org.springframework.messaging.MessageHeaders;
36+
import org.springframework.messaging.support.GenericMessage;
3437
import org.springframework.messaging.support.MessageBuilder;
38+
import org.springframework.util.StreamUtils;
3539

3640
/**
3741
*
@@ -77,6 +81,23 @@ static boolean isSupportedAWSType(Type inputType) {
7781
|| typeName.equals("com.amazonaws.services.lambda.runtime.events.KinesisEvent");
7882
}
7983

84+
@SuppressWarnings("rawtypes")
85+
public static Message generateMessage(InputStream payload, Type inputType, boolean isSupplier, JsonMapper jsonMapper, Context context) throws IOException {
86+
if (inputType != null && FunctionTypeUtils.isMessage(inputType)) {
87+
inputType = FunctionTypeUtils.getImmediateGenericType(inputType, 0);
88+
}
89+
if (inputType != null && InputStream.class.isAssignableFrom(FunctionTypeUtils.getRawType(inputType))) {
90+
MessageBuilder msgBuilder = MessageBuilder.withPayload(payload);
91+
if (context != null) {
92+
msgBuilder.setHeader(AWSLambdaUtils.AWS_CONTEXT, context);
93+
}
94+
return msgBuilder.build();
95+
}
96+
else {
97+
return generateMessage(StreamUtils.copyToByteArray(payload), inputType, isSupplier, jsonMapper, context);
98+
}
99+
}
100+
80101
public static Message<byte[]> generateMessage(byte[] payload, Type inputType, boolean isSupplier, JsonMapper jsonMapper) {
81102
return generateMessage(payload, inputType, isSupplier, jsonMapper, null);
82103
}
@@ -87,6 +108,7 @@ public static Message<byte[]> generateMessage(byte[] payload, Type inputType, bo
87108
logger.info("Received: " + new String(payload, StandardCharsets.UTF_8));
88109
}
89110

111+
90112
Object structMessage = jsonMapper.fromJson(payload, Object.class);
91113
boolean isApiGateway = structMessage instanceof Map
92114
&& (((Map<String, Object>) structMessage).containsKey("httpMethod") ||

spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/main/java/org/springframework/cloud/function/adapter/aws/FunctionInvoker.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ public FunctionInvoker() {
8080
@Override
8181
public void handleRequest(InputStream input, OutputStream output, Context context) throws IOException {
8282
Message requestMessage = AWSLambdaUtils
83-
.generateMessage(StreamUtils.copyToByteArray(input), this.function.getInputType(), this.function.isSupplier(), jsonMapper, context);
83+
.generateMessage(input, this.function.getInputType(), this.function.isSupplier(), jsonMapper, context);
8484

8585
Object response = this.function.apply(requestMessage);
8686
byte[] responseBytes = this.buildResult(requestMessage, response);

spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/test/java/org/springframework/cloud/function/adapter/aws/FunctionInvokerTests.java

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
import org.springframework.messaging.converter.AbstractMessageConverter;
5858
import org.springframework.messaging.support.MessageBuilder;
5959
import org.springframework.util.MimeType;
60+
import org.springframework.util.StreamUtils;
6061

6162
import static org.assertj.core.api.Assertions.assertThat;
6263
import static org.junit.jupiter.api.Assertions.fail;
@@ -989,6 +990,40 @@ public void testApiGatewayAsSupplier() throws Exception {
989990
assertThat(result.get("body")).isEqualTo("\"boom\"");
990991
}
991992

993+
@SuppressWarnings({ "rawtypes", "unchecked" })
994+
@Test
995+
public void testApiGatewayInAndOutInputStream() throws Exception {
996+
System.setProperty("MAIN_CLASS", ApiGatewayConfiguration.class.getName());
997+
System.setProperty("spring.cloud.function.definition", "echoInputStreamToString");
998+
FunctionInvoker invoker = new FunctionInvoker();
999+
1000+
InputStream targetStream = new ByteArrayInputStream(this.apiGatewayEvent.getBytes());
1001+
ByteArrayOutputStream output = new ByteArrayOutputStream();
1002+
invoker.handleRequest(targetStream, output, null);
1003+
1004+
Map result = mapper.readValue(output.toByteArray(), Map.class);
1005+
assertThat(result.get("body")).isEqualTo("hello");
1006+
Map headers = (Map) result.get("headers");
1007+
assertThat(headers).isNotEmpty();
1008+
}
1009+
1010+
@SuppressWarnings({ "rawtypes", "unchecked" })
1011+
@Test
1012+
public void testApiGatewayInAndOutInputStreamMsg() throws Exception {
1013+
System.setProperty("MAIN_CLASS", ApiGatewayConfiguration.class.getName());
1014+
System.setProperty("spring.cloud.function.definition", "echoInputStreamMsgToString");
1015+
FunctionInvoker invoker = new FunctionInvoker();
1016+
1017+
InputStream targetStream = new ByteArrayInputStream(this.apiGatewayEvent.getBytes());
1018+
ByteArrayOutputStream output = new ByteArrayOutputStream();
1019+
invoker.handleRequest(targetStream, output, null);
1020+
1021+
Map result = mapper.readValue(output.toByteArray(), Map.class);
1022+
assertThat(result.get("body")).isEqualTo("hello");
1023+
Map headers = (Map) result.get("headers");
1024+
assertThat(headers).isNotEmpty();
1025+
}
1026+
9921027
@SuppressWarnings("rawtypes")
9931028
@Test
9941029
public void testApiGatewayInAndOut() throws Exception {
@@ -1426,6 +1461,34 @@ public Function<APIGatewayProxyRequestEvent, String> inputApiEvent() {
14261461
};
14271462
}
14281463

1464+
@Bean
1465+
1466+
public Function<InputStream, String> echoInputStreamToString() {
1467+
return is -> {
1468+
try {
1469+
String result = StreamUtils.copyToString(is, StandardCharsets.UTF_8);
1470+
return result;
1471+
}
1472+
catch (Exception e) {
1473+
throw new RuntimeException(e);
1474+
}
1475+
};
1476+
}
1477+
1478+
@Bean
1479+
1480+
public Function<Message<InputStream>, String> echoInputStreamMsgToString() {
1481+
return msg -> {
1482+
try {
1483+
String result = StreamUtils.copyToString(msg.getPayload(), StandardCharsets.UTF_8);
1484+
return result;
1485+
}
1486+
catch (Exception e) {
1487+
throw new RuntimeException(e);
1488+
}
1489+
};
1490+
}
1491+
14291492
@Bean
14301493
public Function<APIGatewayProxyRequestEvent, APIGatewayProxyResponseEvent> inputOutputApiEvent() {
14311494
return v -> {

0 commit comments

Comments
 (0)