Skip to content

Commit 8f07159

Browse files
committed
Convert to floats at the beginning.
1 parent bda072d commit 8f07159

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

references/classification/presets.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def __init__(
1919
):
2020
trans = [
2121
transforms.ToImageTensor(),
22+
transforms.ConvertImageDtype(torch.float),
2223
transforms.RandomResizedCrop(crop_size, interpolation=interpolation, antialias=True),
2324
]
2425
if hflip_prob > 0:
@@ -35,7 +36,6 @@ def __init__(
3536
trans.append(transforms.AutoAugment(policy=aa_policy, interpolation=interpolation))
3637
trans.extend(
3738
[
38-
transforms.ConvertImageDtype(torch.float),
3939
transforms.Normalize(mean=mean, std=std),
4040
]
4141
)
@@ -62,9 +62,9 @@ def __init__(
6262
self.transforms = transforms.Compose(
6363
[
6464
transforms.ToImageTensor(),
65+
transforms.ConvertImageDtype(torch.float),
6566
transforms.Resize(resize_size, interpolation=interpolation, antialias=True),
6667
transforms.CenterCrop(crop_size),
67-
transforms.ConvertImageDtype(torch.float),
6868
transforms.Normalize(mean=mean, std=std),
6969
]
7070
)

0 commit comments

Comments
 (0)