54
54
import org .junit .BeforeClass ;
55
55
56
56
import java .util .ArrayList ;
57
+ import java .util .Arrays ;
57
58
import java .util .Collections ;
58
59
import java .util .HashMap ;
59
60
import java .util .HashSet ;
@@ -177,6 +178,8 @@ public void testThreadContext() throws InterruptedException {
177
178
178
179
try (ThreadContext .StoredContext ignored = threadPool .getThreadContext ().stashContext ()) {
179
180
final Map <String , String > expectedHeaders = Collections .singletonMap ("test" , "test" );
181
+ final Map <String , List <String >> expectedResponseHeaders = Collections .singletonMap ("testResponse" ,
182
+ Arrays .asList ("testResponse" ));
180
183
threadPool .getThreadContext ().putHeader (expectedHeaders );
181
184
182
185
final TimeValue ackTimeout = randomBoolean () ? TimeValue .ZERO : TimeValue .timeValueMillis (randomInt (10000 ));
@@ -187,6 +190,8 @@ public void testThreadContext() throws InterruptedException {
187
190
public ClusterState execute (ClusterState currentState ) {
188
191
assertTrue (threadPool .getThreadContext ().isSystemContext ());
189
192
assertEquals (Collections .emptyMap (), threadPool .getThreadContext ().getHeaders ());
193
+ threadPool .getThreadContext ().addResponseHeader ("testResponse" , "testResponse" );
194
+ assertEquals (expectedResponseHeaders , threadPool .getThreadContext ().getResponseHeaders ());
190
195
191
196
if (randomBoolean ()) {
192
197
return ClusterState .builder (currentState ).build ();
@@ -201,13 +206,15 @@ public ClusterState execute(ClusterState currentState) {
201
206
public void onFailure (String source , Exception e ) {
202
207
assertFalse (threadPool .getThreadContext ().isSystemContext ());
203
208
assertEquals (expectedHeaders , threadPool .getThreadContext ().getHeaders ());
209
+ assertEquals (expectedResponseHeaders , threadPool .getThreadContext ().getResponseHeaders ());
204
210
latch .countDown ();
205
211
}
206
212
207
213
@ Override
208
214
public void clusterStateProcessed (String source , ClusterState oldState , ClusterState newState ) {
209
215
assertFalse (threadPool .getThreadContext ().isSystemContext ());
210
216
assertEquals (expectedHeaders , threadPool .getThreadContext ().getHeaders ());
217
+ assertEquals (expectedResponseHeaders , threadPool .getThreadContext ().getResponseHeaders ());
211
218
latch .countDown ();
212
219
}
213
220
@@ -229,20 +236,23 @@ public TimeValue timeout() {
229
236
public void onAllNodesAcked (@ Nullable Exception e ) {
230
237
assertFalse (threadPool .getThreadContext ().isSystemContext ());
231
238
assertEquals (expectedHeaders , threadPool .getThreadContext ().getHeaders ());
239
+ assertEquals (expectedResponseHeaders , threadPool .getThreadContext ().getResponseHeaders ());
232
240
latch .countDown ();
233
241
}
234
242
235
243
@ Override
236
244
public void onAckTimeout () {
237
245
assertFalse (threadPool .getThreadContext ().isSystemContext ());
238
246
assertEquals (expectedHeaders , threadPool .getThreadContext ().getHeaders ());
247
+ assertEquals (expectedResponseHeaders , threadPool .getThreadContext ().getResponseHeaders ());
239
248
latch .countDown ();
240
249
}
241
250
242
251
});
243
252
244
253
assertFalse (threadPool .getThreadContext ().isSystemContext ());
245
254
assertEquals (expectedHeaders , threadPool .getThreadContext ().getHeaders ());
255
+ assertEquals (Collections .emptyMap (), threadPool .getThreadContext ().getResponseHeaders ());
246
256
}
247
257
248
258
latch .await ();
0 commit comments