24
24
25
25
import collections
26
26
import random
27
- import six
28
- from werkzeug import wrappers
29
27
30
28
import numpy as np
29
+ import six
31
30
import tensorflow as tf
31
+ from werkzeug import wrappers
32
32
33
33
from tensorboard import plugin_util
34
34
from tensorboard .backend import http_util
@@ -47,6 +47,10 @@ class HistogramsPlugin(base_plugin.TBPlugin):
47
47
48
48
plugin_name = metadata .PLUGIN_NAME
49
49
50
+ # Use a round number + 1 since sampling includes both start and end steps,
51
+ # so N+1 samples corresponds to dividing the step sequence into N intervals.
52
+ SAMPLE_SIZE = 51
53
+
50
54
def __init__ (self , context ):
51
55
"""Instantiates HistogramsPlugin via TensorBoard core.
52
56
@@ -119,7 +123,7 @@ def index_impl(self):
119
123
120
124
return result
121
125
122
- def histograms_impl (self , tag , run , downsample_to = 50 ):
126
+ def histograms_impl (self , tag , run , downsample_to = None ):
123
127
"""Result of the form `(body, mime_type)`, or `ValueError`.
124
128
125
129
At most `downsample_to` events will be returned. If this value is
@@ -128,45 +132,61 @@ def histograms_impl(self, tag, run, downsample_to=50):
128
132
if self ._db_connection_provider :
129
133
# Serve data from the database.
130
134
db = self ._db_connection_provider ()
131
- # We select for steps greater than -1 because the writer inserts
132
- # placeholder rows en masse. The check for step filters out those rows.
133
- query = '''
134
- SELECT
135
- Tensors.computed_time AS computed_time,
136
- Tensors.step AS step,
137
- Tensors.data AS data,
138
- Tensors.dtype AS dtype,
139
- Tensors.shape AS shape
140
- FROM Tensors
141
- JOIN Tags
142
- ON Tensors.series = Tags.tag_id
143
- JOIN Runs
144
- ON Tags.run_id = Runs.run_id
145
- WHERE
146
- Runs.run_name = ?
147
- AND Tags.tag_name = ?
148
- AND Tags.plugin_name = ?
149
- AND Tensors.step > -1
150
- '''
151
- if downsample_to is not None :
152
- # Wrap the query in an outer one that samples.
153
- query = '''
154
- SELECT *
155
- FROM
156
- (%(query)s
157
- ORDER BY RANDOM()
158
- LIMIT %(downsample_to)d)
159
- ''' % {
160
- 'query' : query ,
161
- 'downsample_to' : downsample_to ,
162
- }
163
- query = '''
164
- %s
165
- ORDER BY step
166
- ''' % query
167
- cursor = db .execute (query , (run , tag , metadata .PLUGIN_NAME ))
168
- events = [(row [0 ], row [1 ], self ._get_values (row [2 ], row [3 ], row [4 ]))
169
- for row in cursor ]
135
+ cursor = db .cursor ()
136
+ # Prefetch the tag ID matching this run and tag.
137
+ cursor .execute (
138
+ '''
139
+ SELECT
140
+ tag_id
141
+ FROM Tags
142
+ JOIN Runs USING (run_id)
143
+ WHERE
144
+ Runs.run_name = :run
145
+ AND Tags.tag_name = :tag
146
+ AND Tags.plugin_name = :plugin
147
+ ''' ,
148
+ {'run' : run , 'tag' : tag , 'plugin' : metadata .PLUGIN_NAME })
149
+ row = cursor .fetchone ()
150
+ if not row :
151
+ raise ValueError ('No histogram tag %r for run %r' % (tag , run ))
152
+ (tag_id ,) = row
153
+ # Fetch tensor values, optionally with linear-spaced sampling by step.
154
+ # For steps ranging from s_min to s_max and sample size k, this query
155
+ # divides the range into k - 1 equal-sized intervals and returns the
156
+ # lowest step at or above each of the k interval boundaries (which always
157
+ # includes s_min and s_max, and may be fewer than k results if there are
158
+ # intervals where no steps are present). For contiguous steps the results
159
+ # can be formally expressed as the following:
160
+ # [s_min + math.ceil(i / k * (s_max - s_min)) for i in range(0, k + 1)]
161
+ cursor .execute (
162
+ '''
163
+ SELECT
164
+ MIN(step) AS step,
165
+ computed_time,
166
+ data,
167
+ dtype,
168
+ shape
169
+ FROM Tensors
170
+ INNER JOIN (
171
+ SELECT
172
+ MIN(step) AS min_step,
173
+ MAX(step) AS max_step
174
+ FROM Tensors
175
+ /* Filter out NULL so we can use TensorSeriesStepIndex. */
176
+ WHERE series = :tag_id AND step IS NOT NULL
177
+ )
178
+ /* Ensure we omit reserved rows, which have NULL step values. */
179
+ WHERE series = :tag_id AND step IS NOT NULL
180
+ /* Bucket rows into sample_size linearly spaced buckets, or do
181
+ no sampling if sample_size is NULL. */
182
+ GROUP BY
183
+ IFNULL(:sample_size - 1, max_step - min_step)
184
+ * (step - min_step) / (max_step - min_step)
185
+ ORDER BY step
186
+ ''' ,
187
+ {'tag_id' : tag_id , 'sample_size' : downsample_to })
188
+ events = [(computed_time , step , self ._get_values (data , dtype , shape ))
189
+ for step , computed_time , data , dtype , shape in cursor ]
170
190
else :
171
191
# Serve data from events files.
172
192
try :
@@ -204,7 +224,8 @@ def histograms_route(self, request):
204
224
tag = request .args .get ('tag' )
205
225
run = request .args .get ('run' )
206
226
try :
207
- (body , mime_type ) = self .histograms_impl (tag , run )
227
+ (body , mime_type ) = self .histograms_impl (
228
+ tag , run , downsample_to = self .SAMPLE_SIZE )
208
229
code = 200
209
230
except ValueError as e :
210
231
(body , mime_type ) = (str (e ), 'text/plain' )
0 commit comments