diff --git a/supervision/detection/tools/inference_slicer.py b/supervision/detection/tools/inference_slicer.py index 05469dd61..886cf0909 100644 --- a/supervision/detection/tools/inference_slicer.py +++ b/supervision/detection/tools/inference_slicer.py @@ -1,6 +1,6 @@ import warnings from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import Callable, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import numpy as np @@ -13,6 +13,7 @@ SupervisionWarnings, warn_deprecated, ) +from supervision.utils.iterables import create_batches def move_detections( @@ -71,9 +72,14 @@ class InferenceSlicer: overlap_filter (Union[OverlapFilter, str]): Strategy for filtering or merging overlapping detections in slices. iou_threshold (float): Intersection over Union (IoU) threshold - used when filtering by overlap. + used for non-max suppression. callback (Callable): A function that performs inference on a given image - slice and returns detections. + slice and returns detections. Should accept `np.ndarray` if + `batch_size` is `1` (default) and `List[np.ndarray]` otherwise. + See examples for more details. + batch_size (int): How many images to pass to the model. Defaults to 1. + For other values, `callback` should accept a list of images. Higher + value uses more memory but may be faster. thread_workers (int): Number of threads for parallel execution. 
Note: @@ -85,12 +91,16 @@ class InferenceSlicer: def __init__( self, - callback: Callable[[np.ndarray], Detections], + callback: Union[ + Callable[[np.ndarray], Detections], + Callable[[List[np.ndarray]], List[Detections]], + ], slice_wh: Tuple[int, int] = (320, 320), overlap_ratio_wh: Optional[Tuple[float, float]] = (0.2, 0.2), overlap_wh: Optional[Tuple[int, int]] = None, overlap_filter: Union[OverlapFilter, str] = OverlapFilter.NON_MAX_SUPPRESSION, iou_threshold: float = 0.5, + batch_size: int = 1, thread_workers: int = 1, ): if overlap_ratio_wh is not None: @@ -108,8 +118,14 @@ def __init__( self.iou_threshold = iou_threshold self.overlap_filter = OverlapFilter.from_value(overlap_filter) self.callback = callback + self.batch_size = batch_size self.thread_workers = thread_workers + if self.batch_size < 1: + raise ValueError("batch_size should be greater than 0") + if self.thread_workers < 1: + raise ValueError("thread_workers should be greater than 0.") + def __call__(self, image: np.ndarray) -> Detections: """ Performs slicing-based inference on the provided image using the specified @@ -133,9 +149,22 @@ def __call__(self, image: np.ndarray) -> Detections: image = cv2.imread(SOURCE_IMAGE_PATH) model = YOLO(...) 
- def callback(image_slice: np.ndarray) -> sv.Detections: - result = model(image_slice)[0] - return sv.Detections.from_ultralytics(result) + # Option 1: Single slice + def callback(slice: np.ndarray) -> sv.Detections: + result = model(slice)[0] + detections = sv.Detections.from_ultralytics(result) + return detections + + slicer = sv.InferenceSlicer(callback=callback) + detections = slicer(image) + + + # Option 2: Batch slices (Faster, but uses more memory) + def callback(slices: List[np.ndarray]) -> List[sv.Detections]: + results = model(slices) + detections_list = [ + sv.Detections.from_ultralytics(result) for result in results] + return detections_list slicer = sv.InferenceSlicer( callback=callback, @@ -153,13 +182,36 @@ def callback(image_slice: np.ndarray) -> sv.Detections: overlap_ratio_wh=self.overlap_ratio_wh, overlap_wh=self.overlap_wh, ) + batched_offsets_generator = create_batches(offsets, self.batch_size) + + if self.thread_workers == 1: + for offset_batch in batched_offsets_generator: + if self.batch_size == 1: + result = self._callback_image_single(image, offset_batch[0]) + detections_list.append(result) + else: + results = self._callback_image_batch(image, offset_batch) + detections_list.extend(results) - with ThreadPoolExecutor(max_workers=self.thread_workers) as executor: - futures = [ - executor.submit(self._run_callback, image, offset) for offset in offsets - ] - for future in as_completed(futures): - detections_list.append(future.result()) + else: + with ThreadPoolExecutor(max_workers=self.thread_workers) as executor: + futures = [] + for offset_batch in batched_offsets_generator: + if self.batch_size == 1: + future = executor.submit( + self._callback_image_single, image, offset_batch[0] + ) + else: + future = executor.submit( + self._callback_image_batch, image, offset_batch + ) + futures.append(future) + + for future in as_completed(futures): + if self.batch_size == 1: + detections_list.append(future.result()) + else: + 
detections_list.extend(future.result()) merged = Detections.merge(detections_list=detections_list) if self.overlap_filter == OverlapFilter.NONE: @@ -175,27 +227,60 @@ def callback(image_slice: np.ndarray) -> sv.Detections: ) return merged - def _run_callback(self, image, offset) -> Detections: + def _callback_image_single( + self, image: np.ndarray, offset: np.ndarray + ) -> Detections: """ - Run the provided callback on a slice of an image. + Run the callback on a single image. Args: image (np.ndarray): The input image on which inference needs to run - offset (np.ndarray): An array of shape `(4,)` containing coordinates - for the slice. - - Returns: - Detections: A collection of detections for the slice. """ + assert isinstance(offset, np.ndarray) + image_slice = crop_image(image=image, xyxy=offset) detections = self.callback(image_slice) - resolution_wh = (image.shape[1], image.shape[0]) - detections = move_detections( - detections=detections, offset=offset[:2], resolution_wh=resolution_wh - ) + if not isinstance(detections, Detections): + raise ValueError( + f"Callback should return a single Detections object when " + f"batch_size is 1. Instead it returned: {type(detections)}" + ) + detections = move_detections(detections=detections, offset=offset[:2]) return detections + def _callback_image_batch( + self, image: np.ndarray, offsets_batch: List[np.ndarray] + ) -> List[Detections]: + """ + Run the callback on a batch of images. + + Args: + image (np.ndarray): The input image on which inference needs to run + offsets_batch (List[np.ndarray]): List of N arrays of shape `(4,)`, + containing coordinates of the slices. 
+ + Returns: + List[Detections]: Detections found in each slice + """ + assert isinstance(offsets_batch, list) + + slices = [crop_image(image=image, xyxy=offset) for offset in offsets_batch] + detections_in_slices = self.callback(slices) + if not isinstance(detections_in_slices, list): + raise ValueError( + f"Callback should return a list of Detections objects when " + f"batch_size is greater than 1. " + f"Instead it returned: {type(detections_in_slices)}" + ) + + detections_with_offset = [ + move_detections(detections=detections, offset=offset[:2]) + for detections, offset in zip(detections_in_slices, offsets_batch) + ] + + return detections_with_offset + @staticmethod def _generate_offset( resolution_wh: Tuple[int, int],