 import org.elasticsearch.common.settings.Setting;
 import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.common.util.concurrent.AtomicArray;
+import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.common.xcontent.support.XContentMapValues;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.index.IndexingPressure;
 import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
 import org.elasticsearch.inference.ChunkInferenceInput;
 import org.elasticsearch.inference.ChunkedInference;
@@ -108,18 +110,21 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
     private final InferenceServiceRegistry inferenceServiceRegistry;
     private final ModelRegistry modelRegistry;
     private final XPackLicenseState licenseState;
+    private final IndexingPressure indexingPressure;
     private volatile long batchSizeInBytes;
 
     public ShardBulkInferenceActionFilter(
         ClusterService clusterService,
         InferenceServiceRegistry inferenceServiceRegistry,
         ModelRegistry modelRegistry,
-        XPackLicenseState licenseState
+        XPackLicenseState licenseState,
+        IndexingPressure indexingPressure
     ) {
         this.clusterService = clusterService;
         this.inferenceServiceRegistry = inferenceServiceRegistry;
         this.modelRegistry = modelRegistry;
         this.licenseState = licenseState;
+        this.indexingPressure = indexingPressure;
         this.batchSizeInBytes = INDICES_INFERENCE_BATCH_SIZE.get(clusterService.getSettings()).getBytes();
         clusterService.getClusterSettings().addSettingsUpdateConsumer(INDICES_INFERENCE_BATCH_SIZE, this::setBatchSize);
     }
@@ -145,8 +150,15 @@ public <Request extends ActionRequest, Response extends ActionResponse> void app
         BulkShardRequest bulkShardRequest = (BulkShardRequest) request;
         var fieldInferenceMetadata = bulkShardRequest.consumeInferenceFieldMap();
         if (fieldInferenceMetadata != null && fieldInferenceMetadata.isEmpty() == false) {
-            Runnable onInferenceCompletion = () -> chain.proceed(task, action, request, listener);
-            processBulkShardRequest(fieldInferenceMetadata, bulkShardRequest, onInferenceCompletion);
+            // Maintain coordinating indexing pressure from inference until the indexing operations are complete
+            IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.createCoordinatingOperation(false);
+            Runnable onInferenceCompletion = () -> chain.proceed(
+                task,
+                action,
+                request,
+                ActionListener.releaseAfter(listener, coordinatingIndexingPressure)
+            );
+            processBulkShardRequest(fieldInferenceMetadata, bulkShardRequest, onInferenceCompletion, coordinatingIndexingPressure);
             return;
         }
     }
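Note: the hunk above follows an acquire-and-release discipline: a coordinating-pressure handle is created before inference starts, every byte charged to it stays reserved while the bulk request is in flight, and ActionListener.releaseAfter frees the whole reservation once the downstream listener completes. Below is a minimal, self-contained sketch of that lifecycle; PressureBudget and Handle are hypothetical stand-ins, not the real IndexingPressure API.

import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.atomic.AtomicLong;

public class CoordinatingPressureSketch {

    // Hypothetical stand-in for IndexingPressure: a node-wide byte budget.
    static class PressureBudget {
        private final long limitInBytes;
        private final AtomicLong used = new AtomicLong();

        PressureBudget(long limitInBytes) {
            this.limitInBytes = limitInBytes;
        }

        // Analogous to createCoordinatingOperation: one handle per bulk request.
        Handle createCoordinatingOperation() {
            return new Handle();
        }

        // Analogous to IndexingPressure.Coordinating: accumulates reservations
        // and releases them all at once when closed.
        class Handle implements AutoCloseable {
            private long held;

            void increment(long bytes) {
                if (used.addAndGet(bytes) > limitInBytes) {
                    used.addAndGet(-bytes); // roll back the failed reservation
                    throw new RejectedExecutionException("coordinating pressure budget exceeded");
                }
                held += bytes;
            }

            @Override
            public void close() { // what ActionListener.releaseAfter triggers
                used.addAndGet(-held);
                held = 0;
            }
        }
    }

    public static void main(String[] args) {
        PressureBudget budget = new PressureBudget(1024);
        // try-with-resources mirrors releaseAfter: the reservation is freed only
        // after the whole operation (inference + indexing) has completed.
        try (PressureBudget.Handle handle = budget.createCoordinatingOperation()) {
            handle.increment(512);  // accepted: within budget
            handle.increment(1024); // rejected: would exceed the 1024-byte limit
        } catch (RejectedExecutionException e) {
            System.out.println(e.getMessage());
        }
    }
}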
@@ -156,11 +168,13 @@ public <Request extends ActionRequest, Response extends ActionResponse> void app
     private void processBulkShardRequest(
         Map<String, InferenceFieldMetadata> fieldInferenceMap,
         BulkShardRequest bulkShardRequest,
-        Runnable onCompletion
+        Runnable onCompletion,
+        IndexingPressure.Coordinating coordinatingIndexingPressure
     ) {
         var index = clusterService.state().getMetadata().index(bulkShardRequest.index());
         boolean useLegacyFormat = InferenceMetadataFieldsMapper.isEnabled(index.getSettings()) == false;
-        new AsyncBulkShardInferenceAction(useLegacyFormat, fieldInferenceMap, bulkShardRequest, onCompletion).run();
+        new AsyncBulkShardInferenceAction(useLegacyFormat, fieldInferenceMap, bulkShardRequest, onCompletion, coordinatingIndexingPressure)
+            .run();
     }
 
     private record InferenceProvider(InferenceService service, Model model) {}
@@ -230,18 +244,21 @@ private class AsyncBulkShardInferenceAction implements Runnable {
         private final BulkShardRequest bulkShardRequest;
         private final Runnable onCompletion;
         private final AtomicArray<FieldInferenceResponseAccumulator> inferenceResults;
+        private final IndexingPressure.Coordinating coordinatingIndexingPressure;
 
         private AsyncBulkShardInferenceAction(
             boolean useLegacyFormat,
             Map<String, InferenceFieldMetadata> fieldInferenceMap,
             BulkShardRequest bulkShardRequest,
-            Runnable onCompletion
+            Runnable onCompletion,
+            IndexingPressure.Coordinating coordinatingIndexingPressure
         ) {
             this.useLegacyFormat = useLegacyFormat;
             this.fieldInferenceMap = fieldInferenceMap;
             this.bulkShardRequest = bulkShardRequest;
             this.inferenceResults = new AtomicArray<>(bulkShardRequest.items().length);
             this.onCompletion = onCompletion;
+            this.coordinatingIndexingPressure = coordinatingIndexingPressure;
         }
 
         @Override
@@ -429,9 +446,9 @@ public void onFailure(Exception exc) {
          */
         private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<String, List<FieldInferenceRequest>> requestsMap) {
             boolean isUpdateRequest = false;
-            final IndexRequest indexRequest;
+            final IndexRequestWithIndexingPressure indexRequest;
             if (item.request() instanceof IndexRequest ir) {
-                indexRequest = ir;
+                indexRequest = new IndexRequestWithIndexingPressure(ir);
             } else if (item.request() instanceof UpdateRequest updateRequest) {
                 isUpdateRequest = true;
                 if (updateRequest.script() != null) {
@@ -445,13 +462,13 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
                     );
                     return 0;
                 }
-                indexRequest = updateRequest.doc();
+                indexRequest = new IndexRequestWithIndexingPressure(updateRequest.doc());
             } else {
                 // ignore delete request
                 return 0;
             }
 
-            final Map<String, Object> docMap = indexRequest.sourceAsMap();
+            final Map<String, Object> docMap = indexRequest.getIndexRequest().sourceAsMap();
             long inputLength = 0;
             for (var entry : fieldInferenceMap.values()) {
                 String field = entry.getName();
@@ -487,6 +504,10 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
                      * This ensures that the field is treated as intentionally cleared,
                      * preventing any unintended carryover of prior inference results.
                      */
+                    if (incrementIndexingPressure(indexRequest, itemIndex) == false) {
+                        return inputLength;
+                    }
+
                     var slot = ensureResponseAccumulatorSlot(itemIndex);
                     slot.addOrUpdateResponse(
                         new FieldInferenceResponse(field, sourceField, null, order++, 0, null, EMPTY_CHUNKED_INFERENCE)
@@ -508,6 +529,7 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
                     }
                     continue;
                 }
+
                 var slot = ensureResponseAccumulatorSlot(itemIndex);
                 final List<String> values;
                 try {
@@ -525,7 +547,10 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
                 List<FieldInferenceRequest> requests = requestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>());
                 int offsetAdjustment = 0;
                 for (String v : values) {
-                    inputLength += v.length();
+                    if (incrementIndexingPressure(indexRequest, itemIndex) == false) {
+                        return inputLength;
+                    }
+
                     if (v.isBlank()) {
                         slot.addOrUpdateResponse(
                             new FieldInferenceResponse(field, sourceField, v, order++, 0, null, EMPTY_CHUNKED_INFERENCE)
@@ -534,6 +559,7 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
                         requests.add(
                             new FieldInferenceRequest(itemIndex, field, sourceField, v, order++, offsetAdjustment, chunkingSettings)
                         );
+                        inputLength += v.length();
                     }
 
                     // When using the inference metadata fields format, all the input values are concatenated so that the
@@ -543,9 +569,54 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
                     }
                 }
             }
+
             return inputLength;
         }
 
+        private static class IndexRequestWithIndexingPressure {
+            private final IndexRequest indexRequest;
+            private boolean indexingPressureIncremented;
+
+            private IndexRequestWithIndexingPressure(IndexRequest indexRequest) {
+                this.indexRequest = indexRequest;
+                this.indexingPressureIncremented = false;
+            }
+
+            private IndexRequest getIndexRequest() {
+                return indexRequest;
+            }
+
+            private boolean isIndexingPressureIncremented() {
+                return indexingPressureIncremented;
+            }
+
+            private void setIndexingPressureIncremented() {
+                this.indexingPressureIncremented = true;
+            }
+        }
+
+        private boolean incrementIndexingPressure(IndexRequestWithIndexingPressure indexRequest, int itemIndex) {
+            boolean success = true;
+            if (indexRequest.isIndexingPressureIncremented() == false) {
+                try {
+                    // Track operation count as one operation per document source update
+                    coordinatingIndexingPressure.increment(1, indexRequest.getIndexRequest().source().ramBytesUsed());
+                    indexRequest.setIndexingPressureIncremented();
+                } catch (EsRejectedExecutionException e) {
+                    addInferenceResponseFailure(
+                        itemIndex,
+                        new InferenceException(
+                            "Insufficient memory available to update source on document [" + indexRequest.getIndexRequest().id() + "]",
+                            e
+                        )
+                    );
+                    success = false;
+                }
+            }
+
+            return success;
+        }
+
         private FieldInferenceResponseAccumulator ensureResponseAccumulatorSlot(int id) {
             FieldInferenceResponseAccumulator acc = inferenceResults.get(id);
             if (acc == null) {
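Note: incrementIndexingPressure in the hunk above is deliberately idempotent per document: the first field value that needs inference charges one operation plus the document's current source size, later values for the same document are no-ops, and a rejection records a per-item failure so the caller can stop early. A small sketch of that charge-once guard, using a hypothetical Budget interface in place of IndexingPressure.Coordinating:

import java.util.concurrent.RejectedExecutionException;

// Hypothetical equivalent of IndexRequestWithIndexingPressure: pairs a document's
// source size with a flag remembering whether it was already charged.
class ChargedDocument {

    // Stand-in for IndexingPressure.Coordinating.increment(operations, bytes),
    // assumed to throw RejectedExecutionException when the budget is exhausted.
    interface Budget {
        void increment(int operations, long sourceBytes);
    }

    private final long sourceBytes;
    private boolean charged;

    ChargedDocument(long sourceBytes) {
        this.sourceBytes = sourceBytes;
    }

    // Charge exactly once per document, no matter how many inference inputs it has.
    boolean chargeOnce(Budget budget) {
        if (charged) {
            return true; // already accounted for: later field values are free
        }
        try {
            budget.increment(1, sourceBytes); // one operation per source update
            charged = true;
            return true;
        } catch (RejectedExecutionException e) {
            return false; // caller records a per-item failure and stops early
        }
    }
}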
@@ -622,6 +693,7 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons
                 inferenceFieldsMap.put(fieldName, result);
             }
 
+            BytesReference originalSource = indexRequest.source();
             if (useLegacyFormat) {
                 var newDocMap = indexRequest.sourceAsMap();
                 for (var entry : inferenceFieldsMap.entrySet()) {
@@ -634,6 +706,23 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons
                     indexRequest.source(builder);
                 }
             }
+            long modifiedSourceSize = indexRequest.source().ramBytesUsed();
+
+            // Add the indexing pressure from the source modifications.
+            // Don't increment operation count because we count one source update as one operation, and we already accounted for those
+            // in addFieldInferenceRequests.
+            try {
+                coordinatingIndexingPressure.increment(0, modifiedSourceSize - originalSource.ramBytesUsed());
+            } catch (EsRejectedExecutionException e) {
+                indexRequest.source(originalSource, indexRequest.getContentType());
+                item.abort(
+                    item.index(),
+                    new InferenceException(
+                        "Insufficient memory available to insert inference results into document [" + indexRequest.id() + "]",
+                        e
+                    )
+                );
+            }
         }
     }
 
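Note: because the operation and the original source bytes were already charged in addFieldInferenceRequests, the second increment in the hunk above passes zero operations and only the growth of the source (modifiedSourceSize - originalSource.ramBytesUsed()); if even that delta is rejected, the source is restored and the item aborted rather than indexed half-accounted. A sketch of the arithmetic, again with a hypothetical Budget stand-in and plain byte arrays in place of BytesReference:

import java.util.concurrent.RejectedExecutionException;

public class DeltaAccountingSketch {

    // Stand-in for IndexingPressure.Coordinating.increment(operations, bytes).
    interface Budget {
        void increment(int operations, long bytes);
    }

    // Returns the source to index: the modified one if its extra bytes fit the
    // budget, otherwise the original so the caller can abort the item cleanly.
    static byte[] chargeSourceGrowth(Budget budget, byte[] originalSource, byte[] modifiedSource) {
        long delta = (long) modifiedSource.length - originalSource.length;
        try {
            // No new operation: the document was counted when inference started.
            budget.increment(0, delta);
            return modifiedSource;
        } catch (RejectedExecutionException e) {
            // Roll back, mirroring indexRequest.source(originalSource, ...) above.
            return originalSource;
        }
    }
}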