38
38
import org .springframework .web .context .request .async .StandardServletAsyncWebRequest ;
39
39
import org .springframework .web .context .request .async .WebAsyncManager ;
40
40
import org .springframework .web .context .request .async .WebAsyncUtils ;
41
+ import org .springframework .web .util .NestedServletException ;
41
42
42
43
import static org .hamcrest .Matchers .containsString ;
43
44
import static org .hamcrest .Matchers .equalTo ;
@@ -62,7 +63,8 @@ public class ErrorPageFilterTests {
62
63
63
64
private ErrorPageFilter filter = new ErrorPageFilter ();
64
65
65
- private MockHttpServletRequest request = new MockHttpServletRequest ();
66
+ private MockHttpServletRequest request = new MockHttpServletRequest ("GET" ,
67
+ "/test/path" );
66
68
67
69
private MockHttpServletResponse response = new MockHttpServletResponse ();
68
70
@@ -199,6 +201,9 @@ public void doFilter(ServletRequest request, ServletResponse response)
199
201
equalTo ((Object ) 400 ));
200
202
assertThat (this .request .getAttribute (RequestDispatcher .ERROR_MESSAGE ),
201
203
equalTo ((Object ) "BAD" ));
204
+ assertThat (this .request .getAttribute (RequestDispatcher .ERROR_REQUEST_URI ),
205
+ equalTo ((Object ) "/test/path" ));
206
+
202
207
assertTrue (this .response .isCommitted ());
203
208
assertThat (this .response .getForwardedUrl (), equalTo ("/error" ));
204
209
}
@@ -221,6 +226,8 @@ public void doFilter(ServletRequest request, ServletResponse response)
221
226
equalTo ((Object ) 400 ));
222
227
assertThat (this .request .getAttribute (RequestDispatcher .ERROR_MESSAGE ),
223
228
equalTo ((Object ) "BAD" ));
229
+ assertThat (this .request .getAttribute (RequestDispatcher .ERROR_REQUEST_URI ),
230
+ equalTo ((Object ) "/test/path" ));
224
231
assertTrue (this .response .isCommitted ());
225
232
assertThat (this .response .getForwardedUrl (), equalTo ("/400" ));
226
233
}
@@ -264,6 +271,8 @@ public void doFilter(ServletRequest request, ServletResponse response)
264
271
equalTo ((Object ) "BAD" ));
265
272
assertThat (this .request .getAttribute (RequestDispatcher .ERROR_EXCEPTION_TYPE ),
266
273
equalTo ((Object ) RuntimeException .class .getName ()));
274
+ assertThat (this .request .getAttribute (RequestDispatcher .ERROR_REQUEST_URI ),
275
+ equalTo ((Object ) "/test/path" ));
267
276
assertTrue (this .response .isCommitted ());
268
277
assertThat (this .response .getForwardedUrl (), equalTo ("/500" ));
269
278
}
@@ -319,6 +328,8 @@ public void doFilter(ServletRequest request, ServletResponse response)
319
328
equalTo ((Object ) "BAD" ));
320
329
assertThat (this .request .getAttribute (RequestDispatcher .ERROR_EXCEPTION_TYPE ),
321
330
equalTo ((Object ) IllegalStateException .class .getName ()));
331
+ assertThat (this .request .getAttribute (RequestDispatcher .ERROR_REQUEST_URI ),
332
+ equalTo ((Object ) "/test/path" ));
322
333
assertTrue (this .response .isCommitted ());
323
334
}
324
335
@@ -465,6 +476,32 @@ public void doFilter(ServletRequest request, ServletResponse response)
465
476
assertThat (this .output .toString (), containsString ("request [/test/alpha]" ));
466
477
}
467
478
479
+ @ Test
480
+ public void nestedServletExceptionIsUnwrapped () throws Exception {
481
+ this .filter .addErrorPages (new ErrorPage (RuntimeException .class , "/500" ));
482
+ this .chain = new MockFilterChain () {
483
+ @ Override
484
+ public void doFilter (ServletRequest request , ServletResponse response )
485
+ throws IOException , ServletException {
486
+ super .doFilter (request , response );
487
+ throw new NestedServletException ("Wrapper" , new RuntimeException ("BAD" ));
488
+ }
489
+ };
490
+ this .filter .doFilter (this .request , this .response , this .chain );
491
+ assertThat (((HttpServletResponseWrapper ) this .chain .getResponse ()).getStatus (),
492
+ equalTo (500 ));
493
+ assertThat (this .request .getAttribute (RequestDispatcher .ERROR_STATUS_CODE ),
494
+ equalTo ((Object ) 500 ));
495
+ assertThat (this .request .getAttribute (RequestDispatcher .ERROR_MESSAGE ),
496
+ equalTo ((Object ) "BAD" ));
497
+ assertThat (this .request .getAttribute (RequestDispatcher .ERROR_EXCEPTION_TYPE ),
498
+ equalTo ((Object ) RuntimeException .class .getName ()));
499
+ assertThat (this .request .getAttribute (RequestDispatcher .ERROR_REQUEST_URI ),
500
+ equalTo ((Object ) "/test/path" ));
501
+ assertTrue (this .response .isCommitted ());
502
+ assertThat (this .response .getForwardedUrl (), equalTo ("/500" ));
503
+ }
504
+
468
505
private void setUpAsyncDispatch () throws Exception {
469
506
this .request .setAsyncSupported (true );
470
507
this .request .setAsyncStarted (true );
0 commit comments