Skip to content

Commit f69faa7

Browse files
authored
[ML] simplifying model license checks (#80031)
Since trained models require either a platinum or basic license, our license checking can be simplified. Now, we check if ML APIs are allowed (platinum) or if the trained model is `basic`. relates to: #79908
1 parent e6e41c0 commit f69faa7

File tree

5 files changed

+11
-21
lines changed

5 files changed

+11
-21
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MachineLearningField.java

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import org.elasticsearch.core.TimeValue;
1414
import org.elasticsearch.license.License;
1515
import org.elasticsearch.license.LicensedFeature;
16-
import org.elasticsearch.license.XPackLicenseState;
1716

1817
import java.math.BigInteger;
1918
import java.nio.charset.StandardCharsets;
@@ -41,12 +40,6 @@ public final class MachineLearningField {
4140
License.OperationMode.PLATINUM
4241
);
4342

44-
public static final LicensedFeature.Momentary ML_MODEL_INFERENCE_PLATINUM_FEATURE = LicensedFeature.momentary(
45-
MachineLearningField.ML_FEATURE_FAMILY,
46-
"model-inference-platinum-check",
47-
License.OperationMode.PLATINUM
48-
);
49-
5043
private MachineLearningField() {}
5144

5245
public static String valuesToId(String... values) {
@@ -59,10 +52,4 @@ public static String valuesToId(String... values) {
5952
return new BigInteger(hashedBytes) + "_" + combined.length();
6053
}
6154

62-
public static boolean featureCheckForMode(License.OperationMode mode, XPackLicenseState licenseState) {
63-
if (mode.equals(License.OperationMode.PLATINUM)) {
64-
return ML_MODEL_INFERENCE_PLATINUM_FEATURE.check(licenseState);
65-
}
66-
return true;
67-
}
6855
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,7 @@ public long getEstimatedOperations() {
351351
}
352352

353353
// TODO if we ever support anything other than "basic" and platinum, we need to adjust our feature tracking logic
354-
// Additionally, see `MachineLearningField. featureCheckForMode` for handling modes
354+
// and we need to adjust our license checks to validate more than "is basic" or not
355355
public License.OperationMode getLicenseLevel() {
356356
return licenseLevel;
357357
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import org.elasticsearch.cluster.service.ClusterService;
1616
import org.elasticsearch.common.inject.Inject;
1717
import org.elasticsearch.core.TimeValue;
18+
import org.elasticsearch.license.License;
1819
import org.elasticsearch.license.LicenseUtils;
1920
import org.elasticsearch.license.XPackLicenseState;
2021
import org.elasticsearch.rest.RestStatus;
@@ -42,7 +43,6 @@
4243

4344
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
4445
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
45-
import static org.elasticsearch.xpack.core.ml.MachineLearningField.featureCheckForMode;
4646

4747
public class TransportInternalInferModelAction extends HandledTransportAction<Request, Response> {
4848

@@ -83,7 +83,9 @@ protected void doExecute(Task task, Request request, ActionListener<Response> li
8383
request.getModelId(),
8484
GetTrainedModelsAction.Includes.empty(),
8585
ActionListener.wrap(trainedModelConfig -> {
86-
final boolean allowed = featureCheckForMode(trainedModelConfig.getLicenseLevel(), licenseState);
86+
// Since we just checked MachineLearningField.ML_API_FEATURE.check(licenseState) and that check failed
87+
// That means we don't have a plat+ license. The only licenses for trained models are basic (free) and plat.
88+
boolean allowed = trainedModelConfig.getLicenseLevel() == License.OperationMode.BASIC;
8789
responseBuilder.setLicensed(allowed);
8890
if (allowed || request.isPreviouslyLicensed()) {
8991
doInfer(request, responseBuilder, listener);

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAliasAction.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.elasticsearch.common.logging.DeprecationLogger;
2525
import org.elasticsearch.common.logging.HeaderWarning;
2626
import org.elasticsearch.common.util.set.Sets;
27+
import org.elasticsearch.license.License;
2728
import org.elasticsearch.license.LicenseUtils;
2829
import org.elasticsearch.license.XPackLicenseState;
2930
import org.elasticsearch.tasks.Task;
@@ -47,7 +48,6 @@
4748
import java.util.Set;
4849
import java.util.function.Predicate;
4950

50-
import static org.elasticsearch.xpack.core.ml.MachineLearningField.featureCheckForMode;
5151
import static org.elasticsearch.xpack.core.ml.job.messages.Messages.TRAINED_MODEL_INPUTS_DIFFER_SIGNIFICANTLY;
5252

5353
public class TransportPutTrainedModelAliasAction extends AcknowledgedTransportMasterNodeAction<PutTrainedModelAliasAction.Request> {
@@ -93,7 +93,8 @@ protected void masterOperation(
9393
) throws Exception {
9494
final boolean mlSupported = MachineLearningField.ML_API_FEATURE.check(licenseState);
9595
final Predicate<TrainedModelConfig> isLicensed = (model) -> mlSupported
96-
|| featureCheckForMode(model.getLicenseLevel(), licenseState);
96+
// Either we support plat+ or the model is basic licensed
97+
|| model.getLicenseLevel() == License.OperationMode.BASIC;
9798
final String oldModelId = ModelAliasMetadata.fromState(state).getModelId(request.getModelAlias());
9899

99100
if (oldModelId != null && (request.isReassign() == false)) {

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/inference/InferencePipelineAggregationBuilder.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import org.elasticsearch.common.io.stream.StreamOutput;
1616
import org.elasticsearch.common.settings.Settings;
1717
import org.elasticsearch.index.query.QueryRewriteContext;
18+
import org.elasticsearch.license.License;
1819
import org.elasticsearch.license.LicenseUtils;
1920
import org.elasticsearch.license.XPackLicenseState;
2021
import org.elasticsearch.plugins.SearchPlugin;
@@ -51,7 +52,6 @@
5152
import java.util.function.Supplier;
5253

5354
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
54-
import static org.elasticsearch.xpack.core.ml.MachineLearningField.featureCheckForMode;
5555
import static org.elasticsearch.xpack.ml.utils.SecondaryAuthorizationUtils.useSecondaryAuthIfAvailable;
5656

5757
public class InferencePipelineAggregationBuilder extends AbstractPipelineAggregationBuilder<InferencePipelineAggregationBuilder> {
@@ -267,8 +267,8 @@ public InferencePipelineAggregationBuilder rewrite(QueryRewriteContext context)
267267
.getModelForSearch(modelId, listener.delegateFailure((delegate, model) -> {
268268
loadedModel.set(model);
269269

270-
boolean isLicensed = MachineLearningField.ML_API_FEATURE.check(licenseState)
271-
|| featureCheckForMode(model.getLicenseLevel(), licenseState);
270+
boolean isLicensed = model.getLicenseLevel() == License.OperationMode.BASIC
271+
|| MachineLearningField.ML_API_FEATURE.check(licenseState);
272272
if (isLicensed) {
273273
delegate.onResponse(null);
274274
} else {

0 commit comments

Comments
 (0)