Skip to content

[ML][HLRC] adds put and delete trained model alias APIs to rest high-level client #69214

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.elasticsearch.client.ml.DeleteForecastRequest;
import org.elasticsearch.client.ml.DeleteJobRequest;
import org.elasticsearch.client.ml.DeleteModelSnapshotRequest;
import org.elasticsearch.client.ml.DeleteTrainedModelAliasRequest;
import org.elasticsearch.client.ml.DeleteTrainedModelRequest;
import org.elasticsearch.client.ml.EstimateModelMemoryRequest;
import org.elasticsearch.client.ml.EvaluateDataFrameRequest;
Expand Down Expand Up @@ -62,6 +63,7 @@
import org.elasticsearch.client.ml.PutDatafeedRequest;
import org.elasticsearch.client.ml.PutFilterRequest;
import org.elasticsearch.client.ml.PutJobRequest;
import org.elasticsearch.client.ml.PutTrainedModelAliasRequest;
import org.elasticsearch.client.ml.PutTrainedModelRequest;
import org.elasticsearch.client.ml.RevertModelSnapshotRequest;
import org.elasticsearch.client.ml.SetUpgradeModeRequest;
Expand Down Expand Up @@ -857,6 +859,32 @@ static Request putTrainedModel(PutTrainedModelRequest putTrainedModelRequest) th
return request;
}

static Request putTrainedModelAlias(PutTrainedModelAliasRequest putTrainedModelAliasRequest) throws IOException {
String endpoint = new EndpointBuilder()
.addPathPartAsIs("_ml", "trained_models")
.addPathPart(putTrainedModelAliasRequest.getModelId())
.addPathPartAsIs("model_aliases")
.addPathPart(putTrainedModelAliasRequest.getModelAlias())
.build();
Request request = new Request(HttpPut.METHOD_NAME, endpoint);
RequestConverters.Params params = new RequestConverters.Params();
if (putTrainedModelAliasRequest.getReassign() != null) {
params.putParam(PutTrainedModelAliasRequest.REASSIGN, Boolean.toString(putTrainedModelAliasRequest.getReassign()));
}
request.addParameters(params.asMap());
return request;
}

static Request deleteTrainedModelAlias(DeleteTrainedModelAliasRequest deleteTrainedModelAliasRequest) throws IOException {
String endpoint = new EndpointBuilder()
.addPathPartAsIs("_ml", "trained_models")
.addPathPart(deleteTrainedModelAliasRequest.getModelId())
.addPathPartAsIs("model_aliases")
.addPathPart(deleteTrainedModelAliasRequest.getModelAlias())
.build();
return new Request(HttpDelete.METHOD_NAME, endpoint);
}

static Request putFilter(PutFilterRequest putFilterRequest) throws IOException {
String endpoint = new EndpointBuilder()
.addPathPartAsIs("_ml")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.client.ml.CloseJobRequest;
import org.elasticsearch.client.ml.CloseJobResponse;
import org.elasticsearch.client.ml.DeleteTrainedModelAliasRequest;
import org.elasticsearch.client.ml.DeleteTrainedModelRequest;
import org.elasticsearch.client.ml.EstimateModelMemoryRequest;
import org.elasticsearch.client.ml.EstimateModelMemoryResponse;
Expand Down Expand Up @@ -89,6 +90,7 @@
import org.elasticsearch.client.ml.PutFilterResponse;
import org.elasticsearch.client.ml.PutJobRequest;
import org.elasticsearch.client.ml.PutJobResponse;
import org.elasticsearch.client.ml.PutTrainedModelAliasRequest;
import org.elasticsearch.client.ml.PutTrainedModelRequest;
import org.elasticsearch.client.ml.PutTrainedModelResponse;
import org.elasticsearch.client.ml.RevertModelSnapshotRequest;
Expand Down Expand Up @@ -2552,4 +2554,90 @@ public Cancellable deleteTrainedModelAsync(DeleteTrainedModelRequest request,
listener,
Collections.emptySet());
}

