
Commit 7412791

Add support for optional Functional inputs.
1 parent 7a81739 commit 7412791

8 files changed: +77 additions, -43 deletions

keras/src/layers/core/input_layer.py
Lines changed: 6 additions & 0 deletions

@@ -16,6 +16,7 @@ def __init__(
         sparse=None,
         batch_shape=None,
         input_tensor=None,
+        optional=False,
         name=None,
         **kwargs,
     ):
@@ -69,6 +70,7 @@ def __init__(
         self._input_tensor = input_tensor
         Node(operation=self, call_args=(), call_kwargs={}, outputs=input_tensor)
         self.built = True
+        self.optional = optional
 
     def call(self):
         return
@@ -95,6 +97,7 @@ def Input(
     batch_shape=None,
     name=None,
     tensor=None,
+    optional=False,
 ):
     """Used to instantiate a Keras tensor.
 
@@ -127,6 +130,8 @@ def Input(
         tensor: Optional existing tensor to wrap into the `Input` layer.
             If set, the layer will use this tensor rather
            than creating a new placeholder tensor.
+        optional: Boolean, whether the input is optional or not.
+            An optional input can accept `None` values.
 
     Returns:
         A Keras tensor.
@@ -148,5 +153,6 @@ def Input(
         batch_shape=batch_shape,
         name=name,
         input_tensor=tensor,
+        optional=optional,
     )
     return layer.output
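
Taken together, these changes let callers declare an input as optional at model-construction time and then pass `None` for it when calling the model eagerly. A minimal usage sketch, mirroring the test added later in this commit (the `MaybeAdd` layer is illustrative, not part of the commit):

    import numpy as np
    from keras import Input, Model, layers

    class MaybeAdd(layers.Layer):
        # Hypothetical layer whose second input may be absent.
        def call(self, x, y=None):
            return x if y is None else x + y

        def compute_output_shape(self, x_shape):
            return x_shape

    i1 = Input((2,))
    i2 = Input((2,), optional=True)  # new flag from this commit
    model = Model([i1, i2], MaybeAdd()(i1, i2))
    out = model([np.ones((2, 2)), None])  # eager call, optional input omitted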

keras/src/layers/input_spec.py
Lines changed: 14 additions & 14 deletions

@@ -30,6 +30,8 @@ class InputSpec:
             as long as the last axis of the spec is 1.
         name: Expected key corresponding to this input when passing data as
             a dictionary.
+        optional: Boolean, whether the input is optional or not.
+            An optional input can accept `None` values.
 
     Example:
 
@@ -56,6 +58,7 @@ def __init__(
         axes=None,
         allow_last_axis_squeeze=False,
         name=None,
+        optional=False,
     ):
         self.dtype = (
             backend.standardize_dtype(dtype) if dtype is not None else None
@@ -69,6 +72,7 @@ def __init__(
         self.max_ndim = max_ndim
         self.min_ndim = min_ndim
         self.name = name
+        self.optional = optional
         self.allow_last_axis_squeeze = allow_last_axis_squeeze
         try:
             axes = axes or {}
@@ -152,12 +156,18 @@ def assert_input_compatibility(input_spec, inputs, layer_name):
         inputs = list_inputs
 
     inputs = tree.flatten(inputs)
-    if len(input_spec) != len(inputs):
+    if len(inputs) != len(input_spec):
         raise ValueError(
-            f"Layer '{layer_name}' expected {len(input_spec)} input(s). "
-            f"Received {len(inputs)} instead."
+            f'Layer "{layer_name}" expects {len(input_spec)} input(s),'
+            f" but it received {len(inputs)} input tensors. "
+            f"Inputs received: {inputs}"
         )
-    for x in inputs:
+    for input_index, (x, spec) in enumerate(zip(inputs, input_spec)):
+        if spec is None:
+            continue
+        if x is None and spec.optional:
+            continue
+
         # Having a shape/dtype is the only commonality of the various
         # tensor-like objects that may be passed. The most common kind of
         # invalid type we are guarding for is a Layer instance (Functional API),
@@ -168,16 +178,6 @@ def assert_input_compatibility(input_spec, inputs, layer_name):
                 f"(of type {type(x)}) as input for layer '{layer_name}'."
             )
 
-    if len(inputs) != len(input_spec):
-        raise ValueError(
-            f'Layer "{layer_name}" expects {len(input_spec)} input(s),'
-            f" but it received {len(inputs)} input tensors. "
-            f"Inputs received: {inputs}"
-        )
-    for input_index, (x, spec) in enumerate(zip(inputs, input_spec)):
-        if spec is None:
-            continue
-
         shape = backend.standardize_shape(x.shape)
         ndim = len(shape)
         # Check ndim.
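
Because `InputSpec` now carries an `optional` flag, `assert_input_compatibility` skips the tensor type and shape checks for a `None` input whose spec is marked optional. A sketch of how a custom layer might declare such a spec (the `GatedSum` layer is hypothetical, not from this commit):

    from keras import layers
    from keras.layers import InputSpec

    class GatedSum(layers.Layer):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            # The second input may be passed as `None`; the first may not.
            self.input_spec = [
                InputSpec(ndim=2),
                InputSpec(ndim=2, optional=True),
            ]

        def call(self, inputs):
            x, gate = inputs
            return x if gate is None else x * gate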

keras/src/layers/layer.py
Lines changed: 4 additions & 2 deletions

@@ -747,8 +747,10 @@ def maybe_convert(x):
         # 2. Enforce that only tensors can be passed positionally.
         if not self._allow_non_tensor_positional_args:
             for arg in tree.flatten(args):
-                if not isinstance(arg, KerasTensor) and not backend.is_tensor(
-                    arg
+                if (
+                    not isinstance(arg, KerasTensor)
+                    and not backend.is_tensor(arg)
+                    and arg is not None
                 ):
                     raise ValueError(
                         "Only input tensors may be passed as "

keras/src/models/functional.py
Lines changed: 16 additions & 3 deletions

@@ -207,15 +207,23 @@ def _flatten_to_reference_inputs(self, inputs):
     def _convert_inputs_to_tensors(self, flat_inputs):
         converted = []
         for x, input in zip(flat_inputs, self._inputs):
-            converted.append(
-                ops.convert_to_tensor(x, dtype=input.dtype, sparse=input.sparse)
-            )
+            if x is None:  # TODO: check if optional
+                converted.append(x)
+            else:
+                converted.append(
+                    ops.convert_to_tensor(
+                        x, dtype=input.dtype, sparse=input.sparse
+                    )
+                )
         return converted
 
     def _adjust_input_rank(self, flat_inputs):
         flat_ref_shapes = [x.shape for x in self._inputs]
         adjusted = []
         for x, ref_shape in zip(flat_inputs, flat_ref_shapes):
+            if x is None:
+                adjusted.append(x)
+                continue
             x_rank = len(x.shape)
             ref_rank = len(ref_shape)
             if x_rank == ref_rank:
@@ -273,10 +281,15 @@ def shape_with_no_batch_size(x):
             return tuple(x)
 
         def make_spec_for_tensor(x):
+            optional = False
+            if isinstance(x._keras_history[0], InputLayer):
+                if x._keras_history[0].optional:
+                    optional = True
             return InputSpec(
                 shape=shape_with_no_batch_size(x.shape),
                 allow_last_axis_squeeze=True,
                 name=x._keras_history[0].name,
+                optional=optional,
             )
 
         if isinstance(self._inputs_struct, dict):
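
Note that both helpers now pass `None` through untouched instead of handing it to `ops.convert_to_tensor`, which would raise. The guard reduces to this pattern (a standalone sketch, not the Keras code itself):

    def convert_preserving_none(flat_inputs, convert_fn):
        # `None` marks an omitted optional input; convert everything else.
        return [x if x is None else convert_fn(x) for x in flat_inputs]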

keras/src/models/functional_test.py
Lines changed: 23 additions & 3 deletions

@@ -304,7 +304,7 @@ def test_bad_input_spec(self):
             ValueError, r"expected shape=\(None, 4\), found shape=\(2, 3\)"
         ):
             model(np.zeros((2, 3)))
-        with self.assertRaisesRegex(ValueError, "expected 1 input"):
+        with self.assertRaisesRegex(ValueError, "expects 1 input"):
             model([np.zeros((2, 4)), np.zeros((2, 4))])
 
         # List input
@@ -313,7 +313,7 @@ def test_bad_input_spec(self):
         x = input_a + input_b
         outputs = layers.Dense(2)(x)
         model = Functional([input_a, input_b], outputs)
-        with self.assertRaisesRegex(ValueError, "expected 2 input"):
+        with self.assertRaisesRegex(ValueError, "expects 2 input"):
            model(np.zeros((2, 3)))
        with self.assertRaisesRegex(
            ValueError, r"expected shape=\(None, 4\), found shape=\(2, 3\)"
@@ -322,7 +322,7 @@ def test_bad_input_spec(self):
 
         # Dict input
         model = Functional({"a": input_a, "b": input_b}, outputs)
-        with self.assertRaisesRegex(ValueError, "expected 2 input"):
+        with self.assertRaisesRegex(ValueError, "expects 2 input"):
             model(np.zeros((2, 3)))
         with self.assertRaisesRegex(
             ValueError, r"expected shape=\(None, 4\), found shape=\(2, 3\)"
@@ -432,6 +432,26 @@ def test_deeply_nested_model(self):
                 out_eager["others"]["3"], new_out_eager["others"]["3"]
             )
 
+    def test_optional_inputs(self):
+        class OptionalInputLayer(layers.Layer):
+            def call(self, x, y=None):
+                if y is not None:
+                    return x + y
+                return x
+
+            def compute_output_shape(self, x_shape):
+                return x_shape
+
+        i1 = Input((2,))
+        i2 = Input((2,), optional=True)
+        outputs = OptionalInputLayer()(i1, i2)
+        model = Model([i1, i2], outputs)
+
+        # Eager test
+        out = model([np.ones((2, 2)), None])
+        self.assertAllClose(out, np.ones((2, 2)))
+        # Note: it's not intended to work in symbolic mode (yet).
+
     def test_add_loss(self):
         # TODO
         pass

keras/src/ops/operation.py
Lines changed: 10 additions & 13 deletions

@@ -79,19 +79,16 @@ def compute_output_spec(self, *args, **kwargs):
         try:
             return backend.compute_output_spec(self.call, *args, **kwargs)
         except Exception as e:
-            if isinstance(e, TypeError):
-                raise e
-            else:
-                new_e = RuntimeError(
-                    "Could not automatically infer the output shape / dtype of "
-                    f"'{self.name}' (of type {self.__class__.__name__}). "
-                    f"Either the `{self.__class__.__name__}.call()` method "
-                    f"is incorrect, or you need to implement the "
-                    f"`{self.__class__.__name__}.compute_output_spec() / "
-                    "compute_output_shape()` method. "
-                    f"Error encountered:\n\n{e}"
-                )
-                raise new_e.with_traceback(e.__traceback__) from None
+            new_e = e.__class__(
+                "Could not automatically infer the output shape / dtype of "
+                f"'{self.name}' (of type {self.__class__.__name__}). "
+                f"Either the `{self.__class__.__name__}.call()` method "
+                f"is incorrect, or you need to implement the "
+                f"`{self.__class__.__name__}.compute_output_spec() / "
+                "compute_output_shape()` method. "
+                f"Error encountered:\n\n{e}"
+            )
+            raise new_e.with_traceback(e.__traceback__) from None
 
     def __new__(cls, *args, **kwargs):
         """We override __new__ to saving serializable constructor arguments.

keras/src/ops/symbolic_arguments.py
Lines changed: 1 addition & 3 deletions

@@ -40,9 +40,7 @@ def fill_in(self, tensor_dict):
 
         def switch_fn(x):
            if isinstance(x, KerasTensor):
-               val = tensor_dict.get(id(x), None)
-               if val is not None:
-                   return val
+               return tensor_dict.get(id(x), None)
            return x
 
        return self.convert(switch_fn)

keras/src/ops/symbolic_arguments_test.py
Lines changed: 3 additions & 5 deletions

@@ -99,8 +99,7 @@ def test_fill_in_multiple_arg(self):
 
         # Call the method to be tested
         result, _ = sym_args.fill_in(tensor_dict)
-
-        self.assertEqual(result, ((a, 2),))
+        self.assertEqual(result, ((None, 2),))
 
     # Testing fill in function for args and kwargs
     def test_fill_in(self):
@@ -115,9 +114,8 @@ def test_fill_in(self):
                 a,
                 b,
             ),
-            {1: c},
+            {"1": c},
         )
 
         (values, _) = sym_args.fill_in(dictionary)
-
-        self.assertEqual(values, ((3, b), {1: 2}))
+        self.assertEqual(values, ((3, None), {"1": 2}))
