8
8
9
9
import org .apache .lucene .util .SetOnce ;
10
10
import org .elasticsearch .action .ActionListener ;
11
+ import org .elasticsearch .client .Client ;
11
12
import org .elasticsearch .common .ParseField ;
12
13
import org .elasticsearch .common .Strings ;
13
14
import org .elasticsearch .common .io .stream .StreamInput ;
14
15
import org .elasticsearch .common .io .stream .StreamOutput ;
16
+ import org .elasticsearch .common .settings .Settings ;
15
17
import org .elasticsearch .common .xcontent .ConstructingObjectParser ;
16
18
import org .elasticsearch .common .xcontent .XContentBuilder ;
17
19
import org .elasticsearch .common .xcontent .XContentParser ;
21
23
import org .elasticsearch .search .aggregations .pipeline .AbstractPipelineAggregationBuilder ;
22
24
import org .elasticsearch .search .aggregations .pipeline .PipelineAggregator ;
23
25
import org .elasticsearch .xpack .core .XPackField ;
26
+ import org .elasticsearch .xpack .core .ml .action .GetTrainedModelsAction ;
24
27
import org .elasticsearch .xpack .core .ml .inference .trainedmodel .ClassificationConfig ;
25
28
import org .elasticsearch .xpack .core .ml .inference .trainedmodel .ClassificationConfigUpdate ;
26
29
import org .elasticsearch .xpack .core .ml .inference .trainedmodel .InferenceConfigUpdate ;
27
30
import org .elasticsearch .xpack .core .ml .inference .trainedmodel .ResultsFieldUpdate ;
31
+ import org .elasticsearch .xpack .core .security .SecurityContext ;
32
+ import org .elasticsearch .xpack .core .security .action .user .HasPrivilegesAction ;
33
+ import org .elasticsearch .xpack .core .security .action .user .HasPrivilegesRequest ;
34
+ import org .elasticsearch .xpack .core .security .action .user .HasPrivilegesResponse ;
35
+ import org .elasticsearch .xpack .core .security .authz .RoleDescriptor ;
36
+ import org .elasticsearch .xpack .core .security .support .Exceptions ;
28
37
import org .elasticsearch .xpack .ml .inference .loadingservice .LocalModel ;
29
38
import org .elasticsearch .xpack .ml .inference .loadingservice .ModelLoadingService ;
30
39
31
40
import java .io .IOException ;
32
41
import java .util .Map ;
33
42
import java .util .Objects ;
34
43
import java .util .TreeMap ;
44
+ import java .util .function .BiConsumer ;
35
45
import java .util .function .Supplier ;
36
46
37
47
import static org .elasticsearch .common .xcontent .ConstructingObjectParser .constructorArg ;
48
+ import static org .elasticsearch .xpack .ml .utils .SecondaryAuthorizationUtils .useSecondaryAuthIfAvailable ;
38
49
39
50
public class InferencePipelineAggregationBuilder extends AbstractPipelineAggregationBuilder <InferencePipelineAggregationBuilder > {
40
51
@@ -186,8 +197,9 @@ public InferencePipelineAggregationBuilder rewrite(QueryRewriteContext context)
186
197
if (model != null ) {
187
198
return this ;
188
199
}
200
+
189
201
SetOnce <LocalModel > loadedModel = new SetOnce <>();
190
- context . registerAsyncAction (( client , listener ) -> {
202
+ BiConsumer < Client , ActionListener <?>> modelLoadAction = ( client , listener ) ->
191
203
modelLoadingService .get ().getModelForSearch (modelId , ActionListener .delegateFailure (listener , (delegate , model ) -> {
192
204
loadedModel .set (model );
193
205
@@ -199,6 +211,36 @@ public InferencePipelineAggregationBuilder rewrite(QueryRewriteContext context)
199
211
delegate .onFailure (LicenseUtils .newComplianceException (XPackField .MACHINE_LEARNING ));
200
212
}
201
213
}));
214
+
215
+
216
+ context .registerAsyncAction ((client , listener ) -> {
217
+ if (licenseState .isSecurityEnabled ()) {
218
+ // check the user has ml privileges
219
+ SecurityContext securityContext = new SecurityContext (Settings .EMPTY , client .threadPool ().getThreadContext ());
220
+ useSecondaryAuthIfAvailable (securityContext , () -> {
221
+ final String username = securityContext .getUser ().principal ();
222
+ final HasPrivilegesRequest privRequest = new HasPrivilegesRequest ();
223
+ privRequest .username (username );
224
+ privRequest .clusterPrivileges (GetTrainedModelsAction .NAME );
225
+ privRequest .indexPrivileges (new RoleDescriptor .IndicesPrivileges []{});
226
+ privRequest .applicationPrivileges (new RoleDescriptor .ApplicationResourcePrivileges []{});
227
+
228
+ ActionListener <HasPrivilegesResponse > privResponseListener = ActionListener .wrap (
229
+ r -> {
230
+ if (r .isCompleteMatch ()) {
231
+ modelLoadAction .accept (client , listener );
232
+ } else {
233
+ listener .onFailure (Exceptions .authorizationError ("user [" + username
234
+ + "] does not have the privilege to get trained models so cannot use ml inference" ));
235
+ }
236
+ },
237
+ listener ::onFailure );
238
+
239
+ client .execute (HasPrivilegesAction .INSTANCE , privRequest , privResponseListener );
240
+ });
241
+ } else {
242
+ modelLoadAction .accept (client , listener );
243
+ }
202
244
});
203
245
return new InferencePipelineAggregationBuilder (name , bucketPathMap , loadedModel ::get , modelId , inferenceConfig , licenseState );
204
246
}
0 commit comments