Skip to content

Commit 6500d36

Browse files
yatbearAnuarTB
authored andcommitted
Migrate graph plugin to Keras 3 and remove dependencies on Keras 2 and tf-keras-nightly (tensorflow#6823)
This should fix the issue described in tensorflow#6686. Compared the graph content with the one parsed from Keras 3, there are some changes in the structure and naming, this PR modifies the graph parsing accordingly. Googlers, see b/325451531 and b/312739672. Tested with `tensorboard/plugins/graph:graphs_demo`: ![image](https://github.com/tensorflow/tensorboard/assets/15273931/2e13205a-5ea2-437c-9fa7-3198f0eed1c5) #keras3 #oncall
1 parent 9da01a4 commit 6500d36

File tree

5 files changed

+128
-118
lines changed

5 files changed

+128
-118
lines changed

Diff for: .github/workflows/ci.yml

-2
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ env:
3131
BUILDIFIER_SHA256SUM: 'e92a6793c7134c5431c58fbc34700664f101e5c9b1c1fcd93b97978e8b7f88db'
3232
BUILDOZER_SHA256SUM: '3d58a0b6972e4535718cdd6c12778170ea7382de7c75bc3728f5719437ffb84d'
3333
TENSORFLOW_VERSION: 'tf-nightly'
34-
TF_KERAS_VERSION: 'tf-keras-nightly' # Keras 2
3534

3635
jobs:
3736
build:
@@ -78,7 +77,6 @@ jobs:
7877
run: |
7978
python -m pip install -U pip
8079
pip install "${TENSORFLOW_VERSION}"
81-
pip install "${TF_KERAS_VERSION}"
8280
if: matrix.tf_version_id != 'notf'
8381
- name: 'Install Python dependencies'
8482
run: |

Diff for: DEVELOPMENT.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ TensorBoard at HEAD relies on the nightly installation of TensorFlow: this allow
1919
$ virtualenv -p python3 tf
2020
$ source tf/bin/activate
2121
(tf)$ pip install --upgrade pip
22-
(tf)$ pip install tf-nightly tf-keras-nightly -r tensorboard/pip_package/requirements.txt -r tensorboard/pip_package/requirements_dev.txt
22+
(tf)$ pip install tf-nightly -r tensorboard/pip_package/requirements.txt -r tensorboard/pip_package/requirements_dev.txt
2323
(tf)$ pip uninstall -y tb-nightly
2424
```
2525

Diff for: tensorboard/plugins/graph/graphs_plugin_v2_test.py

+8-11
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,9 @@
2424
from tensorboard.compat.proto import graph_pb2
2525
from tensorboard.plugins.graph import graphs_plugin_test
2626

27-
# Stay on Keras 2 for now: https://github.com/keras-team/keras/issues/18467.
28-
version_fn = getattr(tf.keras, "version", None)
29-
if version_fn and version_fn().startswith("3."):
30-
import tf_keras as keras # Keras 2
31-
else:
32-
keras = tf.keras # Keras 2
27+
28+
# Graph plugin V2 Keras 3 is only supported in TensorFlow eager mode.
29+
tf.compat.v1.enable_eager_execution()
3330

3431

3532
class GraphsPluginV2Test(
@@ -41,13 +38,13 @@ def generate_run(
4138
x, y = np.ones((10, 10)), np.ones((10, 1))
4239
val_x, val_y = np.ones((4, 10)), np.ones((4, 1))
4340

44-
model = keras.Sequential(
41+
model = tf.keras.Sequential(
4542
[
46-
keras.layers.Dense(10, activation="relu"),
47-
keras.layers.Dense(1, activation="sigmoid"),
43+
tf.keras.layers.Dense(10, activation="relu"),
44+
tf.keras.layers.Dense(1, activation="sigmoid"),
4845
]
4946
)
50-
model.compile("rmsprop", "binary_crossentropy")
47+
model.compile(optimizer="rmsprop", loss="binary_crossentropy")
5148

5249
model.fit(
5350
x,
@@ -56,7 +53,7 @@ def generate_run(
5653
batch_size=2,
5754
epochs=1,
5855
callbacks=[
59-
keras.callbacks.TensorBoard(
56+
tf.keras.callbacks.TensorBoard(
6057
log_dir=os.path.join(logdir, run_name),
6158
write_graph=include_graph,
6259
)

Diff for: tensorboard/plugins/graph/keras_util.py

+45-29
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,24 @@ def _norm_to_list_of_layers(maybe_layers):
117117
)
118118

119119

120+
def _get_inbound_nodes(layer):
121+
"""Returns a list of [name, size, index] for all inbound nodes of the given layer."""
122+
inbound_nodes = []
123+
if layer.get("inbound_nodes") is not None:
124+
for maybe_inbound_node in layer.get("inbound_nodes", []):
125+
for inbound_node_args in maybe_inbound_node.get("args", []):
126+
# Sometimes this field is a list when there are multiple inbound nodes
127+
# for the given layer.
128+
if not isinstance(inbound_node_args, list):
129+
inbound_node_args = [inbound_node_args]
130+
for arg in inbound_node_args:
131+
history = arg.get("config", {}).get("keras_history", [])
132+
if len(history) < 3:
133+
continue
134+
inbound_nodes.append(history[:3])
135+
return inbound_nodes
136+
137+
120138
def _update_dicts(
121139
name_scope,
122140
model_layer,
@@ -149,7 +167,7 @@ def _update_dicts(
149167
node_name = _scoped_name(name_scope, layer_config.get("name"))
150168
input_layers = layer_config.get("input_layers")
151169
output_layers = layer_config.get("output_layers")
152-
inbound_nodes = model_layer.get("inbound_nodes")
170+
inbound_nodes = _get_inbound_nodes(model_layer)
153171

154172
is_functional_model = bool(input_layers and output_layers)
155173
# In case of [1] and the parent model is functional, current layer
@@ -164,7 +182,7 @@ def _update_dicts(
164182
elif is_parent_functional_model and not is_functional_model:
165183
# Sequential model can take only one input. Make sure inbound to the
166184
# model is linked to the first layer in the Sequential model.
167-
prev_node_name = _scoped_name(name_scope, inbound_nodes[0][0][0])
185+
prev_node_name = _scoped_name(name_scope, inbound_nodes[0][0])
168186
elif (
169187
not is_parent_functional_model
170188
and prev_node_name
@@ -244,33 +262,31 @@ def keras_model_to_graph_def(keras_layer):
244262
tf_dtype = dtypes.as_dtype(layer_config.get("dtype"))
245263
node_def.attr["dtype"].type = tf_dtype.as_datatype_enum
246264
if layer.get("inbound_nodes") is not None:
247-
for maybe_inbound_node in layer.get("inbound_nodes"):
248-
inbound_nodes = _norm_to_list_of_layers(maybe_inbound_node)
249-
for [name, size, index, _] in inbound_nodes:
250-
inbound_name = _scoped_name(name_scope, name)
251-
# An input to a layer can be output from a model. In that case, the name
252-
# of inbound_nodes to a layer is a name of a model. Remap the name of the
253-
# model to output layer of the model. Also, since there can be multiple
254-
# outputs in a model, make sure we pick the right output_layer from the model.
255-
inbound_node_names = model_name_to_output.get(
256-
inbound_name, [inbound_name]
257-
)
258-
# There can be multiple inbound_nodes that reference the
259-
# same upstream layer. This causes issues when looking for
260-
# a particular index in that layer, since the indices
261-
# captured in `inbound_nodes` doesn't necessarily match the
262-
# number of entries in the `inbound_node_names` list. To
263-
# avoid IndexErrors, we just use the last element in the
264-
# `inbound_node_names` in this situation.
265-
# Note that this is a quick hack to avoid IndexErrors in
266-
# this situation, and might not be an appropriate solution
267-
# to this problem in general.
268-
input_name = (
269-
inbound_node_names[index]
270-
if index < len(inbound_node_names)
271-
else inbound_node_names[-1]
272-
)
273-
node_def.input.append(input_name)
265+
for name, size, index in _get_inbound_nodes(layer):
266+
inbound_name = _scoped_name(name_scope, name)
267+
# An input to a layer can be output from a model. In that case, the name
268+
# of inbound_nodes to a layer is a name of a model. Remap the name of the
269+
# model to output layer of the model. Also, since there can be multiple
270+
# outputs in a model, make sure we pick the right output_layer from the model.
271+
inbound_node_names = model_name_to_output.get(
272+
inbound_name, [inbound_name]
273+
)
274+
# There can be multiple inbound_nodes that reference the
275+
# same upstream layer. This causes issues when looking for
276+
# a particular index in that layer, since the indices
277+
# captured in `inbound_nodes` doesn't necessarily match the
278+
# number of entries in the `inbound_node_names` list. To
279+
# avoid IndexErrors, we just use the last element in the
280+
# `inbound_node_names` in this situation.
281+
# Note that this is a quick hack to avoid IndexErrors in
282+
# this situation, and might not be an appropriate solution
283+
# to this problem in general.
284+
input_name = (
285+
inbound_node_names[index]
286+
if index < len(inbound_node_names)
287+
else inbound_node_names[-1]
288+
)
289+
node_def.input.append(input_name)
274290
elif prev_node_name is not None:
275291
node_def.input.append(prev_node_name)
276292

0 commit comments

Comments
 (0)