Skip to content

Commit da73c91

Browse files
authored
[ML] Fix tests randomly failing on CI (elastic#51142) (elastic#51150)
1 parent b70ebde commit da73c91

File tree

3 files changed

+24
-38
lines changed

3 files changed

+24
-38
lines changed

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsActionResponseTests.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,7 @@
1919

2020
public class GetDataFrameAnalyticsStatsActionResponseTests extends AbstractWireSerializingTestCase<Response> {
2121

22-
public static Response randomResponse() {
23-
int listSize = randomInt(10);
22+
public static Response randomResponse(int listSize) {
2423
List<Response.Stats> analytics = new ArrayList<>(listSize);
2524
for (int j = 0; j < listSize; j++) {
2625
String failureReason = randomBoolean() ? null : randomAlphaOfLength(10);
@@ -37,7 +36,7 @@ public static Response randomResponse() {
3736

3837
@Override
3938
protected Response createTestInstance() {
40-
return randomResponse();
39+
return randomResponse(randomInt(10));
4140
}
4241

4342
@Override

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTask.java

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -121,12 +121,12 @@ public void markAsCompleted() {
121121
isMarkAsCompletedCalled = true;
122122
}
123123

124-
persistProgress(() -> super.markAsCompleted());
124+
persistProgress(client, taskParams.getId(), () -> super.markAsCompleted());
125125
}
126126

127127
@Override
128128
public void markAsFailed(Exception e) {
129-
persistProgress(() -> super.markAsFailed(e));
129+
persistProgress(client, taskParams.getId(), () -> super.markAsFailed(e));
130130
}
131131

132132
public void stop(String reason, TimeValue timeout) {
@@ -244,21 +244,22 @@ private TaskId getReindexTaskId() {
244244
}
245245
}
246246

247-
private void persistProgress(Runnable runnable) {
248-
LOGGER.debug("[{}] Persisting progress", taskParams.getId());
247+
// Visible for testing
248+
static void persistProgress(Client client, String jobId, Runnable runnable) {
249+
LOGGER.debug("[{}] Persisting progress", jobId);
249250

250-
String progressDocId = StoredProgress.documentId(taskParams.getId());
251+
String progressDocId = StoredProgress.documentId(jobId);
251252
SetOnce<GetDataFrameAnalyticsStatsAction.Response.Stats> stats = new SetOnce<>();
252253

253254
// Step 4: Run the runnable provided as the argument
254255
ActionListener<IndexResponse> indexProgressDocListener = ActionListener.wrap(
255256
indexResponse -> {
256-
LOGGER.debug("[{}] Successfully indexed progress document", taskParams.getId());
257+
LOGGER.debug("[{}] Successfully indexed progress document", jobId);
257258
runnable.run();
258259
},
259260
indexError -> {
260261
LOGGER.error(new ParameterizedMessage(
261-
"[{}] cannot persist progress as an error occurred while indexing", taskParams.getId()), indexError);
262+
"[{}] cannot persist progress as an error occurred while indexing", jobId), indexError);
262263
runnable.run();
263264
}
264265
);
@@ -283,7 +284,7 @@ private void persistProgress(Runnable runnable) {
283284
},
284285
e -> {
285286
LOGGER.error(new ParameterizedMessage(
286-
"[{}] cannot persist progress as an error occurred while retrieving former progress document", taskParams.getId()), e);
287+
"[{}] cannot persist progress as an error occurred while retrieving former progress document", jobId), e);
287288
runnable.run();
288289
}
289290
);
@@ -302,13 +303,13 @@ private void persistProgress(Runnable runnable) {
302303
},
303304
e -> {
304305
LOGGER.error(new ParameterizedMessage(
305-
"[{}] cannot persist progress as an error occurred while retrieving stats", taskParams.getId()), e);
306+
"[{}] cannot persist progress as an error occurred while retrieving stats", jobId), e);
306307
runnable.run();
307308
}
308309
);
309310

310311
// Step 1: Fetch progress to be persisted
311-
GetDataFrameAnalyticsStatsAction.Request getStatsRequest = new GetDataFrameAnalyticsStatsAction.Request(taskParams.getId());
312+
GetDataFrameAnalyticsStatsAction.Request getStatsRequest = new GetDataFrameAnalyticsStatsAction.Request(jobId);
312313
executeAsyncWithOrigin(client, ML_ORIGIN, GetDataFrameAnalyticsStatsAction.INSTANCE, getStatsRequest, getStatsListener);
313314
}
314315

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTaskTests.java

Lines changed: 11 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,13 @@
55
*/
66
package org.elasticsearch.xpack.ml.dataframe;
77

