Skip to content

Commit 658ca53

Browse files
authored
cleanup prototype transforms functional tests (#6622)
* cleanup prototype transforms functional tests * fix * oust local functions
1 parent f49edd3 commit 658ca53

5 files changed

+201
-402
lines changed

test/common_utils.py

+2-14
Original file line numberDiff line numberDiff line change
@@ -205,12 +205,8 @@ def _test_fn_on_batch(batch_tensors, fn, scripted_fn_atol=1e-8, **fn_kwargs):
205205

206206

207207
def cache(fn):
208-
"""Similar to :func:`functools.cache` (Python >= 3.8) or :func:`functools.lru_cache` with infinite buffer size,
209-
but also caches exceptions.
210-
211-
.. warning::
212-
213-
Only use this on deterministic functions.
208+
"""Similar to :func:`functools.cache` (Python >= 3.8) or :func:`functools.lru_cache` with infinite cache size,
209+
but this also caches exceptions.
214210
"""
215211
sentinel = object()
216212
out_cache = {}
@@ -238,11 +234,3 @@ def wrapper(*args, **kwargs):
238234
return out
239235

240236
return wrapper
241-
242-
243-
@cache
244-
def script(fn):
245-
try:
246-
return torch.jit.script(fn)
247-
except Exception as error:
248-
raise AssertionError(f"Trying to `torch.jit.script` '{fn.__name__}' raised the error above.") from error

test/prototype_transforms_dispatcher_infos.py

-24
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,16 @@
11
import dataclasses
2-
import functools
32
from typing import Callable, Dict, Type
43

54
import pytest
6-
import torch
75
import torchvision.prototype.transforms.functional as F
8-
from prototype_common_utils import ArgsKwargs
96
from prototype_transforms_kernel_infos import KERNEL_INFOS
10-
from test_prototype_transforms_functional import FUNCTIONAL_INFOS
117
from torchvision.prototype import features
128

139
__all__ = ["DispatcherInfo", "DISPATCHER_INFOS"]
1410

1511
KERNEL_SAMPLE_INPUTS_FN_MAP = {info.kernel: info.sample_inputs_fn for info in KERNEL_INFOS}
1612

1713

18-
# Helper class to use the infos from the old framework for now tests
19-
class PreloadedArgsKwargs(ArgsKwargs):
20-
def load(self, device="cpu"):
21-
args = tuple(arg.to(device) if isinstance(arg, torch.Tensor) else arg for arg in self.args)
22-
kwargs = {
23-
keyword: arg.to(device) if isinstance(arg, torch.Tensor) else arg for keyword, arg in self.kwargs.items()
24-
}
25-
return args, kwargs
26-
27-
28-
def preloaded_sample_inputs(args_kwargs):
29-
for args, kwargs in args_kwargs:
30-
yield PreloadedArgsKwargs(*args, **kwargs)
31-
32-
33-
KERNEL_SAMPLE_INPUTS_FN_MAP.update(
34-
{info.functional: functools.partial(preloaded_sample_inputs, info.sample_inputs()) for info in FUNCTIONAL_INFOS}
35-
)
36-
37-
3814
@dataclasses.dataclass
3915
class DispatcherInfo:
4016
dispatcher: Callable

test/test_prototype_transforms_dispatchers.py

-31
This file was deleted.

0 commit comments

Comments
 (0)