Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit e302ec4

Browse files
authored
[YOLOv8] Fix support for --dataset_dir argument (#1520)
* working * working * addressing comments from review * add more verbose error message * Apply suggestions from code review * Update helpers.py
1 parent 08ad4ef commit e302ec4

File tree

5 files changed

+70
-33
lines changed

5 files changed

+70
-33
lines changed

src/sparseml/yolov8/export.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,9 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import os
1615

1716
import click
1817
from sparseml.yolov8.trainers import SparseYOLO
19-
from ultralytics.yolo.utils import USER_CONFIG_DIR, get_settings, yaml_save
2018

2119

2220
# Options generated from
@@ -82,18 +80,12 @@
8280
help="cuda device, i.e. 0 or 0,1,2,3 or cpu. Device to run on",
8381
)
8482
@click.option(
85-
"--datasets-dir",
83+
"--dataset-path",
8684
type=str,
8785
default=None,
88-
help="Path to override default datasets dir.",
86+
help="Path to override default dataset path.",
8987
)
9088
def main(**kwargs):
91-
if kwargs["datasets_dir"] is not None:
92-
settings = get_settings()
93-
settings["datasets_dir"] = os.path.abspath(
94-
os.path.expanduser(kwargs["datasets_dir"])
95-
)
96-
yaml_save(USER_CONFIG_DIR / "settings.yaml", settings)
9789

9890
model = SparseYOLO(kwargs["model"])
9991
model.export(**kwargs)

src/sparseml/yolov8/train.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,10 @@
1313
# limitations under the License.
1414

1515
import logging
16-
import os
1716

1817
import click
1918
from sparseml.yolov8.trainers import SparseYOLO
20-
from ultralytics.yolo.utils import USER_CONFIG_DIR, get_settings, yaml_save
19+
from sparseml.yolov8.utils import data_from_dataset_path
2120

2221

2322
logger = logging.getLogger()
@@ -212,18 +211,15 @@
212211
"--copy-paste", type=float, default=0.0, help="segment copy-paste (probability)"
213212
)
214213
@click.option(
215-
"--datasets-dir",
214+
"--dataset-path",
216215
type=str,
217216
default=None,
218-
help="Path to override default datasets dir.",
217+
help="Path to override default dataset path.",
219218
)
220219
def main(**kwargs):
221-
if kwargs["datasets_dir"] is not None:
222-
settings = get_settings()
223-
settings["datasets_dir"] = os.path.abspath(
224-
os.path.expanduser(kwargs["datasets_dir"])
225-
)
226-
yaml_save(USER_CONFIG_DIR / "settings.yaml", settings)
220+
if kwargs["dataset_path"] is not None:
221+
kwargs["data"] = data_from_dataset_path(kwargs["data"], kwargs["dataset_path"])
222+
del kwargs["dataset_path"]
227223

228224
model = SparseYOLO(kwargs["model"])
229225
model.train(**kwargs)

