|
9 | 9 |
|
10 | 10 | import org.elasticsearch.ElasticsearchStatusException;
|
11 | 11 | import org.elasticsearch.action.ActionListener;
|
| 12 | +import org.elasticsearch.common.Strings; |
12 | 13 | import org.elasticsearch.core.TimeValue;
|
13 | 14 | import org.elasticsearch.inference.InferenceService;
|
| 15 | +import org.elasticsearch.inference.InferenceServiceResults; |
14 | 16 | import org.elasticsearch.inference.Model;
|
| 17 | +import org.elasticsearch.inference.TaskType; |
15 | 18 | import org.elasticsearch.rest.RestStatus;
|
| 19 | +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; |
| 20 | +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; |
| 21 | +import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandEmbeddingModel; |
16 | 22 |
|
17 | 23 | public class ElasticsearchInternalServiceModelValidator implements ModelValidator {
|
18 | 24 |
|
19 |
| - ModelValidator modelValidator; |
| 25 | + private final ServiceIntegrationValidator serviceIntegrationValidator; |
20 | 26 |
|
21 |
| - public ElasticsearchInternalServiceModelValidator(ModelValidator modelValidator) { |
22 |
| - this.modelValidator = modelValidator; |
| 27 | + public ElasticsearchInternalServiceModelValidator(ServiceIntegrationValidator serviceIntegrationValidator) { |
| 28 | + this.serviceIntegrationValidator = serviceIntegrationValidator; |
23 | 29 | }
|
24 | 30 |
|
25 | 31 | @Override
|
26 | 32 | public void validate(InferenceService service, Model model, TimeValue timeout, ActionListener<Model> listener) {
|
27 |
| - service.start(model, timeout, ActionListener.wrap((modelDeploymentStarted) -> { |
28 |
| - if (modelDeploymentStarted) { |
29 |
| - try { |
30 |
| - modelValidator.validate(service, model, timeout, listener.delegateResponse((l, exception) -> { |
31 |
| - stopModelDeployment(service, model, l, exception); |
32 |
| - })); |
33 |
| - } catch (Exception e) { |
34 |
| - stopModelDeployment(service, model, listener, e); |
35 |
| - } |
36 |
| - } else { |
37 |
| - listener.onFailure( |
38 |
| - new ElasticsearchStatusException("Could not deploy model for inference endpoint", RestStatus.INTERNAL_SERVER_ERROR) |
| 33 | + if (model instanceof CustomElandEmbeddingModel elandModel && elandModel.getTaskType() == TaskType.TEXT_EMBEDDING) { |
| 34 | + var temporaryModelWithModelId = new CustomElandEmbeddingModel( |
| 35 | + elandModel.getServiceSettings().modelId(), |
| 36 | + elandModel.getTaskType(), |
| 37 | + elandModel.getConfigurations().getService(), |
| 38 | + elandModel.getServiceSettings(), |
| 39 | + elandModel.getConfigurations().getChunkingSettings() |
| 40 | + ); |
| 41 | + |
| 42 | + serviceIntegrationValidator.validate( |
| 43 | + service, |
| 44 | + temporaryModelWithModelId, |
| 45 | + timeout, |
| 46 | + listener.delegateFailureAndWrap((delegate, r) -> { |
| 47 | + delegate.onResponse(postValidate(service, model, r)); |
| 48 | + }) |
| 49 | + ); |
| 50 | + } else { |
| 51 | + listener.onResponse(model); |
| 52 | + } |
| 53 | + } |
| 54 | + |
| 55 | + private Model postValidate(InferenceService service, Model model, InferenceServiceResults results) { |
| 56 | + if (results instanceof TextEmbeddingResults<?> embeddingResults) { |
| 57 | + var serviceSettings = model.getServiceSettings(); |
| 58 | + var dimensions = serviceSettings.dimensions(); |
| 59 | + int embeddingSize = getEmbeddingSize(embeddingResults); |
| 60 | + |
| 61 | + if (Boolean.TRUE.equals(serviceSettings.dimensionsSetByUser()) |
| 62 | + && dimensions != null |
| 63 | + && (dimensions.equals(embeddingSize) == false)) { |
| 64 | + throw new ElasticsearchStatusException( |
| 65 | + Strings.format( |
| 66 | + "The retrieved embeddings size [%s] does not match the size specified in the settings [%s]. " |
| 67 | + + "Please recreate the [%s] configuration with the correct dimensions", |
| 68 | + embeddingResults.getFirstEmbeddingSize(), |
| 69 | + serviceSettings.dimensions(), |
| 70 | + model.getInferenceEntityId() |
| 71 | + ), |
| 72 | + RestStatus.BAD_REQUEST |
39 | 73 | );
|
40 | 74 | }
|
41 |
| - }, listener::onFailure)); |
| 75 | + |
| 76 | + return service.updateModelWithEmbeddingDetails(model, embeddingSize); |
| 77 | + } else { |
| 78 | + throw new ElasticsearchStatusException( |
| 79 | + "Validation call did not return expected results type." |
| 80 | + + "Expected a result of type [" |
| 81 | + + TextEmbeddingFloatResults.NAME |
| 82 | + + "] got [" |
| 83 | + + (results == null ? "null" : results.getWriteableName()) |
| 84 | + + "]", |
| 85 | + RestStatus.BAD_REQUEST |
| 86 | + ); |
| 87 | + } |
42 | 88 | }
|
43 | 89 |
|
44 |
| - private void stopModelDeployment(InferenceService service, Model model, ActionListener<Model> listener, Exception e) { |
45 |
| - service.stop( |
46 |
| - model, |
47 |
| - ActionListener.wrap( |
48 |
| - (v) -> listener.onFailure(e), |
49 |
| - (ex) -> listener.onFailure( |
50 |
| - new ElasticsearchStatusException( |
51 |
| - "Model validation failed and model deployment could not be stopped", |
52 |
| - RestStatus.INTERNAL_SERVER_ERROR, |
53 |
| - ex |
54 |
| - ) |
55 |
| - ) |
56 |
| - ) |
57 |
| - ); |
| 90 | + private int getEmbeddingSize(TextEmbeddingResults<?> embeddingResults) { |
| 91 | + int embeddingSize; |
| 92 | + try { |
| 93 | + embeddingSize = embeddingResults.getFirstEmbeddingSize(); |
| 94 | + } catch (Exception e) { |
| 95 | + throw new ElasticsearchStatusException("Could not determine embedding size", RestStatus.BAD_REQUEST, e); |
| 96 | + } |
| 97 | + return embeddingSize; |
58 | 98 | }
|
59 | 99 | }
|
0 commit comments