16
16
import org .elasticsearch .cluster .service .ClusterService ;
17
17
import org .elasticsearch .common .settings .Settings ;
18
18
import org .elasticsearch .common .util .concurrent .ThreadContext ;
19
+ import org .elasticsearch .common .xcontent .DeprecationHandler ;
20
+ import org .elasticsearch .common .xcontent .NamedXContentRegistry ;
21
+ import org .elasticsearch .common .xcontent .XContentParser ;
22
+ import org .elasticsearch .common .xcontent .json .JsonXContent ;
19
23
import org .elasticsearch .persistent .PersistentTasksService ;
20
24
import org .elasticsearch .persistent .UpdatePersistentTaskStatusAction ;
21
25
import org .elasticsearch .search .SearchHit ;
22
26
import org .elasticsearch .search .SearchHits ;
23
27
import org .elasticsearch .tasks .TaskManager ;
24
28
import org .elasticsearch .test .ESTestCase ;
25
29
import org .elasticsearch .threadpool .ThreadPool ;
26
- import org .elasticsearch .xpack .core .ml .action .GetDataFrameAnalyticsStatsAction ;
27
- import org .elasticsearch .xpack .core .ml .action .GetDataFrameAnalyticsStatsActionResponseTests ;
28
30
import org .elasticsearch .xpack .core .ml .action .StartDataFrameAnalyticsAction ;
29
31
import org .elasticsearch .xpack .core .ml .dataframe .DataFrameAnalyticsState ;
30
32
import org .elasticsearch .xpack .core .ml .dataframe .DataFrameAnalyticsTaskState ;
33
+ import org .elasticsearch .xpack .core .ml .job .persistence .AnomalyDetectorsIndex ;
31
34
import org .elasticsearch .xpack .core .ml .utils .PhaseProgress ;
32
35
import org .elasticsearch .xpack .ml .dataframe .DataFrameAnalyticsTask .StartingState ;
33
36
import org .elasticsearch .xpack .ml .dataframe .stats .ProgressTracker ;
36
39
import org .mockito .InOrder ;
37
40
import org .mockito .stubbing .Answer ;
38
41
42
+ import java .io .IOException ;
39
43
import java .util .Arrays ;
40
44
import java .util .Collections ;
41
45
import java .util .List ;
@@ -126,14 +130,24 @@ public void testDetermineStartingState_GivenEmptyProgress() {
126
130
assertThat (startingState , equalTo (StartingState .FINISHED ));
127
131
}
128
132
129
- private void testPersistProgress (SearchHits searchHits , String expectedIndexOrAlias ) {
133
+ private void testPersistProgress (SearchHits searchHits , String expectedIndexOrAlias ) throws IOException {
130
134
Client client = mock (Client .class );
131
135
ThreadPool threadPool = mock (ThreadPool .class );
132
136
when (threadPool .getThreadContext ()).thenReturn (new ThreadContext (Settings .EMPTY ));
133
137
when (client .threadPool ()).thenReturn (threadPool );
134
138
135
- GetDataFrameAnalyticsStatsAction .Response getStatsResponse = GetDataFrameAnalyticsStatsActionResponseTests .randomResponse (1 );
136
- doAnswer (withResponse (getStatsResponse )).when (client ).execute (eq (GetDataFrameAnalyticsStatsAction .INSTANCE ), any (), any ());
139
+ ClusterService clusterService = mock (ClusterService .class );
140
+ DataFrameAnalyticsManager analyticsManager = mock (DataFrameAnalyticsManager .class );
141
+ DataFrameAnalyticsAuditor auditor = mock (DataFrameAnalyticsAuditor .class );
142
+ PersistentTasksService persistentTasksService = new PersistentTasksService (clusterService , mock (ThreadPool .class ), client );
143
+
144
+ List <PhaseProgress > progress = List .of (
145
+ new PhaseProgress (ProgressTracker .REINDEXING , 100 ),
146
+ new PhaseProgress (ProgressTracker .LOADING_DATA , 50 ),
147
+ new PhaseProgress (ProgressTracker .WRITING_RESULTS , 0 ));
148
+
149
+ StartDataFrameAnalyticsAction .TaskParams taskParams = new StartDataFrameAnalyticsAction .TaskParams (
150
+ "task_id" , Version .CURRENT , progress , false );
137
151
138
152
SearchResponse searchResponse = mock (SearchResponse .class );
139
153
when (searchResponse .getHits ()).thenReturn (searchHits );
@@ -142,14 +156,20 @@ private void testPersistProgress(SearchHits searchHits, String expectedIndexOrAl
142
156
IndexResponse indexResponse = mock (IndexResponse .class );
143
157
doAnswer (withResponse (indexResponse )).when (client ).execute (eq (IndexAction .INSTANCE ), any (), any ());
144
158
159
+ TaskManager taskManager = mock (TaskManager .class );
160
+
145
161
Runnable runnable = mock (Runnable .class );
146
162
147
- DataFrameAnalyticsTask .persistProgress (client , "task_id" , runnable );
163
+ DataFrameAnalyticsTask task =
164
+ new DataFrameAnalyticsTask (
165
+ 123 , "type" , "action" , null , Map .of (), client , clusterService , analyticsManager , auditor , taskParams );
166
+ task .init (persistentTasksService , taskManager , "task-id" , 42 );
167
+
168
+ task .persistProgress (client , "task_id" , runnable );
148
169
149
170
ArgumentCaptor <IndexRequest > indexRequestCaptor = ArgumentCaptor .forClass (IndexRequest .class );
150
171
151
172
InOrder inOrder = inOrder (client , runnable );
152
- inOrder .verify (client ).execute (eq (GetDataFrameAnalyticsStatsAction .INSTANCE ), any (), any ());
153
173
inOrder .verify (client ).execute (eq (SearchAction .INSTANCE ), any (), any ());
154
174
inOrder .verify (client ).execute (eq (IndexAction .INSTANCE ), indexRequestCaptor .capture (), any ());
155
175
inOrder .verify (runnable ).run ();
@@ -158,27 +178,33 @@ private void testPersistProgress(SearchHits searchHits, String expectedIndexOrAl
158
178
IndexRequest indexRequest = indexRequestCaptor .getValue ();
159
179
assertThat (indexRequest .index (), equalTo (expectedIndexOrAlias ));
160
180
assertThat (indexRequest .id (), equalTo ("data_frame_analytics-task_id-progress" ));
181
+
182
+ try (XContentParser parser = JsonXContent .jsonXContent .createParser (
183
+ NamedXContentRegistry .EMPTY , DeprecationHandler .IGNORE_DEPRECATIONS , indexRequest .source ().utf8ToString ())) {
184
+ StoredProgress parsedProgress = StoredProgress .PARSER .apply (parser , null );
185
+ assertThat (parsedProgress .get (), equalTo (progress ));
186
+ }
161
187
}
162
188
163
- public void testPersistProgress_ProgressDocumentCreated () {
189
+ public void testPersistProgress_ProgressDocumentCreated () throws IOException {
164
190
testPersistProgress (SearchHits .empty (), ".ml-state-write" );
165
191
}
166
192
167
- public void testPersistProgress_ProgressDocumentUpdated () {
193
+ public void testPersistProgress_ProgressDocumentUpdated () throws IOException {
168
194
testPersistProgress (
169
195
new SearchHits (new SearchHit []{ SearchHit .createFromMap (Map .of ("_index" , ".ml-state-dummy" )) }, null , 0.0f ),
170
196
".ml-state-dummy" );
171
197
}
172
198
173
- public void testSetFailed () {
199
+ public void testSetFailed () throws IOException {
174
200
testSetFailed (false );
175
201
}
176
202
177
- public void testSetFailedDuringNodeShutdown () {
203
+ public void testSetFailedDuringNodeShutdown () throws IOException {
178
204
testSetFailed (true );
179
205
}
180
206
181
- private void testSetFailed (boolean nodeShuttingDown ) {
207
+ private void testSetFailed (boolean nodeShuttingDown ) throws IOException {
182
208
ThreadPool threadPool = mock (ThreadPool .class );
183
209
when (threadPool .getThreadContext ()).thenReturn (new ThreadContext (Settings .EMPTY ));
184
210
Client client = mock (Client .class );
@@ -190,15 +216,25 @@ private void testSetFailed(boolean nodeShuttingDown) {
190
216
PersistentTasksService persistentTasksService = new PersistentTasksService (clusterService , mock (ThreadPool .class ), client );
191
217
TaskManager taskManager = mock (TaskManager .class );
192
218
219
+ List <PhaseProgress > progress = List .of (
220
+ new PhaseProgress (ProgressTracker .REINDEXING , 100 ),
221
+ new PhaseProgress (ProgressTracker .LOADING_DATA , 100 ),
222
+ new PhaseProgress (ProgressTracker .WRITING_RESULTS , 30 ));
223
+
193
224
StartDataFrameAnalyticsAction .TaskParams taskParams =
194
225
new StartDataFrameAnalyticsAction .TaskParams (
195
226
"job-id" ,
196
227
Version .CURRENT ,
197
- List .of (
198
- new PhaseProgress (ProgressTracker .REINDEXING , 0 ),
199
- new PhaseProgress (ProgressTracker .LOADING_DATA , 0 ),
200
- new PhaseProgress (ProgressTracker .WRITING_RESULTS , 0 )),
228
+ progress ,
201
229
false );
230
+
231
+ SearchResponse searchResponse = mock (SearchResponse .class );
232
+ when (searchResponse .getHits ()).thenReturn (SearchHits .empty ());
233
+ doAnswer (withResponse (searchResponse )).when (client ).execute (eq (SearchAction .INSTANCE ), any (), any ());
234
+
235
+ IndexResponse indexResponse = mock (IndexResponse .class );
236
+ doAnswer (withResponse (indexResponse )).when (client ).execute (eq (IndexAction .INSTANCE ), any (), any ());
237
+
202
238
DataFrameAnalyticsTask task =
203
239
new DataFrameAnalyticsTask (
204
240
123 , "type" , "action" , null , Map .of (), client , clusterService , analyticsManager , auditor , taskParams );
@@ -210,7 +246,23 @@ private void testSetFailed(boolean nodeShuttingDown) {
210
246
verify (analyticsManager ).isNodeShuttingDown ();
211
247
verify (client , atLeastOnce ()).settings ();
212
248
verify (client , atLeastOnce ()).threadPool ();
249
+
213
250
if (nodeShuttingDown == false ) {
251
+ // Verify progress was persisted
252
+ ArgumentCaptor <IndexRequest > indexRequestCaptor = ArgumentCaptor .forClass (IndexRequest .class );
253
+ verify (client ).execute (eq (SearchAction .INSTANCE ), any (), any ());
254
+ verify (client ).execute (eq (IndexAction .INSTANCE ), indexRequestCaptor .capture (), any ());
255
+
256
+ IndexRequest indexRequest = indexRequestCaptor .getValue ();
257
+ assertThat (indexRequest .index (), equalTo (AnomalyDetectorsIndex .jobStateIndexWriteAlias ()));
258
+ assertThat (indexRequest .id (), equalTo ("data_frame_analytics-job-id-progress" ));
259
+
260
+ try (XContentParser parser = JsonXContent .jsonXContent .createParser (
261
+ NamedXContentRegistry .EMPTY , DeprecationHandler .IGNORE_DEPRECATIONS , indexRequest .source ().utf8ToString ())) {
262
+ StoredProgress parsedProgress = StoredProgress .PARSER .apply (parser , null );
263
+ assertThat (parsedProgress .get (), equalTo (progress ));
264
+ }
265
+
214
266
verify (client ).execute (
215
267
same (UpdatePersistentTaskStatusAction .INSTANCE ),
216
268
eq (new UpdatePersistentTaskStatusAction .Request (
0 commit comments