
Doing some Claude-enabled docstring, type annotation and other cleanup #2504


Status: Open, wants to merge 1 commit into base: main
8 changes: 6 additions & 2 deletions timm/layers/blur_pool.py
@@ -6,12 +6,12 @@
Hacked together by Chris Ha and Ross Wightman
"""
from functools import partial
+from math import comb  # Python 3.8
from typing import Optional, Type

import torch
import torch.nn as nn
import torch.nn.functional as F
-import numpy as np

from .padding import get_padding
from .typing import LayerType
@@ -45,7 +45,11 @@ def __init__(
        self.pad_mode = pad_mode
        self.padding = [get_padding(filt_size, stride, dilation=1)] * 4

-        coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs.astype(np.float32))
+        # (0.5 + 0.5 x)^N => coefficients = C(N,k) / 2^N, k = 0..N
+        coeffs = torch.tensor(
+            [comb(filt_size - 1, k) for k in range(filt_size)],
+            dtype=torch.float32,
+        ) / (2 ** (filt_size - 1))  # normalise so coefficients sum to 1
        blur_filter = (coeffs[:, None] * coeffs[None, :])[None, None, :, :]
        if channels is not None:
            blur_filter = blur_filter.repeat(self.channels, 1, 1, 1)
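A reviewer-side note on the blur_pool change: the identity behind the new code is that the coefficients of (0.5 + 0.5x)^N are C(N, k) / 2^N for k = 0..N, so they always sum to 1, and the binomial row is symmetric, so coefficient order does not matter. A quick equivalence check against the removed np.poly1d formulation (a minimal sketch, assuming numpy is still installed in the dev environment):

```python
from math import comb

import numpy as np
import torch

for filt_size in range(2, 8):
    # old formulation: expand (0.5 + 0.5x)^(filt_size - 1) via polynomial powers
    old = torch.tensor((np.poly1d((0.5, 0.5)) ** (filt_size - 1)).coeffs.astype(np.float32))
    # new formulation: binomial coefficients normalised by 2^(filt_size - 1)
    new = torch.tensor(
        [comb(filt_size - 1, k) for k in range(filt_size)],
        dtype=torch.float32,
    ) / (2 ** (filt_size - 1))
    assert torch.allclose(old, new)
    assert abs(new.sum().item() - 1.0) < 1e-6  # filter coefficients sum to 1
```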
5 changes: 2 additions & 3 deletions timm/layers/cond_conv2d.py
@@ -8,7 +8,6 @@

import math
from functools import partial
-import numpy as np
import torch
from torch import nn as nn
from torch.nn import functional as F
@@ -21,7 +20,7 @@
def get_condconv_initializer(initializer, num_experts, expert_shape):
    def condconv_initializer(weight):
        """CondConv initializer function."""
-        num_params = np.prod(expert_shape)
+        num_params = math.prod(expert_shape)
        if (len(weight.shape) != 2 or weight.shape[0] != num_experts or
                weight.shape[1] != num_params):
            raise (ValueError(
@@ -75,7 +74,7 @@ def reset_parameters(self):
            partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape)
        init_weight(self.weight)
        if self.bias is not None:
-            fan_in = np.prod(self.weight_shape[1:])
+            fan_in = math.prod(self.weight_shape[1:])
            bound = 1 / math.sqrt(fan_in)
            init_bias = get_condconv_initializer(
                partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape)
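One behavioural nuance of the np.prod to math.prod swap worth recording: math.prod returns a plain Python int for integer inputs, while np.prod returns a numpy scalar, so num_params and fan_in stay native ints. A minimal sanity check (assuming numpy is available for the comparison):

```python
import math

import numpy as np

shape = (8, 3, 3, 3)  # e.g. an expert weight shape: out_ch x in_ch x kH x kW
assert math.prod(shape) == int(np.prod(shape)) == 216
assert type(math.prod(shape)) is int           # plain Python int
assert isinstance(np.prod(shape), np.integer)  # numpy scalar
```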
43 changes: 27 additions & 16 deletions timm/models/_builder.py
@@ -3,7 +3,7 @@
import os
from copy import deepcopy
from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union

from torch import nn as nn
from torch.hub import load_state_dict_from_url
@@ -26,11 +26,21 @@
_CHECK_HASH = False
_USE_OLD_CACHE = int(os.environ.get('TIMM_USE_OLD_CACHE', 0)) > 0

-__all__ = ['set_pretrained_download_progress', 'set_pretrained_check_hash', 'load_custom_pretrained', 'load_pretrained',
-           'pretrained_cfg_for_features', 'resolve_pretrained_cfg', 'build_model_with_cfg']
+__all__ = [
+    'set_pretrained_download_progress',
+    'set_pretrained_check_hash',
+    'load_custom_pretrained',
+    'load_pretrained',
+    'pretrained_cfg_for_features',
+    'resolve_pretrained_cfg',
+    'build_model_with_cfg',
+]


-def _resolve_pretrained_source(pretrained_cfg):
+ModelT = TypeVar("ModelT", bound=nn.Module)  # any subclass of nn.Module
+
+
+def _resolve_pretrained_source(pretrained_cfg: Dict[str, Any]) -> Tuple[str, str]:
    cfg_source = pretrained_cfg.get('source', '')
    pretrained_url = pretrained_cfg.get('url', None)
    pretrained_file = pretrained_cfg.get('file', None)
@@ -78,25 +88,25 @@ def _resolve_pretrained_source(pretrained_cfg):
    return load_from, pretrained_loc


-def set_pretrained_download_progress(enable=True):
+def set_pretrained_download_progress(enable: bool = True) -> None:
    """ Set download progress for pretrained weights on/off (globally). """
    global _DOWNLOAD_PROGRESS
    _DOWNLOAD_PROGRESS = enable


-def set_pretrained_check_hash(enable=True):
+def set_pretrained_check_hash(enable: bool = True) -> None:
    """ Set hash checking for pretrained weights on/off (globally). """
    global _CHECK_HASH
    _CHECK_HASH = enable


def load_custom_pretrained(
        model: nn.Module,
-        pretrained_cfg: Optional[Dict] = None,
+        pretrained_cfg: Optional[Dict[str, Any]] = None,
        load_fn: Optional[Callable] = None,
        cache_dir: Optional[Union[str, Path]] = None,
-):
-    r"""Loads a custom (read non .pth) weight file
+) -> None:
+    """Loads a custom (read non .pth) weight file

    Downloads checkpoint file into cache-dir like torch.hub based loaders, but calls
    a passed in custom load fun, or the `load_pretrained` model member fn.
@@ -141,13 +151,13 @@ def load_custom_pretrained(

def load_pretrained(
        model: nn.Module,
-        pretrained_cfg: Optional[Dict] = None,
+        pretrained_cfg: Optional[Dict[str, Any]] = None,
        num_classes: int = 1000,
        in_chans: int = 3,
        filter_fn: Optional[Callable] = None,
        strict: bool = True,
        cache_dir: Optional[Union[str, Path]] = None,
-):
+) -> None:
    """ Load pretrained checkpoint

    Args:
@@ -278,7 +288,7 @@ def load_pretrained(
            f' This may be expected if model is being adapted.')


-def pretrained_cfg_for_features(pretrained_cfg):
+def pretrained_cfg_for_features(pretrained_cfg: Dict[str, Any]) -> Dict[str, Any]:
    pretrained_cfg = deepcopy(pretrained_cfg)
    # remove default pretrained cfg fields that don't have much relevance for feature backbone
    to_remove = ('num_classes', 'classifier', 'global_pool')  # add default final pool size?
@@ -287,14 +297,14 @@
    return pretrained_cfg


-def _filter_kwargs(kwargs, names):
+def _filter_kwargs(kwargs: Dict[str, Any], names: List[str]) -> None:
    if not kwargs or not names:
        return
    for n in names:
        kwargs.pop(n, None)


-def _update_default_model_kwargs(pretrained_cfg, kwargs, kwargs_filter):
+def _update_default_model_kwargs(pretrained_cfg, kwargs, kwargs_filter) -> None:
    """ Update the default_cfg and kwargs before passing to model

    Args:
@@ -340,6 +350,7 @@ def resolve_pretrained_cfg(
        pretrained_cfg: Optional[Union[str, Dict[str, Any]]] = None,
        pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
) -> PretrainedCfg:
+    """Resolve pretrained configuration from various sources."""
    model_with_tag = variant
    pretrained_tag = None
    if pretrained_cfg:
@@ -371,7 +382,7 @@


def build_model_with_cfg(
-        model_cls: Callable,
+        model_cls: Union[Type[ModelT], Callable[..., ModelT]],
        variant: str,
        pretrained: bool,
        pretrained_cfg: Optional[Dict] = None,
@@ -383,7 +394,7 @@
        cache_dir: Optional[Union[str, Path]] = None,
        kwargs_filter: Optional[Tuple[str]] = None,
        **kwargs,
-):
+) -> ModelT:
    """ Build model with specified default_cfg and optional model_cfg

    This helper fn aids in the construction of a model including:
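The ModelT TypeVar is the interesting bit of this file: build_model_with_cfg now tells static type checkers which concrete model class comes back instead of erasing it to nn.Module. A sketch of the payoff, modelled on timm's resnet entrypoints (the _create_resnet helper here is illustrative, not part of this diff):

```python
from timm.models import build_model_with_cfg
from timm.models.resnet import BasicBlock, ResNet


def _create_resnet(variant: str, pretrained: bool = False, **kwargs) -> ResNet:
    # With model_cls typed as Union[Type[ModelT], Callable[..., ModelT]],
    # mypy/pyright infer ResNet here rather than a bare nn.Module.
    return build_model_with_cfg(ResNet, variant, pretrained, **kwargs)


model = _create_resnet('resnet18', block=BasicBlock, layers=[2, 2, 2, 2])
model.get_classifier()  # resolves as a ResNet method for the type checker
```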
19 changes: 11 additions & 8 deletions timm/models/_factory.py
@@ -1,8 +1,10 @@
import os
from pathlib import Path
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional, Tuple, Union
from urllib.parse import urlsplit

+from torch import nn
+
from timm.layers import set_layer_config
from ._helpers import load_checkpoint
from ._hub import load_model_config_from_hf, load_model_config_from_path
@@ -13,7 +15,8 @@
__all__ = ['parse_model_name', 'safe_model_name', 'create_model']


-def parse_model_name(model_name: str):
+def parse_model_name(model_name: str) -> Tuple[Optional[str], str]:
+    """Parse source and name from potentially prefixed model name."""
    if model_name.startswith('hf_hub'):
        # NOTE for backwards compat, deprecate hf_hub use
        model_name = model_name.replace('hf_hub', 'hf-hub')
@@ -29,9 +32,9 @@ def parse_model_name(model_name: str):
    return None, model_name


-def safe_model_name(model_name: str, remove_source: bool = True):
-    # return a filename / path safe model name
-    def make_safe(name):
+def safe_model_name(model_name: str, remove_source: bool = True) -> str:
+    """Return a filename / path safe model name."""
+    def make_safe(name: str) -> str:
        return ''.join(c if c.isalnum() else '_' for c in name).rstrip('_')
    if remove_source:
        model_name = parse_model_name(model_name)[-1]
@@ -42,14 +45,14 @@ def create_model(
        model_name: str,
        pretrained: bool = False,
        pretrained_cfg: Optional[Union[str, Dict[str, Any], PretrainedCfg]] = None,
-        pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
+        pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
        checkpoint_path: Optional[Union[str, Path]] = None,
        cache_dir: Optional[Union[str, Path]] = None,
        scriptable: Optional[bool] = None,
        exportable: Optional[bool] = None,
        no_jit: Optional[bool] = None,
-        **kwargs,
-):
+        **kwargs: Any,
+) -> nn.Module:
    """Create a model.

    Lookup model's entrypoint function and pass relevant args to create a new model.
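With the new Tuple[Optional[str], str] annotation, parse_model_name's two behaviours are easy to pin down; a doctest-style check (a minimal sketch, the hub tag is just an example):

```python
from timm.models import parse_model_name, safe_model_name

# plain names carry no source prefix
assert parse_model_name('resnet50') == (None, 'resnet50')
# hf-hub prefixed names split into (source, repo_id)
assert parse_model_name('hf-hub:timm/resnet50.a1_in1k') == ('hf-hub', 'timm/resnet50.a1_in1k')
# non-alphanumeric characters collapse to '_' in the filename-safe form
assert safe_model_name('hf-hub:timm/resnet50.a1_in1k') == 'timm_resnet50_a1_in1k'
```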
67 changes: 57 additions & 10 deletions timm/models/_helpers.py
@@ -7,8 +7,10 @@
from typing import Any, Callable, Dict, Optional, Union

import torch
+
try:
    import safetensors.torch
+
    _has_safetensors = True
except ImportError:
    _has_safetensors = False
@@ -18,7 +20,7 @@
__all__ = ['clean_state_dict', 'load_state_dict', 'load_checkpoint', 'remap_state_dict', 'resume_checkpoint']


-def _remove_prefix(text, prefix):
+def _remove_prefix(text: str, prefix: str) -> str:
    # FIXME replace with 3.9 stdlib fn when min at 3.9
    if text.startswith(prefix):
        return text[len(prefix):]
@@ -45,6 +47,17 @@ def load_state_dict(
        device: Union[str, torch.device] = 'cpu',
        weights_only: bool = False,
) -> Dict[str, Any]:
+    """Load state dictionary from checkpoint file.
+
+    Args:
+        checkpoint_path: Path to checkpoint file.
+        use_ema: Whether to use EMA weights if available.
+        device: Device to load checkpoint to.
+        weights_only: Whether to load only weights (torch.load parameter).
+
+    Returns:
+        State dictionary loaded from checkpoint.
+    """
    if checkpoint_path and os.path.isfile(checkpoint_path):
        # Check if safetensors or not and load weights accordingly
        if str(checkpoint_path).endswith(".safetensors"):
@@ -83,7 +96,22 @@ def load_checkpoint(
        remap: bool = False,
        filter_fn: Optional[Callable] = None,
        weights_only: bool = False,
-):
+) -> Any:
+    """Load checkpoint into model.
+
+    Args:
+        model: Model to load checkpoint into.
+        checkpoint_path: Path to checkpoint file.
+        use_ema: Whether to use EMA weights if available.
+        device: Device to load checkpoint to.
+        strict: Whether to strictly enforce state_dict keys match.
+        remap: Whether to remap state dict keys by order.
+        filter_fn: Optional function to filter state dict.
+        weights_only: Whether to load only weights (torch.load parameter).
+
+    Returns:
+        Incompatible keys from model.load_state_dict().
+    """
    if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'):
        # numpy checkpoint, try to load via model specific load_pretrained fn
        if hasattr(model, 'load_pretrained'):
@@ -105,9 +133,18 @@ def remap_state_dict(
        state_dict: Dict[str, Any],
        model: torch.nn.Module,
        allow_reshape: bool = True
-):
-    """ remap checkpoint by iterating over state dicts in order (ignoring original keys).
+) -> Dict[str, Any]:
+    """Remap checkpoint by iterating over state dicts in order (ignoring original keys).

    This assumes models (and originating state dict) were created with params registered in same order.

+    Args:
+        state_dict: State dict to remap.
+        model: Model whose state dict keys to use.
+        allow_reshape: Whether to allow reshaping tensors to match.
+
+    Returns:
+        Remapped state dictionary.
    """
    out_dict = {}
    for (ka, va), (kb, vb) in zip(model.state_dict().items(), state_dict.items()):
@@ -116,18 +153,30 @@ def remap_state_dict(
            if allow_reshape:
                vb = vb.reshape(va.shape)
            else:
-                assert False, f'Tensor shape mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.'
+                assert False, f'Tensor shape mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.'
        out_dict[ka] = vb
    return out_dict


def resume_checkpoint(
        model: torch.nn.Module,
        checkpoint_path: str,
-        optimizer: torch.optim.Optimizer = None,
-        loss_scaler: Any = None,
+        optimizer: Optional[torch.optim.Optimizer] = None,
+        loss_scaler: Optional[Any] = None,
        log_info: bool = True,
-):
+) -> Optional[int]:
+    """Resume training from checkpoint.
+
+    Args:
+        model: Model to load checkpoint into.
+        checkpoint_path: Path to checkpoint file.
+        optimizer: Optional optimizer to restore state.
+        loss_scaler: Optional AMP loss scaler to restore state.
+        log_info: Whether to log loading info.
+
+    Returns:
+        Resume epoch number if available, else None.
+    """
    resume_epoch = None
    if os.path.isfile(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
@@ -162,5 +211,3 @@ def resume_checkpoint(
    else:
        _logger.error("No checkpoint found at '{}'".format(checkpoint_path))
        raise FileNotFoundError()


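Finally, a usage note for the newly documented remap_state_dict: it pairs tensors purely by registration order, so a checkpoint with foreign key names can still be loaded into a structurally identical model. A minimal sketch (the layer_{i} key scheme is invented for the demo):

```python
import torch.nn as nn

from timm.models import remap_state_dict

dst = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))

# simulate a checkpoint whose keys differ but whose tensors were
# registered in the same order as dst's parameters and buffers
src_sd = {f'layer_{i}': v.clone() for i, v in enumerate(dst.state_dict().values())}

remapped = remap_state_dict(src_sd, dst)
dst.load_state_dict(remapped)  # keys now match dst's own names
```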