Skip to content

Commit 375e0e6

Browse files
committed
Handle Content-Length in ShallowEtagHeaderFilter more robustly
This commit ensures that setting the Content-Length through setHeader("Content-Length", x") has the same effect as calling setContentLength in the ShallowEtagHeaderFilter. It also filters out Content-Type headers similarly to Content-Length. Closes gh-32039
1 parent b8b31ff commit 375e0e6

File tree

2 files changed

+142
-6
lines changed

2 files changed

+142
-6
lines changed

Diff for: spring-web/src/main/java/org/springframework/web/util/ContentCachingResponseWrapper.java

+134-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2022 the original author or authors.
2+
* Copyright 2002-2024 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -21,6 +21,10 @@
2121
import java.io.OutputStreamWriter;
2222
import java.io.PrintWriter;
2323
import java.io.UnsupportedEncodingException;
24+
import java.util.ArrayList;
25+
import java.util.Collection;
26+
import java.util.Collections;
27+
import java.util.List;
2428

2529
import jakarta.servlet.ServletOutputStream;
2630
import jakarta.servlet.WriteListener;
@@ -55,6 +59,9 @@ public class ContentCachingResponseWrapper extends HttpServletResponseWrapper {
5559
@Nullable
5660
private Integer contentLength;
5761

62+
@Nullable
63+
private String contentType;
64+
5865

5966
/**
6067
* Create a new ContentCachingResponseWrapper for the given servlet response.
@@ -139,6 +146,122 @@ public void setContentLengthLong(long len) {
139146
this.contentLength = lenInt;
140147
}
141148

149+
@Override
150+
public void setContentType(String type) {
151+
this.contentType = type;
152+
}
153+
154+
@Override
155+
@Nullable
156+
public String getContentType() {
157+
return this.contentType;
158+
}
159+
160+
@Override
161+
public boolean containsHeader(String name) {
162+
if (HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) {
163+
return this.contentLength != null;
164+
}
165+
else if (HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) {
166+
return this.contentType != null;
167+
}
168+
else {
169+
return super.containsHeader(name);
170+
}
171+
}
172+
173+
@Override
174+
public void setHeader(String name, String value) {
175+
if (HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) {
176+
this.contentLength = Integer.valueOf(value);
177+
}
178+
else if (HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) {
179+
this.contentType = value;
180+
}
181+
else {
182+
super.setHeader(name, value);
183+
}
184+
}
185+
186+
@Override
187+
public void addHeader(String name, String value) {
188+
if (HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) {
189+
this.contentLength = Integer.valueOf(value);
190+
}
191+
else if (HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) {
192+
this.contentType = value;
193+
}
194+
else {
195+
super.addHeader(name, value);
196+
}
197+
}
198+
199+
@Override
200+
public void setIntHeader(String name, int value) {
201+
if (HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) {
202+
this.contentLength = Integer.valueOf(value);
203+
}
204+
else {
205+
super.setIntHeader(name, value);
206+
}
207+
}
208+
209+
@Override
210+
public void addIntHeader(String name, int value) {
211+
if (HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) {
212+
this.contentLength = Integer.valueOf(value);
213+
}
214+
else {
215+
super.addIntHeader(name, value);
216+
}
217+
}
218+
219+
@Override
220+
@Nullable
221+
public String getHeader(String name) {
222+
if (HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) {
223+
return (this.contentLength != null) ? this.contentLength.toString() : null;
224+
}
225+
else if (HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) {
226+
return this.contentType;
227+
}
228+
else {
229+
return super.getHeader(name);
230+
}
231+
}
232+
233+
@Override
234+
public Collection<String> getHeaders(String name) {
235+
if (HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) {
236+
return this.contentLength != null ? Collections.singleton(this.contentLength.toString()) :
237+
Collections.emptySet();
238+
}
239+
else if (HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) {
240+
return this.contentType != null ? Collections.singleton(this.contentType) : Collections.emptySet();
241+
}
242+
else {
243+
return super.getHeaders(name);
244+
}
245+
}
246+
247+
@Override
248+
public Collection<String> getHeaderNames() {
249+
Collection<String> headerNames = super.getHeaderNames();
250+
if (this.contentLength != null || this.contentType != null) {
251+
List<String> result = new ArrayList<>(headerNames);
252+
if (this.contentLength != null) {
253+
result.add(HttpHeaders.CONTENT_LENGTH);
254+
}
255+
if (this.contentType != null) {
256+
result.add(HttpHeaders.CONTENT_TYPE);
257+
}
258+
return result;
259+
}
260+
else {
261+
return headerNames;
262+
}
263+
}
264+
142265
@Override
143266
public void setBufferSize(int size) {
144267
if (size > this.content.size()) {
@@ -197,11 +320,17 @@ public void copyBodyToResponse() throws IOException {
197320
protected void copyBodyToResponse(boolean complete) throws IOException {
198321
if (this.content.size() > 0) {
199322
HttpServletResponse rawResponse = (HttpServletResponse) getResponse();
200-
if ((complete || this.contentLength != null) && !rawResponse.isCommitted()) {
201-
if (rawResponse.getHeader(HttpHeaders.TRANSFER_ENCODING) == null) {
202-
rawResponse.setContentLength(complete ? this.content.size() : this.contentLength);
323+
if (!rawResponse.isCommitted()) {
324+
if (complete || this.contentLength != null) {
325+
if (rawResponse.getHeader(HttpHeaders.TRANSFER_ENCODING) == null) {
326+
rawResponse.setContentLength(complete ? this.content.size() : this.contentLength);
327+
}
328+
this.contentLength = null;
329+
}
330+
if (complete || this.contentType != null) {
331+
rawResponse.setContentType(this.contentType);
332+
this.contentType = null;
203333
}
204-
this.contentLength = null;
205334
}
206335
this.content.writeTo(rawResponse.getOutputStream());
207336
this.content.reset();

Diff for: spring-web/src/test/java/org/springframework/web/filter/ShallowEtagHeaderFilterTests.java

+8-1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import jakarta.servlet.http.HttpServletResponse;
2424
import org.junit.jupiter.api.Test;
2525

26+
import org.springframework.http.MediaType;
2627
import org.springframework.util.FileCopyUtils;
2728
import org.springframework.web.testfixture.servlet.MockHttpServletRequest;
2829
import org.springframework.web.testfixture.servlet.MockHttpServletResponse;
@@ -68,13 +69,15 @@ void filterNoMatch() throws Exception {
6869
FilterChain filterChain = (filterRequest, filterResponse) -> {
6970
assertThat(filterRequest).as("Invalid request passed").isEqualTo(request);
7071
((HttpServletResponse) filterResponse).setStatus(HttpServletResponse.SC_OK);
72+
filterResponse.setContentType(MediaType.TEXT_PLAIN_VALUE);
7173
FileCopyUtils.copy(responseBody, filterResponse.getOutputStream());
7274
};
7375
filter.doFilter(request, response, filterChain);
7476

7577
assertThat(response.getStatus()).as("Invalid status").isEqualTo(200);
7678
assertThat(response.getHeader("ETag")).as("Invalid ETag").isEqualTo("\"0b10a8db164e0754105b7a99be72e3fe5\"");
7779
assertThat(response.getContentLength()).as("Invalid Content-Length header").isGreaterThan(0);
80+
assertThat(response.getContentType()).as("Invalid Content-Type header").isEqualTo(MediaType.TEXT_PLAIN_VALUE);
7881
assertThat(response.getContentAsByteArray()).as("Invalid content").isEqualTo(responseBody);
7982
}
8083

@@ -88,13 +91,15 @@ void filterNoMatchWeakETag() throws Exception {
8891
FilterChain filterChain = (filterRequest, filterResponse) -> {
8992
assertThat(filterRequest).as("Invalid request passed").isEqualTo(request);
9093
((HttpServletResponse) filterResponse).setStatus(HttpServletResponse.SC_OK);
94+
filterResponse.setContentType(MediaType.TEXT_PLAIN_VALUE);
9195
FileCopyUtils.copy(responseBody, filterResponse.getOutputStream());
9296
};
9397
filter.doFilter(request, response, filterChain);
9498

9599
assertThat(response.getStatus()).as("Invalid status").isEqualTo(200);
96100
assertThat(response.getHeader("ETag")).as("Invalid ETag").isEqualTo("W/\"0b10a8db164e0754105b7a99be72e3fe5\"");
97101
assertThat(response.getContentLength()).as("Invalid Content-Length header").isGreaterThan(0);
102+
assertThat(response.getContentType()).as("Invalid Content-Type header").isEqualTo(MediaType.TEXT_PLAIN_VALUE);
98103
assertThat(response.getContentAsByteArray()).as("Invalid content").isEqualTo(responseBody);
99104
}
100105

@@ -108,14 +113,16 @@ void filterMatch() throws Exception {
108113
FilterChain filterChain = (filterRequest, filterResponse) -> {
109114
assertThat(filterRequest).as("Invalid request passed").isEqualTo(request);
110115
byte[] responseBody = "Hello World".getBytes(StandardCharsets.UTF_8);
111-
FileCopyUtils.copy(responseBody, filterResponse.getOutputStream());
112116
filterResponse.setContentLength(responseBody.length);
117+
filterResponse.setContentType(MediaType.TEXT_PLAIN_VALUE);
118+
FileCopyUtils.copy(responseBody, filterResponse.getOutputStream());
113119
};
114120
filter.doFilter(request, response, filterChain);
115121

116122
assertThat(response.getStatus()).as("Invalid status").isEqualTo(304);
117123
assertThat(response.getHeader("ETag")).as("Invalid ETag").isEqualTo("\"0b10a8db164e0754105b7a99be72e3fe5\"");
118124
assertThat(response.containsHeader("Content-Length")).as("Response has Content-Length header").isFalse();
125+
assertThat(response.containsHeader("Content-Type")).as("Response has Content-Type header").isFalse();
119126
byte[] expecteds = new byte[0];
120127
assertThat(response.getContentAsByteArray()).as("Invalid content").isEqualTo(expecteds);
121128
}

0 commit comments

Comments
 (0)