8-
import org.elasticsearch.Version;
98
import org.elasticsearch.action.ActionListener;
109
import org.elasticsearch.action.index.IndexAction;
1110
import org.elasticsearch.action.index.IndexRequest;
1211
import org.elasticsearch.action.index.IndexResponse;
1312
import org.elasticsearch.action.search.SearchAction;
1413
import org.elasticsearch.action.search.SearchResponse;
1514
import org.elasticsearch.client.Client;
16-
import org.elasticsearch.cluster.service.ClusterService;
1715
import org.elasticsearch.common.settings.Settings;
1816
import org.elasticsearch.common.util.concurrent.ThreadContext;
1917
import org.elasticsearch.search.SearchHit;
@@ -22,10 +20,8 @@
2220
import org.elasticsearch.threadpool.ThreadPool;
2321
import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction;
2422
import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsActionResponseTests;
25-
import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction.TaskParams;
2623
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
2724
import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask.StartingState;
28-
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
2925
import org.mockito.ArgumentCaptor;
3026
import org.mockito.InOrder;
3127
import org.mockito.stubbing.Answer;
@@ -115,13 +111,13 @@ public void testDetermineStartingState_GivenEmptyProgress() {
115111
assertThat(startingState, equalTo(StartingState.FINISHED));
116112
}
117113

118-
private void testMarkAsCompleted(SearchHits searchHits, String expectedIndexOrAlias) {
114+
private void testPersistProgress(SearchHits searchHits, String expectedIndexOrAlias) {
119115
Client client = mock(Client.class);
120116
ThreadPool threadPool = mock(ThreadPool.class);
121117
when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
122118
when(client.threadPool()).thenReturn(threadPool);
123119

124-
GetDataFrameAnalyticsStatsAction.Response getStatsResponse = GetDataFrameAnalyticsStatsActionResponseTests.randomResponse();
120+
GetDataFrameAnalyticsStatsAction.Response getStatsResponse = GetDataFrameAnalyticsStatsActionResponseTests.randomResponse(1);
125121
doAnswer(withResponse(getStatsResponse)).when(client).execute(eq(GetDataFrameAnalyticsStatsAction.INSTANCE), any(), any());
126122

127123
SearchResponse searchResponse = mock(SearchResponse.class);
@@ -131,40 +127,30 @@ private void testMarkAsCompleted(SearchHits searchHits, String expectedIndexOrAl
131127
IndexResponse indexResponse = mock(IndexResponse.class);
132128
doAnswer(withResponse(indexResponse)).when(client).execute(eq(IndexAction.INSTANCE), any(), any());
133129

134-
TaskParams taskParams = new TaskParams("task_id", Version.CURRENT, Collections.emptyList(), false);
135-
DataFrameAnalyticsTask task =
136-
new DataFrameAnalyticsTask(
137-
0,
138-
"",
139-
"",
140-
null,
141-
null,
142-
client,
143-
mock(ClusterService.class),
144-
mock(DataFrameAnalyticsManager.class),
145-
mock(DataFrameAnalyticsAuditor.class),
146-
taskParams);
147-
task.markAsCompleted();
130+
Runnable runnable = mock(Runnable.class);
131+
132+
DataFrameAnalyticsTask.persistProgress(client, "task_id", runnable);
148133

149134
ArgumentCaptor<IndexRequest> indexRequestCaptor = ArgumentCaptor.forClass(IndexRequest.class);
150135

151-
InOrder inOrder = inOrder(client);
136+
InOrder inOrder = inOrder(client, runnable);
152137
inOrder.verify(client).execute(eq(GetDataFrameAnalyticsStatsAction.INSTANCE), any(), any());
153138
inOrder.verify(client).execute(eq(SearchAction.INSTANCE), any(), any());
154139
inOrder.verify(client).execute(eq(IndexAction.INSTANCE), indexRequestCaptor.capture(), any());
140+
inOrder.verify(runnable).run();
155141
inOrder.verifyNoMoreInteractions();
156142

157143
IndexRequest indexRequest = indexRequestCaptor.getValue();
158144
assertThat(indexRequest.index(), equalTo(expectedIndexOrAlias));
159145
assertThat(indexRequest.id(), equalTo("data_frame_analytics-task_id-progress"));
160146
}
161147

162-
public void testMarkAsCompleted_ProgressDocumentCreated() {
163-
testMarkAsCompleted(SearchHits.empty(), ".ml-state-write");
148+
public void testPersistProgress_ProgressDocumentCreated() {
149+
testPersistProgress(SearchHits.empty(), ".ml-state-write");
164150
}
165151

166-
public void testMarkAsCompleted_ProgressDocumentUpdated() {
167-
testMarkAsCompleted(
152+
public void testPersistProgress_ProgressDocumentUpdated() {
153+
testPersistProgress(
168154
new SearchHits(new SearchHit[]{ SearchHit.createFromMap(Collections.singletonMap("_index", ".ml-state-dummy")) }, null, 0.0f),
169155
".ml-state-dummy");
170156
}

0 commit comments

Comments
 (0)