Skip to content

Commit 86dce6f

Browse files
Add antialias to layers.Resizing and add more tests. (#20972)
1 parent d1fb581 commit 86dce6f

File tree

2 files changed

+36
-63
lines changed

2 files changed

+36
-63
lines changed

keras/src/layers/preprocessing/image_preprocessing/resizing.py

+4
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def __init__(
7979
pad_to_aspect_ratio=False,
8080
fill_mode="constant",
8181
fill_value=0.0,
82+
antialias=False,
8283
data_format=None,
8384
**kwargs,
8485
):
@@ -91,6 +92,7 @@ def __init__(
9192
self.pad_to_aspect_ratio = pad_to_aspect_ratio
9293
self.fill_mode = fill_mode
9394
self.fill_value = fill_value
95+
self.antialias = bool(antialias)
9496
if self.data_format == "channels_first":
9597
self.height_axis = -2
9698
self.width_axis = -1
@@ -104,6 +106,7 @@ def transform_images(self, images, transformation=None, training=True):
104106
images,
105107
size=size,
106108
interpolation=self.interpolation,
109+
antialias=self.antialias,
107110
data_format=self.data_format,
108111
crop_to_aspect_ratio=self.crop_to_aspect_ratio,
109112
pad_to_aspect_ratio=self.pad_to_aspect_ratio,
@@ -299,6 +302,7 @@ def get_config(self):
299302
"pad_to_aspect_ratio": self.pad_to_aspect_ratio,
300303
"fill_mode": self.fill_mode,
301304
"fill_value": self.fill_value,
305+
"antialias": self.antialias,
302306
"data_format": self.data_format,
303307
}
304308
return {**base_config, **config}

keras/src/layers/preprocessing/image_preprocessing/resizing_test.py

+32-63
Original file line numberDiff line numberDiff line change
@@ -7,80 +7,49 @@
77
from keras.src import backend
88
from keras.src import layers
99
from keras.src import testing
10+
from keras.src.testing.test_utils import named_product
1011

1112

1213
class ResizingTest(testing.TestCase):
13-
def test_resizing_basics(self):
14-
self.run_layer_test(
15-
layers.Resizing,
16-
init_kwargs={
17-
"height": 6,
18-
"width": 6,
19-
"data_format": "channels_last",
20-
"interpolation": "bicubic",
21-
"crop_to_aspect_ratio": True,
22-
},
23-
input_shape=(2, 12, 12, 3),
24-
expected_output_shape=(2, 6, 6, 3),
25-
expected_num_trainable_weights=0,
26-
expected_num_non_trainable_weights=0,
27-
expected_num_seed_generators=0,
28-
expected_num_losses=0,
29-
supports_masking=False,
30-
run_training_check=False,
31-
)
32-
self.run_layer_test(
33-
layers.Resizing,
34-
init_kwargs={
35-
"height": 6,
36-
"width": 6,
37-
"data_format": "channels_first",
38-
"interpolation": "bilinear",
39-
"crop_to_aspect_ratio": True,
40-
},
41-
input_shape=(2, 3, 12, 12),
42-
expected_output_shape=(2, 3, 6, 6),
43-
expected_num_trainable_weights=0,
44-
expected_num_non_trainable_weights=0,
45-
expected_num_seed_generators=0,
46-
expected_num_losses=0,
47-
supports_masking=False,
48-
run_training_check=False,
49-
)
50-
self.run_layer_test(
51-
layers.Resizing,
52-
init_kwargs={
53-
"height": 6,
54-
"width": 6,
55-
"data_format": "channels_last",
56-
"interpolation": "nearest",
57-
"crop_to_aspect_ratio": False,
58-
},
59-
input_shape=(2, 12, 12, 3),
60-
expected_output_shape=(2, 6, 6, 3),
61-
expected_num_trainable_weights=0,
62-
expected_num_non_trainable_weights=0,
63-
expected_num_seed_generators=0,
64-
expected_num_losses=0,
65-
supports_masking=False,
66-
run_training_check=False,
14+
@parameterized.named_parameters(
15+
named_product(
16+
interpolation=["nearest", "bilinear", "bicubic", "lanczos5"],
17+
crop_pad=[(False, False), (True, False), (False, True)],
18+
antialias=[False, True],
19+
data_format=["channels_last", "channels_first"],
6720
)
68-
69-
@pytest.mark.skipif(
70-
backend.backend() == "torch", reason="Torch does not support lanczos."
7121
)
72-
def test_resizing_basics_lanczos5(self):
22+
def test_resizing_basics(
23+
self,
24+
interpolation,
25+
crop_pad,
26+
antialias,
27+
data_format,
28+
):
29+
if interpolation == "lanczos5" and backend.backend() == "torch":
30+
self.skipTest("Torch does not support lanczos.")
31+
32+
crop_to_aspect_ratio, pad_to_aspect_ratio = crop_pad
33+
if data_format == "channels_last":
34+
input_shape = (2, 12, 12, 3)
35+
expected_output_shape = (2, 6, 6, 3)
36+
else:
37+
input_shape = (2, 3, 12, 12)
38+
expected_output_shape = (2, 3, 6, 6)
39+
7340
self.run_layer_test(
7441
layers.Resizing,
7542
init_kwargs={
7643
"height": 6,
7744
"width": 6,
78-
"data_format": "channels_first",
79-
"interpolation": "lanczos5",
80-
"crop_to_aspect_ratio": False,
45+
"interpolation": interpolation,
46+
"crop_to_aspect_ratio": crop_to_aspect_ratio,
47+
"pad_to_aspect_ratio": pad_to_aspect_ratio,
48+
"antialias": antialias,
49+
"data_format": data_format,
8150
},
82-
input_shape=(2, 3, 12, 12),
83-
expected_output_shape=(2, 3, 6, 6),
51+
input_shape=input_shape,
52+
expected_output_shape=expected_output_shape,
8453
expected_num_trainable_weights=0,
8554
expected_num_non_trainable_weights=0,
8655
expected_num_seed_generators=0,

0 commit comments

Comments
 (0)