Skip to content

Commit bf164c7

Browse files
committed
Change histogram SQL query to use linspaced sampling
1 parent e547f78 commit bf164c7

File tree

2 files changed

+69
-44
lines changed

2 files changed

+69
-44
lines changed

tensorboard/plugins/distribution/distributions_plugin.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ class DistributionsPlugin(base_plugin.TBPlugin):
4141

4242
plugin_name = 'distributions'
4343

44+
# Use a round number + 1 since sampling includes both start and end steps,
45+
# so N+1 samples corresponds to dividing the step sequence into N intervals.
46+
SAMPLE_SIZE = 501
47+
4448
def __init__(self, context):
4549
"""Instantiates DistributionsPlugin via TensorBoard core.
4650
@@ -67,7 +71,7 @@ def is_active(self):
6771
def distributions_impl(self, tag, run):
6872
"""Result of the form `(body, mime_type)`, or `ValueError`."""
6973
(histograms, mime_type) = self._histograms_plugin.histograms_impl(
70-
tag, run, downsample_to=None)
74+
tag, run, downsample_to=self.SAMPLE_SIZE)
7175
return ([self._compress(histogram) for histogram in histograms],
7276
mime_type)
7377

tensorboard/plugins/histogram/histograms_plugin.py

+64-43
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@
2424

2525
import collections
2626
import random
27-
import six
28-
from werkzeug import wrappers
2927

3028
import numpy as np
29+
import six
3130
import tensorflow as tf
31+
from werkzeug import wrappers
3232

3333
from tensorboard import plugin_util
3434
from tensorboard.backend import http_util
@@ -47,6 +47,10 @@ class HistogramsPlugin(base_plugin.TBPlugin):
4747

4848
plugin_name = metadata.PLUGIN_NAME
4949

50+
# Use a round number + 1 since sampling includes both start and end steps,
51+
# so N+1 samples corresponds to dividing the step sequence into N intervals.
52+
SAMPLE_SIZE = 51
53+
5054
def __init__(self, context):
5155
"""Instantiates HistogramsPlugin via TensorBoard core.
5256
@@ -119,7 +123,7 @@ def index_impl(self):
119123

120124
return result
121125

122-
def histograms_impl(self, tag, run, downsample_to=50):
126+
def histograms_impl(self, tag, run, downsample_to=None):
123127
"""Result of the form `(body, mime_type)`, or `ValueError`.
124128
125129
At most `downsample_to` events will be returned. If this value is
@@ -128,45 +132,61 @@ def histograms_impl(self, tag, run, downsample_to=50):
128132
if self._db_connection_provider:
129133
# Serve data from the database.
130134
db = self._db_connection_provider()
131-
# We select for steps greater than -1 because the writer inserts
132-
# placeholder rows en masse. The check for step filters out those rows.
133-
query = '''
134-
SELECT
135-
Tensors.computed_time AS computed_time,
136-
Tensors.step AS step,
137-
Tensors.data AS data,
138-
Tensors.dtype AS dtype,
139-
Tensors.shape AS shape
140-
FROM Tensors
141-
JOIN Tags
142-
ON Tensors.series = Tags.tag_id
143-
JOIN Runs
144-
ON Tags.run_id = Runs.run_id
145-
WHERE
146-
Runs.run_name = ?
147-
AND Tags.tag_name = ?
148-
AND Tags.plugin_name = ?
149-
AND Tensors.step > -1
150-
'''
151-
if downsample_to is not None:
152-
# Wrap the query in an outer one that samples.
153-
query = '''
154-
SELECT *
155-
FROM
156-
(%(query)s
157-
ORDER BY RANDOM()
158-
LIMIT %(downsample_to)d)
159-
''' % {
160-
'query': query,
161-
'downsample_to': downsample_to,
162-
}
163-
query = '''
164-
%s
165-
ORDER BY step
166-
''' % query
167-
cursor = db.execute(query, (run, tag, metadata.PLUGIN_NAME))
168-
events = [(row[0], row[1], self._get_values(row[2], row[3], row[4]))
169-
for row in cursor]
135+
cursor = db.cursor()
136+
# Prefetch the tag ID matching this run and tag.
137+
cursor.execute(
138+
'''
139+
SELECT
140+
tag_id
141+
FROM Tags
142+
JOIN Runs USING (run_id)
143+
WHERE
144+
Runs.run_name = :run
145+
AND Tags.tag_name = :tag
146+
AND Tags.plugin_name = :plugin
147+
''',
148+
{'run': run, 'tag': tag, 'plugin': metadata.PLUGIN_NAME})
149+
row = cursor.fetchone()
150+
if not row:
151+
raise ValueError('No histogram tag %r for run %r' % (tag, run))
152+
(tag_id,) = row
153+
# Fetch tensor values, optionally with linear-spaced sampling by step.
154+
# For steps ranging from s_min to s_max and sample size k, this query
155+
# divides the range into k - 1 equal-sized intervals and returns the
156+
# lowest step at or above each of the k interval boundaries (which always
157+
# includes s_min and s_max, and may be fewer than k results if there are
158+
# intervals where no steps are present). For contiguous steps the results
159+
# can be formally expressed as the following:
160+
# [s_min + math.ceil(i / (k - 1) * (s_max - s_min)) for i in range(0, k)]
161+
cursor.execute(
162+
'''
163+
SELECT
164+
MIN(step) AS step,
165+
computed_time,
166+
data,
167+
dtype,
168+
shape
169+
FROM Tensors
170+
INNER JOIN (
171+
SELECT
172+
MIN(step) AS min_step,
173+
MAX(step) AS max_step
174+
FROM Tensors
175+
/* Filter out NULL so we can use TensorSeriesStepIndex. */
176+
WHERE series = :tag_id AND step IS NOT NULL
177+
)
178+
/* Ensure we omit reserved rows, which have NULL step values. */
179+
WHERE series = :tag_id AND step IS NOT NULL
180+
/* Bucket rows into sample_size linearly spaced buckets, or do
181+
no sampling if sample_size is NULL. */
182+
GROUP BY
183+
IFNULL(:sample_size - 1, max_step - min_step)
184+
* (step - min_step) / (max_step - min_step)
185+
ORDER BY step
186+
''',
187+
{'tag_id': tag_id, 'sample_size': downsample_to})
188+
events = [(computed_time, step, self._get_values(data, dtype, shape))
189+
for step, computed_time, data, dtype, shape in cursor]
170190
else:
171191
# Serve data from events files.
172192
try:
@@ -204,7 +224,8 @@ def histograms_route(self, request):
204224
tag = request.args.get('tag')
205225
run = request.args.get('run')
206226
try:
207-
(body, mime_type) = self.histograms_impl(tag, run)
227+
(body, mime_type) = self.histograms_impl(
228+
tag, run, downsample_to=self.SAMPLE_SIZE)
208229
code = 200
209230
except ValueError as e:
210231
(body, mime_type) = (str(e), 'text/plain')

0 commit comments

Comments
 (0)