
Fixes regression for issue #5548: Avoid attempting to convert dtypes from "mixed precision" policy types. #6859

Merged 5 commits on May 21, 2024
1 change: 1 addition & 0 deletions tensorboard/plugins/graph/BUILD
@@ -114,6 +114,7 @@ py_library(
     deps = [
         "//tensorboard/compat/proto:protos_all_py_pb2",
         "//tensorboard/compat/tensorflow_stub",
+        "//tensorboard/util:tb_logging",
     ],
 )

37 changes: 28 additions & 9 deletions tensorboard/plugins/graph/keras_util.py
@@ -43,6 +43,10 @@
 """
 from tensorboard.compat.proto.graph_pb2 import GraphDef
 from tensorboard.compat.tensorflow_stub import dtypes
+from tensorboard.util import tb_logging
+
+
+logger = tb_logging.get_logger()
 
 
 def _walk_layers(keras_layer):
@@ -259,19 +263,34 @@ def keras_model_to_graph_def(keras_layer):
         dtype_or_policy = layer_config.get("dtype")
         dtype = None
+        has_unsupported_value = False
         # If this is a dict, try and extract the dtype string from
-        # `config.name`; keras will export like this for non-input layers. If
-        # we can't find `config.name`, we skip it as it's presumably a instance
-        # of tf/keras/mixed_precision/Policy rather than a single dtype.
-        # TODO(#5548): parse the policy dict and populate the dtype attr with the variable dtype.
-        if isinstance(dtype_or_policy, dict):
-            if "config" in dtype_or_policy:
-                dtype = dtype_or_policy.get("config").get("name")
+        # `config.name`. Keras will export like this for non-input layers and
+        # some other cases (e.g. tf/keras/mixed_precision/Policy, as described
+        # in issue #5548).
+        if isinstance(dtype_or_policy, dict) and "config" in dtype_or_policy:
+            dtype = dtype_or_policy.get("config").get("name")
         elif dtype_or_policy is not None:
             dtype = dtype_or_policy
 
         if dtype is not None:
-            tf_dtype = dtypes.as_dtype(dtype)
-            node_def.attr["dtype"].type = tf_dtype.as_datatype_enum
+            try:
+                tf_dtype = dtypes.as_dtype(dtype)
+                node_def.attr["dtype"].type = tf_dtype.as_datatype_enum
+            except TypeError:
+                has_unsupported_value = True
+        elif dtype_or_policy is not None:
+            has_unsupported_value = True
+
+        if has_unsupported_value:
+            # There's at least one known case when this happens, which is when
+            # mixed precision dtype policies are used, as described in issue
+            # #5548. (See https://keras.io/api/mixed_precision/).
+            # There might be a better way to handle this, but here we are.
+            logger.warning(
+                "Unsupported dtype value in graph model config (json):\n%s",
+                dtype_or_policy,
+            )
+
         if layer.get("inbound_nodes") is not None:
             for name, size, index in _get_inbound_nodes(layer):
                 inbound_name = _scoped_name(name_scope, name)
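Note (illustrative, not part of the diff): the serialized `dtype` field in `model.to_json()` can be either a plain dtype string or a nested policy dict, and `dtypes.as_dtype` only accepts concrete dtype names, which is why a mixed precision policy used to crash here. A minimal sketch, assuming the usual Keras serialization shape (the exact `class_name` varies across Keras versions):

# Illustrative sketch only; the policy dict shape is an assumption based on
# Keras serialization, not copied from this PR.
from tensorboard.compat.tensorflow_stub import dtypes

plain_dtype = "float32"  # e.g. what an input layer's config carries
policy_dtype = {  # mixed precision policy, as in issue #5548
    "class_name": "Policy",  # may differ across Keras versions
    "config": {"name": "mixed_float16"},
}

dtypes.as_dtype(plain_dtype)  # resolves to a concrete DType
# dtypes.as_dtype("mixed_float16") raises TypeError, since a policy name is
# not a dtype; the new try/except above catches it and logs a warning instead.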
14 changes: 14 additions & 0 deletions tensorboard/plugins/graph/keras_util_test.py
@@ -1043,6 +1043,20 @@ def test_keras_model_to_graph_def_functional_multiple_inbound_nodes_from_same_no
 
         self.assertGraphDefToModel(expected_proto, model)
 
+    def test__keras_model_to_graph_def__does_not_crash_with_mixed_precision_dtype_policy(
+        self,
+    ):
+        # See https://keras.io/api/mixed_precision/ for more info.
+        # Test to avoid regression on issue #5548
+        first_layer = tf.keras.layers.Dense(
+            1, input_shape=(1,), dtype="mixed_float16"
+        )
+        model = tf.keras.Sequential([first_layer])
+
+        model_config = json.loads(model.to_json())
+        # This line should not raise errors:
+        keras_util.keras_model_to_graph_def(model_config)
+
 
 class _DoublingLayer(tf.keras.layers.Layer):
     def call(self, inputs):
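Follow-up note (hypothetical, not part of this change): the removed TODO(#5548) suggested parsing the policy dict and populating the dtype attr with the policy's variable dtype instead of only warning. A rough sketch of that direction, assuming the public tf.keras.mixed_precision.Policy API and a hypothetical helper name:

# Hypothetical helper sketching the removed TODO(#5548); not part of this PR.
import tensorflow as tf


def _variable_dtype_from_config(dtype_or_policy):
    """Best-effort mapping of a serialized dtype or policy to a concrete dtype."""
    if isinstance(dtype_or_policy, dict):
        name = dtype_or_policy.get("config", {}).get("name")
    else:
        name = dtype_or_policy
    if name is None:
        return None
    try:
        # Policy("mixed_float16").variable_dtype is "float32"; plain names
        # like "float32" pass through unchanged.
        return tf.keras.mixed_precision.Policy(name).variable_dtype
    except (TypeError, ValueError):
        return None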