Skip to content

Commit 0d2ea1b

Browse files
authored
Check for ml privilege when using the Inference Aggregation (#59530) (#59562)
The inference pipeline aggregation requires the user has permission to access the ml get trained models endpoint (_ml/inference/)
1 parent 408a07f commit 0d2ea1b

File tree

5 files changed

+90
-6
lines changed

5 files changed

+90
-6
lines changed

x-pack/plugin/ml/qa/ml-with-security/build.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ integTest.runner {
184184
'ml/job_groups/Test put job with id that matches an existing group',
185185
'ml/job_groups/Test put job with invalid group',
186186
'ml/ml_info/Test ml info',
187+
'ml/pipeline_inference/Test setting results field is invalid',
187188
'ml/post_data/Test Flush data with invalid parameters',
188189
'ml/post_data/Test flushing and posting a closed job',
189190
'ml/post_data/Test open and close with non-existent job id',

x-pack/plugin/ml/qa/ml-with-security/src/test/java/org/elasticsearch/smoketest/MlWithSecurityInsufficientRoleIT.java

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
import org.elasticsearch.test.rest.yaml.section.ExecutableSection;
1313

1414
import java.io.IOException;
15+
import java.util.List;
16+
import java.util.Map;
17+
import java.util.Set;
1518

1619
import static org.hamcrest.Matchers.containsString;
1720
import static org.hamcrest.Matchers.either;
@@ -26,6 +29,7 @@ public MlWithSecurityInsufficientRoleIT(@Name("yaml") ClientYamlTestCandidate te
2629
}
2730

2831
@Override
32+
@SuppressWarnings("unchecked")
2933
public void test() throws IOException {
3034
try {
3135
// Cannot use expectThrows here because blacklisted tests will throw an
@@ -38,7 +42,19 @@ public void test() throws IOException {
3842
String apiName = ((DoSection) section).getApiCallSection().getApi();
3943

4044
if (apiName.startsWith("ml.")) {
41-
fail("call to ml endpoint should have failed because of missing role");
45+
fail("call to ml endpoint [" + apiName + "] should have failed because of missing role");
46+
} else if (apiName.startsWith("search")) {
47+
DoSection doSection = (DoSection) section;
48+
List<Map<String, Object>> bodies = doSection.getApiCallSection().getBodies();
49+
boolean containsInferenceAgg = false;
50+
for (Map<String, Object> body : bodies) {
51+
Map<String, Object> aggs = (Map<String, Object>)body.get("aggs");
52+
containsInferenceAgg = containsInferenceAgg || containsKey("inference", aggs);
53+
}
54+
55+
if (containsInferenceAgg) {
56+
fail("call to [search] with the ml inference agg should have failed because of missing role");
57+
}
4258
}
4359
}
4460
}
@@ -49,9 +65,13 @@ public void test() throws IOException {
4965
assertThat(ae.getMessage(), containsString("but was Integer [0]"));
5066
} else {
5167
assertThat(ae.getMessage(),
52-
either(containsString("action [cluster:monitor/xpack/ml")).or(containsString("action [cluster:admin/xpack/ml")));
68+
either(containsString("action [cluster:monitor/xpack/ml"))
69+
.or(containsString("action [cluster:admin/xpack/ml"))
70+
.or(containsString("security_exception")));
5371
assertThat(ae.getMessage(), containsString("returned [403 Forbidden]"));
54-
assertThat(ae.getMessage(), containsString("is unauthorized for user [no_ml]"));
72+
assertThat(ae.getMessage(),
73+
either(containsString("is unauthorized for user [no_ml]"))
74+
.or(containsString("user [no_ml] does not have the privilege to get trained models")));
5575
}
5676
}
5777
}
@@ -60,5 +80,24 @@ public void test() throws IOException {
6080
protected String[] getCredentials() {
6181
return new String[]{"no_ml", "x-pack-test-password"};
6282
}
83+
84+
@SuppressWarnings("unchecked")
85+
static boolean containsKey(String key, Map<String, Object> mapOfMaps) {
86+
if (mapOfMaps.containsKey(key)) {
87+
return true;
88+
}
89+
90+
Set<Map.Entry<String, Object>> entries = mapOfMaps.entrySet();
91+
for (Map.Entry<String, Object> entry : entries) {
92+
if (entry.getValue() instanceof Map<?,?>) {
93+
boolean isInNestedMap = containsKey(key, (Map<String, Object>)entry.getValue());
94+
if (isInNestedMap) {
95+
return true;
96+
}
97+
}
98+
}
99+
100+
return false;
101+
}
63102
}
64103

x-pack/plugin/ml/qa/ml-with-security/src/test/java/org/elasticsearch/smoketest/MlWithSecurityUserRoleIT.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,14 @@ public void test() throws IOException {
4848
String apiName = ((DoSection) section).getApiCallSection().getApi();
4949

5050
if (apiName.startsWith("ml.") && isAllowed(apiName) == false) {
51-
fail("should have failed because of missing role");
51+
fail("call to ml endpoint [" + apiName + "] should have failed because of missing role");
5252
}
5353
}
5454
}
5555
} catch (AssertionError ae) {
5656
assertThat(ae.getMessage(),
57-
either(containsString("action [cluster:monitor/xpack/ml")).or(containsString("action [cluster:admin/xpack/ml")));
57+
either(containsString("action [cluster:monitor/xpack/ml"))
58+
.or(containsString("action [cluster:admin/xpack/ml")));
5859
assertThat(ae.getMessage(), containsString("returned [403 Forbidden]"));
5960
assertThat(ae.getMessage(), containsString("is unauthorized for user [ml_user]"));
6061
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/aggs/InferencePipelineAggregationBuilder.java

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@
88

