Skip to content

Commit fa116a6

Browse files
authored
[7.x] [ML][Inference] PUT API (elastic#50852) (elastic#50887)
* [ML][Inference] PUT API (elastic#50852) This adds the `PUT` API for creating trained models that support our format. This includes * HLRC change for the API * API creation * Validations of model format and call * fixing backport
1 parent 456de59 commit fa116a6

File tree

38 files changed

+1638
-413
lines changed

38 files changed

+1638
-413
lines changed

client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java

+11
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
import org.elasticsearch.client.ml.PutDatafeedRequest;
7474
import org.elasticsearch.client.ml.PutFilterRequest;
7575
import org.elasticsearch.client.ml.PutJobRequest;
76+
import org.elasticsearch.client.ml.PutTrainedModelRequest;
7677
import org.elasticsearch.client.ml.RevertModelSnapshotRequest;
7778
import org.elasticsearch.client.ml.SetUpgradeModeRequest;
7879
import org.elasticsearch.client.ml.StartDataFrameAnalyticsRequest;
@@ -792,6 +793,16 @@ static Request deleteTrainedModel(DeleteTrainedModelRequest deleteRequest) {
792793
return new Request(HttpDelete.METHOD_NAME, endpoint);
793794
}
794795

796+
static Request putTrainedModel(PutTrainedModelRequest putTrainedModelRequest) throws IOException {
797+
String endpoint = new EndpointBuilder()
798+
.addPathPartAsIs("_ml", "inference")
799+
.addPathPart(putTrainedModelRequest.getTrainedModelConfig().getModelId())
800+
.build();
801+
Request request = new Request(HttpPut.METHOD_NAME, endpoint);
802+
request.setEntity(createEntity(putTrainedModelRequest, REQUEST_BODY_CONTENT_TYPE));
803+
return request;
804+
}
805+
795806
static Request putFilter(PutFilterRequest putFilterRequest) throws IOException {
796807
String endpoint = new EndpointBuilder()
797808
.addPathPartAsIs("_ml")

client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java

+44
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,8 @@
100100
import org.elasticsearch.client.ml.PutFilterResponse;
101101
import org.elasticsearch.client.ml.PutJobRequest;
102102
import org.elasticsearch.client.ml.PutJobResponse;
103+
import org.elasticsearch.client.ml.PutTrainedModelRequest;
104+
import org.elasticsearch.client.ml.PutTrainedModelResponse;
103105
import org.elasticsearch.client.ml.RevertModelSnapshotRequest;
104106
import org.elasticsearch.client.ml.RevertModelSnapshotResponse;
105107
import org.elasticsearch.client.ml.SetUpgradeModeRequest;
@@ -2340,6 +2342,48 @@ public Cancellable getTrainedModelsAsync(GetTrainedModelsRequest request,
23402342
Collections.emptySet());
23412343
}
23422344

2345+
/**
2346+
* Put trained model config
2347+
* <p>
2348+
* For additional info
2349+
* see <a href="TODO">
2350+
* PUT Trained Model Config documentation</a>
2351+
*
2352+
* @param request The {@link PutTrainedModelRequest}
2353+
* @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized
2354+
* @return {@link PutTrainedModelResponse} response object
2355+
*/
2356+
public PutTrainedModelResponse putTrainedModel(PutTrainedModelRequest request, RequestOptions options) throws IOException {
2357+
return restHighLevelClient.performRequestAndParseEntity(request,
2358+
MLRequestConverters::putTrainedModel,
2359+
options,
2360+
PutTrainedModelResponse::fromXContent,
2361+
Collections.emptySet());
2362+
}
2363+
2364+
/**
2365+
* Put trained model config asynchronously and notifies listener upon completion
2366+
* <p>
2367+
* For additional info
2368+
* see <a href="TODO">
2369+
* PUT Trained Model Config documentation</a>
2370+
*
2371+
* @param request The {@link PutTrainedModelRequest}
2372+
* @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized
2373+
* @param listener Listener to be notified upon request completion
2374+
* @return cancellable that may be used to cancel the request
2375+
*/
2376+
public Cancellable putTrainedModelAsync(PutTrainedModelRequest request,
2377+
RequestOptions options,
2378+
ActionListener<PutTrainedModelResponse> listener) {
2379+
return restHighLevelClient.performRequestAsyncAndParseEntity(request,
2380+
MLRequestConverters::putTrainedModel,
2381+
options,
2382+
PutTrainedModelResponse::fromXContent,
2383+
listener,
2384+
Collections.emptySet());
2385+
}
2386+
23432387
/**
23442388
* Gets trained model stats
23452389
* <p>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
/*
2+
* Licensed to Elasticsearch under one or more contributor
3+
* license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright
5+
* ownership. Elasticsearch licenses this file to you under
6+
* the Apache License, Version 2.0 (the "License"); you may
7+
* not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
package org.elasticsearch.client.ml;
20+
21+
import org.elasticsearch.client.Validatable;
22+
import org.elasticsearch.client.ml.inference.TrainedModelConfig;
23+
import org.elasticsearch.common.Strings;
24+
import org.elasticsearch.common.xcontent.ToXContent;
25+
import org.elasticsearch.common.xcontent.ToXContentObject;
26+
import org.elasticsearch.common.xcontent.XContentBuilder;
27+
28+
import java.io.IOException;
29+
import java.util.Objects;
30+
31+
32+
public class PutTrainedModelRequest implements Validatable, ToXContentObject {
33+
34+
private final TrainedModelConfig config;
35+
36+
public PutTrainedModelRequest(TrainedModelConfig config) {
37+
this.config = config;
38+
}
39+
40+
public TrainedModelConfig getTrainedModelConfig() {
41+
return config;
42+
}
43+
44+
@Override
45+
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
46+
return config.toXContent(builder, params);
47+
}
48+
49+
@Override
50+
public boolean equals(Object o) {
51+
if (this == o) return true;
52+
if (o == null || getClass() != o.getClass()) return false;
53+
PutTrainedModelRequest request = (PutTrainedModelRequest) o;
54+
return Objects.equals(config, request.config);
55+
}
56+
57+
@Override
58+
public int hashCode() {
59+
return Objects.hash(config);
60+
}
61+
62+
@Override
63+
public final String toString() {
64+
return Strings.toString(config);
65+
}
66+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
/*
2+
* Licensed to Elasticsearch under one or more contributor
3+
* license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright
5+
* ownership. Elasticsearch licenses this file to you under
6+
* the Apache License, Version 2.0 (the "License"); you may
7+
* not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
package org.elasticsearch.client.ml;
20+
21+
import org.elasticsearch.client.ml.inference.TrainedModelConfig;
22+
import org.elasticsearch.common.xcontent.ToXContentObject;
23+
import org.elasticsearch.common.xcontent.XContentBuilder;
24+
import org.elasticsearch.common.xcontent.XContentParser;
25+
26+
import java.io.IOException;
27+
import java.util.Objects;
28+
29+
30+
public class PutTrainedModelResponse implements ToXContentObject {
31+
32+
private final TrainedModelConfig trainedModelConfig;
33+
34+
public static PutTrainedModelResponse fromXContent(XContentParser parser) throws IOException {
35+
return new PutTrainedModelResponse(TrainedModelConfig.PARSER.parse(parser, null).build());
36+
}
37+
38+
public PutTrainedModelResponse(TrainedModelConfig trainedModelConfig) {
39+
this.trainedModelConfig = trainedModelConfig;
40+
}
41+
42+
public TrainedModelConfig getResponse() {
43+
return trainedModelConfig;
44+
}
45+
46+
@Override
47+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
48+
return trainedModelConfig.toXContent(builder, params);
49+
}
50+
51+
@Override
52+
public boolean equals(Object o) {
53+
if (this == o) return true;
54+
if (o == null || getClass() != o.getClass()) return false;
55+
PutTrainedModelResponse response = (PutTrainedModelResponse) o;
56+
return Objects.equals(trainedModelConfig, response.trainedModelConfig);
57+
}
58+
59+
@Override
60+
public int hashCode() {
61+
return Objects.hash(trainedModelConfig);
62+
}
63+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
/*
2+
* Licensed to Elasticsearch under one or more contributor
3+
* license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright
5+
* ownership. Elasticsearch licenses this file to you under
6+
* the Apache License, Version 2.0 (the "License"); you may
7+
* not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.elasticsearch.client.ml.inference;
21+
22+
import org.elasticsearch.common.CheckedFunction;
23+
import org.elasticsearch.common.bytes.BytesArray;
24+
import org.elasticsearch.common.bytes.BytesReference;
25+
import org.elasticsearch.common.io.Streams;
26+
import org.elasticsearch.common.io.stream.BytesStreamOutput;
27+
import org.elasticsearch.common.xcontent.DeprecationHandler;
28+
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
29+
import org.elasticsearch.common.xcontent.ToXContentObject;
30+
import org.elasticsearch.common.xcontent.XContentHelper;
31+
import org.elasticsearch.common.xcontent.XContentParser;
32+
import org.elasticsearch.common.xcontent.XContentType;
33+
34+
import java.io.IOException;
35+
import java.io.InputStream;
36+
import java.io.OutputStream;
37+
import java.nio.charset.StandardCharsets;
38+
import java.util.Base64;
39+
import java.util.zip.GZIPInputStream;
40+
import java.util.zip.GZIPOutputStream;
41+
42+
/**
43+
* Collection of helper methods. Similar to CompressedXContent, but this utilizes GZIP.
44+
*/
45+
public final class InferenceToXContentCompressor {
46+
private static final int BUFFER_SIZE = 4096;
47+
private static final long MAX_INFLATED_BYTES = 1_000_000_000; // 1 gb maximum
48+
49+
private InferenceToXContentCompressor() {}
50+
51+
public static <T extends ToXContentObject> String deflate(T objectToCompress) throws IOException {
52+
BytesReference reference = XContentHelper.toXContent(objectToCompress, XContentType.JSON, false);
53+
return deflate(reference);
54+
}
55+
56+
public static <T> T inflate(String compressedString,
57+
CheckedFunction<XContentParser, T, IOException> parserFunction,
58+
NamedXContentRegistry xContentRegistry) throws IOException {
59+
try(XContentParser parser = XContentHelper.createParser(xContentRegistry,
60+
DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
61+
inflate(compressedString, MAX_INFLATED_BYTES),
62+
XContentType.JSON)) {
63+
return parserFunction.apply(parser);
64+
}
65+
}
66+
67+
static BytesReference inflate(String compressedString, long streamSize) throws IOException {
68+
byte[] compressedBytes = Base64.getDecoder().decode(compressedString.getBytes(StandardCharsets.UTF_8));
69+
InputStream gzipStream = new GZIPInputStream(new BytesArray(compressedBytes).streamInput(), BUFFER_SIZE);
70+
InputStream inflateStream = new SimpleBoundedInputStream(gzipStream, streamSize);
71+
return Streams.readFully(inflateStream);
72+
}
73+
74+
private static String deflate(BytesReference reference) throws IOException {
75+
BytesStreamOutput out = new BytesStreamOutput();
76+
try (OutputStream compressedOutput = new GZIPOutputStream(out, BUFFER_SIZE)) {
77+
reference.writeTo(compressedOutput);
78+
}
79+
return new String(Base64.getEncoder().encode(BytesReference.toBytes(out.bytes())), StandardCharsets.UTF_8);
80+
}
81+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
/*
2+
* Licensed to Elasticsearch under one or more contributor
3+
* license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright
5+
* ownership. Elasticsearch licenses this file to you under
6+
* the Apache License, Version 2.0 (the "License"); you may
7+
* not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
package org.elasticsearch.client.ml.inference;
20+
21+
22+
import java.io.IOException;
23+
import java.io.InputStream;
24+
import java.util.Objects;
25+
26+
/**
27+
* This is a pared down bounded input stream.
28+
* Only read is specifically enforced.
29+
*/
30+
final class SimpleBoundedInputStream extends InputStream {
31+
32+
private final InputStream in;
33+
private final long maxBytes;
34+
private long numBytes;
35+
36+
SimpleBoundedInputStream(InputStream inputStream, long maxBytes) {
37+
this.in = Objects.requireNonNull(inputStream, "inputStream");
38+
if (maxBytes < 0) {
39+
throw new IllegalArgumentException("[maxBytes] must be greater than or equal to 0");
40+
}
41+
this.maxBytes = maxBytes;
42+
}
43+
44+
45+
/**
46+
* A simple wrapper around the injected input stream that restricts the total number of bytes able to be read.
47+
* @return The byte read. -1 on internal stream completion or when maxBytes is exceeded.
48+
* @throws IOException on failure
49+
*/
50+
@Override
51+
public int read() throws IOException {
52+
// We have reached the maximum, signal stream completion.
53+
if (numBytes >= maxBytes) {
54+
return -1;
55+
}
56+
numBytes++;
57+
return in.read();
58+
}
59+
60+
/**
61+
* Delegates `close` to the wrapped InputStream
62+
* @throws IOException on failure
63+
*/
64+
@Override
65+
public void close() throws IOException {
66+
in.close();
67+
}
68+
}

0 commit comments

Comments
 (0)