@@ -117,6 +117,24 @@ def _norm_to_list_of_layers(maybe_layers):
117
117
)
118
118
119
119
120
+ def _get_inbound_nodes (layer ):
121
+ """Returns a list of [name, size, index] for all inbound nodes of the given layer."""
122
+ inbound_nodes = []
123
+ if layer .get ("inbound_nodes" ) is not None :
124
+ for maybe_inbound_node in layer .get ("inbound_nodes" , []):
125
+ for inbound_node_args in maybe_inbound_node .get ("args" , []):
126
+ # Sometimes this field is a list when there are multiple inbound nodes
127
+ # for the given layer.
128
+ if not isinstance (inbound_node_args , list ):
129
+ inbound_node_args = [inbound_node_args ]
130
+ for arg in inbound_node_args :
131
+ history = arg .get ("config" , {}).get ("keras_history" , [])
132
+ if len (history ) < 3 :
133
+ continue
134
+ inbound_nodes .append (history [:3 ])
135
+ return inbound_nodes
136
+
137
+
120
138
def _update_dicts (
121
139
name_scope ,
122
140
model_layer ,
@@ -149,7 +167,7 @@ def _update_dicts(
149
167
node_name = _scoped_name (name_scope , layer_config .get ("name" ))
150
168
input_layers = layer_config .get ("input_layers" )
151
169
output_layers = layer_config .get ("output_layers" )
152
- inbound_nodes = model_layer . get ( "inbound_nodes" )
170
+ inbound_nodes = _get_inbound_nodes ( model_layer )
153
171
154
172
is_functional_model = bool (input_layers and output_layers )
155
173
# In case of [1] and the parent model is functional, current layer
@@ -164,7 +182,7 @@ def _update_dicts(
164
182
elif is_parent_functional_model and not is_functional_model :
165
183
# Sequential model can take only one input. Make sure inbound to the
166
184
# model is linked to the first layer in the Sequential model.
167
- prev_node_name = _scoped_name (name_scope , inbound_nodes [0 ][0 ][ 0 ] )
185
+ prev_node_name = _scoped_name (name_scope , inbound_nodes [0 ][0 ])
168
186
elif (
169
187
not is_parent_functional_model
170
188
and prev_node_name
@@ -244,33 +262,31 @@ def keras_model_to_graph_def(keras_layer):
244
262
tf_dtype = dtypes .as_dtype (layer_config .get ("dtype" ))
245
263
node_def .attr ["dtype" ].type = tf_dtype .as_datatype_enum
246
264
if layer .get ("inbound_nodes" ) is not None :
247
- for maybe_inbound_node in layer .get ("inbound_nodes" ):
248
- inbound_nodes = _norm_to_list_of_layers (maybe_inbound_node )
249
- for [name , size , index , _ ] in inbound_nodes :
250
- inbound_name = _scoped_name (name_scope , name )
251
- # An input to a layer can be output from a model. In that case, the name
252
- # of inbound_nodes to a layer is a name of a model. Remap the name of the
253
- # model to output layer of the model. Also, since there can be multiple
254
- # outputs in a model, make sure we pick the right output_layer from the model.
255
- inbound_node_names = model_name_to_output .get (
256
- inbound_name , [inbound_name ]
257
- )
258
- # There can be multiple inbound_nodes that reference the
259
- # same upstream layer. This causes issues when looking for
260
- # a particular index in that layer, since the indices
261
- # captured in `inbound_nodes` doesn't necessarily match the
262
- # number of entries in the `inbound_node_names` list. To
263
- # avoid IndexErrors, we just use the last element in the
264
- # `inbound_node_names` in this situation.
265
- # Note that this is a quick hack to avoid IndexErrors in
266
- # this situation, and might not be an appropriate solution
267
- # to this problem in general.
268
- input_name = (
269
- inbound_node_names [index ]
270
- if index < len (inbound_node_names )
271
- else inbound_node_names [- 1 ]
272
- )
273
- node_def .input .append (input_name )
265
+ for name , size , index in _get_inbound_nodes (layer ):
266
+ inbound_name = _scoped_name (name_scope , name )
267
+ # An input to a layer can be output from a model. In that case, the name
268
+ # of inbound_nodes to a layer is a name of a model. Remap the name of the
269
+ # model to output layer of the model. Also, since there can be multiple
270
+ # outputs in a model, make sure we pick the right output_layer from the model.
271
+ inbound_node_names = model_name_to_output .get (
272
+ inbound_name , [inbound_name ]
273
+ )
274
+ # There can be multiple inbound_nodes that reference the
275
+ # same upstream layer. This causes issues when looking for
276
+ # a particular index in that layer, since the indices
277
+ # captured in `inbound_nodes` doesn't necessarily match the
278
+ # number of entries in the `inbound_node_names` list. To
279
+ # avoid IndexErrors, we just use the last element in the
280
+ # `inbound_node_names` in this situation.
281
+ # Note that this is a quick hack to avoid IndexErrors in
282
+ # this situation, and might not be an appropriate solution
283
+ # to this problem in general.
284
+ input_name = (
285
+ inbound_node_names [index ]
286
+ if index < len (inbound_node_names )
287
+ else inbound_node_names [- 1 ]
288
+ )
289
+ node_def .input .append (input_name )
274
290
elif prev_node_name is not None :
275
291
node_def .input .append (prev_node_name )
276
292
0 commit comments