Skip to content

Commit 6ecabf8

Browse files
authored
[ML] Fix check on E5 model platform compatibility (#113437) (#113776)
Creating an endpoint for the built-in multilingual E5 model failed for the Linux-optimised variant because of an error in the logic that checks model/platform compatibility. # Conflicts: # x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java # x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java
1 parent 62147b8 commit 6ecabf8

File tree

4 files changed

+86
-30
lines changed

4 files changed

+86
-30
lines changed

docs/changelog/113437.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 113437
2+
summary: Fix check on E5 model platform compatibility
3+
area: Machine Learning
4+
type: bug
5+
issues:
6+
- 113577

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/TextEmbeddingCrudIT.java

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ public class TextEmbeddingCrudIT extends InferenceBaseRestTest {
2424

2525
public void testPutE5Small_withNoModelVariant() {
2626
{
27-
String inferenceEntityId = randomAlphaOfLength(10).toLowerCase();
27+
String inferenceEntityId = "testPutE5Small_withNoModelVariant";
2828
expectThrows(
2929
org.elasticsearch.client.ResponseException.class,
3030
() -> putTextEmbeddingModel(inferenceEntityId, noModelIdVariantJsonEntity())
@@ -33,7 +33,7 @@ public void testPutE5Small_withNoModelVariant() {
3333
}
3434

3535
public void testPutE5Small_withPlatformAgnosticVariant() throws IOException {
36-
String inferenceEntityId = randomAlphaOfLength(10).toLowerCase();
36+
String inferenceEntityId = "teste5mall_withplatformagnosticvariant";
3737
putTextEmbeddingModel(inferenceEntityId, platformAgnosticModelVariantJsonEntity());
3838
var models = getTrainedModel("_all");
3939
assertThat(models.toString(), containsString("deployment_id=" + inferenceEntityId));
@@ -50,9 +50,8 @@ public void testPutE5Small_withPlatformAgnosticVariant() throws IOException {
5050
deleteTextEmbeddingModel(inferenceEntityId);
5151
}
5252

53-
@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/105198")
5453
public void testPutE5Small_withPlatformSpecificVariant() throws IOException {
55-
String inferenceEntityId = randomAlphaOfLength(10).toLowerCase();
54+
String inferenceEntityId = "teste5mall_withplatformspecificvariant";
5655
if ("linux-x86_64".equals(Platforms.PLATFORM_NAME)) {
5756
putTextEmbeddingModel(inferenceEntityId, platformSpecificModelVariantJsonEntity());
5857
var models = getTrainedModel("_all");
@@ -77,7 +76,7 @@ public void testPutE5Small_withPlatformSpecificVariant() throws IOException {
7776
}
7877

7978
public void testPutE5Small_withFakeModelVariant() {
80-
String inferenceEntityId = randomAlphaOfLength(10).toLowerCase();
79+
String inferenceEntityId = "teste5mall_withfakevariant";
8180
expectThrows(
8281
org.elasticsearch.client.ResponseException.class,
8382
() -> putTextEmbeddingModel(inferenceEntityId, fakeModelVariantJsonEntity())
@@ -112,7 +111,7 @@ private Map<String, Object> putTextEmbeddingModel(String inferenceEntityId, Stri
112111
private String noModelIdVariantJsonEntity() {
113112
return """
114113
{
115-
"service": "text_embedding",
114+
"service": "elasticsearch",
116115
"service_settings": {
117116
"num_allocations": 1,
118117
"num_threads": 1

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -169,16 +169,14 @@ private void e5Case(
169169
Map<String, Object> serviceSettingsMap,
170170
ActionListener<Model> modelListener
171171
) {
172-
var e5ServiceSettings = MultilingualE5SmallInternalServiceSettings.fromMap(serviceSettingsMap);
172+
var esServiceSettingsBuilder = MultilingualE5SmallInternalServiceSettings.fromMap(serviceSettingsMap);
173173

174-
if (e5ServiceSettings.getModelId() == null) {
175-
e5ServiceSettings.setModelId(selectDefaultModelVariantBasedOnClusterArchitecture(platformArchitectures));
176-
}
177-
178-
if (modelVariantDoesNotMatchArchitecturesAndIsNotPlatformAgnostic(platformArchitectures, e5ServiceSettings)) {
174+
if (esServiceSettingsBuilder.getModelId() == null) {
175+
esServiceSettingsBuilder.setModelId(selectDefaultModelVariantBasedOnClusterArchitecture(platformArchitectures));
176+
} else if (modelVariantValidForArchitecture(platformArchitectures, esServiceSettingsBuilder.getModelId()) == false) {
179177
throw new IllegalArgumentException(
180178
"Error parsing request config, model id does not match any models available on this platform. Was ["
181-
+ e5ServiceSettings.getModelId()
179+
+ esServiceSettingsBuilder.getModelId()
182180
+ "]"
183181
);
184182
}
@@ -191,17 +189,18 @@ private void e5Case(
191189
inferenceEntityId,
192190
taskType,
193191
NAME,
194-
(MultilingualE5SmallInternalServiceSettings) e5ServiceSettings.build()
192+
(MultilingualE5SmallInternalServiceSettings) esServiceSettingsBuilder.build()
195193
)
196194
);
197195
}
198196

199-
private static boolean modelVariantDoesNotMatchArchitecturesAndIsNotPlatformAgnostic(
200-
Set<String> platformArchitectures,
201-
InternalServiceSettings.Builder e5ServiceSettings
202-
) {
203-
return e5ServiceSettings.getModelId().equals(selectDefaultModelVariantBasedOnClusterArchitecture(platformArchitectures)) == false
204-
&& e5ServiceSettings.getModelId().equals(MULTILINGUAL_E5_SMALL_MODEL_ID) == false;
197+
static boolean modelVariantValidForArchitecture(Set<String> platformArchitectures, String modelId) {
198+
if (modelId.equals(MULTILINGUAL_E5_SMALL_MODEL_ID)) {
199+
// platform agnostic model is always compatible
200+
return true;
201+
}
202+
203+
return modelId.equals(selectDefaultModelVariantBasedOnClusterArchitecture(platformArchitectures));
205204
}
206205

207206
@Override

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java

Lines changed: 62 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@
6464
import java.util.concurrent.atomic.AtomicInteger;
6565
import java.util.concurrent.atomic.AtomicReference;
6666

67+
import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID;
68+
import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86;
6769
import static org.hamcrest.Matchers.containsString;
6870
import static org.hamcrest.Matchers.hasSize;
6971
import static org.hamcrest.Matchers.instanceOf;
@@ -165,6 +167,36 @@ public void testParseRequestConfig() {
165167

166168
service.parseRequestConfig(randomInferenceEntityId, taskType, settings, Set.of(), modelListener);
167169
}
170+
}
171+
172+
public void testParseRequestConfig_E5() {
173+
{
174+
var service = createService(mock(Client.class));
175+
var settings = new HashMap<String, Object>();
176+
settings.put(
177+
ModelConfigurations.SERVICE_SETTINGS,
178+
new HashMap<>(
179+
Map.of(
180+
ElasticsearchInternalServiceSettings.NUM_ALLOCATIONS,
181+
1,
182+
ElasticsearchInternalServiceSettings.NUM_THREADS,
183+
4,
184+
ElasticsearchInternalServiceSettings.MODEL_ID,
185+
MULTILINGUAL_E5_SMALL_MODEL_ID
186+
)
187+
)
188+
);
189+
190+
var e5ServiceSettings = new MultilingualE5SmallInternalServiceSettings(1, 4, MULTILINGUAL_E5_SMALL_MODEL_ID);
191+
192+
service.parseRequestConfig(
193+
randomInferenceEntityId,
194+
TaskType.TEXT_EMBEDDING,
195+
settings,
196+
Set.of(),
197+
getModelVerificationActionListener(e5ServiceSettings)
198+
);
199+
}
168200

169201
// Invalid service settings
170202
{
@@ -178,9 +210,8 @@ public void testParseRequestConfig() {
178210
1,
179211
ElasticsearchInternalServiceSettings.NUM_THREADS,
180212
4,
181-
InternalServiceSettings.MODEL_ID,
182-
ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID, // we can't directly test the eland case until we mock
183-
// the threadpool within the client
213+
ElasticsearchInternalServiceSettings.MODEL_ID,
214+
MULTILINGUAL_E5_SMALL_MODEL_ID,
184215
"not_a_valid_service_setting",
185216
randomAlphaOfLength(10)
186217
)
@@ -419,19 +450,15 @@ public void testParsePersistedConfig() {
419450
1,
420451
ElasticsearchInternalServiceSettings.NUM_THREADS,
421452
4,
422-
InternalServiceSettings.MODEL_ID,
423-
ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID,
453+
ElasticsearchInternalServiceSettings.MODEL_ID,
454+
MULTILINGUAL_E5_SMALL_MODEL_ID,
424455
ServiceFields.DIMENSIONS,
425456
1
426457
)
427458
)
428459
);
429460

430-
var e5ServiceSettings = new MultilingualE5SmallInternalServiceSettings(
431-
1,
432-
4,
433-
ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID
434-
);
461+
var e5ServiceSettings = new MultilingualE5SmallInternalServiceSettings(1, 4, MULTILINGUAL_E5_SMALL_MODEL_ID);
435462

436463
MultilingualE5SmallModel parsedModel = (MultilingualE5SmallModel) service.parsePersistedConfig(
437464
randomInferenceEntityId,
@@ -860,6 +887,31 @@ public void testParseRequestConfigEland_SetsDimensionsToOne() {
860887
assertThat(model, is(expectedModel));
861888
}
862889

890+
public void testModelVariantDoesNotMatchArchitecturesAndIsNotPlatformAgnostic() {
891+
{
892+
var architectures = Set.of("Aarch64");
893+
assertFalse(
894+
ElasticsearchInternalService.modelVariantValidForArchitecture(architectures, MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86)
895+
);
896+
897+
assertTrue(ElasticsearchInternalService.modelVariantValidForArchitecture(architectures, MULTILINGUAL_E5_SMALL_MODEL_ID));
898+
}
899+
{
900+
var architectures = Set.of("linux-x86_64");
901+
assertTrue(
902+
ElasticsearchInternalService.modelVariantValidForArchitecture(architectures, MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86)
903+
);
904+
assertTrue(ElasticsearchInternalService.modelVariantValidForArchitecture(architectures, MULTILINGUAL_E5_SMALL_MODEL_ID));
905+
}
906+
{
907+
var architectures = Set.of("linux-x86_64", "Aarch64");
908+
assertFalse(
909+
ElasticsearchInternalService.modelVariantValidForArchitecture(architectures, MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86)
910+
);
911+
assertTrue(ElasticsearchInternalService.modelVariantValidForArchitecture(architectures, MULTILINGUAL_E5_SMALL_MODEL_ID));
912+
}
913+
}
914+
863915
private ElasticsearchInternalService createService(Client client) {
864916
var context = new InferenceServiceExtension.InferenceServiceFactoryContext(client);
865917
return new ElasticsearchInternalService(context);

0 commit comments

Comments
 (0)