99
import org.apache.lucene.util.SetOnce;
1010
import org.elasticsearch.action.ActionListener;
11+
import org.elasticsearch.client.Client;
1112
import org.elasticsearch.common.ParseField;
1213
import org.elasticsearch.common.Strings;
1314
import org.elasticsearch.common.io.stream.StreamInput;
1415
import org.elasticsearch.common.io.stream.StreamOutput;
16+
import org.elasticsearch.common.settings.Settings;
1517
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
1618
import org.elasticsearch.common.xcontent.XContentBuilder;
1719
import org.elasticsearch.common.xcontent.XContentParser;
@@ -21,20 +23,29 @@
2123
import org.elasticsearch.search.aggregations.pipeline.AbstractPipelineAggregationBuilder;
2224
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
2325
import org.elasticsearch.xpack.core.XPackField;
26+
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
2427
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
2528
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
2629
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
2730
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ResultsFieldUpdate;
31+
import org.elasticsearch.xpack.core.security.SecurityContext;
32+
import org.elasticsearch.xpack.core.security.action.user.HasPrivilegesAction;
33+
import org.elasticsearch.xpack.core.security.action.user.HasPrivilegesRequest;
34+
import org.elasticsearch.xpack.core.security.action.user.HasPrivilegesResponse;
35+
import org.elasticsearch.xpack.core.security.authz.RoleDescriptor;
36+
import org.elasticsearch.xpack.core.security.support.Exceptions;
2837
import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
2938
import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
3039

3140
import java.io.IOException;
3241
import java.util.Map;
3342
import java.util.Objects;
3443
import java.util.TreeMap;
44+
import java.util.function.BiConsumer;
3545
import java.util.function.Supplier;
3646

3747
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
48+
import static org.elasticsearch.xpack.ml.utils.SecondaryAuthorizationUtils.useSecondaryAuthIfAvailable;
3849

3950
public class InferencePipelineAggregationBuilder extends AbstractPipelineAggregationBuilder<InferencePipelineAggregationBuilder> {
4051

@@ -186,8 +197,9 @@ public InferencePipelineAggregationBuilder rewrite(QueryRewriteContext context)
186197
if (model != null) {
187198
return this;
188199
}
200+
189201
SetOnce<LocalModel> loadedModel = new SetOnce<>();
190-
context.registerAsyncAction((client, listener) -> {
202+
BiConsumer<Client, ActionListener<?>> modelLoadAction = (client, listener) ->
191203
modelLoadingService.get().getModelForSearch(modelId, ActionListener.delegateFailure(listener, (delegate, model) -> {
192204
loadedModel.set(model);
193205

@@ -199,6 +211,36 @@ public InferencePipelineAggregationBuilder rewrite(QueryRewriteContext context)
199211
delegate.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING));
200212
}
201213
}));
214+
215+
216+
context.registerAsyncAction((client, listener) -> {
217+
if (licenseState.isSecurityEnabled()) {
218+
// check the user has ml privileges
219+
SecurityContext securityContext = new SecurityContext(Settings.EMPTY, client.threadPool().getThreadContext());
220+
useSecondaryAuthIfAvailable(securityContext, () -> {
221+
final String username = securityContext.getUser().principal();
222+
final HasPrivilegesRequest privRequest = new HasPrivilegesRequest();
223+
privRequest.username(username);
224+
privRequest.clusterPrivileges(GetTrainedModelsAction.NAME);
225+
privRequest.indexPrivileges(new RoleDescriptor.IndicesPrivileges[]{});
226+
privRequest.applicationPrivileges(new RoleDescriptor.ApplicationResourcePrivileges[]{});
227+
228+
ActionListener<HasPrivilegesResponse> privResponseListener = ActionListener.wrap(
229+
r -> {
230+
if (r.isCompleteMatch()) {
231+
modelLoadAction.accept(client, listener);
232+
} else {
233+
listener.onFailure(Exceptions.authorizationError("user [" + username
234+
+ "] does not have the privilege to get trained models so cannot use ml inference"));
235+
}
236+
},
237+
listener::onFailure);
238+
239+
client.execute(HasPrivilegesAction.INSTANCE, privRequest, privResponseListener);
240+
});
241+
} else {
242+
modelLoadAction.accept(client, listener);
243+
}
202244
});
203245
return new InferencePipelineAggregationBuilder(name, bucketPathMap, loadedModel::get, modelId, inferenceConfig, licenseState);
204246
}

x-pack/plugin/src/test/resources/rest-api-spec/test/ml/pipeline_inference.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ setup:
139139
}
140140
}
141141
- match: { aggregations.good.buckets.0.regression_agg.value: 2.0 }
142+
142143
---
143144
"Test pipeline agg referencing a single bucket":
144145

0 commit comments

Comments
 (0)