
Commit bfa7ec9

Doing some Claude-enabled docstring, type annotation and other cleanup

1 parent a0a30a6 commit bfa7ec9


53 files changed: +4847 additions, -1734 deletions

timm/layers/blur_pool.py

Lines changed: 6 additions & 2 deletions

@@ -6,12 +6,12 @@
 Hacked together by Chris Ha and Ross Wightman
 """
 from functools import partial
+from math import comb  # Python 3.8
 from typing import Optional, Type
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import numpy as np
 
 from .padding import get_padding
 from .typing import LayerType
@@ -45,7 +45,11 @@ def __init__(
         self.pad_mode = pad_mode
         self.padding = [get_padding(filt_size, stride, dilation=1)] * 4
 
-        coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs.astype(np.float32))
+        # (0.5 + 0.5 x)^N => coefficients = C(N,k) / 2^N, k = 0..N
+        coeffs = torch.tensor(
+            [comb(filt_size - 1, k) for k in range(filt_size)],
+            dtype=torch.float32,
+        ) / (2 ** (filt_size - 1))  # normalise so coefficients sum to 1
         blur_filter = (coeffs[:, None] * coeffs[None, :])[None, None, :, :]
         if channels is not None:
             blur_filter = blur_filter.repeat(self.channels, 1, 1, 1)
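The replacement computes the binomial blur kernel directly: (0.5 + 0.5 x)^(filt_size - 1) expands to coefficients C(N, k) / 2^N. A small standalone check (not part of the commit) that the math.comb form reproduces the old numpy-based coefficients:

from math import comb

import numpy as np
import torch

for filt_size in (3, 4, 5, 7):
    n = filt_size - 1
    new = torch.tensor([comb(n, k) for k in range(filt_size)], dtype=torch.float32) / (2 ** n)
    old = torch.tensor((np.poly1d((0.5, 0.5)) ** n).coeffs.astype(np.float32))
    assert torch.allclose(new, old)                      # same kernel as before
    assert torch.isclose(new.sum(), torch.tensor(1.0))   # coefficients sum to 1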

timm/layers/cond_conv2d.py

Lines changed: 2 additions & 3 deletions

@@ -8,7 +8,6 @@
 
 import math
 from functools import partial
-import numpy as np
 import torch
 from torch import nn as nn
 from torch.nn import functional as F
@@ -21,7 +20,7 @@
 def get_condconv_initializer(initializer, num_experts, expert_shape):
     def condconv_initializer(weight):
         """CondConv initializer function."""
-        num_params = np.prod(expert_shape)
+        num_params = math.prod(expert_shape)
         if (len(weight.shape) != 2 or weight.shape[0] != num_experts or
                 weight.shape[1] != num_params):
             raise (ValueError(
@@ -75,7 +74,7 @@ def reset_parameters(self):
             partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape)
         init_weight(self.weight)
         if self.bias is not None:
-            fan_in = np.prod(self.weight_shape[1:])
+            fan_in = math.prod(self.weight_shape[1:])
            bound = 1 / math.sqrt(fan_in)
             init_bias = get_condconv_initializer(
                 partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape)
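With numpy dropped from this module, math.prod (Python 3.8+) takes over the shape products; it returns a plain int rather than a numpy scalar, which is all these call sites need. A minimal illustration (not from the commit):

import math

import numpy as np

expert_shape = (32, 16, 3, 3)
# same value either way; math.prod just avoids the numpy dependency
assert math.prod(expert_shape) == int(np.prod(expert_shape)) == 4608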

timm/models/_builder.py

Lines changed: 27 additions & 16 deletions

@@ -3,7 +3,7 @@
 import os
 from copy import deepcopy
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
 
 from torch import nn as nn
 from torch.hub import load_state_dict_from_url
@@ -26,11 +26,21 @@
 _CHECK_HASH = False
 _USE_OLD_CACHE = int(os.environ.get('TIMM_USE_OLD_CACHE', 0)) > 0
 
-__all__ = ['set_pretrained_download_progress', 'set_pretrained_check_hash', 'load_custom_pretrained', 'load_pretrained',
-           'pretrained_cfg_for_features', 'resolve_pretrained_cfg', 'build_model_with_cfg']
+__all__ = [
+    'set_pretrained_download_progress',
+    'set_pretrained_check_hash',
+    'load_custom_pretrained',
+    'load_pretrained',
+    'pretrained_cfg_for_features',
+    'resolve_pretrained_cfg',
+    'build_model_with_cfg',
+]
 
 
-def _resolve_pretrained_source(pretrained_cfg):
+ModelT = TypeVar("ModelT", bound=nn.Module)  # any subclass of nn.Module
+
+
+def _resolve_pretrained_source(pretrained_cfg: Dict[str, Any]) -> Tuple[str, str]:
     cfg_source = pretrained_cfg.get('source', '')
     pretrained_url = pretrained_cfg.get('url', None)
     pretrained_file = pretrained_cfg.get('file', None)
@@ -78,25 +88,25 @@ def _resolve_pretrained_source(pretrained_cfg):
     return load_from, pretrained_loc
 
 
-def set_pretrained_download_progress(enable=True):
+def set_pretrained_download_progress(enable: bool = True) -> None:
     """ Set download progress for pretrained weights on/off (globally). """
     global _DOWNLOAD_PROGRESS
     _DOWNLOAD_PROGRESS = enable
 
 
-def set_pretrained_check_hash(enable=True):
+def set_pretrained_check_hash(enable: bool = True) -> None:
     """ Set hash checking for pretrained weights on/off (globally). """
     global _CHECK_HASH
     _CHECK_HASH = enable
 
 
 def load_custom_pretrained(
         model: nn.Module,
-        pretrained_cfg: Optional[Dict] = None,
+        pretrained_cfg: Optional[Dict[str, Any]] = None,
         load_fn: Optional[Callable] = None,
         cache_dir: Optional[Union[str, Path]] = None,
-):
-    r"""Loads a custom (read non .pth) weight file
+) -> None:
+    """Loads a custom (read non .pth) weight file
 
     Downloads checkpoint file into cache-dir like torch.hub based loaders, but calls
     a passed in custom load fun, or the `load_pretrained` model member fn.
@@ -141,13 +151,13 @@ def load_custom_pretrained(
 
 def load_pretrained(
         model: nn.Module,
-        pretrained_cfg: Optional[Dict] = None,
+        pretrained_cfg: Optional[Dict[str, Any]] = None,
         num_classes: int = 1000,
         in_chans: int = 3,
         filter_fn: Optional[Callable] = None,
         strict: bool = True,
         cache_dir: Optional[Union[str, Path]] = None,
-):
+) -> None:
     """ Load pretrained checkpoint
 
     Args:
@@ -278,7 +288,7 @@ def load_pretrained(
             f' This may be expected if model is being adapted.')
 
 
-def pretrained_cfg_for_features(pretrained_cfg):
+def pretrained_cfg_for_features(pretrained_cfg: Dict[str, Any]) -> Dict[str, Any]:
     pretrained_cfg = deepcopy(pretrained_cfg)
     # remove default pretrained cfg fields that don't have much relevance for feature backbone
     to_remove = ('num_classes', 'classifier', 'global_pool')  # add default final pool size?
@@ -287,14 +297,14 @@ def pretrained_cfg_for_features(pretrained_cfg):
     return pretrained_cfg
 
 
-def _filter_kwargs(kwargs, names):
+def _filter_kwargs(kwargs: Dict[str, Any], names: List[str]) -> None:
     if not kwargs or not names:
         return
     for n in names:
         kwargs.pop(n, None)
 
 
-def _update_default_model_kwargs(pretrained_cfg, kwargs, kwargs_filter):
+def _update_default_model_kwargs(pretrained_cfg, kwargs, kwargs_filter) -> None:
     """ Update the default_cfg and kwargs before passing to model
 
     Args:
@@ -340,6 +350,7 @@ def resolve_pretrained_cfg(
         pretrained_cfg: Optional[Union[str, Dict[str, Any]]] = None,
         pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
 ) -> PretrainedCfg:
+    """Resolve pretrained configuration from various sources."""
     model_with_tag = variant
     pretrained_tag = None
     if pretrained_cfg:
@@ -371,7 +382,7 @@
 
 
 def build_model_with_cfg(
-        model_cls: Callable,
+        model_cls: Union[Type[ModelT], Callable[..., ModelT]],
         variant: str,
         pretrained: bool,
         pretrained_cfg: Optional[Dict] = None,
@@ -383,7 +394,7 @@
         cache_dir: Optional[Union[str, Path]] = None,
         kwargs_filter: Optional[Tuple[str]] = None,
         **kwargs,
-):
+) -> ModelT:
     """ Build model with specified default_cfg and optional model_cfg
 
     This helper fn aids in the construction of a model including:
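The new ModelT TypeVar is what lets build_model_with_cfg return the concrete model class rather than a bare nn.Module. A minimal sketch of the same typing pattern; the builder and model names here (build, TinyNet) are hypothetical, not timm code:

from typing import Any, Callable, Type, TypeVar, Union

from torch import nn

ModelT = TypeVar("ModelT", bound=nn.Module)


def build(model_cls: Union[Type[ModelT], Callable[..., ModelT]], **kwargs: Any) -> ModelT:
    # whatever class (or factory returning that class) is passed in is what
    # the type checker sees coming back out
    return model_cls(**kwargs)


class TinyNet(nn.Module):
    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)


model = build(TinyNet, dim=16)  # checkers infer `model: TinyNet`, not just nn.Module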

timm/models/_factory.py

Lines changed: 11 additions & 8 deletions

@@ -1,8 +1,10 @@
 import os
 from pathlib import Path
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional, Tuple, Union
 from urllib.parse import urlsplit
 
+from torch import nn
+
 from timm.layers import set_layer_config
 from ._helpers import load_checkpoint
 from ._hub import load_model_config_from_hf, load_model_config_from_path
@@ -13,7 +15,8 @@
 __all__ = ['parse_model_name', 'safe_model_name', 'create_model']
 
 
-def parse_model_name(model_name: str):
+def parse_model_name(model_name: str) -> Tuple[Optional[str], str]:
+    """Parse source and name from potentially prefixed model name."""
     if model_name.startswith('hf_hub'):
         # NOTE for backwards compat, deprecate hf_hub use
         model_name = model_name.replace('hf_hub', 'hf-hub')
@@ -29,9 +32,9 @@ def parse_model_name(model_name: str):
         return None, model_name
 
 
-def safe_model_name(model_name: str, remove_source: bool = True):
-    # return a filename / path safe model name
-    def make_safe(name):
+def safe_model_name(model_name: str, remove_source: bool = True) -> str:
+    """Return a filename / path safe model name."""
+    def make_safe(name: str) -> str:
         return ''.join(c if c.isalnum() else '_' for c in name).rstrip('_')
     if remove_source:
         model_name = parse_model_name(model_name)[-1]
@@ -42,14 +45,14 @@ def create_model(
         model_name: str,
         pretrained: bool = False,
         pretrained_cfg: Optional[Union[str, Dict[str, Any], PretrainedCfg]] = None,
-        pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
+        pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
         checkpoint_path: Optional[Union[str, Path]] = None,
         cache_dir: Optional[Union[str, Path]] = None,
         scriptable: Optional[bool] = None,
         exportable: Optional[bool] = None,
         no_jit: Optional[bool] = None,
-        **kwargs,
-):
+        **kwargs: Any,
+) -> nn.Module:
     """Create a model.
 
     Lookup model's entrypoint function and pass relevant args to create a new model.
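The make_safe helper shown in the hunk above maps every non-alphanumeric character to '_' and strips trailing underscores. A quick standalone check (the hf-hub id is just a sample input, not taken from the commit):

def make_safe(name: str) -> str:
    return ''.join(c if c.isalnum() else '_' for c in name).rstrip('_')

# '-', ':', '/' and '.' all collapse to '_', giving a filename-safe name
assert make_safe('hf-hub:timm/resnet50.a1_in1k') == 'hf_hub_timm_resnet50_a1_in1k'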

timm/models/_helpers.py

Lines changed: 57 additions & 10 deletions

@@ -7,8 +7,10 @@
 from typing import Any, Callable, Dict, Optional, Union
 
 import torch
+
 try:
     import safetensors.torch
+
     _has_safetensors = True
 except ImportError:
     _has_safetensors = False
@@ -18,7 +20,7 @@
 __all__ = ['clean_state_dict', 'load_state_dict', 'load_checkpoint', 'remap_state_dict', 'resume_checkpoint']
 
 
-def _remove_prefix(text, prefix):
+def _remove_prefix(text: str, prefix: str) -> str:
     # FIXME replace with 3.9 stdlib fn when min at 3.9
     if text.startswith(prefix):
         return text[len(prefix):]
@@ -45,6 +47,17 @@ def load_state_dict(
         device: Union[str, torch.device] = 'cpu',
         weights_only: bool = False,
 ) -> Dict[str, Any]:
+    """Load state dictionary from checkpoint file.
+
+    Args:
+        checkpoint_path: Path to checkpoint file.
+        use_ema: Whether to use EMA weights if available.
+        device: Device to load checkpoint to.
+        weights_only: Whether to load only weights (torch.load parameter).
+
+    Returns:
+        State dictionary loaded from checkpoint.
+    """
     if checkpoint_path and os.path.isfile(checkpoint_path):
         # Check if safetensors or not and load weights accordingly
         if str(checkpoint_path).endswith(".safetensors"):
@@ -83,7 +96,22 @@ def load_checkpoint(
         remap: bool = False,
         filter_fn: Optional[Callable] = None,
         weights_only: bool = False,
-):
+) -> Any:
+    """Load checkpoint into model.
+
+    Args:
+        model: Model to load checkpoint into.
+        checkpoint_path: Path to checkpoint file.
+        use_ema: Whether to use EMA weights if available.
+        device: Device to load checkpoint to.
+        strict: Whether to strictly enforce state_dict keys match.
+        remap: Whether to remap state dict keys by order.
+        filter_fn: Optional function to filter state dict.
+        weights_only: Whether to load only weights (torch.load parameter).
+
+    Returns:
+        Incompatible keys from model.load_state_dict().
+    """
     if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'):
         # numpy checkpoint, try to load via model specific load_pretrained fn
         if hasattr(model, 'load_pretrained'):
@@ -105,9 +133,18 @@ def remap_state_dict(
         state_dict: Dict[str, Any],
         model: torch.nn.Module,
         allow_reshape: bool = True
-):
-    """ remap checkpoint by iterating over state dicts in order (ignoring original keys).
+) -> Dict[str, Any]:
+    """Remap checkpoint by iterating over state dicts in order (ignoring original keys).
+
     This assumes models (and originating state dict) were created with params registered in same order.
+
+    Args:
+        state_dict: State dict to remap.
+        model: Model whose state dict keys to use.
+        allow_reshape: Whether to allow reshaping tensors to match.
+
+    Returns:
+        Remapped state dictionary.
     """
     out_dict = {}
     for (ka, va), (kb, vb) in zip(model.state_dict().items(), state_dict.items()):
@@ -116,18 +153,30 @@ def remap_state_dict(
         if allow_reshape:
             vb = vb.reshape(va.shape)
         else:
-            assert False, f'Tensor shape mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.'
+            assert False, f'Tensor shape mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.'
         out_dict[ka] = vb
     return out_dict
 
 
 def resume_checkpoint(
         model: torch.nn.Module,
         checkpoint_path: str,
-        optimizer: torch.optim.Optimizer = None,
-        loss_scaler: Any = None,
+        optimizer: Optional[torch.optim.Optimizer] = None,
+        loss_scaler: Optional[Any] = None,
         log_info: bool = True,
-):
+) -> Optional[int]:
+    """Resume training from checkpoint.
+
+    Args:
+        model: Model to load checkpoint into.
+        checkpoint_path: Path to checkpoint file.
+        optimizer: Optional optimizer to restore state.
+        loss_scaler: Optional AMP loss scaler to restore state.
+        log_info: Whether to log loading info.
+
+    Returns:
+        Resume epoch number if available, else None.
+    """
     resume_epoch = None
     if os.path.isfile(checkpoint_path):
         checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
@@ -162,5 +211,3 @@ def resume_checkpoint(
     else:
         _logger.error("No checkpoint found at '{}'".format(checkpoint_path))
         raise FileNotFoundError()
-
-
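The new remap_state_dict docstring spells out the order-based matching it performs. A self-contained sketch mirroring that documented behaviour, using toy nn.Sequential modules rather than timm models:

import torch.nn as nn

src = nn.Sequential(nn.Linear(4, 4))   # checkpoint source, keys '0.weight', '0.bias'
dst = nn.Sequential(nn.Linear(4, 4))   # target model
state = {f'renamed_{k}': v for k, v in src.state_dict().items()}  # keys no longer match

# match purely by registration order, as the loop in the hunk above does
remapped = {ka: vb for (ka, va), (kb, vb) in zip(dst.state_dict().items(), state.items())}
dst.load_state_dict(remapped)          # loads cleanly despite the renamed keys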
