5
5
*/
6
6
package org .elasticsearch .xpack .ml .dataframe ;
7
7
8
- import org .elasticsearch .Version ;
9
8
import org .elasticsearch .action .ActionListener ;
10
9
import org .elasticsearch .action .index .IndexAction ;
11
10
import org .elasticsearch .action .index .IndexRequest ;
12
11
import org .elasticsearch .action .index .IndexResponse ;
13
12
import org .elasticsearch .action .search .SearchAction ;
14
13
import org .elasticsearch .action .search .SearchResponse ;
15
14
import org .elasticsearch .client .Client ;
16
- import org .elasticsearch .cluster .service .ClusterService ;
17
15
import org .elasticsearch .common .settings .Settings ;
18
16
import org .elasticsearch .common .util .concurrent .ThreadContext ;
19
17
import org .elasticsearch .search .SearchHit ;
22
20
import org .elasticsearch .threadpool .ThreadPool ;
23
21
import org .elasticsearch .xpack .core .ml .action .GetDataFrameAnalyticsStatsAction ;
24
22
import org .elasticsearch .xpack .core .ml .action .GetDataFrameAnalyticsStatsActionResponseTests ;
25
- import org .elasticsearch .xpack .core .ml .action .StartDataFrameAnalyticsAction .TaskParams ;
26
23
import org .elasticsearch .xpack .core .ml .utils .PhaseProgress ;
27
24
import org .elasticsearch .xpack .ml .dataframe .DataFrameAnalyticsTask .StartingState ;
28
- import org .elasticsearch .xpack .ml .notifications .DataFrameAnalyticsAuditor ;
29
25
import org .mockito .ArgumentCaptor ;
30
26
import org .mockito .InOrder ;
31
27
import org .mockito .stubbing .Answer ;
@@ -115,13 +111,13 @@ public void testDetermineStartingState_GivenEmptyProgress() {
115
111
assertThat (startingState , equalTo (StartingState .FINISHED ));
116
112
}
117
113
118
- private void testMarkAsCompleted (SearchHits searchHits , String expectedIndexOrAlias ) {
114
+ private void testPersistProgress (SearchHits searchHits , String expectedIndexOrAlias ) {
119
115
Client client = mock (Client .class );
120
116
ThreadPool threadPool = mock (ThreadPool .class );
121
117
when (threadPool .getThreadContext ()).thenReturn (new ThreadContext (Settings .EMPTY ));
122
118
when (client .threadPool ()).thenReturn (threadPool );
123
119
124
- GetDataFrameAnalyticsStatsAction .Response getStatsResponse = GetDataFrameAnalyticsStatsActionResponseTests .randomResponse ();
120
+ GetDataFrameAnalyticsStatsAction .Response getStatsResponse = GetDataFrameAnalyticsStatsActionResponseTests .randomResponse (1 );
125
121
doAnswer (withResponse (getStatsResponse )).when (client ).execute (eq (GetDataFrameAnalyticsStatsAction .INSTANCE ), any (), any ());
126
122
127
123
SearchResponse searchResponse = mock (SearchResponse .class );
@@ -131,40 +127,30 @@ private void testMarkAsCompleted(SearchHits searchHits, String expectedIndexOrAl
131
127
IndexResponse indexResponse = mock (IndexResponse .class );
132
128
doAnswer (withResponse (indexResponse )).when (client ).execute (eq (IndexAction .INSTANCE ), any (), any ());
133
129
134
- TaskParams taskParams = new TaskParams ("task_id" , Version .CURRENT , Collections .emptyList (), false );
135
- DataFrameAnalyticsTask task =
136
- new DataFrameAnalyticsTask (
137
- 0 ,
138
- "" ,
139
- "" ,
140
- null ,
141
- null ,
142
- client ,
143
- mock (ClusterService .class ),
144
- mock (DataFrameAnalyticsManager .class ),
145
- mock (DataFrameAnalyticsAuditor .class ),
146
- taskParams );
147
- task .markAsCompleted ();
130
+ Runnable runnable = mock (Runnable .class );
131
+
132
+ DataFrameAnalyticsTask .persistProgress (client , "task_id" , runnable );
148
133
149
134
ArgumentCaptor <IndexRequest > indexRequestCaptor = ArgumentCaptor .forClass (IndexRequest .class );
150
135
151
- InOrder inOrder = inOrder (client );
136
+ InOrder inOrder = inOrder (client , runnable );
152
137
inOrder .verify (client ).execute (eq (GetDataFrameAnalyticsStatsAction .INSTANCE ), any (), any ());
153
138
inOrder .verify (client ).execute (eq (SearchAction .INSTANCE ), any (), any ());
154
139
inOrder .verify (client ).execute (eq (IndexAction .INSTANCE ), indexRequestCaptor .capture (), any ());
140
+ inOrder .verify (runnable ).run ();
155
141
inOrder .verifyNoMoreInteractions ();
156
142
157
143
IndexRequest indexRequest = indexRequestCaptor .getValue ();
158
144
assertThat (indexRequest .index (), equalTo (expectedIndexOrAlias ));
159
145
assertThat (indexRequest .id (), equalTo ("data_frame_analytics-task_id-progress" ));
160
146
}
161
147
162
- public void testMarkAsCompleted_ProgressDocumentCreated () {
163
- testMarkAsCompleted (SearchHits .empty (), ".ml-state-write" );
148
+ public void testPersistProgress_ProgressDocumentCreated () {
149
+ testPersistProgress (SearchHits .empty (), ".ml-state-write" );
164
150
}
165
151
166
- public void testMarkAsCompleted_ProgressDocumentUpdated () {
167
- testMarkAsCompleted (
152
+ public void testPersistProgress_ProgressDocumentUpdated () {
153
+ testPersistProgress (
168
154
new SearchHits (new SearchHit []{ SearchHit .createFromMap (Collections .singletonMap ("_index" , ".ml-state-dummy" )) }, null , 0.0f ),
169
155
".ml-state-dummy" );
170
156
}
0 commit comments