15
15
import org .elasticsearch .action .ActionListener ;
16
16
import org .elasticsearch .action .DocWriteRequest ;
17
17
import org .elasticsearch .action .bulk .BulkAction ;
18
- import org .elasticsearch .action .bulk .BulkRequest ;
18
+ import org .elasticsearch .action .bulk .BulkItemResponse ;
19
+ import org .elasticsearch .action .bulk .BulkRequestBuilder ;
19
20
import org .elasticsearch .action .bulk .BulkResponse ;
20
21
import org .elasticsearch .action .index .IndexRequest ;
21
22
import org .elasticsearch .action .search .MultiSearchAction ;
86
87
import java .util .Map ;
87
88
import java .util .Set ;
88
89
import java .util .TreeSet ;
90
+ import java .util .stream .Collectors ;
89
91
90
92
import static org .elasticsearch .xpack .core .ClientHelper .ML_ORIGIN ;
91
93
import static org .elasticsearch .xpack .core .ClientHelper .executeAsyncWithOrigin ;
@@ -96,6 +98,9 @@ public class TrainedModelProvider {
96
98
public static final Set <String > MODELS_STORED_AS_RESOURCE = Collections .singleton ("lang_ident_model_1" );
97
99
private static final String MODEL_RESOURCE_PATH = "/org/elasticsearch/xpack/ml/inference/persistence/" ;
98
100
private static final String MODEL_RESOURCE_FILE_EXT = ".json" ;
101
+ private static final int COMPRESSED_STRING_CHUNK_SIZE = 16 * 1024 * 1024 ;
102
+ private static final int MAX_NUM_DEFINITION_DOCS = 100 ;
103
+ private static final int MAX_COMPRESSED_STRING_SIZE = COMPRESSED_STRING_CHUNK_SIZE * MAX_NUM_DEFINITION_DOCS ;
99
104
100
105
private static final Logger logger = LogManager .getLogger (TrainedModelProvider .class );
101
106
private final Client client ;
@@ -139,30 +144,41 @@ public void storeTrainedModel(TrainedModelConfig trainedModelConfig,
139
144
private void storeTrainedModelAndDefinition (TrainedModelConfig trainedModelConfig ,
140
145
ActionListener <Boolean > listener ) {
141
146
142
- TrainedModelDefinitionDoc trainedModelDefinitionDoc ;
147
+ List < TrainedModelDefinitionDoc > trainedModelDefinitionDocs = new ArrayList <>() ;
143
148
try {
144
- // TODO should we check length against allowed stream size???
145
149
String compressedString = trainedModelConfig .getCompressedDefinition ();
146
- trainedModelDefinitionDoc = new TrainedModelDefinitionDoc .Builder ()
147
- .setDocNum (0 )
148
- .setModelId (trainedModelConfig .getModelId ())
149
- .setCompressedString (compressedString )
150
- .setCompressionVersion (TrainedModelConfig .CURRENT_DEFINITION_COMPRESSION_VERSION )
151
- .setDefinitionLength (compressedString .length ())
152
- .setTotalDefinitionLength (compressedString .length ())
153
- .build ();
150
+ if (compressedString .length () > MAX_COMPRESSED_STRING_SIZE ) {
151
+ listener .onFailure (
152
+ ExceptionsHelper .badRequestException (
153
+ "Unable to store model as compressed definition has length [{}] the limit is [{}]" ,
154
+ compressedString .length (),
155
+ MAX_COMPRESSED_STRING_SIZE ));
156
+ return ;
157
+ }
158
+ List <String > chunkedStrings = chunkStringWithSize (compressedString , COMPRESSED_STRING_CHUNK_SIZE );
159
+ for (int i = 0 ; i < chunkedStrings .size (); ++i ) {
160
+ trainedModelDefinitionDocs .add (new TrainedModelDefinitionDoc .Builder ()
161
+ .setDocNum (i )
162
+ .setModelId (trainedModelConfig .getModelId ())
163
+ .setCompressedString (chunkedStrings .get (i ))
164
+ .setCompressionVersion (TrainedModelConfig .CURRENT_DEFINITION_COMPRESSION_VERSION )
165
+ .setDefinitionLength (chunkedStrings .get (i ).length ())
166
+ .setTotalDefinitionLength (compressedString .length ())
167
+ .build ());
168
+ }
154
169
} catch (IOException ex ) {
155
170
listener .onFailure (ExceptionsHelper .serverError (
156
- "Unexpected IOException while serializing definition for storage for model [" + trainedModelConfig .getModelId () + "]" ,
157
- ex ));
171
+ "Unexpected IOException while serializing definition for storage for model [{}]" ,
172
+ ex ,
173
+ trainedModelConfig .getModelId ()));
158
174
return ;
159
175
}
160
176
161
- BulkRequest bulkRequest = client .prepareBulk (InferenceIndexConstants .LATEST_INDEX_NAME , MapperService .SINGLE_MAPPING_NAME )
177
+ BulkRequestBuilder bulkRequest = client .prepareBulk (InferenceIndexConstants .LATEST_INDEX_NAME , MapperService .SINGLE_MAPPING_NAME )
162
178
.setRefreshPolicy (WriteRequest .RefreshPolicy .IMMEDIATE )
163
- .add (createRequest (trainedModelConfig .getModelId (), trainedModelConfig ))
164
- . add ( createRequest ( TrainedModelDefinitionDoc . docId ( trainedModelConfig . getModelId (), 0 ), trainedModelDefinitionDoc ))
165
- . request ( );
179
+ .add (createRequest (trainedModelConfig .getModelId (), trainedModelConfig ));
180
+ trainedModelDefinitionDocs . forEach ( defDoc ->
181
+ bulkRequest . add ( createRequest ( TrainedModelDefinitionDoc . docId ( trainedModelConfig . getModelId (), defDoc . getDocNum ()), defDoc )) );
166
182
167
183
ActionListener <Boolean > wrappedListener = ActionListener .wrap (
168
184
listener ::onResponse ,
@@ -182,9 +198,8 @@ private void storeTrainedModelAndDefinition(TrainedModelConfig trainedModelConfi
182
198
183
199
ActionListener <BulkResponse > bulkResponseActionListener = ActionListener .wrap (
184
200
r -> {
185
- assert r .getItems ().length == 2 ;
201
+ assert r .getItems ().length == trainedModelDefinitionDocs . size () + 1 ;
186
202
if (r .getItems ()[0 ].isFailed ()) {
187
-
188
203
logger .error (new ParameterizedMessage (
189
204
"[{}] failed to store trained model config for inference" ,
190
205
trainedModelConfig .getModelId ()),
@@ -193,20 +208,26 @@ private void storeTrainedModelAndDefinition(TrainedModelConfig trainedModelConfi
193
208
wrappedListener .onFailure (r .getItems ()[0 ].getFailure ().getCause ());
194
209
return ;
195
210
}
196
- if (r .getItems ()[1 ].isFailed ()) {
211
+ if (r .hasFailures ()) {
212
+ Exception firstFailure = Arrays .stream (r .getItems ())
213
+ .filter (BulkItemResponse ::isFailed )
214
+ .map (BulkItemResponse ::getFailure )
215
+ .map (BulkItemResponse .Failure ::getCause )
216
+ .findFirst ()
217
+ .orElse (new Exception ("unknown failure" ));
197
218
logger .error (new ParameterizedMessage (
198
219
"[{}] failed to store trained model definition for inference" ,
199
220
trainedModelConfig .getModelId ()),
200
- r . getItems ()[ 1 ]. getFailure (). getCause () );
201
- wrappedListener .onFailure (r . getItems ()[ 1 ]. getFailure (). getCause () );
221
+ firstFailure );
222
+ wrappedListener .onFailure (firstFailure );
202
223
return ;
203
224
}
204
225
wrappedListener .onResponse (true );
205
226
},
206
227
wrappedListener ::onFailure
207
228
);
208
229
209
- executeAsyncWithOrigin (client , ML_ORIGIN , BulkAction .INSTANCE , bulkRequest , bulkResponseActionListener );
230
+ executeAsyncWithOrigin (client , ML_ORIGIN , BulkAction .INSTANCE , bulkRequest . request () , bulkResponseActionListener );
210
231
}
211
232
212
233
public void getTrainedModel (final String modelId , final boolean includeDefinition , final ActionListener <TrainedModelConfig > listener ) {
@@ -235,11 +256,20 @@ public void getTrainedModel(final String modelId, final boolean includeDefinitio
235
256
if (includeDefinition ) {
236
257
multiSearchRequestBuilder .add (client .prepareSearch (InferenceIndexConstants .INDEX_PATTERN )
237
258
.setQuery (QueryBuilders .constantScoreQuery (QueryBuilders
238
- .idsQuery ()
239
- .addIds (TrainedModelDefinitionDoc .docId (modelId , 0 ))))
240
- // use sort to get the last
259
+ .boolQuery ()
260
+ .filter (QueryBuilders .termQuery (TrainedModelConfig .MODEL_ID .getPreferredName (), modelId ))
261
+ .filter (QueryBuilders .termQuery (InferenceIndexConstants .DOC_TYPE .getPreferredName (), TrainedModelDefinitionDoc .NAME ))))
262
+ // There should be AT MOST this many docs. There might be more if definitions have been reindexed to newer indices
263
+ // If this ends up getting duplicate groups of definition documents, the parsing logic will throw away any doc that
264
+ // is in a different index than the first index seen.
265
+ .setSize (MAX_NUM_DEFINITION_DOCS )
266
+ // First find the latest index
241
267
.addSort ("_index" , SortOrder .DESC )
242
- .setSize (1 )
268
+ // Then, sort by doc_num
269
+ .addSort (SortBuilders .fieldSort (TrainedModelDefinitionDoc .DOC_NUM .getPreferredName ())
270
+ .order (SortOrder .ASC )
271
+ // We need this for the search not to fail when there are no mappings yet in the index
272
+ .unmappedType ("long" ))
243
273
.request ());
244
274
}
245
275
@@ -259,15 +289,18 @@ public void getTrainedModel(final String modelId, final boolean includeDefinitio
259
289
260
290
if (includeDefinition ) {
261
291
try {
262
- TrainedModelDefinitionDoc doc = handleSearchItem (multiSearchResponse .getResponses ()[1 ],
292
+ List < TrainedModelDefinitionDoc > docs = handleSearchItems (multiSearchResponse .getResponses ()[1 ],
263
293
modelId ,
264
294
this ::parseModelDefinitionDocLenientlyFromSource );
265
- if (doc .getCompressedString ().length () != doc .getTotalDefinitionLength ()) {
295
+ String compressedString = docs .stream ()
296
+ .map (TrainedModelDefinitionDoc ::getCompressedString )
297
+ .collect (Collectors .joining ());
298
+ if (compressedString .length () != docs .get (0 ).getTotalDefinitionLength ()) {
266
299
listener .onFailure (ExceptionsHelper .serverError (
267
300
Messages .getMessage (Messages .MODEL_DEFINITION_TRUNCATED , modelId )));
268
301
return ;
269
302
}
270
- builder .setDefinitionFromString (doc . getCompressedString () );
303
+ builder .setDefinitionFromString (compressedString );
271
304
} catch (ResourceNotFoundException ex ) {
272
305
listener .onFailure (new ResourceNotFoundException (
273
306
Messages .getMessage (Messages .MODEL_DEFINITION_NOT_FOUND , modelId )));
@@ -678,13 +711,36 @@ private Set<String> matchedResourceIds(String[] tokens) {
678
711
private static <T > T handleSearchItem (MultiSearchResponse .Item item ,
679
712
String resourceId ,
680
713
CheckedBiFunction <BytesReference , String , T , Exception > parseLeniently ) throws Exception {
714
+ return handleSearchItems (item , resourceId , parseLeniently ).get (0 );
715
+ }
716
+
717
+ // NOTE: This ignores any results that are in a different index than the first one seen in the search response.
718
+ private static <T > List <T > handleSearchItems (MultiSearchResponse .Item item ,
719
+ String resourceId ,
720
+ CheckedBiFunction <BytesReference , String , T , Exception > parseLeniently ) throws Exception {
681
721
if (item .isFailure ()) {
682
722
throw item .getFailure ();
683
723
}
684
724
if (item .getResponse ().getHits ().getHits ().length == 0 ) {
685
725
throw new ResourceNotFoundException (resourceId );
686
726
}
687
- return parseLeniently .apply (item .getResponse ().getHits ().getHits ()[0 ].getSourceRef (), resourceId );
727
+ List <T > results = new ArrayList <>(item .getResponse ().getHits ().getHits ().length );
728
+ String initialIndex = item .getResponse ().getHits ().getHits ()[0 ].getIndex ();
729
+ for (SearchHit hit : item .getResponse ().getHits ().getHits ()) {
730
+ // We don't want to spread across multiple backing indices
731
+ if (hit .getIndex ().equals (initialIndex )) {
732
+ results .add (parseLeniently .apply (hit .getSourceRef (), resourceId ));
733
+ }
734
+ }
735
+ return results ;
736
+ }
737
+
738
+ static List <String > chunkStringWithSize (String str , int chunkSize ) {
739
+ List <String > subStrings = new ArrayList <>((int )Math .ceil (str .length ()/(double )chunkSize ));
740
+ for (int i = 0 ; i < str .length ();i += chunkSize ) {
741
+ subStrings .add (str .substring (i , Math .min (i + chunkSize , str .length ())));
742
+ }
743
+ return subStrings ;
688
744
}
689
745
690
746
private TrainedModelConfig .Builder parseInferenceDocLenientlyFromSource (BytesReference source , String modelId ) throws IOException {
0 commit comments