103
103
#
104
104
# This is also used if you do content translation via gettext catalogs.
105
105
# Usually you set "language" from the command line for these cases.
106
- language = "en"
106
+ language = None
107
107
108
108
# List of patterns, relative to source directory, that match files and
109
109
# directories to ignore when looking for source files.
@@ -334,30 +334,32 @@ def inject_weight_metadata(app, what, name, obj, options, lines):
334
334
lines .append ("" )
335
335
336
336
for field in obj :
337
- meta = copy (field .meta )
338
-
339
337
lines += [f"**{ str (field )} **:" , "" ]
340
- lines += [meta .pop ("_docs" )]
341
-
342
338
if field == obj .DEFAULT :
343
- lines += [f"Also available as ``{ obj .__name__ } .DEFAULT``." ]
344
- lines += ["" ]
339
+ lines += [f"This weight is also available as ``{ obj .__name__ } .DEFAULT``." , "" ]
345
340
346
341
table = []
347
- metrics = meta .pop ("_metrics" )
348
- for dataset , dataset_metrics in metrics .items ():
349
- for metric_name , metric_value in dataset_metrics .items ():
350
- table .append ((f"{ metric_name } (on { dataset } )" , str (metric_value )))
351
342
352
- for k , v in meta .items ():
343
+ # the `meta` dict contains another embedded `metrics` dict. To
344
+ # simplify the table generation below, we create the
345
+ # `meta_with_metrics` dict, where the metrics dict has been "flattened"
346
+ meta = copy (field .meta )
347
+ metrics = meta .pop ("metrics" , {})
348
+ meta_with_metrics = dict (meta , ** metrics )
349
+
350
+ # We don't want to document these, they can be too long
351
+ for k in ["categories" , "keypoint_names" ]:
352
+ meta_with_metrics .pop (k , None )
353
+
354
+ custom_docs = meta_with_metrics .pop ("_docs" , None ) # Custom per-Weights docs
355
+ if custom_docs is not None :
356
+ lines += [custom_docs , "" ]
357
+
358
+ for k , v in meta_with_metrics .items ():
353
359
if k in {"recipe" , "license" }:
354
360
v = f"`link <{ v } >`__"
355
361
elif k == "min_size" :
356
362
v = f"height={ v [0 ]} , width={ v [1 ]} "
357
- elif k in {"categories" , "keypoint_names" } and isinstance (v , list ):
358
- max_visible = 3
359
- v_sample = ", " .join (v [:max_visible ])
360
- v = f"{ v_sample } , ... ({ len (v )- max_visible } omitted)" if len (v ) > max_visible else v_sample
361
363
table .append ((str (k ), str (v )))
362
364
table = tabulate (table , tablefmt = "rst" )
363
365
lines += [".. rst-class:: table-weights" ] # Custom CSS class, see custom_torchvision.css
@@ -366,12 +368,12 @@ def inject_weight_metadata(app, what, name, obj, options, lines):
366
368
lines .append ("" )
367
369
lines .append (
368
370
f"The inference transforms are available at ``{ str (field )} .transforms`` and "
369
- f"perform the following preprocessing operations: { field .transforms ().describe ()} "
371
+ f"perform the following operations: { field .transforms ().describe ()} "
370
372
)
371
373
lines .append ("" )
372
374
373
375
374
- def generate_weights_table (module , table_name , metrics , dataset , include_patterns = None , exclude_patterns = None ):
376
+ def generate_weights_table (module , table_name , metrics , include_patterns = None , exclude_patterns = None ):
375
377
weights_endswith = "_QuantizedWeights" if module .__name__ .split ("." )[- 1 ] == "quantization" else "_Weights"
376
378
weight_enums = [getattr (module , name ) for name in dir (module ) if name .endswith (weights_endswith )]
377
379
weights = [w for weight_enum in weight_enums for w in weight_enum ]
@@ -388,7 +390,7 @@ def generate_weights_table(module, table_name, metrics, dataset, include_pattern
388
390
content = [
389
391
(
390
392
f":class:`{ w } <{ type (w ).__name__ } >`" ,
391
- * (w .meta ["_metrics" ][ dataset ][metric ] for metric in metrics_keys ),
393
+ * (w .meta ["metrics" ][metric ] for metric in metrics_keys ),
392
394
f"{ w .meta ['num_params' ]/ 1e6 :.1f} M" ,
393
395
f"`link <{ w .meta ['recipe' ]} >`__" ,
394
396
)
@@ -403,48 +405,32 @@ def generate_weights_table(module, table_name, metrics, dataset, include_pattern
403
405
table_file .write (".. table::\n " )
404
406
table_file .write (f" :widths: 100 { '20 ' * len (metrics_names )} 20 10\n \n " )
405
407
table_file .write (f"{ textwrap .indent (table , ' ' * 4 )} \n \n " )
406
- table_file .write (f"{ table_name } Weights Table \n { 2 * len (table_name )* '*' } \n " )
408
+ table_file .write (f"{ table_name } Weights Table \n { 4 * len (table_name )* '*' } \n " )
407
409
408
410
411
+ generate_weights_table (module = M , table_name = "classification" , metrics = [("acc@1" , "Acc@1" ), ("acc@5" , "Acc@5" )])
409
412
generate_weights_table (
410
- module = M , table_name = "classification" , metrics = [("acc@1" , "Acc@1" ), ("acc@5" , "Acc@5" )], dataset = "ImageNet-1K"
411
- )
412
- generate_weights_table (
413
- module = M .quantization ,
414
- table_name = "classification_quant" ,
415
- metrics = [("acc@1" , "Acc@1" ), ("acc@5" , "Acc@5" )],
416
- dataset = "ImageNet-1K" ,
413
+ module = M .quantization , table_name = "classification_quant" , metrics = [("acc@1" , "Acc@1" ), ("acc@5" , "Acc@5" )]
417
414
)
418
415
generate_weights_table (
419
- module = M .detection ,
420
- table_name = "detection" ,
421
- metrics = [("box_map" , "Box MAP" )],
422
- exclude_patterns = ["Mask" , "Keypoint" ],
423
- dataset = "COCO-val2017" ,
416
+ module = M .detection , table_name = "detection" , metrics = [("box_map" , "Box MAP" )], exclude_patterns = ["Mask" , "Keypoint" ]
424
417
)
425
418
generate_weights_table (
426
419
module = M .detection ,
427
420
table_name = "instance_segmentation" ,
428
421
metrics = [("box_map" , "Box MAP" ), ("mask_map" , "Mask MAP" )],
429
- dataset = "COCO-val2017" ,
430
422
include_patterns = ["Mask" ],
431
423
)
432
424
generate_weights_table (
433
425
module = M .detection ,
434
426
table_name = "detection_keypoint" ,
435
427
metrics = [("box_map" , "Box MAP" ), ("kp_map" , "Keypoint MAP" )],
436
- dataset = "COCO-val2017" ,
437
428
include_patterns = ["Keypoint" ],
438
429
)
439
430
generate_weights_table (
440
- module = M .segmentation ,
441
- table_name = "segmentation" ,
442
- metrics = [("miou" , "Mean IoU" ), ("pixel_acc" , "pixelwise Acc" )],
443
- dataset = "COCO-val2017-VOC-labels" ,
444
- )
445
- generate_weights_table (
446
- module = M .video , table_name = "video" , metrics = [("acc@1" , "Acc@1" ), ("acc@5" , "Acc@5" )], dataset = "Kinetics-400"
431
+ module = M .segmentation , table_name = "segmentation" , metrics = [("miou" , "Mean IoU" ), ("pixel_acc" , "pixelwise Acc" )]
447
432
)
433
+ generate_weights_table (module = M .video , table_name = "video" , metrics = [("acc@1" , "Acc@1" ), ("acc@5" , "Acc@5" )])
448
434
449
435
450
436
def setup (app ):
0 commit comments