File tree 1 file changed +7
-2
lines changed
tensorboard/plugins/graph
1 file changed +7
-2
lines changed Original file line number Diff line number Diff line change @@ -234,10 +234,15 @@ def keras_model_to_graph_def(keras_layer):
234
234
keras_cls_name = layer .get ("class_name" ).encode ("ascii" )
235
235
node_def .attr ["keras_class" ].s = keras_cls_name
236
236
237
- if layer_config .get ("dtype" ) is not None :
237
+ dtype_or_policy = layer_config .get ("dtype" )
238
+ # Skip dtype processing if this is a dict, since it's presumably a instance of
239
+ # tf/keras/mixed_precision/Policy rather than a single dtype.
240
+ # TODO(#5548): parse the policy dict and populate the dtype attr with the variable dtype.
241
+ if dtype_or_policy is not None and not isinstance (
242
+ dtype_or_policy , dict
243
+ ):
238
244
tf_dtype = dtypes .as_dtype (layer_config .get ("dtype" ))
239
245
node_def .attr ["dtype" ].type = tf_dtype .as_datatype_enum
240
-
241
246
if layer .get ("inbound_nodes" ) is not None :
242
247
for maybe_inbound_node in layer .get ("inbound_nodes" ):
243
248
inbound_nodes = _norm_to_list_of_layers (maybe_inbound_node )
You can’t perform that action at this time.
0 commit comments