Skip to content

Commit 7da41be

Browse files
authored
Merge pull request #1192 from sdbds/main
Add WDV3 support
2 parents e281e86 + 6c51c97 commit 7da41be

File tree

2 files changed

+18
-12
lines changed

2 files changed

+18
-12
lines changed

finetune/tag_images_by_wd14_tagger.py

+12-9
Original file line numberDiff line numberDiff line change
@@ -86,23 +86,26 @@ def main(args):
8686
logger.info(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
8787
files = FILES
8888
if args.onnx:
89+
files = ["selected_tags.csv"]
8990
files += FILES_ONNX
91+
else:
92+
for file in SUB_DIR_FILES:
93+
hf_hub_download(
94+
args.repo_id,
95+
file,
96+
subfolder=SUB_DIR,
97+
cache_dir=os.path.join(args.model_dir, SUB_DIR),
98+
force_download=True,
99+
force_filename=file,
100+
)
90101
for file in files:
91102
hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file)
92-
for file in SUB_DIR_FILES:
93-
hf_hub_download(
94-
args.repo_id,
95-
file,
96-
subfolder=SUB_DIR,
97-
cache_dir=os.path.join(args.model_dir, SUB_DIR),
98-
force_download=True,
99-
force_filename=file,
100-
)
101103
else:
102104
logger.info("using existing wd14 tagger model")
103105

104106
# 画像を読み込む
105107
if args.onnx:
108+
import torch
106109
import onnx
107110
import onnxruntime as ort
108111

requirements.txt

+6-3
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,12 @@ huggingface-hub==0.20.1
2222
# for WD14 captioning (tensorflow)
2323
# tensorflow==2.10.1
2424
# for WD14 captioning (onnx)
25-
# onnx==1.14.1
26-
# onnxruntime-gpu==1.16.0
27-
# onnxruntime==1.16.0
25+
# onnx==1.15.0
26+
# onnxruntime-gpu==1.17.1
27+
# onnxruntime==1.17.1
28+
# for cuda 12.1(default 11.8)
29+
# onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/
30+
2831
# this is for onnx:
2932
# protobuf==3.20.3
3033
# open clip for SDXL

0 commit comments

Comments
 (0)