
[FEATURE][ML] User config appropriate permission checks on creating/running analytics #38928

Merged

DataFrameAnalysisConfig.java
@@ -14,6 +14,8 @@
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

@@ -26,14 +28,14 @@ public static ContextParser<Void, DataFrameAnalysisConfig> parser() {
private final Map<String, Object> config;

public DataFrameAnalysisConfig(Map<String, Object> config) {
this.config = Objects.requireNonNull(config);
this.config = Collections.unmodifiableMap(new HashMap<>(Objects.requireNonNull(config)));
if (config.size() != 1) {
throw ExceptionsHelper.badRequestException("A data frame analysis must specify exactly one analysis type");
}
}

public DataFrameAnalysisConfig(StreamInput in) throws IOException {
config = in.readMap();
config = Collections.unmodifiableMap(in.readMap());
}

public Map<String, Object> asMap() {
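
The constructor change above gives DataFrameAnalysisConfig a defensive copy wrapped as an unmodifiable map, so callers can neither mutate the config after construction nor modify the map returned by asMap(). A minimal sketch of that contract (the analysis type name is arbitrary, and this assumes asMap() simply returns the stored map):

    Map<String, Object> analysis = new HashMap<>();
    analysis.put("outlier_detection", Collections.emptyMap());
    DataFrameAnalysisConfig config = new DataFrameAnalysisConfig(analysis);

    // Mutating the caller's map after construction no longer leaks into the config...
    analysis.put("another_analysis", Collections.emptyMap());
    assert config.asMap().size() == 1;

    // ...and the stored map itself rejects modification.
    // config.asMap().put("foo", "bar");  // would throw UnsupportedOperationException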

DataFrameAnalyticsConfig.java
@@ -29,6 +29,8 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@@ -62,6 +64,7 @@ public class DataFrameAnalyticsConfig implements ToXContentObject, Writeable {
public static final ParseField ANALYSES = new ParseField("analyses");
public static final ParseField CONFIG_TYPE = new ParseField("config_type");
public static final ParseField QUERY = new ParseField("query");
public static final ParseField HEADERS = new ParseField("headers");

public static final ObjectParser<Builder, Void> STRICT_PARSER = createParser(false);
public static final ObjectParser<Builder, Void> LENIENT_PARSER = createParser(true);
@@ -75,6 +78,11 @@ public static ObjectParser<Builder, Void> createParser(boolean ignoreUnknownFields) {
parser.declareString(Builder::setDest, DEST);
parser.declareObjectArray(Builder::setAnalyses, DataFrameAnalysisConfig.parser(), ANALYSES);
parser.declareObject((builder, query) -> builder.setQuery(query, ignoreUnknownFields), (p, c) -> p.mapOrdered(), QUERY);
if (ignoreUnknownFields) {
// Headers are not parsed by the strict (config) parser, so headers supplied in the _body_ of a REST request will be rejected.
// (For config, headers are explicitly transferred from the auth headers by code in the put data frame actions.)
parser.declareObject(Builder::setHeaders, (p, c) -> p.mapStrings(), HEADERS);
}
return parser;
}
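
The conditional above creates an asymmetry between the two parsers: LENIENT_PARSER, which reads configs back from internal storage, accepts a headers field, while STRICT_PARSER, which parses REST request bodies, treats it as an unknown field. A rough sketch of the difference, assuming id, source, and dest are declared on the same parser (the analysis type and header names here are arbitrary):

    String json = "{\"id\":\"my-analytics\",\"source\":\"src\",\"dest\":\"dst\","
        + "\"analyses\":[{\"outlier_detection\":{}}],"
        + "\"headers\":{\"header-name\":\"header-value\"}}";

    try (XContentParser p = JsonXContent.jsonXContent.createParser(
            NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, json)) {
        // lenient parser: headers persisted to internal storage are read back
        DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.LENIENT_PARSER.apply(p, null).build();
        assert config.getHeaders().containsKey("header-name");
    }

    try (XContentParser p = JsonXContent.jsonXContent.createParser(
            NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, json)) {
        // strict parser: a client-supplied "headers" field is rejected
        DataFrameAnalyticsConfig.STRICT_PARSER.apply(p, null);  // throws XContentParseException
    }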

@@ -84,9 +92,10 @@ public static ObjectParser<Builder, Void> createParser(boolean ignoreUnknownFields) {
private final List<DataFrameAnalysisConfig> analyses;
private final Map<String, Object> query;
private final CachedSupplier<QueryBuilder> querySupplier;
private final Map<String, String> headers;

public DataFrameAnalyticsConfig(String id, String source, String dest, List<DataFrameAnalysisConfig> analyses,
Map<String, Object> query) {
Map<String, Object> query, Map<String, String> headers) {
this.id = ExceptionsHelper.requireNonNull(id, ID);
this.source = ExceptionsHelper.requireNonNull(source, SOURCE);
this.dest = ExceptionsHelper.requireNonNull(dest, DEST);
@@ -100,6 +109,7 @@ public DataFrameAnalyticsConfig(String id, String source, String dest, List<DataFrameAnalysisConfig> analyses,
}
this.query = Collections.unmodifiableMap(query);
this.querySupplier = new CachedSupplier<>(() -> lazyQueryParser.apply(query, id, new ArrayList<>()));
this.headers = Collections.unmodifiableMap(headers);
}

public DataFrameAnalyticsConfig(StreamInput in) throws IOException {
@@ -109,6 +119,7 @@ public DataFrameAnalyticsConfig(StreamInput in) throws IOException {
analyses = in.readList(DataFrameAnalysisConfig::new);
this.query = in.readMap();
this.querySupplier = new CachedSupplier<>(() -> lazyQueryParser.apply(query, id, new ArrayList<>()));
this.headers = Collections.unmodifiableMap(in.readMap(StreamInput::readString, StreamInput::readString));
}

public String getId() {
@@ -151,6 +162,10 @@ List<String> getQueryDeprecations(TriFunction<Map<String, Object>, String, List<
return deprecations;
}

public Map<String, String> getHeaders() {
return headers;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
@@ -162,6 +177,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.field(CONFIG_TYPE.getPreferredName(), TYPE);
}
builder.field(QUERY.getPreferredName(), query);
if (headers.isEmpty() == false && params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) {
builder.field(HEADERS.getPreferredName(), headers);
}
builder.endObject();
return builder;
}
@@ -173,6 +191,7 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeString(dest);
out.writeList(analyses);
out.writeMap(query);
out.writeMap(headers, StreamOutput::writeString, StreamOutput::writeString);
}

@Override
@@ -185,12 +204,13 @@ public boolean equals(Object o) {
&& Objects.equals(source, other.source)
&& Objects.equals(dest, other.dest)
&& Objects.equals(analyses, other.analyses)
&& Objects.equals(headers, other.headers)
&& Objects.equals(query, other.query);
}

@Override
public int hashCode() {
return Objects.hash(id, source, dest, analyses, query);
return Objects.hash(id, source, dest, analyses, query, headers);
}

public static String documentId(String id) {
@@ -204,11 +224,23 @@ public static class Builder {
private String dest;
private List<DataFrameAnalysisConfig> analyses;
private Map<String, Object> query = Collections.singletonMap(MatchAllQueryBuilder.NAME, Collections.emptyMap());
private Map<String, String> headers = Collections.emptyMap();

public String getId() {
return id;
}

public Builder() {}

public Builder(DataFrameAnalyticsConfig config) {
this.id = config.id;
this.source = config.source;
this.dest = config.dest;
this.analyses = new ArrayList<>(config.analyses);
this.query = new LinkedHashMap<>(config.query);
this.headers = new HashMap<>(config.headers);
}

public Builder setId(String id) {
this.id = ExceptionsHelper.requireNonNull(id, ID);
return this;
@@ -248,8 +280,13 @@ public Builder setQuery(Map<String, Object> query, boolean lenient) {
return this;
}

public Builder setHeaders(Map<String, String> headers) {
this.headers = headers;
return this;
}

public DataFrameAnalyticsConfig build() {
return new DataFrameAnalyticsConfig(id, source, dest, analyses, query);
return new DataFrameAnalyticsConfig(id, source, dest, analyses, query, headers);
}
}
}
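
Together, the new copy constructor and setHeaders give the transport layer a way to attach or refresh the security headers on a config before it is stored. A hypothetical re-authorization flow stitched from the pieces above (existingConfig and threadPool are assumed to be in scope):

    // Rebuild an existing config, replacing its stored headers with the request
    // headers currently on the thread context -- the same map the put action forwards.
    DataFrameAnalyticsConfig reauthorized = new DataFrameAnalyticsConfig.Builder(existingConfig)
        .setHeaders(threadPool.getThreadContext().getHeaders())
        .build();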

DataFrameAnalyticsConfigTests.java
@@ -7,12 +7,16 @@

import com.carrotsearch.randomizedtesting.generators.CodepointSetGenerator;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.DeprecationHandler;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.XContentParseException;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
@@ -21,13 +25,18 @@
import org.elasticsearch.index.query.TermQueryBuilder;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.Matchers.hasEntry;
import static org.hamcrest.Matchers.hasItem;
import static org.hamcrest.Matchers.hasSize;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;

@@ -61,6 +70,10 @@ protected Writeable.Reader<DataFrameAnalyticsConfig> instanceReader() {
}

public static DataFrameAnalyticsConfig createRandom(String id) {
return createRandomBuilder(id).build();
}

public static DataFrameAnalyticsConfig.Builder createRandomBuilder(String id) {
String source = randomAlphaOfLength(10);
String dest = randomAlphaOfLength(10);
List<DataFrameAnalysisConfig> analyses = Collections.singletonList(DataFrameAnalysisConfigTests.randomConfig());
@@ -74,7 +87,7 @@ public static DataFrameAnalyticsConfig createRandom(String id) {
Collections.singletonMap(TermQueryBuilder.NAME,
Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10))), true);
}
return builder.build();
return builder;
}

public static String randomValidId() {
@@ -142,6 +155,33 @@ public void testPastQueryConfigParse() throws IOException {
}
}

public void testToXContentForInternalStorage() throws IOException {
DataFrameAnalyticsConfig.Builder builder = createRandomBuilder("foo");

// headers are only persisted to cluster state
Map<String, String> headers = new HashMap<>();
headers.put("header-name", "header-value");
builder.setHeaders(headers);
DataFrameAnalyticsConfig config = builder.build();

ToXContent.MapParams params = new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true"));

BytesReference forClusterstateXContent = XContentHelper.toXContent(config, XContentType.JSON, params, false);
XContentParser parser = XContentFactory.xContent(XContentType.JSON)
.createParser(xContentRegistry(), LoggingDeprecationHandler.INSTANCE, forClusterstateXContent.streamInput());

DataFrameAnalyticsConfig parsedConfig = DataFrameAnalyticsConfig.LENIENT_PARSER.apply(parser, null).build();
assertThat(parsedConfig.getHeaders(), hasEntry("header-name", "header-value"));

// headers are not written without the FOR_INTERNAL_STORAGE param
BytesReference nonClusterstateXContent = XContentHelper.toXContent(config, XContentType.JSON, ToXContent.EMPTY_PARAMS, false);
parser = XContentFactory.xContent(XContentType.JSON)
.createParser(xContentRegistry(), LoggingDeprecationHandler.INSTANCE, nonClusterstateXContent.streamInput());

parsedConfig = DataFrameAnalyticsConfig.LENIENT_PARSER.apply(parser, null).build();
assertThat(parsedConfig.getHeaders().entrySet(), hasSize(0));
}

public void testGetQueryDeprecations() {
DataFrameAnalyticsConfig dataFrame = createTestInstance();
String deprecationWarning = "Warning";
1 change: 1 addition & 0 deletions x-pack/plugin/ml/qa/ml-with-security/build.gradle
@@ -35,6 +35,7 @@ integTestRunner {
'ml/datafeeds_crud/Test put datafeed with invalid query',
'ml/datafeeds_crud/Test put datafeed with security headers in the body',
'ml/datafeeds_crud/Test update datafeed with missing id',
'ml/data_frame_analytics_crud/Test put config with security headers in the body',
'ml/data_frame_analytics_crud/Test put config with inconsistent body/param ids',
'ml/data_frame_analytics_crud/Test put config with invalid id',
'ml/data_frame_analytics_crud/Test put config with unknown top level field',
4 changes: 2 additions & 2 deletions x-pack/plugin/ml/qa/ml-with-security/roles.yml
@@ -11,7 +11,7 @@ minimal:
privileges:
- indices:admin/create
- indices:admin/refresh
- indices:data/read/field_caps
- indices:data/read/search
- read
- index
- indices:data/write/bulk
- indices:data/write/index

TransportPutDataFrameAnalyticsAction.java
@@ -8,35 +8,58 @@
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.XPackField;
import org.elasticsearch.xpack.core.XPackSettings;
import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.MlStrings;
import org.elasticsearch.xpack.core.security.SecurityContext;
import org.elasticsearch.xpack.core.security.action.user.HasPrivilegesAction;
import org.elasticsearch.xpack.core.security.action.user.HasPrivilegesRequest;
import org.elasticsearch.xpack.core.security.action.user.HasPrivilegesResponse;
import org.elasticsearch.xpack.core.security.authz.RoleDescriptor;
import org.elasticsearch.xpack.core.security.authz.permission.ResourcePrivileges;
import org.elasticsearch.xpack.core.security.support.Exceptions;
import org.elasticsearch.xpack.ml.dataframe.analyses.DataFrameAnalysesUtils;
import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider;

import java.io.IOException;
import java.util.function.Supplier;

public class TransportPutDataFrameAnalyticsAction
extends HandledTransportAction<PutDataFrameAnalyticsAction.Request, PutDataFrameAnalyticsAction.Response> {

private final XPackLicenseState licenseState;
private final DataFrameAnalyticsConfigProvider configProvider;
private final ThreadPool threadPool;
private final SecurityContext securityContext;
private final Client client;

@Inject
public TransportPutDataFrameAnalyticsAction(TransportService transportService, ActionFilters actionFilters,
XPackLicenseState licenseState, DataFrameAnalyticsConfigProvider configProvider) {
public TransportPutDataFrameAnalyticsAction(Settings settings, TransportService transportService, ActionFilters actionFilters,
XPackLicenseState licenseState, Client client, ThreadPool threadPool,
DataFrameAnalyticsConfigProvider configProvider) {
super(PutDataFrameAnalyticsAction.NAME, transportService, actionFilters,
(Supplier<PutDataFrameAnalyticsAction.Request>) PutDataFrameAnalyticsAction.Request::new);
this.licenseState = licenseState;
this.configProvider = configProvider;
this.threadPool = threadPool;
this.securityContext = XPackSettings.SECURITY_ENABLED.get(settings) ?
new SecurityContext(settings, threadPool.getThreadContext()) : null;
this.client = client;
}

@Override
@@ -46,12 +69,58 @@ protected void doExecute(Task task, PutDataFrameAnalyticsAction.Request request,
listener.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING));
return;
}

validateConfig(request.getConfig());
configProvider.put(request.getConfig(), ActionListener.wrap(
indexResponse -> listener.onResponse(new PutDataFrameAnalyticsAction.Response(request.getConfig())),
listener::onFailure
));
if (licenseState.isAuthAllowed()) {
final String username = securityContext.getUser().principal();
RoleDescriptor.IndicesPrivileges sourceIndexPrivileges = RoleDescriptor.IndicesPrivileges.builder()
.indices(request.getConfig().getSource())
.privileges("read")
.build();
RoleDescriptor.IndicesPrivileges destIndexPrivileges = RoleDescriptor.IndicesPrivileges.builder()
.indices(request.getConfig().getDest())
.privileges("read", "index", "create_index")
.build();

HasPrivilegesRequest privRequest = new HasPrivilegesRequest();
privRequest.applicationPrivileges(new RoleDescriptor.ApplicationResourcePrivileges[0]);
privRequest.username(username);
privRequest.clusterPrivileges(Strings.EMPTY_ARRAY);
privRequest.indexPrivileges(sourceIndexPrivileges, destIndexPrivileges);

ActionListener<HasPrivilegesResponse> privResponseListener = ActionListener.wrap(
r -> handlePrivsResponse(username, request, r, listener),
listener::onFailure);

client.execute(HasPrivilegesAction.INSTANCE, privRequest, privResponseListener);
} else {
configProvider.put(request.getConfig(), threadPool.getThreadContext().getHeaders(), ActionListener.wrap(
indexResponse -> listener.onResponse(new PutDataFrameAnalyticsAction.Response(request.getConfig())),
listener::onFailure
));
}
}

private void handlePrivsResponse(String username, PutDataFrameAnalyticsAction.Request request,
HasPrivilegesResponse response,
ActionListener<PutDataFrameAnalyticsAction.Response> listener) throws IOException {
if (response.isCompleteMatch()) {
configProvider.put(request.getConfig(), threadPool.getThreadContext().getHeaders(), ActionListener.wrap(
indexResponse -> listener.onResponse(new PutDataFrameAnalyticsAction.Response(request.getConfig())),
listener::onFailure
));
} else {
XContentBuilder builder = JsonXContent.contentBuilder();
builder.startObject();
for (ResourcePrivileges index : response.getIndexPrivileges()) {
builder.field(index.getResource());
builder.map(index.getPrivileges());
}
builder.endObject();

listener.onFailure(Exceptions.authorizationError("Cannot create data frame analytics [{}]" +
" because user {} lacks permissions on the indices: {}",
request.getConfig().getId(), username, Strings.toString(builder)));
}
}
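
The headers captured here are what later let the analytics run with the creating user's permissions rather than the node's. A sketch of the downstream pattern, assuming the same ClientHelper utility that ML datafeeds use for this purpose (the analytics-side call site is not part of this diff):

    // Re-apply the stored per-user headers when searching the source index,
    // so index authorization is evaluated against the user who created the config.
    ClientHelper.executeWithHeadersAsync(config.getHeaders(), ClientHelper.ML_ORIGIN, client,
        SearchAction.INSTANCE, new SearchRequest(config.getSource()), searchListener);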

private void validateConfig(DataFrameAnalyticsConfig config) {