 import org.elasticsearch.common.settings.Setting;
 import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.common.util.concurrent.AtomicArray;
+import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.common.xcontent.support.XContentMapValues;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.index.IndexingPressure;
 import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
 import org.elasticsearch.inference.ChunkInferenceInput;
 import org.elasticsearch.inference.ChunkedInference;
@@ -109,18 +111,21 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
     private final InferenceServiceRegistry inferenceServiceRegistry;
     private final ModelRegistry modelRegistry;
     private final XPackLicenseState licenseState;
+    private final IndexingPressure indexingPressure;
     private volatile long batchSizeInBytes;

     public ShardBulkInferenceActionFilter(
         ClusterService clusterService,
         InferenceServiceRegistry inferenceServiceRegistry,
         ModelRegistry modelRegistry,
-        XPackLicenseState licenseState
+        XPackLicenseState licenseState,
+        IndexingPressure indexingPressure
     ) {
         this.clusterService = clusterService;
         this.inferenceServiceRegistry = inferenceServiceRegistry;
         this.modelRegistry = modelRegistry;
         this.licenseState = licenseState;
+        this.indexingPressure = indexingPressure;
         this.batchSizeInBytes = INDICES_INFERENCE_BATCH_SIZE.get(clusterService.getSettings()).getBytes();
         clusterService.getClusterSettings().addSettingsUpdateConsumer(INDICES_INFERENCE_BATCH_SIZE, this::setBatchSize);
     }
@@ -146,8 +151,15 @@ public <Request extends ActionRequest, Response extends ActionResponse> void app
         BulkShardRequest bulkShardRequest = (BulkShardRequest) request;
         var fieldInferenceMetadata = bulkShardRequest.consumeInferenceFieldMap();
         if (fieldInferenceMetadata != null && fieldInferenceMetadata.isEmpty() == false) {
-            Runnable onInferenceCompletion = () -> chain.proceed(task, action, request, listener);
-            processBulkShardRequest(fieldInferenceMetadata, bulkShardRequest, onInferenceCompletion);
+            // Maintain coordinating indexing pressure from inference until the indexing operations are complete
+            IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.createCoordinatingOperation(false);
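+            // Created without force-execution, so later pressure increments can be rejected with
+            // EsRejectedExecutionException (handled in incrementIndexingPressure and applyInferenceResponses)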
+            Runnable onInferenceCompletion = () -> chain.proceed(
+                task,
+                action,
+                request,
+                ActionListener.releaseAfter(listener, coordinatingIndexingPressure)
+            );
+            processBulkShardRequest(fieldInferenceMetadata, bulkShardRequest, onInferenceCompletion, coordinatingIndexingPressure);
             return;
         }
     }
@@ -157,12 +169,14 @@ public <Request extends ActionRequest, Response extends ActionResponse> void app
     private void processBulkShardRequest(
         Map<String, InferenceFieldMetadata> fieldInferenceMap,
         BulkShardRequest bulkShardRequest,
-        Runnable onCompletion
+        Runnable onCompletion,
+        IndexingPressure.Coordinating coordinatingIndexingPressure
     ) {
         final ProjectMetadata project = clusterService.state().getMetadata().getProject();
         var index = project.index(bulkShardRequest.index());
         boolean useLegacyFormat = InferenceMetadataFieldsMapper.isEnabled(index.getSettings()) == false;
-        new AsyncBulkShardInferenceAction(useLegacyFormat, fieldInferenceMap, bulkShardRequest, onCompletion).run();
+        new AsyncBulkShardInferenceAction(useLegacyFormat, fieldInferenceMap, bulkShardRequest, onCompletion, coordinatingIndexingPressure)
+            .run();
     }

     private record InferenceProvider(InferenceService service, Model model) {}
@@ -232,18 +246,21 @@ private class AsyncBulkShardInferenceAction implements Runnable {
         private final BulkShardRequest bulkShardRequest;
         private final Runnable onCompletion;
         private final AtomicArray<FieldInferenceResponseAccumulator> inferenceResults;
+        private final IndexingPressure.Coordinating coordinatingIndexingPressure;

         private AsyncBulkShardInferenceAction(
             boolean useLegacyFormat,
             Map<String, InferenceFieldMetadata> fieldInferenceMap,
             BulkShardRequest bulkShardRequest,
-            Runnable onCompletion
+            Runnable onCompletion,
+            IndexingPressure.Coordinating coordinatingIndexingPressure
         ) {
             this.useLegacyFormat = useLegacyFormat;
             this.fieldInferenceMap = fieldInferenceMap;
             this.bulkShardRequest = bulkShardRequest;
             this.inferenceResults = new AtomicArray<>(bulkShardRequest.items().length);
             this.onCompletion = onCompletion;
+            this.coordinatingIndexingPressure = coordinatingIndexingPressure;
         }

         @Override
@@ -431,9 +448,9 @@ public void onFailure(Exception exc) {
          */
         private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<String, List<FieldInferenceRequest>> requestsMap) {
             boolean isUpdateRequest = false;
-            final IndexRequest indexRequest;
+            final IndexRequestWithIndexingPressure indexRequest;
             if (item.request() instanceof IndexRequest ir) {
-                indexRequest = ir;
+                indexRequest = new IndexRequestWithIndexingPressure(ir);
             } else if (item.request() instanceof UpdateRequest updateRequest) {
                 isUpdateRequest = true;
                 if (updateRequest.script() != null) {
@@ -447,13 +464,13 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
                     );
                     return 0;
                 }
-                indexRequest = updateRequest.doc();
+                indexRequest = new IndexRequestWithIndexingPressure(updateRequest.doc());
             } else {
                 // ignore delete request
                 return 0;
             }

-            final Map<String, Object> docMap = indexRequest.sourceAsMap();
+            final Map<String, Object> docMap = indexRequest.getIndexRequest().sourceAsMap();
             long inputLength = 0;
             for (var entry : fieldInferenceMap.values()) {
                 String field = entry.getName();
@@ -489,6 +506,10 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
                      * This ensures that the field is treated as intentionally cleared,
                      * preventing any unintended carryover of prior inference results.
                      */
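+                    // Reserve indexing pressure for this document's source before recording the cleared field;
+                    // on rejection the item is marked failed and processing of it stops here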
+                    if (incrementIndexingPressure(indexRequest, itemIndex) == false) {
+                        return inputLength;
+                    }
+
                     var slot = ensureResponseAccumulatorSlot(itemIndex);
                     slot.addOrUpdateResponse(
                         new FieldInferenceResponse(field, sourceField, null, order++, 0, null, EMPTY_CHUNKED_INFERENCE)
@@ -510,6 +531,7 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
                     }
                     continue;
                 }
+
                 var slot = ensureResponseAccumulatorSlot(itemIndex);
                 final List<String> values;
                 try {
@@ -527,7 +549,10 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
                 List<FieldInferenceRequest> requests = requestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>());
                 int offsetAdjustment = 0;
                 for (String v : values) {
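+                    // Idempotent per document: only the first call actually increments the pressure
+                    // (see IndexRequestWithIndexingPressure), so calling it for every value is safe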
-                    inputLength += v.length();
+                    if (incrementIndexingPressure(indexRequest, itemIndex) == false) {
+                        return inputLength;
+                    }
+
                     if (v.isBlank()) {
                         slot.addOrUpdateResponse(
                             new FieldInferenceResponse(field, sourceField, v, order++, 0, null, EMPTY_CHUNKED_INFERENCE)
@@ -536,6 +561,7 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
                         requests.add(
                             new FieldInferenceRequest(itemIndex, field, sourceField, v, order++, offsetAdjustment, chunkingSettings)
                         );
+                        inputLength += v.length();
                     }

                     // When using the inference metadata fields format, all the input values are concatenated so that the
@@ -545,9 +571,54 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
                     }
                 }
             }
+
             return inputLength;
         }

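+        // Wraps an IndexRequest and records whether its source bytes have already been added to the
+        // coordinating indexing pressure, so each document is counted at most once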
+        private static class IndexRequestWithIndexingPressure {
+            private final IndexRequest indexRequest;
+            private boolean indexingPressureIncremented;
+
+            private IndexRequestWithIndexingPressure(IndexRequest indexRequest) {
+                this.indexRequest = indexRequest;
+                this.indexingPressureIncremented = false;
+            }
+
+            private IndexRequest getIndexRequest() {
+                return indexRequest;
+            }
+
+            private boolean isIndexingPressureIncremented() {
+                return indexingPressureIncremented;
+            }
+
+            private void setIndexingPressureIncremented() {
+                this.indexingPressureIncremented = true;
+            }
+        }
+
+        private boolean incrementIndexingPressure(IndexRequestWithIndexingPressure indexRequest, int itemIndex) {
+            boolean success = true;
+            if (indexRequest.isIndexingPressureIncremented() == false) {
+                try {
+                    // Track operation count as one operation per document source update
+                    coordinatingIndexingPressure.increment(1, indexRequest.getIndexRequest().source().ramBytesUsed());
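+                    // The reserved bytes are released with the coordinating operation once the bulk request
+                    // completes, via ActionListener.releaseAfter in apply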
+                    indexRequest.setIndexingPressureIncremented();
+                } catch (EsRejectedExecutionException e) {
+                    addInferenceResponseFailure(
+                        itemIndex,
+                        new InferenceException(
+                            "Insufficient memory available to update source on document [" + indexRequest.getIndexRequest().id() + "]",
+                            e
+                        )
+                    );
+                    success = false;
+                }
+            }
+
+            return success;
+        }
+
         private FieldInferenceResponseAccumulator ensureResponseAccumulatorSlot(int id) {
             FieldInferenceResponseAccumulator acc = inferenceResults.get(id);
             if (acc == null) {
@@ -624,6 +695,7 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons
                 inferenceFieldsMap.put(fieldName, result);
             }

+            BytesReference originalSource = indexRequest.source();
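+            // Keep a reference to the unmodified source so the size delta can be computed below and the
+            // source can be restored if the extra bytes are rejected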
             if (useLegacyFormat) {
                 var newDocMap = indexRequest.sourceAsMap();
                 for (var entry : inferenceFieldsMap.entrySet()) {
@@ -636,6 +708,23 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons
                     indexRequest.source(builder);
                 }
             }
+            long modifiedSourceSize = indexRequest.source().ramBytesUsed();
+
+            // Add the indexing pressure from the source modifications.
+            // Don't increment operation count because we count one source update as one operation, and we already
+            // accounted for those in addFieldInferenceRequests.
+            try {
+                coordinatingIndexingPressure.increment(0, modifiedSourceSize - originalSource.ramBytesUsed());
+            } catch (EsRejectedExecutionException e) {
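+                // The enlarged source could not be accounted for: revert to the original source and fail this item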
+                indexRequest.source(originalSource, indexRequest.getContentType());
+                item.abort(
+                    item.index(),
+                    new InferenceException(
+                        "Insufficient memory available to insert inference results into document [" + indexRequest.id() + "]",
+                        e
+                    )
+                );
+            }
         }
     }