src/sparseml/yolov8/trainers.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,11 @@
3232
from sparseml.pytorch.utils.helpers import download_framework_model_by_recipe_type
3333
from sparseml.pytorch.utils.logger import LoggerManager, PythonLogger, WANDBLogger
3434
from sparseml.yolov8.modules import Bottleneck, Conv
35-
from sparseml.yolov8.utils import check_coco128_segmentation, create_grad_sampler
35+
from sparseml.yolov8.utils import (
36+
check_coco128_segmentation,
37+
create_grad_sampler,
38+
data_from_dataset_path,
39+
)
3640
from sparseml.yolov8.utils.export_samples import export_sample_inputs_outputs
3741
from sparseml.yolov8.validators import (
3842
SparseClassificationValidator,
@@ -662,6 +666,11 @@ def export(self, **kwargs):
662666
if kwargs["device"] is not None and "cpu" not in kwargs["device"]:
663667
overrides["device"] = "cuda:" + kwargs["device"]
664668
overrides["deterministic"] = kwargs["deterministic"]
669+
if kwargs["dataset_path"] is not None:
670+
overrides["data"] = data_from_dataset_path(
671+
overrides["data"], kwargs["dataset_path"]
672+
)
673+
665674
trainer = self.TrainerClass(overrides=overrides)
666675
self.model = self.model.to(trainer.device)
667676

@@ -717,9 +726,12 @@ def export(self, **kwargs):
717726
if args["export_samples"]:
718727
trainer_config = get_cfg(cfg=DEFAULT_SPARSEML_CONFIG_PATH)
719728

729+
if args["dataset_path"] is not None:
730+
args["data"] = data_from_dataset_path(
731+
args["data"], args["dataset_path"]
732+
)
720733
trainer_config.data = args["data"]
721734
trainer_config.imgsz = args["imgsz"]
722-
723735
trainer = DetectionTrainer(trainer_config)
724736
# inconsistency in name between
725737
# validation and test sets

src/sparseml/yolov8/utils/helpers.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,25 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import glob
1516
import os
1617
import warnings
1718
from argparse import Namespace
1819
from typing import Any, Dict
1920

21+
import yaml
22+
2023
from ultralytics.yolo.data.dataloaders.v5loader import create_dataloader
24+
from ultralytics.yolo.data.utils import ROOT
2125
from ultralytics.yolo.engine.model import DetectionModel
2226
from ultralytics.yolo.engine.trainer import BaseTrainer
2327

2428

25-
__all__ = ["check_coco128_segmentation", "create_grad_sampler"]
29+
__all__ = [
30+
"check_coco128_segmentation",
31+
"create_grad_sampler",
32+
"data_from_dataset_path",
33+
]
2634

2735

2836
def check_coco128_segmentation(args: Namespace) -> Namespace:
@@ -69,3 +77,36 @@ def create_grad_sampler(
6977
/ train_loader.batch_size,
7078
)
7179
return grad_sampler
80+
81+
82+
def data_from_dataset_path(data: str, dataset_path: str) -> str:
83+
"""
84+
Given a dataset name, fetch the yaml config for the dataset
85+
from the Ultralytics dataset repo, overwrite its 'path'
86+
attribute (dataset root dir) to point to the `dataset_path`
87+
and finally save it to the current working directory.
88+
This makes it possible to create dataset yaml config files that point
89+
to arbitrary directories on disk.
90+
91+
:param data: name of the dataset (e.g. "coco.yaml")
92+
:param dataset_path: path to the dataset directory
93+
:return: a path to the new yaml config file
94+
(saved in the current working directory)
95+
"""
96+
ultralytics_dataset_path = glob.glob(os.path.join(ROOT, "**", data), recursive=True)
97+
if len(ultralytics_dataset_path) != 1:
98+
raise ValueError(
99+
"Expected to find a single path to the "
100+
f"dataset yaml file: {data}, but found {ultralytics_dataset_path}"
101+
)
102+
ultralytics_dataset_path = ultralytics_dataset_path[0]
103+
with open(ultralytics_dataset_path, "r") as f:
104+
yaml_config = yaml.safe_load(f)
105+
yaml_config["path"] = dataset_path
106+
107+
yaml_save_path = os.path.join(os.getcwd(), data)
108+
109+
# save the new dataset yaml file
110+
with open(yaml_save_path, "w") as outfile:
111+
yaml.dump(yaml_config, outfile, default_flow_style=False)
112+
return yaml_save_path

src/sparseml/yolov8/val.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,10 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import os
1615

1716
import click
1817
from sparseml.yolov8.trainers import SparseYOLO
19-
from ultralytics.yolo.utils import USER_CONFIG_DIR, get_settings, yaml_save
18+
from sparseml.yolov8.utils import data_from_dataset_path
2019

2120

2221
@click.command(
@@ -72,18 +71,15 @@
7271
)
7372
@click.option("--plots", default=False, is_flag=True, help="show plots during training")
7473
@click.option(
75-
"--datasets-dir",
74+
"--dataset-path",
7675
type=str,
7776
default=None,
78-
help="Path to override default datasets dir.",
77+
help="Path to override default datasets path.",
7978
)
8079
def main(**kwargs):
81-
if kwargs["datasets_dir"] is not None:
82-
settings = get_settings()
83-
settings["datasets_dir"] = os.path.abspath(
84-
os.path.expanduser(kwargs["datasets_dir"])
85-
)
86-
yaml_save(USER_CONFIG_DIR / "settings.yaml", settings)
80+
if kwargs["dataset_path"] is not None:
81+
kwargs["data"] = data_from_dataset_path(kwargs["data"], kwargs["dataset_path"])
82+
del kwargs["dataset_path"]
8783

8884
model = SparseYOLO(kwargs["model"])
8985
if hasattr(model, "overrides"):

0 commit comments

Comments
 (0)