Skip to content

Commit 24d41eb

Browse files
authored
[ML] partitions model definitions into chunks (#55260) (#55484)
This paves the way in the data layer so that exceptionally large models are partitioned across multiple documents. This change means that nodes before 7.8.0 will not be able to use trained inference models created on nodes on or after 7.8.0. I chose the definition document limit to be 100. This *SHOULD* be plenty for any large model. One of the largest models that I have created so far had the following stats: ~314MB of inflated JSON, ~66MB when compressed, ~177MB of heap. With the chunking sizes of `16 * 1024 * 1024` its compressed string could be partitioned into 5 documents. Supporting models 20 times this size (compressed) seems adequate for now.
1 parent fa0373a commit 24d41eb

File tree

3 files changed

+106
-33
lines changed

3 files changed

+106
-33
lines changed

docs/reference/ml/df-analytics/apis/put-inference.asciidoc

+7-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,13 @@
88
++++
99

1010
Creates an {infer} trained model.
11-
11+
+
12+
--
13+
WARNING: Models created in version 7.8.0 are not backwards compatible
14+
with older node versions. If in a mixed cluster environment,
15+
all nodes must be at least 7.8.0 to use a model stored by
16+
a 7.8.0 node.
17+
--
1218
experimental[]
1319

1420

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java

+12-1
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,18 @@ protected Response read(StreamInput in) throws IOException {
7878
}
7979

8080
@Override
81-
protected void masterOperation(Request request, ClusterState state, ActionListener<Response> listener) {
81+
protected void masterOperation(Request request,
82+
ClusterState state,
83+
ActionListener<Response> listener) {
84+
// 7.8.0 introduced splitting the model definition across multiple documents.
85+
// This means that new models will not be usable on nodes that cannot handle multiple definition documents
86+
if (state.nodes().getMinNodeVersion().before(Version.V_7_8_0)) {
87+
listener.onFailure(ExceptionsHelper.badRequestException(
88+
"Creating a new model requires that all nodes are at least version [{}]",
89+
request.getTrainedModelConfig().getModelId(),
90+
Version.V_7_8_0.toString()));
91+
return;
92+
}
8293
try {
8394
request.getTrainedModelConfig().ensureParsedDefinition(xContentRegistry);
8495
request.getTrainedModelConfig().getModelDefinition().getTrainedModel().validate();

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java

+87-31
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
import org.elasticsearch.action.ActionListener;
1616
import org.elasticsearch.action.DocWriteRequest;
1717
import org.elasticsearch.action.bulk.BulkAction;
18-
import org.elasticsearch.action.bulk.BulkRequest;
18+
import org.elasticsearch.action.bulk.BulkItemResponse;
19+
import org.elasticsearch.action.bulk.BulkRequestBuilder;
1920
import org.elasticsearch.action.bulk.BulkResponse;
2021
import org.elasticsearch.action.index.IndexRequest;
2122
import org.elasticsearch.action.search.MultiSearchAction;
@@ -86,6 +87,7 @@
8687
import java.util.Map;
8788
import java.util.Set;
8889
import java.util.TreeSet;
90+
import java.util.stream.Collectors;
8991

9092
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
9193
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
@@ -96,6 +98,9 @@ public class TrainedModelProvider {
9698
public static final Set<String> MODELS_STORED_AS_RESOURCE = Collections.singleton("lang_ident_model_1");
9799
private static final String MODEL_RESOURCE_PATH = "/org/elasticsearch/xpack/ml/inference/persistence/";
98100
private static final String MODEL_RESOURCE_FILE_EXT = ".json";
101+
private static final int COMPRESSED_STRING_CHUNK_SIZE = 16 * 1024 * 1024;
102+
private static final int MAX_NUM_DEFINITION_DOCS = 100;
103+
private static final int MAX_COMPRESSED_STRING_SIZE = COMPRESSED_STRING_CHUNK_SIZE * MAX_NUM_DEFINITION_DOCS;
99104

100105
private static final Logger logger = LogManager.getLogger(TrainedModelProvider.class);
101106
private final Client client;
@@ -139,30 +144,41 @@ public void storeTrainedModel(TrainedModelConfig trainedModelConfig,
139144
private void storeTrainedModelAndDefinition(TrainedModelConfig trainedModelConfig,
140145
ActionListener<Boolean> listener) {
141146

142-
TrainedModelDefinitionDoc trainedModelDefinitionDoc;
147+
List<TrainedModelDefinitionDoc> trainedModelDefinitionDocs = new ArrayList<>();
143148
try {
144-
// TODO should we check length against allowed stream size???
145149
String compressedString = trainedModelConfig.getCompressedDefinition();
146-
trainedModelDefinitionDoc = new TrainedModelDefinitionDoc.Builder()
147-
.setDocNum(0)
148-
.setModelId(trainedModelConfig.getModelId())
149-
.setCompressedString(compressedString)
150-
.setCompressionVersion(TrainedModelConfig.CURRENT_DEFINITION_COMPRESSION_VERSION)
151-
.setDefinitionLength(compressedString.length())
152-
.setTotalDefinitionLength(compressedString.length())
153-
.build();
150+
if (compressedString.length() > MAX_COMPRESSED_STRING_SIZE) {
151+
listener.onFailure(
152+
ExceptionsHelper.badRequestException(
153+
"Unable to store model as compressed definition has length [{}] the limit is [{}]",
154+
compressedString.length(),
155+
MAX_COMPRESSED_STRING_SIZE));
156+
return;
157+
}
158+
List<String> chunkedStrings = chunkStringWithSize(compressedString, COMPRESSED_STRING_CHUNK_SIZE);
159+
for(int i = 0; i < chunkedStrings.size(); ++i) {
160+
trainedModelDefinitionDocs.add(new TrainedModelDefinitionDoc.Builder()
161+
.setDocNum(i)
162+
.setModelId(trainedModelConfig.getModelId())
163+
.setCompressedString(chunkedStrings.get(i))
164+
.setCompressionVersion(TrainedModelConfig.CURRENT_DEFINITION_COMPRESSION_VERSION)
165+
.setDefinitionLength(chunkedStrings.get(i).length())
166+
.setTotalDefinitionLength(compressedString.length())
167+
.build());
168+
}
154169
} catch (IOException ex) {
155170
listener.onFailure(ExceptionsHelper.serverError(
156-
"Unexpected IOException while serializing definition for storage for model [" + trainedModelConfig.getModelId() + "]",
157-
ex));
171+
"Unexpected IOException while serializing definition for storage for model [{}]",
172+
ex,
173+
trainedModelConfig.getModelId()));
158174
return;
159175
}
160176

161-
BulkRequest bulkRequest = client.prepareBulk(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME)
177+
BulkRequestBuilder bulkRequest = client.prepareBulk(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME)
162178
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
163-
.add(createRequest(trainedModelConfig.getModelId(), trainedModelConfig))
164-
.add(createRequest(TrainedModelDefinitionDoc.docId(trainedModelConfig.getModelId(), 0), trainedModelDefinitionDoc))
165-
.request();
179+
.add(createRequest(trainedModelConfig.getModelId(), trainedModelConfig));
180+
trainedModelDefinitionDocs.forEach(defDoc ->
181+
bulkRequest.add(createRequest(TrainedModelDefinitionDoc.docId(trainedModelConfig.getModelId(), defDoc.getDocNum()), defDoc)));
166182

167183
ActionListener<Boolean> wrappedListener = ActionListener.wrap(
168184
listener::onResponse,
@@ -182,9 +198,8 @@ private void storeTrainedModelAndDefinition(TrainedModelConfig trainedModelConfi
182198

183199
ActionListener<BulkResponse> bulkResponseActionListener = ActionListener.wrap(
184200
r -> {
185-
assert r.getItems().length == 2;
201+
assert r.getItems().length == trainedModelDefinitionDocs.size() + 1;
186202
if (r.getItems()[0].isFailed()) {
187-
188203
logger.error(new ParameterizedMessage(
189204
"[{}] failed to store trained model config for inference",
190205
trainedModelConfig.getModelId()),
@@ -193,20 +208,26 @@ private void storeTrainedModelAndDefinition(TrainedModelConfig trainedModelConfi
193208
wrappedListener.onFailure(r.getItems()[0].getFailure().getCause());
194209
return;
195210
}
196-
if (r.getItems()[1].isFailed()) {
211+
if (r.hasFailures()) {
212+
Exception firstFailure = Arrays.stream(r.getItems())
213+
.filter(BulkItemResponse::isFailed)
214+
.map(BulkItemResponse::getFailure)
215+
.map(BulkItemResponse.Failure::getCause)
216+
.findFirst()
217+
.orElse(new Exception("unknown failure"));
197218
logger.error(new ParameterizedMessage(
198219
"[{}] failed to store trained model definition for inference",
199220
trainedModelConfig.getModelId()),
200-
r.getItems()[1].getFailure().getCause());
201-
wrappedListener.onFailure(r.getItems()[1].getFailure().getCause());
221+
firstFailure);
222+
wrappedListener.onFailure(firstFailure);
202223
return;
203224
}
204225
wrappedListener.onResponse(true);
205226
},
206227
wrappedListener::onFailure
207228
);
208229

209-
executeAsyncWithOrigin(client, ML_ORIGIN, BulkAction.INSTANCE, bulkRequest, bulkResponseActionListener);
230+
executeAsyncWithOrigin(client, ML_ORIGIN, BulkAction.INSTANCE, bulkRequest.request(), bulkResponseActionListener);
210231
}
211232

212233
public void getTrainedModel(final String modelId, final boolean includeDefinition, final ActionListener<TrainedModelConfig> listener) {
@@ -235,11 +256,20 @@ public void getTrainedModel(final String modelId, final boolean includeDefinitio
235256
if (includeDefinition) {
236257
multiSearchRequestBuilder.add(client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN)
237258
.setQuery(QueryBuilders.constantScoreQuery(QueryBuilders
238-
.idsQuery()
239-
.addIds(TrainedModelDefinitionDoc.docId(modelId, 0))))
240-
// use sort to get the last
259+
.boolQuery()
260+
.filter(QueryBuilders.termQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId))
261+
.filter(QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(), TrainedModelDefinitionDoc.NAME))))
262+
// There should be AT MOST this many docs. There might be more if definitions have been reindexed to newer indices
263+
// If this ends up getting duplicate groups of definition documents, the parsing logic will throw away any doc that
264+
// is in a different index than the first index seen.
265+
.setSize(MAX_NUM_DEFINITION_DOCS)
266+
// First find the latest index
241267
.addSort("_index", SortOrder.DESC)
242-
.setSize(1)
268+
// Then, sort by doc_num
269+
.addSort(SortBuilders.fieldSort(TrainedModelDefinitionDoc.DOC_NUM.getPreferredName())
270+
.order(SortOrder.ASC)
271+
// We need this for the search not to fail when there are no mappings yet in the index
272+
.unmappedType("long"))
243273
.request());
244274
}
245275

@@ -259,15 +289,18 @@ public void getTrainedModel(final String modelId, final boolean includeDefinitio
259289

260290
if (includeDefinition) {
261291
try {
262-
TrainedModelDefinitionDoc doc = handleSearchItem(multiSearchResponse.getResponses()[1],
292+
List<TrainedModelDefinitionDoc> docs = handleSearchItems(multiSearchResponse.getResponses()[1],
263293
modelId,
264294
this::parseModelDefinitionDocLenientlyFromSource);
265-
if (doc.getCompressedString().length() != doc.getTotalDefinitionLength()) {
295+
String compressedString = docs.stream()
296+
.map(TrainedModelDefinitionDoc::getCompressedString)
297+
.collect(Collectors.joining());
298+
if (compressedString.length() != docs.get(0).getTotalDefinitionLength()) {
266299
listener.onFailure(ExceptionsHelper.serverError(
267300
Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId)));
268301
return;
269302
}
270-
builder.setDefinitionFromString(doc.getCompressedString());
303+
builder.setDefinitionFromString(compressedString);
271304
} catch (ResourceNotFoundException ex) {
272305
listener.onFailure(new ResourceNotFoundException(
273306
Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId)));
@@ -678,13 +711,36 @@ private Set<String> matchedResourceIds(String[] tokens) {
678711
private static <T> T handleSearchItem(MultiSearchResponse.Item item,
679712
String resourceId,
680713
CheckedBiFunction<BytesReference, String, T, Exception> parseLeniently) throws Exception {
714+
return handleSearchItems(item, resourceId, parseLeniently).get(0);
715+
}
716+
717+
// NOTE: This ignores any results that are in a different index than the first one seen in the search response.
718+
private static <T> List<T> handleSearchItems(MultiSearchResponse.Item item,
719+
String resourceId,
720+
CheckedBiFunction<BytesReference, String, T, Exception> parseLeniently) throws Exception {
681721
if (item.isFailure()) {
682722
throw item.getFailure();
683723
}
684724
if (item.getResponse().getHits().getHits().length == 0) {
685725
throw new ResourceNotFoundException(resourceId);
686726
}
687-
return parseLeniently.apply(item.getResponse().getHits().getHits()[0].getSourceRef(), resourceId);
727+
List<T> results = new ArrayList<>(item.getResponse().getHits().getHits().length);
728+
String initialIndex = item.getResponse().getHits().getHits()[0].getIndex();
729+
for (SearchHit hit : item.getResponse().getHits().getHits()) {
730+
// We don't want to spread across multiple backing indices
731+
if (hit.getIndex().equals(initialIndex)) {
732+
results.add(parseLeniently.apply(hit.getSourceRef(), resourceId));
733+
}
734+
}
735+
return results;
736+
}
737+
738+
static List<String> chunkStringWithSize(String str, int chunkSize) {
739+
List<String> subStrings = new ArrayList<>((int)Math.ceil(str.length()/(double)chunkSize));
740+
for (int i = 0; i < str.length();i += chunkSize) {
741+
subStrings.add(str.substring(i, Math.min(i + chunkSize, str.length())));
742+
}
743+
return subStrings;
688744
}
689745

690746
private TrainedModelConfig.Builder parseInferenceDocLenientlyFromSource(BytesReference source, String modelId) throws IOException {

0 commit comments

Comments
 (0)