13
13
import org .elasticsearch .action .delete .DeleteRequest ;
14
14
import org .elasticsearch .action .delete .DeleteResponse ;
15
15
import org .elasticsearch .action .get .GetResponse ;
16
+ import org .elasticsearch .action .get .MultiGetItemResponse ;
17
+ import org .elasticsearch .action .get .MultiGetRequest ;
18
+ import org .elasticsearch .action .get .MultiGetResponse ;
16
19
import org .elasticsearch .action .index .IndexRequest ;
17
20
import org .elasticsearch .action .index .IndexResponse ;
18
21
import org .elasticsearch .action .search .MultiSearchResponse ;
50
53
import java .io .IOException ;
51
54
import java .util .ArrayList ;
52
55
import java .util .Arrays ;
53
- import java .util .Collection ;
54
56
import java .util .Collections ;
55
57
import java .util .HashMap ;
56
58
import java .util .HashSet ;
@@ -116,33 +118,54 @@ public void getRoleDescriptors(Set<String> names, final ActionListener<RoleRetri
116
118
if (securityIndex .indexExists () == false ) {
117
119
// TODO remove this short circuiting and fix tests that fail without this!
118
120
listener .onResponse (RoleRetrievalResult .success (Collections .emptySet ()));
119
- } else if (names != null && names .size () == 1 ) {
120
- getRoleDescriptor (Objects .requireNonNull (names .iterator ().next ()), listener );
121
- } else {
121
+ } else if (names == null || names .isEmpty ()) {
122
122
securityIndex .checkIndexVersionThenExecute (listener ::onFailure , () -> {
123
- QueryBuilder query ;
124
- if (names == null || names .isEmpty ()) {
125
- query = QueryBuilders .termQuery (RoleDescriptor .Fields .TYPE .getPreferredName (), ROLE_TYPE );
126
- } else {
127
- final String [] roleNames = names .stream ().map (NativeRolesStore ::getIdForUser ).toArray (String []::new );
128
- query = QueryBuilders .boolQuery ().filter (QueryBuilders .idsQuery (ROLE_DOC_TYPE ).addIds (roleNames ));
129
- }
123
+ QueryBuilder query = QueryBuilders .termQuery (RoleDescriptor .Fields .TYPE .getPreferredName (), ROLE_TYPE );
130
124
final Supplier <ThreadContext .StoredContext > supplier = client .threadPool ().getThreadContext ().newRestorableContext (false );
131
125
try (ThreadContext .StoredContext ignore = stashWithOrigin (client .threadPool ().getThreadContext (), SECURITY_ORIGIN )) {
132
126
SearchRequest request = client .prepareSearch (SecurityIndexManager .SECURITY_INDEX_NAME )
133
- .setScroll (DEFAULT_KEEPALIVE_SETTING .get (settings ))
134
- .setQuery (query )
135
- .setSize (1000 )
136
- .setFetchSource (true )
137
- .request ();
127
+ .setScroll (DEFAULT_KEEPALIVE_SETTING .get (settings ))
128
+ .setQuery (query )
129
+ .setSize (1000 )
130
+ .setFetchSource (true )
131
+ .request ();
138
132
request .indicesOptions ().ignoreUnavailable ();
139
- final ActionListener <Collection <RoleDescriptor >> descriptorsListener = ActionListener .wrap (
140
- roleDescriptors -> listener .onResponse (RoleRetrievalResult .success (new HashSet <>(roleDescriptors ))),
141
- e -> listener .onResponse (RoleRetrievalResult .failure (e )));
142
- ScrollHelper .fetchAllByEntity (client , request , new ContextPreservingActionListener <>(supplier , descriptorsListener ),
143
- (hit ) -> transformRole (hit .getId (), hit .getSourceRef (), logger , licenseState ));
133
+ ScrollHelper .fetchAllByEntity (client , request , new ContextPreservingActionListener <>(supplier ,
134
+ ActionListener .wrap (roles -> listener .onResponse (RoleRetrievalResult .success (new HashSet <>(roles ))),
135
+ e -> listener .onResponse (RoleRetrievalResult .failure (e )))),
136
+ (hit ) -> transformRole (hit .getId (), hit .getSourceRef (), logger , licenseState ));
144
137
}
145
138
});
139
+ } else if (names .size () == 1 ) {
140
+ getRoleDescriptor (Objects .requireNonNull (names .iterator ().next ()), listener );
141
+ } else {
142
+ securityIndex .checkIndexVersionThenExecute (listener ::onFailure , () -> {
143
+ final String [] roleIds = names .stream ().map (NativeRolesStore ::getIdForRole ).toArray (String []::new );
144
+ MultiGetRequest multiGetRequest = client .prepareMultiGet ().add (SECURITY_INDEX_NAME , ROLE_DOC_TYPE , roleIds ).request ();
145
+ executeAsyncWithOrigin (client .threadPool ().getThreadContext (), SECURITY_ORIGIN , multiGetRequest ,
146
+ ActionListener .<MultiGetResponse >wrap (mGetResponse -> {
147
+ final MultiGetItemResponse [] responses = mGetResponse .getResponses ();
148
+ Set <RoleDescriptor > descriptors = new HashSet <>();
149
+ for (int i = 0 ; i < responses .length ; i ++) {
150
+ MultiGetItemResponse item = responses [i ];
151
+ if (item .isFailed ()) {
152
+ final Exception failure = item .getFailure ().getFailure ();
153
+ for (int j = i + 1 ; j < responses .length ; j ++) {
154
+ item = responses [j ];
155
+ if (item .isFailed ()) {
156
+ failure .addSuppressed (failure );
157
+ }
158
+ }
159
+ listener .onResponse (RoleRetrievalResult .failure (failure ));
160
+ return ;
161
+ } else if (item .getResponse ().isExists ()) {
162
+ descriptors .add (transformRole (item .getResponse ()));
163
+ }
164
+ }
165
+ listener .onResponse (RoleRetrievalResult .success (descriptors ));
166
+ },
167
+ e -> listener .onResponse (RoleRetrievalResult .failure (e ))), client ::multiGet );
168
+ });
146
169
}
147
170
}
148
171
@@ -157,7 +180,7 @@ public void deleteRole(final DeleteRoleRequest deleteRoleRequest, final ActionLi
157
180
} else {
158
181
securityIndex .checkIndexVersionThenExecute (listener ::onFailure , () -> {
159
182
DeleteRequest request = client .prepareDelete (SecurityIndexManager .SECURITY_INDEX_NAME ,
160
- ROLE_DOC_TYPE , getIdForUser (deleteRoleRequest .name ())).request ();
183
+ ROLE_DOC_TYPE , getIdForRole (deleteRoleRequest .name ())).request ();
161
184
request .setRefreshPolicy (deleteRoleRequest .getRefreshPolicy ());
162
185
executeAsyncWithOrigin (client .threadPool ().getThreadContext (), SECURITY_ORIGIN , request ,
163
186
new ActionListener <DeleteResponse >() {
@@ -199,7 +222,7 @@ void innerPutRole(final PutRoleRequest request, final RoleDescriptor role, final
199
222
listener .onFailure (e );
200
223
return ;
201
224
}
202
- final IndexRequest indexRequest = client .prepareIndex (SECURITY_INDEX_NAME , ROLE_DOC_TYPE , getIdForUser (role .getName ()))
225
+ final IndexRequest indexRequest = client .prepareIndex (SECURITY_INDEX_NAME , ROLE_DOC_TYPE , getIdForRole (role .getName ()))
203
226
.setSource (xContentBuilder )
204
227
.setRefreshPolicy (request .getRefreshPolicy ())
205
228
.request ();
@@ -315,7 +338,7 @@ private void executeGetRoleRequest(String role, ActionListener<GetResponse> list
315
338
securityIndex .checkIndexVersionThenExecute (listener ::onFailure , () ->
316
339
executeAsyncWithOrigin (client .threadPool ().getThreadContext (), SECURITY_ORIGIN ,
317
340
client .prepareGet (SECURITY_INDEX_NAME ,
318
- ROLE_DOC_TYPE , getIdForUser (role )).request (),
341
+ ROLE_DOC_TYPE , getIdForRole (role )).request (),
319
342
listener ,
320
343
client ::get ));
321
344
}
@@ -395,7 +418,7 @@ public static void addSettings(List<Setting<?>> settings) {
395
418
/**
396
419
* Gets the document's id field for the given role name.
397
420
*/
398
- private static String getIdForUser (final String roleName ) {
421
+ private static String getIdForRole (final String roleName ) {
399
422
return ROLE_TYPE + "-" + roleName ;
400
423
}
401
424
}
0 commit comments