16
16
import org .elasticsearch .cluster .service .ClusterService ;
17
17
import org .elasticsearch .common .settings .Settings ;
18
18
import org .elasticsearch .common .util .concurrent .ThreadContext ;
19
+ import org .elasticsearch .common .xcontent .DeprecationHandler ;
20
+ import org .elasticsearch .common .xcontent .NamedXContentRegistry ;
21
+ import org .elasticsearch .common .xcontent .XContentParser ;
22
+ import org .elasticsearch .common .xcontent .json .JsonXContent ;
19
23
import org .elasticsearch .persistent .PersistentTasksService ;
20
24
import org .elasticsearch .persistent .UpdatePersistentTaskStatusAction ;
21
25
import org .elasticsearch .search .SearchHit ;
22
26
import org .elasticsearch .search .SearchHits ;
23
27
import org .elasticsearch .tasks .TaskManager ;
24
28
import org .elasticsearch .test .ESTestCase ;
25
29
import org .elasticsearch .threadpool .ThreadPool ;
26
- import org .elasticsearch .xpack .core .ml .action .GetDataFrameAnalyticsStatsAction ;
27
- import org .elasticsearch .xpack .core .ml .action .GetDataFrameAnalyticsStatsActionResponseTests ;
28
30
import org .elasticsearch .xpack .core .ml .action .StartDataFrameAnalyticsAction ;
29
31
import org .elasticsearch .xpack .core .ml .dataframe .DataFrameAnalyticsState ;
30
32
import org .elasticsearch .xpack .core .ml .dataframe .DataFrameAnalyticsTaskState ;
33
+ import org .elasticsearch .xpack .core .ml .job .persistence .AnomalyDetectorsIndex ;
31
34
import org .elasticsearch .xpack .core .ml .utils .PhaseProgress ;
32
35
import org .elasticsearch .xpack .ml .dataframe .DataFrameAnalyticsTask .StartingState ;
33
36
import org .elasticsearch .xpack .ml .dataframe .stats .ProgressTracker ;
36
39
import org .mockito .InOrder ;
37
40
import org .mockito .stubbing .Answer ;
38
41
42
+ import java .io .IOException ;
39
43
import java .util .Arrays ;
40
44
import java .util .Collections ;
41
45
import java .util .List ;
@@ -125,14 +129,25 @@ public void testDetermineStartingState_GivenEmptyProgress() {
125
129
assertThat (startingState , equalTo (StartingState .FINISHED ));
126
130
}
127
131
128
- private void testPersistProgress (SearchHits searchHits , String expectedIndexOrAlias ) {
132
+ private void testPersistProgress (SearchHits searchHits , String expectedIndexOrAlias ) throws IOException {
129
133
Client client = mock (Client .class );
134
+ when (client .settings ()).thenReturn (Settings .EMPTY );
130
135
ThreadPool threadPool = mock (ThreadPool .class );
131
136
when (threadPool .getThreadContext ()).thenReturn (new ThreadContext (Settings .EMPTY ));
132
137
when (client .threadPool ()).thenReturn (threadPool );
133
138
134
- GetDataFrameAnalyticsStatsAction .Response getStatsResponse = GetDataFrameAnalyticsStatsActionResponseTests .randomResponse (1 );
135
- doAnswer (withResponse (getStatsResponse )).when (client ).execute (eq (GetDataFrameAnalyticsStatsAction .INSTANCE ), any (), any ());
139
+ ClusterService clusterService = mock (ClusterService .class );
140
+ DataFrameAnalyticsManager analyticsManager = mock (DataFrameAnalyticsManager .class );
141
+ DataFrameAnalyticsAuditor auditor = mock (DataFrameAnalyticsAuditor .class );
142
+ PersistentTasksService persistentTasksService = new PersistentTasksService (clusterService , threadPool , client );
143
+
144
+ List <PhaseProgress > progress = Arrays .asList (
145
+ new PhaseProgress (ProgressTracker .REINDEXING , 100 ),
146
+ new PhaseProgress (ProgressTracker .LOADING_DATA , 50 ),
147
+ new PhaseProgress (ProgressTracker .WRITING_RESULTS , 0 ));
148
+
149
+ StartDataFrameAnalyticsAction .TaskParams taskParams = new StartDataFrameAnalyticsAction .TaskParams (
150
+ "task_id" , Version .CURRENT , progress , false );
136
151
137
152
SearchResponse searchResponse = mock (SearchResponse .class );
138
153
when (searchResponse .getHits ()).thenReturn (searchHits );
@@ -141,14 +156,20 @@ private void testPersistProgress(SearchHits searchHits, String expectedIndexOrAl
141
156
IndexResponse indexResponse = mock (IndexResponse .class );
142
157
doAnswer (withResponse (indexResponse )).when (client ).execute (eq (IndexAction .INSTANCE ), any (), any ());
143
158
159
+ TaskManager taskManager = mock (TaskManager .class );
160
+
144
161
Runnable runnable = mock (Runnable .class );
145
162
146
- DataFrameAnalyticsTask .persistProgress (client , "task_id" , runnable );
163
+ DataFrameAnalyticsTask task =
164
+ new DataFrameAnalyticsTask (
165
+ 123 , "type" , "action" , null , Collections .emptyMap (), client , clusterService , analyticsManager , auditor , taskParams );
166
+ task .init (persistentTasksService , taskManager , "task-id" , 42 );
167
+
168
+ task .persistProgress (client , "task_id" , runnable );
147
169
148
170
ArgumentCaptor <IndexRequest > indexRequestCaptor = ArgumentCaptor .forClass (IndexRequest .class );
149
171
150
172
InOrder inOrder = inOrder (client , runnable );
151
- inOrder .verify (client ).execute (eq (GetDataFrameAnalyticsStatsAction .INSTANCE ), any (), any ());
152
173
inOrder .verify (client ).execute (eq (SearchAction .INSTANCE ), any (), any ());
153
174
inOrder .verify (client ).execute (eq (IndexAction .INSTANCE ), indexRequestCaptor .capture (), any ());
154
175
inOrder .verify (runnable ).run ();
@@ -157,27 +178,33 @@ private void testPersistProgress(SearchHits searchHits, String expectedIndexOrAl
157
178
IndexRequest indexRequest = indexRequestCaptor .getValue ();
158
179
assertThat (indexRequest .index (), equalTo (expectedIndexOrAlias ));
159
180
assertThat (indexRequest .id (), equalTo ("data_frame_analytics-task_id-progress" ));
181
+
182
+ try (XContentParser parser = JsonXContent .jsonXContent .createParser (
183
+ NamedXContentRegistry .EMPTY , DeprecationHandler .IGNORE_DEPRECATIONS , indexRequest .source ().utf8ToString ())) {
184
+ StoredProgress parsedProgress = StoredProgress .PARSER .apply (parser , null );
185
+ assertThat (parsedProgress .get (), equalTo (progress ));
186
+ }
160
187
}
161
188
162
- public void testPersistProgress_ProgressDocumentCreated () {
189
+ public void testPersistProgress_ProgressDocumentCreated () throws IOException {
163
190
testPersistProgress (SearchHits .empty (), ".ml-state-write" );
164
191
}
165
192
166
- public void testPersistProgress_ProgressDocumentUpdated () {
193
+ public void testPersistProgress_ProgressDocumentUpdated () throws IOException {
167
194
testPersistProgress (
168
195
new SearchHits (new SearchHit []{ SearchHit .createFromMap (Collections .singletonMap ("_index" , ".ml-state-dummy" )) }, null , 0.0f ),
169
196
".ml-state-dummy" );
170
197
}
171
198
172
- public void testSetFailed () {
199
+ public void testSetFailed () throws IOException {
173
200
testSetFailed (false );
174
201
}
175
202
176
- public void testSetFailedDuringNodeShutdown () {
203
+ public void testSetFailedDuringNodeShutdown () throws IOException {
177
204
testSetFailed (true );
178
205
}
179
206
180
- private void testSetFailed (boolean nodeShuttingDown ) {
207
+ private void testSetFailed (boolean nodeShuttingDown ) throws IOException {
181
208
ThreadPool threadPool = mock (ThreadPool .class );
182
209
when (threadPool .getThreadContext ()).thenReturn (new ThreadContext (Settings .EMPTY ));
183
210
Client client = mock (Client .class );
@@ -190,15 +217,25 @@ private void testSetFailed(boolean nodeShuttingDown) {
190
217
PersistentTasksService persistentTasksService = new PersistentTasksService (clusterService , mock (ThreadPool .class ), client );
191
218
TaskManager taskManager = mock (TaskManager .class );
192
219
220
+ List <PhaseProgress > progress = Arrays .asList (
221
+ new PhaseProgress (ProgressTracker .REINDEXING , 100 ),
222
+ new PhaseProgress (ProgressTracker .LOADING_DATA , 100 ),
223
+ new PhaseProgress (ProgressTracker .WRITING_RESULTS , 30 ));
224
+
193
225
StartDataFrameAnalyticsAction .TaskParams taskParams =
194
226
new StartDataFrameAnalyticsAction .TaskParams (
195
227
"job-id" ,
196
228
Version .CURRENT ,
197
- Arrays .asList (
198
- new PhaseProgress (ProgressTracker .REINDEXING , 0 ),
199
- new PhaseProgress (ProgressTracker .LOADING_DATA , 0 ),
200
- new PhaseProgress (ProgressTracker .WRITING_RESULTS , 0 )),
229
+ progress ,
201
230
false );
231
+
232
+ SearchResponse searchResponse = mock (SearchResponse .class );
233
+ when (searchResponse .getHits ()).thenReturn (SearchHits .empty ());
234
+ doAnswer (withResponse (searchResponse )).when (client ).execute (eq (SearchAction .INSTANCE ), any (), any ());
235
+
236
+ IndexResponse indexResponse = mock (IndexResponse .class );
237
+ doAnswer (withResponse (indexResponse )).when (client ).execute (eq (IndexAction .INSTANCE ), any (), any ());
238
+
202
239
DataFrameAnalyticsTask task =
203
240
new DataFrameAnalyticsTask (
204
241
123 , "type" , "action" , null , Collections .emptyMap (), client , clusterService , analyticsManager , auditor , taskParams );
@@ -210,7 +247,23 @@ private void testSetFailed(boolean nodeShuttingDown) {
210
247
verify (analyticsManager ).isNodeShuttingDown ();
211
248
verify (client , atLeastOnce ()).settings ();
212
249
verify (client , atLeastOnce ()).threadPool ();
250
+
213
251
if (nodeShuttingDown == false ) {
252
+ // Verify progress was persisted
253
+ ArgumentCaptor <IndexRequest > indexRequestCaptor = ArgumentCaptor .forClass (IndexRequest .class );
254
+ verify (client ).execute (eq (SearchAction .INSTANCE ), any (), any ());
255
+ verify (client ).execute (eq (IndexAction .INSTANCE ), indexRequestCaptor .capture (), any ());
256
+
257
+ IndexRequest indexRequest = indexRequestCaptor .getValue ();
258
+ assertThat (indexRequest .index (), equalTo (AnomalyDetectorsIndex .jobStateIndexWriteAlias ()));
259
+ assertThat (indexRequest .id (), equalTo ("data_frame_analytics-job-id-progress" ));
260
+
261
+ try (XContentParser parser = JsonXContent .jsonXContent .createParser (
262
+ NamedXContentRegistry .EMPTY , DeprecationHandler .IGNORE_DEPRECATIONS , indexRequest .source ().utf8ToString ())) {
263
+ StoredProgress parsedProgress = StoredProgress .PARSER .apply (parser , null );
264
+ assertThat (parsedProgress .get (), equalTo (progress ));
265
+ }
266
+
214
267
verify (client ).execute (
215
268
same (UpdatePersistentTaskStatusAction .INSTANCE ),
216
269
eq (new UpdatePersistentTaskStatusAction .Request (
0 commit comments