Skip to content

Commit b2b50a3

Browse files
authored
keras: skip config if it's a policy obj (#5553)
`layer_config` could be a keras Policy object. Gracefully fall with checking the type. #5548
1 parent b6a8f98 commit b2b50a3

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

tensorboard/plugins/graph/keras_util.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -234,10 +234,15 @@ def keras_model_to_graph_def(keras_layer):
234234
keras_cls_name = layer.get("class_name").encode("ascii")
235235
node_def.attr["keras_class"].s = keras_cls_name
236236

237-
if layer_config.get("dtype") is not None:
237+
dtype_or_policy = layer_config.get("dtype")
238+
# Skip dtype processing if this is a dict, since it's presumably a instance of
239+
# tf/keras/mixed_precision/Policy rather than a single dtype.
240+
# TODO(#5548): parse the policy dict and populate the dtype attr with the variable dtype.
241+
if dtype_or_policy is not None and not isinstance(
242+
dtype_or_policy, dict
243+
):
238244
tf_dtype = dtypes.as_dtype(layer_config.get("dtype"))
239245
node_def.attr["dtype"].type = tf_dtype.as_datatype_enum
240-
241246
if layer.get("inbound_nodes") is not None:
242247
for maybe_inbound_node in layer.get("inbound_nodes"):
243248
inbound_nodes = _norm_to_list_of_layers(maybe_inbound_node)

0 commit comments

Comments
 (0)