Skip to content

Commit a2ab55b

Browse files
authored
[FEATURE][ML] User config appropriate permission checks on creating/running analytics (#38928)
* [Feature][ML] Add authz check for dataframe source index * fixing origin for client calls and adding headers * addressing PR comments * Having bulk request be done with headers in origin * addressing pr comments and failing test * making analyses immutable * adjusting indexnames and privs for security tests
1 parent 911b805 commit a2ab55b

File tree

14 files changed

+284
-74
lines changed

14 files changed

+284
-74
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalysisConfig.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
1515

1616
import java.io.IOException;
17+
import java.util.Collections;
18+
import java.util.HashMap;
1719
import java.util.Map;
1820
import java.util.Objects;
1921

@@ -26,14 +28,14 @@ public static ContextParser<Void, DataFrameAnalysisConfig> parser() {
2628
private final Map<String, Object> config;
2729

2830
public DataFrameAnalysisConfig(Map<String, Object> config) {
29-
this.config = Objects.requireNonNull(config);
31+
this.config = Collections.unmodifiableMap(new HashMap<>(Objects.requireNonNull(config)));
3032
if (config.size() != 1) {
3133
throw ExceptionsHelper.badRequestException("A data frame analysis must specify exactly one analysis type");
3234
}
3335
}
3436

3537
public DataFrameAnalysisConfig(StreamInput in) throws IOException {
36-
config = in.readMap();
38+
config = Collections.unmodifiableMap(in.readMap());
3739
}
3840

3941
public Map<String, Object> asMap() {

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
import java.io.IOException;
3030
import java.util.ArrayList;
3131
import java.util.Collections;
32+
import java.util.HashMap;
33+
import java.util.LinkedHashMap;
3234
import java.util.List;
3335
import java.util.Map;
3436
import java.util.Objects;
@@ -62,6 +64,7 @@ public class DataFrameAnalyticsConfig implements ToXContentObject, Writeable {
6264
public static final ParseField ANALYSES = new ParseField("analyses");
6365
public static final ParseField CONFIG_TYPE = new ParseField("config_type");
6466
public static final ParseField QUERY = new ParseField("query");
67+
public static final ParseField HEADERS = new ParseField("headers");
6568

6669
public static final ObjectParser<Builder, Void> STRICT_PARSER = createParser(false);
6770
public static final ObjectParser<Builder, Void> LENIENT_PARSER = createParser(true);
@@ -75,6 +78,11 @@ public static ObjectParser<Builder, Void> createParser(boolean ignoreUnknownFiel
7578
parser.declareString(Builder::setDest, DEST);
7679
parser.declareObjectArray(Builder::setAnalyses, DataFrameAnalysisConfig.parser(), ANALYSES);
7780
parser.declareObject((builder, query) -> builder.setQuery(query, ignoreUnknownFields), (p, c) -> p.mapOrdered(), QUERY);
81+
if (ignoreUnknownFields) {
82+
// Headers are not parsed by the strict (config) parser, so headers supplied in the _body_ of a REST request will be rejected.
83+
// (For config, headers are explicitly transferred from the auth headers by code in the put data frame actions.)
84+
parser.declareObject(Builder::setHeaders, (p, c) -> p.mapStrings(), HEADERS);
85+
}
7886
return parser;
7987
}
8088

@@ -84,9 +92,10 @@ public static ObjectParser<Builder, Void> createParser(boolean ignoreUnknownFiel
8492
private final List<DataFrameAnalysisConfig> analyses;
8593
private final Map<String, Object> query;
8694
private final CachedSupplier<QueryBuilder> querySupplier;
95+
private final Map<String, String> headers;
8796

8897
public DataFrameAnalyticsConfig(String id, String source, String dest, List<DataFrameAnalysisConfig> analyses,
89-
Map<String, Object> query) {
98+
Map<String, Object> query, Map<String, String> headers) {
9099
this.id = ExceptionsHelper.requireNonNull(id, ID);
91100
this.source = ExceptionsHelper.requireNonNull(source, SOURCE);
92101
this.dest = ExceptionsHelper.requireNonNull(dest, DEST);
@@ -100,6 +109,7 @@ public DataFrameAnalyticsConfig(String id, String source, String dest, List<Data
100109
}
101110
this.query = Collections.unmodifiableMap(query);
102111
this.querySupplier = new CachedSupplier<>(() -> lazyQueryParser.apply(query, id, new ArrayList<>()));
112+
this.headers = Collections.unmodifiableMap(headers);
103113
}
104114

105115
public DataFrameAnalyticsConfig(StreamInput in) throws IOException {
@@ -109,6 +119,7 @@ public DataFrameAnalyticsConfig(StreamInput in) throws IOException {
109119
analyses = in.readList(DataFrameAnalysisConfig::new);
110120
this.query = in.readMap();
111121
this.querySupplier = new CachedSupplier<>(() -> lazyQueryParser.apply(query, id, new ArrayList<>()));
122+
this.headers = Collections.unmodifiableMap(in.readMap(StreamInput::readString, StreamInput::readString));
112123
}
113124

114125
public String getId() {
@@ -151,6 +162,10 @@ List<String> getQueryDeprecations(TriFunction<Map<String, Object>, String, List<
151162
return deprecations;
152163
}
153164

165+
public Map<String, String> getHeaders() {
166+
return headers;
167+
}
168+
154169
@Override
155170
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
156171
builder.startObject();
@@ -162,6 +177,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
162177
builder.field(CONFIG_TYPE.getPreferredName(), TYPE);
163178
}
164179
builder.field(QUERY.getPreferredName(), query);
180+
if (headers.isEmpty() == false && params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) {
181+
builder.field(HEADERS.getPreferredName(), headers);
182+
}
165183
builder.endObject();
166184
return builder;
167185
}
@@ -173,6 +191,7 @@ public void writeTo(StreamOutput out) throws IOException {
173191
out.writeString(dest);
174192
out.writeList(analyses);
175193
out.writeMap(query);
194+
out.writeMap(headers, StreamOutput::writeString, StreamOutput::writeString);
176195
}
177196

178197
@Override
@@ -185,12 +204,13 @@ public boolean equals(Object o) {
185204
&& Objects.equals(source, other.source)
186205
&& Objects.equals(dest, other.dest)
187206
&& Objects.equals(analyses, other.analyses)
207+
&& Objects.equals(headers, other.headers)
188208
&& Objects.equals(query, other.query);
189209
}
190210

191211
@Override
192212
public int hashCode() {
193-
return Objects.hash(id, source, dest, analyses, query);
213+
return Objects.hash(id, source, dest, analyses, query, headers);
194214
}
195215

196216
public static String documentId(String id) {
@@ -204,11 +224,23 @@ public static class Builder {
204224
private String dest;
205225
private List<DataFrameAnalysisConfig> analyses;
206226
private Map<String, Object> query = Collections.singletonMap(MatchAllQueryBuilder.NAME, Collections.emptyMap());
227+
private Map<String, String> headers = Collections.emptyMap();
207228

208229
public String getId() {
209230
return id;
210231
}
211232

233+
public Builder() {}
234+
235+
public Builder(DataFrameAnalyticsConfig config) {
236+
this.id = config.id;
237+
this.source = config.source;
238+
this.dest = config.dest;
239+
this.analyses = new ArrayList<>(config.analyses);
240+
this.query = new LinkedHashMap<>(config.query);
241+
this.headers = new HashMap<>(config.headers);
242+
}
243+
212244
public Builder setId(String id) {
213245
this.id = ExceptionsHelper.requireNonNull(id, ID);
214246
return this;
@@ -248,8 +280,13 @@ public Builder setQuery(Map<String, Object> query, boolean lenient) {
248280
return this;
249281
}
250282

283+
public Builder setHeaders(Map<String, String> headers) {
284+
this.headers = headers;
285+
return this;
286+
}
287+
251288
public DataFrameAnalyticsConfig build() {
252-
return new DataFrameAnalyticsConfig(id, source, dest, analyses, query);
289+
return new DataFrameAnalyticsConfig(id, source, dest, analyses, query, headers);
253290
}
254291
}
255292
}

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,16 @@
77

88
import com.carrotsearch.randomizedtesting.generators.CodepointSetGenerator;
99
import org.elasticsearch.ElasticsearchException;
10+
import org.elasticsearch.common.bytes.BytesReference;
1011
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
1112
import org.elasticsearch.common.io.stream.Writeable;
1213
import org.elasticsearch.common.settings.Settings;
1314
import org.elasticsearch.common.xcontent.DeprecationHandler;
15+
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
1416
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
17+
import org.elasticsearch.common.xcontent.ToXContent;
1518
import org.elasticsearch.common.xcontent.XContentFactory;
19+
import org.elasticsearch.common.xcontent.XContentHelper;
1620
import org.elasticsearch.common.xcontent.XContentParseException;
1721
import org.elasticsearch.common.xcontent.XContentParser;
1822
import org.elasticsearch.common.xcontent.XContentType;
@@ -21,13 +25,18 @@
2125
import org.elasticsearch.index.query.TermQueryBuilder;
2226
import org.elasticsearch.search.SearchModule;
2327
import org.elasticsearch.test.AbstractSerializingTestCase;
28+
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
2429

2530
import java.io.IOException;
2631
import java.util.Collections;
32+
import java.util.HashMap;
2733
import java.util.List;
34+
import java.util.Map;
2835

2936
import static org.hamcrest.CoreMatchers.equalTo;
37+
import static org.hamcrest.Matchers.hasEntry;
3038
import static org.hamcrest.Matchers.hasItem;
39+
import static org.hamcrest.Matchers.hasSize;
3140
import static org.mockito.Mockito.spy;
3241
import static org.mockito.Mockito.verify;
3342

@@ -61,6 +70,10 @@ protected Writeable.Reader<DataFrameAnalyticsConfig> instanceReader() {
6170
}
6271

6372
public static DataFrameAnalyticsConfig createRandom(String id) {
73+
return createRandomBuilder(id).build();
74+
}
75+
76+
public static DataFrameAnalyticsConfig.Builder createRandomBuilder(String id) {
6477
String source = randomAlphaOfLength(10);
6578
String dest = randomAlphaOfLength(10);
6679
List<DataFrameAnalysisConfig> analyses = Collections.singletonList(DataFrameAnalysisConfigTests.randomConfig());
@@ -74,7 +87,7 @@ public static DataFrameAnalyticsConfig createRandom(String id) {
7487
Collections.singletonMap(TermQueryBuilder.NAME,
7588
Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10))), true);
7689
}
77-
return builder.build();
90+
return builder;
7891
}
7992

8093
public static String randomValidId() {
@@ -142,6 +155,33 @@ public void testPastQueryConfigParse() throws IOException {
142155
}
143156
}
144157

158+
public void testToXContentForInternalStorage() throws IOException {
159+
DataFrameAnalyticsConfig.Builder builder = createRandomBuilder("foo");
160+
161+
// headers are only persisted to cluster state
162+
Map<String, String> headers = new HashMap<>();
163+
headers.put("header-name", "header-value");
164+
builder.setHeaders(headers);
165+
DataFrameAnalyticsConfig config = builder.build();
166+
167+
ToXContent.MapParams params = new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true"));
168+
169+
BytesReference forClusterstateXContent = XContentHelper.toXContent(config, XContentType.JSON, params, false);
170+
XContentParser parser = XContentFactory.xContent(XContentType.JSON)
171+
.createParser(xContentRegistry(), LoggingDeprecationHandler.INSTANCE, forClusterstateXContent.streamInput());
172+
173+
DataFrameAnalyticsConfig parsedConfig = DataFrameAnalyticsConfig.LENIENT_PARSER.apply(parser, null).build();
174+
assertThat(parsedConfig.getHeaders(), hasEntry("header-name", "header-value"));
175+
176+
// headers are not written without the FOR_INTERNAL_STORAGE param
177+
BytesReference nonClusterstateXContent = XContentHelper.toXContent(config, XContentType.JSON, ToXContent.EMPTY_PARAMS, false);
178+
parser = XContentFactory.xContent(XContentType.JSON)
179+
.createParser(xContentRegistry(), LoggingDeprecationHandler.INSTANCE, nonClusterstateXContent.streamInput());
180+
181+
parsedConfig = DataFrameAnalyticsConfig.LENIENT_PARSER.apply(parser, null).build();
182+
assertThat(parsedConfig.getHeaders().entrySet(), hasSize(0));
183+
}
184+
145185
public void testGetQueryDeprecations() {
146186
DataFrameAnalyticsConfig dataFrame = createTestInstance();
147187
String deprecationWarning = "Warning";

x-pack/plugin/ml/qa/ml-with-security/build.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ integTestRunner {
3535
'ml/datafeeds_crud/Test put datafeed with invalid query',
3636
'ml/datafeeds_crud/Test put datafeed with security headers in the body',
3737
'ml/datafeeds_crud/Test update datafeed with missing id',
38+
'ml/data_frame_analytics_crud/Test put config with security headers in the body',
3839
'ml/data_frame_analytics_crud/Test put config with inconsistent body/param ids',
3940
'ml/data_frame_analytics_crud/Test put config with invalid id',
4041
'ml/data_frame_analytics_crud/Test put config with unknown top level field',

x-pack/plugin/ml/qa/ml-with-security/roles.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ minimal:
1111
privileges:
1212
- indices:admin/create
1313
- indices:admin/refresh
14-
- indices:data/read/field_caps
15-
- indices:data/read/search
14+
- read
15+
- index
1616
- indices:data/write/bulk
1717
- indices:data/write/index

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutDataFrameAnalyticsAction.java

Lines changed: 76 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,35 +8,58 @@
88
import org.elasticsearch.action.ActionListener;
99
import org.elasticsearch.action.support.ActionFilters;
1010
import org.elasticsearch.action.support.HandledTransportAction;
11+
import org.elasticsearch.client.Client;
12+
import org.elasticsearch.common.Strings;
1113
import org.elasticsearch.common.inject.Inject;
14+
import org.elasticsearch.common.settings.Settings;
15+
import org.elasticsearch.common.xcontent.XContentBuilder;
16+
import org.elasticsearch.common.xcontent.json.JsonXContent;
1217
import org.elasticsearch.license.LicenseUtils;
1318
import org.elasticsearch.license.XPackLicenseState;
1419
import org.elasticsearch.tasks.Task;
20+
import org.elasticsearch.threadpool.ThreadPool;
1521
import org.elasticsearch.transport.TransportService;
1622
import org.elasticsearch.xpack.core.XPackField;
23+
import org.elasticsearch.xpack.core.XPackSettings;
1724
import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction;
1825
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
1926
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
2027
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
2128
import org.elasticsearch.xpack.core.ml.utils.MlStrings;
29+
import org.elasticsearch.xpack.core.security.SecurityContext;
30+
import org.elasticsearch.xpack.core.security.action.user.HasPrivilegesAction;
31+
import org.elasticsearch.xpack.core.security.action.user.HasPrivilegesRequest;
32+
import org.elasticsearch.xpack.core.security.action.user.HasPrivilegesResponse;
33+
import org.elasticsearch.xpack.core.security.authz.RoleDescriptor;
34+
import org.elasticsearch.xpack.core.security.authz.permission.ResourcePrivileges;
35+
import org.elasticsearch.xpack.core.security.support.Exceptions;
2236
import org.elasticsearch.xpack.ml.dataframe.analyses.DataFrameAnalysesUtils;
2337
import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider;
2438

39+
import java.io.IOException;
2540
import java.util.function.Supplier;
2641

2742
public class TransportPutDataFrameAnalyticsAction
2843
extends HandledTransportAction<PutDataFrameAnalyticsAction.Request, PutDataFrameAnalyticsAction.Response> {
2944

3045
private final XPackLicenseState licenseState;
3146
private final DataFrameAnalyticsConfigProvider configProvider;
47+
private final ThreadPool threadPool;
48+
private final SecurityContext securityContext;
49+
private final Client client;
3250

3351
@Inject
34-
public TransportPutDataFrameAnalyticsAction(TransportService transportService, ActionFilters actionFilters,
35-
XPackLicenseState licenseState, DataFrameAnalyticsConfigProvider configProvider) {
52+
public TransportPutDataFrameAnalyticsAction(Settings settings, TransportService transportService, ActionFilters actionFilters,
53+
XPackLicenseState licenseState, Client client, ThreadPool threadPool,
54+
DataFrameAnalyticsConfigProvider configProvider) {
3655
super(PutDataFrameAnalyticsAction.NAME, transportService, actionFilters,
3756
(Supplier<PutDataFrameAnalyticsAction.Request>) PutDataFrameAnalyticsAction.Request::new);
3857
this.licenseState = licenseState;
3958
this.configProvider = configProvider;
59+
this.threadPool = threadPool;
60+
this.securityContext = XPackSettings.SECURITY_ENABLED.get(settings) ?
61+
new SecurityContext(settings, threadPool.getThreadContext()) : null;
62+
this.client = client;
4063
}
4164

4265
@Override
@@ -46,12 +69,58 @@ protected void doExecute(Task task, PutDataFrameAnalyticsAction.Request request,
4669
listener.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING));
4770
return;
4871
}
49-
5072
validateConfig(request.getConfig());
51-
configProvider.put(request.getConfig(), ActionListener.wrap(
52-
indexResponse -> listener.onResponse(new PutDataFrameAnalyticsAction.Response(request.getConfig())),
53-
listener::onFailure
54-
));
73+
if (licenseState.isAuthAllowed()) {
74+
final String username = securityContext.getUser().principal();
75+
RoleDescriptor.IndicesPrivileges sourceIndexPrivileges = RoleDescriptor.IndicesPrivileges.builder()
76+
.indices(request.getConfig().getSource())
77+
.privileges("read")
78+
.build();
79+
RoleDescriptor.IndicesPrivileges destIndexPrivileges = RoleDescriptor.IndicesPrivileges.builder()
80+
.indices(request.getConfig().getDest())
81+
.privileges("read", "index", "create_index")
82+
.build();
83+
84+
HasPrivilegesRequest privRequest = new HasPrivilegesRequest();
85+
privRequest.applicationPrivileges(new RoleDescriptor.ApplicationResourcePrivileges[0]);
86+
privRequest.username(username);
87+
privRequest.clusterPrivileges(Strings.EMPTY_ARRAY);
88+
privRequest.indexPrivileges(sourceIndexPrivileges, destIndexPrivileges);
89+
90+
ActionListener<HasPrivilegesResponse> privResponseListener = ActionListener.wrap(
91+
r -> handlePrivsResponse(username, request, r, listener),
92+
listener::onFailure);
93+
94+
client.execute(HasPrivilegesAction.INSTANCE, privRequest, privResponseListener);
95+
} else {
96+
configProvider.put(request.getConfig(), threadPool.getThreadContext().getHeaders(), ActionListener.wrap(
97+
indexResponse -> listener.onResponse(new PutDataFrameAnalyticsAction.Response(request.getConfig())),
98+
listener::onFailure
99+
));
100+
}
101+
}
102+
103+
private void handlePrivsResponse(String username, PutDataFrameAnalyticsAction.Request request,
104+
HasPrivilegesResponse response,
105+
ActionListener<PutDataFrameAnalyticsAction.Response> listener) throws IOException {
106+
if (response.isCompleteMatch()) {
107+
configProvider.put(request.getConfig(), threadPool.getThreadContext().getHeaders(), ActionListener.wrap(
108+
indexResponse -> listener.onResponse(new PutDataFrameAnalyticsAction.Response(request.getConfig())),
109+
listener::onFailure
110+
));
111+
} else {
112+
XContentBuilder builder = JsonXContent.contentBuilder();
113+
builder.startObject();
114+
for (ResourcePrivileges index : response.getIndexPrivileges()) {
115+
builder.field(index.getResource());
116+
builder.map(index.getPrivileges());
117+
}
118+
builder.endObject();
119+
120+
listener.onFailure(Exceptions.authorizationError("Cannot create data frame analytics [{}]" +
121+
" because user {} lacks permissions on the indices: {}",
122+
request.getConfig().getId(), username, Strings.toString(builder)));
123+
}
55124
}
56125

57126
private void validateConfig(DataFrameAnalyticsConfig config) {

0 commit comments

Comments
 (0)