19
19
from __future__ import print_function
20
20
21
21
import collections
22
+ import contextlib
22
23
import json
23
24
import os
24
25
import shutil
27
28
from werkzeug import test as werkzeug_test
28
29
from werkzeug import wrappers
29
30
31
+ from tensorboard import db
30
32
from tensorboard .backend import application
31
33
from tensorboard .backend .event_processing import plugin_event_multiplexer as event_multiplexer # pylint: disable=line-too-long
32
34
from tensorboard .plugins import base_plugin
@@ -37,30 +39,60 @@ class CorePluginTest(tf.test.TestCase):
37
39
_only_use_meta_graph = False # Server data contains only a GraphDef
38
40
39
41
def setUp (self ):
40
- self .logdir = self .get_temp_dir ()
41
- self .addCleanup (shutil .rmtree , self .logdir )
42
+ self .temp_dir = self .get_temp_dir ()
43
+ self .addCleanup (shutil .rmtree , self .temp_dir )
44
+
45
+ self .startLogdirBasedServer (self .temp_dir )
46
+ self .startDbBasedServer (self .temp_dir )
47
+
48
+ def startLogdirBasedServer (self , temp_dir ):
49
+ self .logdir = temp_dir
42
50
self ._generate_test_data (run_name = 'run1' )
43
51
self .multiplexer = event_multiplexer .EventMultiplexer (
44
52
size_guidance = application .DEFAULT_SIZE_GUIDANCE ,
45
53
purge_orphaned_data = True )
46
- self . _context = base_plugin .TBContext (
54
+ context = base_plugin .TBContext (
47
55
assets_zip_provider = get_test_assets_zip_provider (),
48
56
logdir = self .logdir ,
49
- multiplexer = self .multiplexer )
50
- self .plugin = core_plugin .CorePlugin (self ._context )
57
+ multiplexer = self .multiplexer ,
58
+ window_title = 'title foo' )
59
+ self .logdir_based_plugin = core_plugin .CorePlugin (context )
51
60
app = application .TensorBoardWSGIApp (
52
- self .logdir , [self .plugin ], self .multiplexer , 0 , path_prefix = '' )
53
- self .server = werkzeug_test .Client (app , wrappers .BaseResponse )
61
+ self .logdir ,
62
+ [self .logdir_based_plugin ],
63
+ self .multiplexer ,
64
+ 0 ,
65
+ path_prefix = '' )
66
+ self .logdir_based_server = werkzeug_test .Client (app , wrappers .BaseResponse )
67
+
68
+ def startDbBasedServer (self , temp_dir ):
69
+ self .db_uri = 'sqlite:' + os .path .join (temp_dir , 'db.sqlite' )
70
+ db_module , db_connection_provider = application .get_database_info (
71
+ self .db_uri )
72
+ if db_connection_provider is not None :
73
+ with contextlib .closing (db_connection_provider ()) as db_conn :
74
+ schema = db .Schema (db_conn )
75
+ schema .create_tables ()
76
+ schema .create_indexes ()
77
+ context = base_plugin .TBContext (
78
+ assets_zip_provider = get_test_assets_zip_provider (),
79
+ db_module = db_module ,
80
+ db_connection_provider = db_connection_provider ,
81
+ db_uri = self .db_uri ,
82
+ window_title = 'title foo' )
83
+ self .db_based_plugin = core_plugin .CorePlugin (context )
84
+ app = application .TensorBoardWSGI ([self .db_based_plugin ])
85
+ self .db_based_server = werkzeug_test .Client (app , wrappers .BaseResponse )
54
86
55
87
def testRoutesProvided (self ):
56
88
"""Tests that the plugin offers the correct routes."""
57
- routes = self .plugin .get_plugin_apps ()
89
+ routes = self .logdir_based_plugin .get_plugin_apps ()
58
90
self .assertIsInstance (routes ['/data/logdir' ], collections .Callable )
59
91
self .assertIsInstance (routes ['/data/runs' ], collections .Callable )
60
92
61
93
def testIndex_returnsActualHtml (self ):
62
94
"""Test the format of the /data/runs endpoint."""
63
- response = self .server .get ('/' )
95
+ response = self .logdir_based_server .get ('/' )
64
96
self .assertEqual (200 , response .status_code )
65
97
self .assertStartsWith (response .headers .get ('Content-Type' ), 'text/html' )
66
98
html = response .get_data ()
@@ -69,18 +101,39 @@ def testIndex_returnsActualHtml(self):
69
101
def testDataPaths_disableAllCaching (self ):
70
102
"""Test the format of the /data/runs endpoint."""
71
103
for path in ('/data/runs' , '/data/logdir' ):
72
- response = self .server .get (path )
104
+ response = self .logdir_based_server .get (path )
73
105
self .assertEqual (200 , response .status_code , msg = path )
74
106
self .assertEqual ('0' , response .headers .get ('Expires' ), msg = path )
75
107
108
+ def testEnvironmentForDbUri (self ):
109
+ """Test that the environment route correctly returns the database URI."""
110
+ parsed_object = self ._get_json (self .db_based_server , '/data/environment' )
111
+ self .assertEqual (parsed_object ['data_location' ], self .db_uri )
112
+
113
+ def testEnvironmentForLogdir (self ):
114
+ """Test that the environment route correctly returns the logdir."""
115
+ parsed_object = self ._get_json (
116
+ self .logdir_based_server , '/data/environment' )
117
+ self .assertEqual (parsed_object ['data_location' ], self .logdir )
118
+
119
+ def testEnvironmentForWindowTitle (self ):
120
+ """Test that the environment route correctly returns the window title."""
121
+ parsed_object_db = self ._get_json (
122
+ self .db_based_server , '/data/environment' )
123
+ parsed_object_logdir = self ._get_json (
124
+ self .logdir_based_server , '/data/environment' )
125
+ self .assertEqual (
126
+ parsed_object_db ['window_title' ], parsed_object_logdir ['window_title' ])
127
+ self .assertEqual (parsed_object_db ['window_title' ], 'title foo' )
128
+
76
129
def testLogdir (self ):
77
130
"""Test the format of the data/logdir endpoint."""
78
- parsed_object = self ._get_json ('/data/logdir' )
131
+ parsed_object = self ._get_json (self . logdir_based_server , '/data/logdir' )
79
132
self .assertEqual (parsed_object , {'logdir' : self .logdir })
80
133
81
134
def testRuns (self ):
82
135
"""Test the format of the /data/runs endpoint."""
83
- run_json = self ._get_json ('/data/runs' )
136
+ run_json = self ._get_json (self . logdir_based_server , '/data/runs' )
84
137
self .assertEqual (run_json , ['run1' ])
85
138
86
139
def testRunsAppendOnly (self ):
@@ -120,23 +173,23 @@ def add_run(run_name):
120
173
121
174
# Add one run: it should come last.
122
175
add_run ('avocado' )
123
- self .assertEqual (self ._get_json ('/data/runs' ),
176
+ self .assertEqual (self ._get_json (self . logdir_based_server , '/data/runs' ),
124
177
['run1' , 'avocado' ])
125
178
126
179
# Add another run: it should come last, too.
127
180
add_run ('zebra' )
128
- self .assertEqual (self ._get_json ('/data/runs' ),
181
+ self .assertEqual (self ._get_json (self . logdir_based_server , '/data/runs' ),
129
182
['run1' , 'avocado' , 'zebra' ])
130
183
131
184
# And maybe there's a run for which we somehow have no timestamp.
132
185
add_run ('mysterious' )
133
- self .assertEqual (self ._get_json ('/data/runs' ),
186
+ self .assertEqual (self ._get_json (self . logdir_based_server , '/data/runs' ),
134
187
['run1' , 'avocado' , 'zebra' , 'mysterious' ])
135
188
136
189
stubs .UnsetAll ()
137
190
138
- def _get_json (self , path ):
139
- response = self . server .get (path )
191
+ def _get_json (self , server , path ):
192
+ response = server .get (path )
140
193
self .assertEqual (200 , response .status_code )
141
194
return self ._get_json_payload (response )
142
195
0 commit comments