/**
* Creates or reassigns a trained model alias
* <p>
* For additional info
* see <a href="https://www.elastic.co/guide/en/elasticsearch/reference/current/put-trained-models-aliases.html">
* Put Trained Model Aliases documentation</a>
*
* @param request The {@link PutTrainedModelAliasRequest}
* @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized
* @return action acknowledgement
* @throws IOException when there is a serialization issue sending the request or receiving the response
*/
public AcknowledgedResponse putTrainedModelAlias(PutTrainedModelAliasRequest request, RequestOptions options) throws IOException {
return restHighLevelClient.performRequestAndParseEntity(request,
MLRequestConverters::putTrainedModelAlias,
options,
AcknowledgedResponse::fromXContent,
Collections.emptySet());
}

/**
* Creates or reassigns a trained model alias asynchronously and notifies listener upon completion
* <p>
* For additional info
* see <a href="https://www.elastic.co/guide/en/elasticsearch/reference/current/put-trained-models-aliases.html">
* Put Trained Model Aliases documentation</a>
*
* @param request The {@link PutTrainedModelAliasRequest}
* @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized
* @param listener Listener to be notified upon request completion
* @return cancellable that may be used to cancel the request
*/
public Cancellable putTrainedModelAliasAsync(PutTrainedModelAliasRequest request,
RequestOptions options,
ActionListener<AcknowledgedResponse> listener) {
return restHighLevelClient.performRequestAsyncAndParseEntity(request,
MLRequestConverters::putTrainedModelAlias,
options,
AcknowledgedResponse::fromXContent,
listener,
Collections.emptySet());
}

/**
* Deletes a trained model alias
* <p>
* For additional info
* see <a href="https://www.elastic.co/guide/en/elasticsearch/reference/current/delete-trained-models-aliases.html">
* Delete Trained Model Aliases documentation</a>
*
* @param request The {@link DeleteTrainedModelAliasRequest}
* @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized
* @return action acknowledgement
* @throws IOException when there is a serialization issue sending the request or receiving the response
*/
public AcknowledgedResponse deleteTrainedModelAlias(DeleteTrainedModelAliasRequest request, RequestOptions options) throws IOException {
return restHighLevelClient.performRequestAndParseEntity(request,
MLRequestConverters::deleteTrainedModelAlias,
options,
AcknowledgedResponse::fromXContent,
Collections.emptySet());
}

