[YOLOv8-Seg] Validation #924

Merged 18 commits on Feb 28, 2023
Changes from 15 commits

4 changes: 4 additions & 0 deletions src/deepsparse/yolo/pipelines.py
@@ -210,6 +210,7 @@ def process_inputs(
multi_label=inputs.multi_label,
original_image_shapes=original_image_shapes,
return_masks=inputs.return_masks,
return_intermediate_outputs=inputs.return_intermediate_outputs,
)
return [image_batch], postprocessing_kwargs

@@ -309,6 +310,9 @@ def process_engine_outputs(
boxes=batch_boxes,
scores=batch_scores,
labels=batch_labels,
intermediate_outputs=engine_outputs[0]
if kwargs.get("return_intermediate_outputs")
else None,
)

def _make_batch(self, image_batch: List[numpy.ndarray]) -> numpy.ndarray:
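For context, a minimal usage sketch of the new flag from the caller's side; the task name follows the existing `yolo` pipeline registration, while the model path and image file below are illustrative placeholders, not part of this PR:

```python
from deepsparse import Pipeline

# Illustrative model path; any YOLO ONNX file or SparseZoo stub would do here.
yolo_pipeline = Pipeline.create(task="yolo", model_path="yolo_model.onnx")

output = yolo_pipeline(
    images=["sample.jpg"],  # hypothetical local image
    return_intermediate_outputs=True,  # new flag introduced in this PR
)

# When the flag is set, the raw engine output (engine_outputs[0]) is attached;
# otherwise the field stays None.
print(type(output.intermediate_outputs))
```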
11 changes: 10 additions & 1 deletion src/deepsparse/yolo/schemas.py
@@ -18,7 +18,7 @@
"""

from collections import namedtuple
from typing import Iterable, List, TextIO
from typing import Any, Iterable, List, Optional, TextIO

import numpy
from PIL import Image
@@ -61,6 +61,11 @@ class YOLOInput(ComputerVisionSchema):
description="Controls whether the pipeline should additionally "
"return segmentation masks (if running a segmentation model)",
)
return_intermediate_outputs: bool = Field(
default=False,
description="Controls whether the pipeline should additionally "
"return intermediate outputs from the model",
)

@classmethod
def from_files(
@@ -105,6 +110,10 @@ class YOLOOutput(BaseModel):
labels: List[List[str]] = Field(
description="List of labels, one for each prediction"
)
intermediate_outputs: Optional[Any] = Field(
default=None,
description="Intermediate outputs " "from the YOLOv8 segmentation model.",
)

def __getitem__(self, index):
if index >= len(self.boxes):
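At the schema level the new input field is opt-in and the output field defaults to None, so existing callers are unaffected. A small sketch, assuming the remaining `YOLOInput` fields keep their defaults and that an image path is an acceptable `images` entry:

```python
from deepsparse.yolo.schemas import YOLOInput

inp = YOLOInput(images=["sample.jpg"])  # hypothetical image path
print(inp.return_intermediate_outputs)  # False by default

inp = YOLOInput(images=["sample.jpg"], return_intermediate_outputs=True)
print(inp.return_intermediate_outputs)  # True
```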
15 changes: 14 additions & 1 deletion src/deepsparse/yolov8/pipelines.py
@@ -131,7 +131,7 @@ def process_engine_outputs_seg(
nm=nm,
multi_label=kwargs.get("multi_label", False),
)
detections_output = torch.stack(detections_output)

mask_protos = numpy.stack(mask_protos)
original_image_shapes = kwargs.get("original_image_shapes")
batch_boxes, batch_scores, batch_labels, batch_masks = [], [], [], []
@@ -145,6 +145,15 @@
)

bboxes = detection_output[:, :4]

# check if empty detection
if bboxes.shape[0] == 0:
batch_boxes.append([None])
batch_scores.append([None])
batch_labels.append([None])
batch_masks.append([None])
continue

bboxes = self._scale_boxes(bboxes, original_image_shape)
scores = detection_output[:, 4]
labels = detection_output[:, 5]
@@ -155,6 +164,7 @@
batch_boxes.append(bboxes.tolist())
batch_scores.append(scores.tolist())
batch_labels.append(labels.tolist())

batch_masks.append(
process_mask_upsample(
protos=protos,
@@ -179,4 +189,7 @@ def process_engine_outputs_seg(
scores=batch_scores,
classes=batch_labels,
masks=batch_masks if kwargs.get("return_masks") else None,
intermediate_outputs=(detections, mask_protos)
if kwargs.get("return_intermediate_outputs")
else None,
)
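Since images with no detections now contribute `[None]` placeholders rather than being dropped, the per-image alignment of the output lists is preserved; a sketch of how a consumer might skip such entries, assuming `output` is a `YOLOSegOutput` returned by the segmentation pipeline:

```python
# Iterate per image; images with no detections carry [None] placeholders.
for image_idx, (boxes, scores) in enumerate(zip(output.boxes, output.scores)):
    if boxes == [None]:
        print(f"image {image_idx}: no detections")
        continue
    print(f"image {image_idx}: {len(boxes)} detections, best score {max(scores):.2f}")
```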
20 changes: 15 additions & 5 deletions src/deepsparse/yolov8/schemas.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, List
from typing import Any, List, Optional, Tuple

from pydantic import BaseModel, Field

@@ -25,13 +25,23 @@ class YOLOSegOutput(BaseModel):
Output model for YOLOv8 Segmentation model
"""

boxes: List[List[List[float]]] = Field(
boxes: List[List[Optional[List[float]]]] = Field(
description="List of bounding boxes, one for each prediction"
)
scores: List[List[float]] = Field(
scores: List[List[Optional[float]]] = Field(
description="List of scores, one for each prediction"
)
classes: List[List[str]] = Field(
classes: List[List[Optional[str]]] = Field(
description="List of labels, one for each prediction"
)
masks: List[Any] = Field(description="List of masks, one for each prediction")
masks: Optional[List[Any]] = Field(
description="List of masks, one for each prediction"
)

intermediate_outputs: Optional[Tuple[Any, Any]] = Field(
default=None,
description="A tuple that contains of intermediate outputs "
"from the YOLOv8 segmentation model. The tuple"
"contains two items: predictions from the model"
"and mask prototypes",
)
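When `return_intermediate_outputs=True` is requested, the field can be unpacked directly; a sketch assuming `seg_output` is such a `YOLOSegOutput`:

```python
if seg_output.intermediate_outputs is not None:
    # Model predictions and mask prototypes, as described in the field above.
    predictions, mask_protos = seg_output.intermediate_outputs
```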
2 changes: 2 additions & 0 deletions src/deepsparse/yolov8/utils/validation/__init__.py
@@ -14,3 +14,5 @@

# flake8: noqa
from .detection_validator import *
from .helpers import *
from .segmentation_validator import *
174 changes: 27 additions & 147 deletions src/deepsparse/yolov8/utils/validation/deepsparse_validator.py
@@ -1,32 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# neuralmagic: no copyright
# flake8: noqa

import json
from collections import defaultdict
from pathlib import Path
from typing import Dict, List
from typing import Dict

from tqdm import tqdm

import torch
from deepsparse import Pipeline
from deepsparse.yolo import YOLOOutput
from ultralytics.yolo.cfg import get_cfg
from deepsparse.yolov8.utils.validation.helpers import schema_to_tensor
from ultralytics.yolo.data.utils import check_det_dataset
from ultralytics.yolo.engine.validator import BaseValidator
from ultralytics.yolo.utils import (
DEFAULT_CFG,
LOGGER,
@@ -35,119 +16,22 @@
TQDM_BAR_FORMAT,
callbacks,
)
from ultralytics.yolo.utils.files import increment_path
from ultralytics.yolo.utils.ops import Profile
from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode


__all__ = ["DeepSparseValidator"]


def schema_to_tensor(pipeline_outputs: YOLOOutput, device: str) -> List[torch.Tensor]:
"""
Transform the YOLOOutput to the format expected by the validation code.

:param pipeline_outputs: YOLOOutput from the pipeline
:param device: device to move the tensors to
:return: list of tensors with the format [x1, y1, x2, y2, confidence, class]
"""

preds = []

for boxes, labels, confidence in zip(
pipeline_outputs.boxes, pipeline_outputs.labels, pipeline_outputs.scores
):

boxes = torch.tensor(boxes)

# map labels to integers and reshape for concatenation
labels = list(map(int, list(map(float, labels))))
labels = torch.tensor(labels).view(-1, 1)

# reshape for concatenation
scores = torch.tensor(confidence).view(-1, 1)
# concatenate and append to preds
preds.append(torch.cat([boxes, scores, labels], axis=1).to(device))
return preds
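The helper above now lives in `deepsparse.yolov8.utils.validation.helpers` (imported at the top of this file); a minimal, self-contained usage sketch with made-up detections, assuming the remaining `YOLOOutput` fields keep their defaults:

```python
from deepsparse.yolo import YOLOOutput
from deepsparse.yolov8.utils.validation.helpers import schema_to_tensor

# One image with two made-up detections.
outputs = YOLOOutput(
    boxes=[[[0.0, 0.0, 10.0, 10.0], [5.0, 5.0, 20.0, 20.0]]],
    scores=[[0.9, 0.8]],
    labels=[["0.0", "3.0"]],
)
preds = schema_to_tensor(pipeline_outputs=outputs, device="cpu")
print(preds[0].shape)  # torch.Size([2, 6]) -> [x1, y1, x2, y2, confidence, class]
```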


# adapted from ULTRALYTICS GITHUB:
# https://github.com/ultralytics/ultralytics/blob/main/ultralytics/yolo/engine/validator.py
# the appropriate edits are marked with # deepsparse edit: <edit comment>


class DeepSparseValidator(BaseValidator): # deepsparse edit: overwriting BaseValidator
"""
A DeepSparseValidator class for creating validators for
YOLOv8 Deepsparse pipeline.

Attributes:
pipeline (Pipeline): DeepSparse Pipeline to be evaluated
dataloader (DataLoader): Dataloader to use for validation.
pbar (tqdm): Progress bar to update during validation.
logger (logging.Logger): Logger to use for validation.
args (SimpleNamespace): Configuration for the validator.
model (nn.Module): Model to validate.
data (dict): Data dictionary.
device (torch.device): Device to use for validation.
batch_i (int): Current batch index.
training (bool): Whether the model is in training mode.
speed (float): Batch processing speed in seconds.
jdict (dict): Dictionary to store validation results.
save_dir (Path): Directory to save results.
"""

def __init__(
self,
pipeline: Pipeline, # deepsparse edit: added pipeline
dataloader=None,
save_dir=None,
pbar=None,
logger=None,
args=None,
):
"""
Initializes a DeepSparseValidator instance.
Args:
pipeline (Pipeline): DeepSparse Pipeline to be evaluated
dataloader (torch.utils.data.DataLoader): Dataloader to be used for validation.
save_dir (Path): Directory to save results.
pbar (tqdm.tqdm): Progress bar for displaying progress.
logger (logging.Logger): Logger to log messages.
args (SimpleNamespace): Configuration for the validator.
"""
self.pipeline = pipeline # deepsparse edit: added pipeline
self.dataloader = dataloader
self.pbar = pbar
self.logger = logger or LOGGER
self.args = args or get_cfg(DEFAULT_CFG)
self.model = None
self.data = None
self.device = None
self.batch_i = None
class DeepSparseValidator:
def __init__(self, pipeline):
self.pipeline = pipeline
self.training = False
self.speed = None
self.jdict = None

project = self.args.project or Path(SETTINGS["runs_dir"]) / self.args.task
name = self.args.name or f"{self.args.mode}"
self.save_dir = save_dir or increment_path(
Path(project) / name,
exist_ok=self.args.exist_ok if RANK in {-1, 0} else True,
)
(self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(
parents=True, exist_ok=True
)

if self.args.conf is None:
self.args.conf = 0.001 # default conf=0.001

self.callbacks = defaultdict(list, callbacks.default_callbacks) # add callbacks

@smart_inference_mode()
# deepsparse edit: replaced arguments `trainer` and `model`
# with `stride` and `classes`
def __call__(self, stride: int, classes: Dict[int, str]):
# `classes`
def __call__(self, classes: Dict[int, str]):
"""
Supports validation of a pre-trained model if passed or a model being trained
if trainer is passed (trainer gets priority).
@@ -167,40 +51,34 @@ def __call__(self, stride: int, classes: Dict[int, str]):
self.dataloader = self.dataloader or self.get_dataloader(
self.data.get("val") or self.data.set("test"), self.args.batch
)
# deepsparse edit: left only profiler for inference, removed the redundant
# profilers for pre-process, loss and post-process
dt = Profile()

dt = Profile(), Profile(), Profile(), Profile()
n_batches = len(self.dataloader)
desc = self.get_desc()
# NOTE: keeping `not self.training` in tqdm will eliminate pbar after segmentation evaluation during training,
# which may affect classification task since this arg is in yolov5/classify/val.py.
# bar = tqdm(self.dataloader, desc, n_batches, not self.training, bar_format=TQDM_BAR_FORMAT)
bar = tqdm(self.dataloader, desc, n_batches, bar_format=TQDM_BAR_FORMAT)
# deepsparse edit: replaced argument `model` with `classes`
self.init_metrics(classes=classes)
self.jdict = [] # empty before each val
for batch_i, batch in enumerate(bar):
self.run_callbacks("on_val_batch_start")
self.batch_i = batch_i
# pre-process
with dt[0]:
batch = self.preprocess(batch)

# deepsparse edit:
# - removed the redundant pre-process function
# - removed the redundant loss computation
# - removed the redundant post-process function

# deepsparse edit: replaced the inference model with the DeepSparse pipeline
# inference
with dt:
with dt[1]:
outputs = self.pipeline(
images=[x.cpu().numpy() for x in batch["img"]],
iou_thres=self.args.iou,
conf_thres=self.args.conf,
multi_label=True,
images=[x.cpu().numpy() * 255 for x in batch["img"]],
return_intermediate_outputs=True,
)
preds = schema_to_tensor(pipeline_outputs=outputs, device=self.device)
batch["bboxes"] = batch["bboxes"].to(self.device)
batch["cls"] = batch["cls"].to(self.device)
batch["batch_idx"] = batch["batch_idx"].to(self.device)

# pre-process predictions
with dt[3]:
preds = self.postprocess(preds)

self.update_metrics(preds, batch)
if self.args.plots and batch_i < 3:
@@ -211,13 +89,15 @@ def __call__(self, stride: int, classes: Dict[int, str]):
stats = self.get_stats()
self.check_stats(stats)
self.print_results()
self.speed = dt.t / len(self.dataloader.dataset) * 1e3 # speeds per image
self.speed = tuple(
x.t / len(self.dataloader.dataset) * 1e3 for x in dt
) # speeds per image
self.run_callbacks("on_val_end")

# deepsparse_edit: changed the string formatting to match the
# removed profilers
self.logger.info("Speed: %.1fms inference per image" % self.speed)

self.logger.info(
"Speed: %.1fms pre-process, %.1fms inference, %.1fms loss, %.1fms post-process per image"
% self.speed
)
if self.args.save_json and self.jdict:
with open(str(self.save_dir / "predictions.json"), "w") as f:
self.logger.info(f"Saving {f.name}...")