Skip to content

Tasks: Only require task permissions #35667

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Nov 28, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
* Action for retrieving a list of currently running tasks
*/
public class GetTaskAction extends Action<GetTaskResponse> {
public static final String TASKS_ORIGIN = "tasks";
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I stuck this here because it felt like a fairly right place at the time. Now that I look at it again maybe it should be on something a little more centrally located. But it is about actions so I stuck it on an action. I'm not sure of a better spot.


public static final GetTaskAction INSTANCE = new GetTaskAction();
public static final String NAME = "cluster:monitor/task/get";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.client.Client;
import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.inject.Inject;
Expand All @@ -51,6 +52,7 @@

import java.io.IOException;

import static org.elasticsearch.action.admin.cluster.node.tasks.get.GetTaskAction.TASKS_ORIGIN;
import static org.elasticsearch.action.admin.cluster.node.tasks.list.TransportListTasksAction.waitForCompletionTimeout;

/**
Expand All @@ -77,7 +79,7 @@ public TransportGetTaskAction(ThreadPool threadPool, TransportService transportS
this.threadPool = threadPool;
this.clusterService = clusterService;
this.transportService = transportService;
this.client = client;
this.client = new OriginSettingClient(client, TASKS_ORIGIN);
this.xContentRegistry = xContentRegistry;
}

Expand Down Expand Up @@ -210,6 +212,7 @@ void getFinishedTaskFromIndex(Task thisTask, GetTaskRequest request, ActionListe
GetRequest get = new GetRequest(TaskResultsService.TASK_INDEX, TaskResultsService.TASK_TYPE,
request.getTaskId().toString());
get.setParentTask(clusterService.localNode().getId(), thisTask.getId());

client.get(get, new ActionListener<GetResponse>() {
@Override
public void onResponse(GetResponse getResponse) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.elasticsearch.client;

import org.elasticsearch.action.Action;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.common.util.concurrent.ThreadContext;

import java.util.function.Supplier;

/**
* A {@linkplain Client} that sends requests with the
* {@link ThreadContext#stashWithOrigin origin} set to a particular
* value and calls its {@linkplain ActionListener} in its original
* {@link ThreadContext}.
*/
public final class OriginSettingClient extends FilterClient {

private final String origin;

public OriginSettingClient(Client in, String origin) {
super(in);
this.origin = origin;
}

@Override
protected <Request extends ActionRequest, Response extends ActionResponse>
void doExecute(Action<Response> action, Request request, ActionListener<Response> listener) {
final Supplier<ThreadContext.StoredContext> supplier = in().threadPool().getThreadContext().newRestorableContext(false);
try (ThreadContext.StoredContext ignore = in().threadPool().getThreadContext().stashWithOrigin(origin)) {
super.doExecute(action, request, new ContextPreservingActionListener<>(supplier, listener));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import org.apache.logging.log4j.LogManager;
import org.apache.lucene.util.CloseableThreadLocal;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
Expand Down Expand Up @@ -85,6 +87,12 @@ public final class ThreadContext implements Closeable, Writeable {

public static final String PREFIX = "request.headers";
public static final Setting<Settings> DEFAULT_HEADERS_SETTING = Setting.groupSetting(PREFIX + ".", Property.NodeScope);

/**
* Name for the {@link #stashWithOrigin origin} attribute.
*/
public static final String ACTION_ORIGIN_TRANSIENT_NAME = "action.origin";

private static final Logger logger = LogManager.getLogger(ThreadContext.class);
private static final ThreadContextStruct DEFAULT_CONTEXT = new ThreadContextStruct();
private final Map<String, String> defaultHeader;
Expand Down Expand Up @@ -119,14 +127,39 @@ public void close() throws IOException {

/**
* Removes the current context and resets a default context. The removed context can be
* restored when closing the returned {@link StoredContext}
* restored by closing the returned {@link StoredContext}.
*/
public StoredContext stashContext() {
final ThreadContextStruct context = threadLocal.get();
threadLocal.set(null);
return () -> threadLocal.set(context);
}

/**
* Removes the current context and resets a default context marked with as
* originating from the supplied string. The removed context can be
* restored by closing the returned {@link StoredContext}. Callers should
* be careful to save the current context before calling this method and
* restore it any listeners, likely with
* {@link ContextPreservingActionListener}. Use {@link OriginSettingClient}
* which can be used to do this automatically.
* <p>
* Without security the origin is ignored, but security uses it to authorize
* actions that are made up of many sub-actions. These actions call
* {@link #stashWithOrigin} before performing on behalf of a user that
* should be allowed even if the user doesn't have permission to perform
* those actions on their own.
* <p>
* For example, a user might not have permission to GET from the tasks index
* but the tasks API will perform a get on their behalf using this method
* if it can't find the task in memory.
*/
public StoredContext stashWithOrigin(String origin) {
final ThreadContext.StoredContext storedContext = stashContext();
putTransient(ACTION_ORIGIN_TRANSIENT_NAME, origin);
return storedContext;
}

/**
* Removes the current context and resets a new context that contains a merge of the current headers and the given headers.
* The removed context can be restored when closing the returned {@link StoredContext}. The merge strategy is that headers
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,19 @@
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest;
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksResponse;
import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.client.Client;
import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.ClusterStateObserver;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.node.NodeClosedException;
import org.elasticsearch.persistent.PersistentTasksCustomMetaData.PersistentTask;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.threadpool.ThreadPool;

import java.util.function.Predicate;
import java.util.function.Supplier;

/**
* This service is used by persistent tasks and allocated persistent tasks to communicate changes
Expand All @@ -50,15 +48,14 @@ public class PersistentTasksService {

private static final Logger logger = LogManager.getLogger(PersistentTasksService.class);

private static final String ACTION_ORIGIN_TRANSIENT_NAME = "action.origin";
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that the server has ThreadContext.stashWithOrigin and OriginSettingClient we can use it. I believe this change is a noop, and the tests seem to agree, but I'd love a second set of eyes on it.

private static final String PERSISTENT_TASK_ORIGIN = "persistent_tasks";

private final Client client;
private final ClusterService clusterService;
private final ThreadPool threadPool;

public PersistentTasksService(ClusterService clusterService, ThreadPool threadPool, Client client) {
this.client = client;
this.client = new OriginSettingClient(client, PERSISTENT_TASK_ORIGIN);
this.clusterService = clusterService;
this.threadPool = threadPool;
}
Expand Down Expand Up @@ -98,12 +95,7 @@ void sendCancelRequest(final long taskId, final String reason, final ActionListe
request.setTaskId(new TaskId(clusterService.localNode().getId(), taskId));
request.setReason(reason);
try {
final ThreadContext threadContext = client.threadPool().getThreadContext();
final Supplier<ThreadContext.StoredContext> supplier = threadContext.newRestorableContext(false);

try (ThreadContext.StoredContext ignore = stashWithOrigin(threadContext, PERSISTENT_TASK_ORIGIN)) {
client.admin().cluster().cancelTasks(request, new ContextPreservingActionListener<>(supplier, listener));
}
client.admin().cluster().cancelTasks(request, listener);
} catch (Exception e) {
listener.onFailure(e);
}
Expand Down Expand Up @@ -140,14 +132,8 @@ public void sendRemoveRequest(final String taskId, final ActionListener<Persiste
private <Req extends ActionRequest, Resp extends PersistentTaskResponse>
void execute(final Req request, final Action<Resp> action, final ActionListener<PersistentTask<?>> listener) {
try {
final ThreadContext threadContext = client.threadPool().getThreadContext();
final Supplier<ThreadContext.StoredContext> supplier = threadContext.newRestorableContext(false);

try (ThreadContext.StoredContext ignore = stashWithOrigin(threadContext, PERSISTENT_TASK_ORIGIN)) {
client.execute(action, request,
new ContextPreservingActionListener<>(supplier,
ActionListener.wrap(r -> listener.onResponse(r.getTask()), listener::onFailure)));
}
client.execute(action, request,
ActionListener.wrap(r -> listener.onResponse(r.getTask()), listener::onFailure));
} catch (Exception e) {
listener.onFailure(e);
}
Expand Down Expand Up @@ -233,10 +219,4 @@ default void onTimeout(TimeValue timeout) {
onFailure(new IllegalStateException("Timed out when waiting for persistent task after " + timeout));
}
}

public static ThreadContext.StoredContext stashWithOrigin(ThreadContext threadContext, String origin) {
final ThreadContext.StoredContext storedContext = threadContext.stashContext();
threadContext.putTransient(ACTION_ORIGIN_TRANSIENT_NAME, origin);
return storedContext;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.elasticsearch.action.index.IndexResponse;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.client.Client;
import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.client.Requests;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.metadata.IndexMetaData;
Expand All @@ -50,6 +51,8 @@
import java.nio.charset.StandardCharsets;
import java.util.Map;

import static org.elasticsearch.action.admin.cluster.node.tasks.get.GetTaskAction.TASKS_ORIGIN;

/**
* Service that can store task results.
*/
Expand All @@ -73,7 +76,7 @@ public class TaskResultsService {

@Inject
public TaskResultsService(Client client, ClusterService clusterService) {
this.client = client;
this.client = new OriginSettingClient(client, TASKS_ORIGIN);
this.clusterService = clusterService;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,14 @@
import org.elasticsearch.action.TaskOperationFailure;
import org.elasticsearch.action.admin.cluster.health.ClusterHealthAction;
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksResponse;
import org.elasticsearch.action.admin.cluster.node.tasks.get.GetTaskRequest;
import org.elasticsearch.action.admin.cluster.node.tasks.get.GetTaskResponse;
import org.elasticsearch.action.admin.cluster.node.tasks.list.ListTasksAction;
import org.elasticsearch.action.admin.cluster.node.tasks.list.ListTasksResponse;
import org.elasticsearch.action.admin.indices.refresh.RefreshAction;
import org.elasticsearch.action.admin.indices.upgrade.post.UpgradeAction;
import org.elasticsearch.action.admin.indices.validate.query.ValidateQueryAction;
import org.elasticsearch.action.bulk.BulkAction;
import org.elasticsearch.action.get.GetResponse;
import org.elasticsearch.action.index.IndexAction;
import org.elasticsearch.action.index.IndexResponse;
import org.elasticsearch.action.search.SearchAction;
Expand Down Expand Up @@ -85,7 +85,6 @@
import static org.elasticsearch.common.unit.TimeValue.timeValueMillis;
import static org.elasticsearch.common.unit.TimeValue.timeValueSeconds;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_MAX_HEADER_SIZE;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertNoFailures;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertSearchResponse;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertThrows;
Expand Down Expand Up @@ -725,12 +724,6 @@ public void testTasksWaitForAllTask() throws Exception {
}

public void testTaskStoringSuccesfulResult() throws Exception {
// Randomly create an empty index to make sure the type is created automatically
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed this because it made writing the assertions about setting the origin very difficult and it isn't really a supported configuration.

if (randomBoolean()) {
logger.info("creating an empty results index with custom settings");
assertAcked(client().admin().indices().prepareCreate(TaskResultsService.TASK_INDEX));
}

registerTaskManageListeners(TestTaskPlugin.TestTaskAction.NAME); // we need this to get task id of the process

// Start non-blocking test task
Expand All @@ -743,23 +736,20 @@ public void testTaskStoringSuccesfulResult() throws Exception {
TaskInfo taskInfo = events.get(0);
TaskId taskId = taskInfo.getTaskId();

GetResponse resultDoc = client()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I replaced these with the tasks API calls because it gives me a good opportunity to assert that they find the documents while setting the origin.

.prepareGet(TaskResultsService.TASK_INDEX, TaskResultsService.TASK_TYPE, taskId.toString()).get();
assertTrue(resultDoc.isExists());

Map<String, Object> source = resultDoc.getSource();
@SuppressWarnings("unchecked")
Map<String, Object> task = (Map<String, Object>) source.get("task");
assertEquals(taskInfo.getTaskId().getNodeId(), task.get("node"));
assertEquals(taskInfo.getAction(), task.get("action"));
assertEquals(Long.toString(taskInfo.getId()), task.get("id").toString());

@SuppressWarnings("unchecked")
Map<String, Object> result = (Map<String, Object>) source.get("response");
TaskResult taskResult = client().admin().cluster()
.getTask(new GetTaskRequest().setTaskId(taskId)).get().getTask();
assertTrue(taskResult.isCompleted());
assertNull(taskResult.getError());

assertEquals(taskInfo.getTaskId(), taskResult.getTask().getTaskId());
assertEquals(taskInfo.getType(), taskResult.getTask().getType());
assertEquals(taskInfo.getAction(), taskResult.getTask().getAction());
assertEquals(taskInfo.getDescription(), taskResult.getTask().getDescription());
assertEquals(taskInfo.getStartTime(), taskResult.getTask().getStartTime());
assertEquals(taskInfo.getHeaders(), taskResult.getTask().getHeaders());
Map<?, ?> result = taskResult.getResponseAsMap();
assertEquals("0", result.get("failure_count").toString());

assertNull(source.get("failure"));

assertNoFailures(client().admin().indices().prepareRefresh(TaskResultsService.TASK_INDEX).get());

SearchResponse searchResponse = client().prepareSearch(TaskResultsService.TASK_INDEX)
Expand Down Expand Up @@ -797,25 +787,21 @@ public void testTaskStoringFailureResult() throws Exception {
TaskInfo failedTaskInfo = events.get(0);
TaskId failedTaskId = failedTaskInfo.getTaskId();

GetResponse failedResultDoc = client()
.prepareGet(TaskResultsService.TASK_INDEX, TaskResultsService.TASK_TYPE, failedTaskId.toString())
.get();
assertTrue(failedResultDoc.isExists());

Map<String, Object> source = failedResultDoc.getSource();
@SuppressWarnings("unchecked")
Map<String, Object> task = (Map<String, Object>) source.get("task");
assertEquals(failedTaskInfo.getTaskId().getNodeId(), task.get("node"));
assertEquals(failedTaskInfo.getAction(), task.get("action"));
assertEquals(Long.toString(failedTaskInfo.getId()), task.get("id").toString());

@SuppressWarnings("unchecked")
Map<String, Object> error = (Map<String, Object>) source.get("error");
TaskResult taskResult = client().admin().cluster()
.getTask(new GetTaskRequest().setTaskId(failedTaskId)).get().getTask();
assertTrue(taskResult.isCompleted());
assertNull(taskResult.getResponse());

assertEquals(failedTaskInfo.getTaskId(), taskResult.getTask().getTaskId());
assertEquals(failedTaskInfo.getType(), taskResult.getTask().getType());
assertEquals(failedTaskInfo.getAction(), taskResult.getTask().getAction());
assertEquals(failedTaskInfo.getDescription(), taskResult.getTask().getDescription());
assertEquals(failedTaskInfo.getStartTime(), taskResult.getTask().getStartTime());
assertEquals(failedTaskInfo.getHeaders(), taskResult.getTask().getHeaders());
Map<?, ?> error = (Map<?, ?>) taskResult.getErrorAsMap();
assertEquals("Simulating operation failure", error.get("reason"));
assertEquals("illegal_state_exception", error.get("type"));

assertNull(source.get("result"));

GetTaskResponse getResponse = expectFinishedTask(failedTaskId);
assertNull(getResponse.getTask().getResponse());
assertEquals(error, getResponse.getTask().getErrorAsMap());
Expand Down
Loading