/**
* Deletes a trained model alias asynchronously and notifies listener upon completion
* <p>
* For additional info
* see <a href="https://www.elastic.co/guide/en/elasticsearch/reference/current/delete-trained-models-aliases.html">
* Delete Trained Model Aliases documentation</a>
*
* @param request The {@link DeleteTrainedModelAliasRequest}
* @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized
* @param listener Listener to be notified upon request completion
* @return cancellable that may be used to cancel the request
*/
public Cancellable deleteTrainedModelAliasAsync(DeleteTrainedModelAliasRequest request,
RequestOptions options,
ActionListener<AcknowledgedResponse> listener) {
return restHighLevelClient.performRequestAsyncAndParseEntity(request,
MLRequestConverters::deleteTrainedModelAlias,
options,
AcknowledgedResponse::fromXContent,
listener,
Collections.emptySet());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.client.ml;

import org.elasticsearch.client.Validatable;

import java.util.Objects;

public class DeleteTrainedModelAliasRequest implements Validatable {

private final String modelAlias;
private final String modelId;

public DeleteTrainedModelAliasRequest(String modelAlias, String modelId) {
this.modelAlias = Objects.requireNonNull(modelAlias);
this.modelId = Objects.requireNonNull(modelId);
}

public String getModelAlias() {
return modelAlias;
}

public String getModelId() {
return modelId;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
DeleteTrainedModelAliasRequest request = (DeleteTrainedModelAliasRequest) o;
return Objects.equals(modelAlias, request.modelAlias)
&& Objects.equals(modelId, request.modelId);
}

@Override
public int hashCode() {
return Objects.hash(modelAlias, modelId);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.client.ml;

import org.elasticsearch.client.Validatable;

import java.util.Objects;

public class PutTrainedModelAliasRequest implements Validatable {

public static final String REASSIGN = "reassign";

private final String modelAlias;
private final String modelId;
private final Boolean reassign;

public PutTrainedModelAliasRequest(String modelAlias, String modelId, Boolean reassign) {
this.modelAlias = Objects.requireNonNull(modelAlias);
this.modelId = Objects.requireNonNull(modelId);
this.reassign = reassign;
}

public String getModelAlias() {
return modelAlias;
}

public String getModelId() {
return modelId;
}

public Boolean getReassign() {
return reassign;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
PutTrainedModelAliasRequest request = (PutTrainedModelAliasRequest) o;
return Objects.equals(modelAlias, request.modelAlias)
&& Objects.equals(modelId, request.modelId)
&& Objects.equals(reassign, request.reassign);
}

@Override
public int hashCode() {
return Objects.hash(modelAlias, modelId, reassign);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.elasticsearch.client.ml.DeleteForecastRequest;
import org.elasticsearch.client.ml.DeleteJobRequest;
import org.elasticsearch.client.ml.DeleteModelSnapshotRequest;
import org.elasticsearch.client.ml.DeleteTrainedModelAliasRequest;
import org.elasticsearch.client.ml.DeleteTrainedModelRequest;
import org.elasticsearch.client.ml.EstimateModelMemoryRequest;
import org.elasticsearch.client.ml.EvaluateDataFrameRequest;
Expand Down Expand Up @@ -59,6 +60,7 @@
import org.elasticsearch.client.ml.PutDatafeedRequest;
import org.elasticsearch.client.ml.PutFilterRequest;
import org.elasticsearch.client.ml.PutJobRequest;
import org.elasticsearch.client.ml.PutTrainedModelAliasRequest;
import org.elasticsearch.client.ml.PutTrainedModelRequest;
import org.elasticsearch.client.ml.RevertModelSnapshotRequest;
import org.elasticsearch.client.ml.SetUpgradeModeRequest;
Expand Down Expand Up @@ -119,7 +121,9 @@
import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasEntry;
import static org.hamcrest.Matchers.hasKey;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.core.IsNull.nullValue;

public class MLRequestConvertersTests extends ESTestCase {
Expand Down Expand Up @@ -965,6 +969,52 @@ public void testPutTrainedModel() throws IOException {
}
}

public void testPutTrainedModelAlias() throws IOException {
PutTrainedModelAliasRequest putTrainedModelAliasRequest = new PutTrainedModelAliasRequest(
randomAlphaOfLength(10),
randomAlphaOfLength(10),
randomBoolean() ? null : randomBoolean()
);

Request request = MLRequestConverters.putTrainedModelAlias(putTrainedModelAliasRequest);

assertEquals(HttpPut.METHOD_NAME, request.getMethod());
assertThat(
request.getEndpoint(),
equalTo(
"/_ml/trained_models/"
+ putTrainedModelAliasRequest.getModelId()
+ "/model_aliases/"
+ putTrainedModelAliasRequest.getModelAlias()
)
);
if (putTrainedModelAliasRequest.getReassign() != null) {
assertThat(request.getParameters().get("reassign"), equalTo(putTrainedModelAliasRequest.getReassign().toString()));
} else {
assertThat(request.getParameters(), not(hasKey("reassign")));
}
}

public void testDeleteTrainedModelAlias() throws IOException {
DeleteTrainedModelAliasRequest deleteTrainedModelAliasRequest = new DeleteTrainedModelAliasRequest(
randomAlphaOfLength(10),
randomAlphaOfLength(10)
);

Request request = MLRequestConverters.deleteTrainedModelAlias(deleteTrainedModelAliasRequest);

assertEquals(HttpDelete.METHOD_NAME, request.getMethod());
assertThat(
request.getEndpoint(),
equalTo(
"/_ml/trained_models/"
+ deleteTrainedModelAliasRequest.getModelId()
+ "/model_aliases/"
+ deleteTrainedModelAliasRequest.getModelAlias()
)
);
}

public void testPutFilter() throws IOException {
MlFilter filter = MlFilterTests.createRandomBuilder("foo").build();
PutFilterRequest putFilterRequest = new PutFilterRequest(filter);
Expand Down
Loading