Skip to content

Commit 4c929f6

Browse files
authored
Refactor AuthContext creation into standalone AuthContextMiddleware. (tensorflow#6131)
The TensorBoard application layer creates a middleware that injects AuthContext into the RequestContext. It would be useful to expose this middleware for other usage - there are at least a couple test files in the internal code base that could make use of it. This change refactors the middleware into its own file for reuse, naming it AuthContextMiddleware and giving it its own devoted test.
1 parent 9c00b84 commit 4c929f6

File tree

4 files changed

+142
-16
lines changed

4 files changed

+142
-16
lines changed

tensorboard/backend/BUILD

+26-2
Original file line numberDiff line numberDiff line change
@@ -53,15 +53,14 @@ py_library(
5353
srcs_version = "PY3",
5454
visibility = ["//visibility:public"],
5555
deps = [
56+
":auth_context_middleware",
5657
":client_feature_flags",
5758
":empty_path_redirect",
5859
":experiment_id",
5960
":experimental_plugin",
6061
":http_util",
6162
":path_prefix",
6263
":security_validator",
63-
"//tensorboard:auth",
64-
"//tensorboard:context",
6564
"//tensorboard:errors",
6665
"//tensorboard:plugin_util",
6766
"//tensorboard/plugins/core:core_plugin",
@@ -88,6 +87,31 @@ py_test(
8887
],
8988
)
9089

90+
py_library(
91+
name = "auth_context_middleware",
92+
srcs = ["auth_context_middleware.py"],
93+
srcs_version = "PY3",
94+
deps = [
95+
"//tensorboard:auth",
96+
"//tensorboard:context",
97+
],
98+
)
99+
100+
py_test(
101+
name = "auth_context_middleware_test",
102+
size = "small",
103+
srcs = ["auth_context_middleware_test.py"],
104+
srcs_version = "PY3",
105+
tags = ["support_notf"],
106+
deps = [
107+
":auth_context_middleware",
108+
"//tensorboard:auth",
109+
"//tensorboard:context",
110+
"//tensorboard:test",
111+
"@org_pocoo_werkzeug",
112+
],
113+
)
114+
91115
py_library(
92116
name = "empty_path_redirect",
93117
srcs = ["empty_path_redirect.py"],

tensorboard/backend/application.py

+4-14
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,7 @@
3232

3333
from tensorboard import errors
3434
from tensorboard import plugin_util
35-
from tensorboard import auth
36-
from tensorboard import context
35+
from tensorboard.backend import auth_context_middleware
3736
from tensorboard.backend import client_feature_flags
3837
from tensorboard.backend import empty_path_redirect
3938
from tensorboard.backend import experiment_id
@@ -326,7 +325,9 @@ def _create_wsgi_app(self):
326325
app = self._route_request
327326
for middleware in self._extra_middlewares:
328327
app = middleware(app)
329-
app = _auth_context_middleware(app, self._auth_providers)
328+
app = auth_context_middleware.AuthContextMiddleware(
329+
app, self._auth_providers
330+
)
330331
app = client_feature_flags.ClientFeatureFlagsMiddleware(app)
331332
app = empty_path_redirect.EmptyPathRedirectMiddleware(app)
332333
app = experiment_id.ExperimentIdMiddleware(app)
@@ -582,17 +583,6 @@ def wrapper(environ, start_response):
582583
return wrapper
583584

584585

585-
def _auth_context_middleware(wsgi_app, auth_providers):
586-
def wrapper(environ, start_response):
587-
environ = dict(environ)
588-
auth_ctx = auth.AuthContext(auth_providers, environ)
589-
ctx = context.from_environ(environ).replace(auth=auth_ctx)
590-
context.set_in_environ(environ, ctx)
591-
return wsgi_app(environ, start_response)
592-
593-
return wrapper
594-
595-
596586
def _clean_path(path):
597587
"""Removes a trailing slash from a non-root path.
598588
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
from tensorboard import auth
16+
from tensorboard import context
17+
18+
19+
class AuthContextMiddleware:
20+
"""WSGI middleware to inject an AuthContext into the RequestContext."""
21+
22+
def __init__(self, application, auth_providers):
23+
"""Initializes this middleware.
24+
25+
Args:
26+
application: A WSGI application to delegate to.
27+
auth_providers: The auth_providers to provide to the AuthContext.
28+
"""
29+
self._application = application
30+
self._auth_providers = auth_providers
31+
32+
def __call__(self, environ, start_response):
33+
"""Invoke this WSGI application."""
34+
environ = dict(environ)
35+
auth_ctx = auth.AuthContext(self._auth_providers, environ)
36+
ctx = context.from_environ(environ).replace(auth=auth_ctx)
37+
context.set_in_environ(environ, ctx)
38+
return self._application(environ, start_response)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Tests for `tensorboard.backend.auth_context_middleware`."""
16+
from werkzeug import test as werkzeug_test
17+
from werkzeug import wrappers
18+
19+
from tensorboard import auth
20+
from tensorboard import context
21+
from tensorboard import test as tb_test
22+
from tensorboard.backend import auth_context_middleware
23+
24+
25+
class SimpleAuthProvider(auth.AuthProvider):
26+
"""Simple AuthProvider that returns the value it is initialized with."""
27+
28+
def __init__(self, credential):
29+
self._credential = credential
30+
31+
def authenticate(self, environ):
32+
return self._credential
33+
34+
35+
def _create_auth_provider_verifier_app(expected_auth_key):
36+
"""Generates a WSGI application for verifying AuthContextMiddleware.
37+
38+
It should be placed after AuthContextMiddleware in the WSGI handler chain.
39+
It will generate a credential using the AuthContext populated by
40+
AuthContextMiddleware.
41+
42+
Args:
43+
expected_auth_key: The auth key to be used for invoking the AuthContext.
44+
This key should correspond to the auth_providers used to configure
45+
the AuthContextMiddleware .
46+
"""
47+
48+
def _app(environ, start_response):
49+
ctx = context.from_environ(environ)
50+
credential = ctx.auth.get(expected_auth_key)
51+
start_response("200 OK", [("Content-Type", "text\plain")])
52+
return f"credential: {credential}"
53+
54+
return _app
55+
56+
57+
class AuthContextMiddlewareTest(tb_test.TestCase):
58+
"""Tests for `AuthContextMiddleware`"""
59+
60+
def test_populates_auth_context(self):
61+
app = auth_context_middleware.AuthContextMiddleware(
62+
_create_auth_provider_verifier_app("my_key"),
63+
{"my_key": SimpleAuthProvider("my_credential")},
64+
)
65+
66+
server = werkzeug_test.Client(app, wrappers.Response)
67+
response = server.get("")
68+
self.assertEqual(
69+
"credential: my_credential", response.get_data().decode()
70+
)
71+
72+
73+
if __name__ == "__main__":
74+
tb_test.main()

0 commit comments

Comments
 (0)